Skip to content

Commit

Permalink
update(ConcatWith,MergeWith,RaceWith): do not subscribe to the source…
Browse files Browse the repository at this point in the history
… in a goroutine
  • Loading branch information
b97tsk committed Mar 5, 2024
1 parent 9d32cca commit ed39508
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 54 deletions.
37 changes: 24 additions & 13 deletions concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,52 @@ func Concat[T any](some ...Observable[T]) Observable[T] {
return Empty[T]()
}

return observables[T](some).Concat
return concatWithObservable[T]{Others: some}.Subscribe
}

// ConcatWith applies [Concat] to the source Observable along with some other
// Observables to create a first-order Observable.
// ConcatWith concatenates the source Observable and some other Observables
// together to create an Observable that sequentially emits their values,
// one Observable after the other.
func ConcatWith[T any](some ...Observable[T]) Operator[T, T] {
return NewOperator(
func(source Observable[T]) Observable[T] {
return observables[T](append([]Observable[T]{source}, some...)).Concat
return concatWithObservable[T]{source, some}.Subscribe
},
)
}

func (some observables[T]) Concat(c Context, sink Observer[T]) {
type concatWithObservable[T any] struct {
Source Observable[T]
Others []Observable[T]
}

func (obs concatWithObservable[T]) Subscribe(c Context, sink Observer[T]) {
var observer Observer[T]

done := c.Done()

subscribeToNext := resistReentrance(func() {
next := resistReentrance(func() {
if source := obs.Source; source != nil {
obs.Source = nil
source.Subscribe(c, observer)
return
}

select {
default:
case <-done:
sink.Error(c.Err())
return
}

if len(some) == 0 {
if len(obs.Others) == 0 {
sink.Complete()
return
}

obs := some[0]
some = some[1:]

obs.Subscribe(c, observer)
obs1 := obs.Others[0]
obs.Others = obs.Others[1:]
obs1.Subscribe(c, observer)
})

observer = func(n Notification[T]) {
Expand All @@ -58,10 +69,10 @@ func (some observables[T]) Concat(c Context, sink Observer[T]) {
return
}

subscribeToNext()
next()
}

subscribeToNext()
next()
}

// ConcatAll flattens a higher-order Observable into a first-order Observable
Expand Down
38 changes: 29 additions & 9 deletions merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,56 @@ func Merge[T any](some ...Observable[T]) Observable[T] {
return Empty[T]()
}

return observables[T](some).Merge
return mergeWithObservable[T]{Others: some}.Subscribe
}

// MergeWith applies [Merge] to the source Observable along with some other
// Observables to create a first-order Observable.
// MergeWith merges the source Observable and some other Observables together
// to create an Observable that concurrently emits all values from the source
// and every given input Observable.
func MergeWith[T any](some ...Observable[T]) Operator[T, T] {
return NewOperator(
func(source Observable[T]) Observable[T] {
return observables[T](append([]Observable[T]{source}, some...)).Merge
return mergeWithObservable[T]{source, some}.Subscribe
},
)
}

func (some observables[T]) Merge(c Context, sink Observer[T]) {
type mergeWithObservable[T any] struct {
Source Observable[T]
Others []Observable[T]
}

func (obs mergeWithObservable[T]) Subscribe(c Context, sink Observer[T]) {
c, cancel := c.WithCancel()
sink = sink.OnLastNotification(cancel).Serialized()

var workers atomic.Uint32
var num atomic.Uint32

workers.Store(uint32(len(some)))
num.Store(uint32(obs.numObservables()))

worker := func(n Notification[T]) {
if n.Kind != KindComplete || workers.Add(^uint32(0)) == 0 {
if n.Kind != KindComplete || num.Add(^uint32(0)) == 0 {
sink(n)
}
}

for _, obs := range some {
for _, obs := range obs.Others {
c.Go(func() { obs.Subscribe(c, worker) })
}

if obs.Source != nil {
obs.Source.Subscribe(c, worker)
}
}

func (obs mergeWithObservable[T]) numObservables() int {
n := len(obs.Others)

if obs.Source != nil {
n++
}

return n
}

// MergeAll flattens a higher-order Observable into a first-order Observable
Expand Down
81 changes: 51 additions & 30 deletions race.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,80 @@ func Race[T any](some ...Observable[T]) Observable[T] {
return Empty[T]()
}

return observables[T](some).Race
return raceWithObservable[T]{Others: some}.Subscribe
}

// RaceWith applies [Race] to the source Observable along with some other
// Observables to create a first-order Observable.
// RaceWith mirrors the first Observable to emit a value, from the source
// and given input Observables.
func RaceWith[T any](some ...Observable[T]) Operator[T, T] {
return NewOperator(
func(source Observable[T]) Observable[T] {
return observables[T](append([]Observable[T]{source}, some...)).Race
return raceWithObservable[T]{source, some}.Subscribe
},
)
}

func (some observables[T]) Race(c Context, sink Observer[T]) {
subs := make([]Pair[Context, CancelFunc], len(some))
type raceWithObservable[T any] struct {
Source Observable[T]
Others []Observable[T]
}

func (obs raceWithObservable[T]) Subscribe(c Context, sink Observer[T]) {
subs := make([]Pair[Context, CancelFunc], obs.numObservables())

for i := range subs {
subs[i] = NewPair(c.WithCancel())
}

var race atomic.Uint32

for index, obs := range some {
c.Go(func() {
var won, lost bool

obs.Subscribe(subs[index].Left(), func(n Notification[T]) {
switch {
case won:
sink(n)
return
case lost:
return
}
subscribe := func(i int, obs Observable[T]) {
var won, lost bool

obs.Subscribe(subs[i].Left(), func(n Notification[T]) {
switch {
case won:
sink(n)
return
case lost:
return
}

if race.CompareAndSwap(0, 1) {
for i := range subs {
if i != index {
subs[i].Right()()
}
if race.CompareAndSwap(0, 1) {
for j := range subs {
if j != i {
subs[j].Right()()
}
}

won = true
won = true

sink(n)
sink(n)

return
}
return
}

lost = true
lost = true

subs[index].Right()()
})
subs[i].Right()()
})
}

for i, obs := range obs.Others {
c.Go(func() { subscribe(i, obs) })
}

if obs.Source != nil {
subscribe(len(subs)-1, obs.Source)
}
}

func (obs raceWithObservable[T]) numObservables() int {
n := len(obs.Others)

if obs.Source != nil {
n++
}

return n
}
2 changes: 0 additions & 2 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ var sentinel = func() context.Context {
return ctx
}()

type observables[T any] []Observable[T]

func identity[T any](v T) T { return v }

// resistReentrance returns a function that calls f in a non-recursive way
Expand Down

0 comments on commit ed39508

Please sign in to comment.