update(Merge*): do not buffer source values by default
b97tsk committed Dec 15, 2023
1 parent 905c606 commit d3c4de5
Showing 2 changed files with 218 additions and 51 deletions.
206 changes: 156 additions & 50 deletions merge.go
@@ -31,7 +31,6 @@ func MergeWith[T any](some ...Observable[T]) Operator[T, T] {
func (some observables[T]) Merge(ctx context.Context, sink Observer[T]) {
wg := WaitGroupFromContext(ctx)
ctx, cancel := context.WithCancel(ctx)

sink = sink.OnLastNotification(cancel).Serialized()

var workers atomic.Uint32
@@ -55,12 +54,7 @@ func (some observables[T]) Merge(ctx context.Context, sink Observer[T]) {
// which concurrently delivers all values that are emitted on the inner
// Observables.
func MergeAll[_ Observable[T], T any]() MergeMapOperator[Observable[T], T] {
return MergeMapOperator[Observable[T], T]{
opts: mergeMapConfig[Observable[T], T]{
Project: identity[Observable[T]],
Concurrency: -1,
},
}
return mergeMap(identity[Observable[T]])
}

// MergeMap converts the source Observable into a higher-order Observable,
@@ -71,36 +65,47 @@ func MergeMap[T, R any](proj func(v T) Observable[R]) MergeMapOperator[T, R] {
panic("proj == nil")
}

return MergeMapOperator[T, R]{
opts: mergeMapConfig[T, R]{
Project: proj,
Concurrency: -1,
},
}
return mergeMap(proj)
}

// MergeMapTo converts the source Observable into a higher-order Observable,
// by projecting each source value to the same Observable, then flattens it
// into a first-order Observable using MergeAll.
func MergeMapTo[T, R any](inner Observable[R]) MergeMapOperator[T, R] {
return mergeMap(func(T) Observable[R] { return inner })
}

func mergeMap[T, R any](proj func(v T) Observable[R]) MergeMapOperator[T, R] {
return MergeMapOperator[T, R]{
opts: mergeMapConfig[T, R]{
Project: func(T) Observable[R] { return inner },
Concurrency: -1,
Project: proj,
Concurrency: -1,
UseBuffering: false,
},
}
}

type mergeMapConfig[T, R any] struct {
Project func(T) Observable[R]
Concurrency int
Project func(T) Observable[R]
Concurrency int
UseBuffering bool
}

// MergeMapOperator is an [Operator] type for [MergeMap].
type MergeMapOperator[T, R any] struct {
opts mergeMapConfig[T, R]
}

// WithBuffering turns on source buffering.
// By default, this Operator might block the source due to the concurrency
// limit. With source buffering on, this Operator buffers every source value,
// which might consume a lot of memory over time if the source emits many
// values faster than they can be flattened.
func (op MergeMapOperator[T, R]) WithBuffering() MergeMapOperator[T, R] {
op.opts.UseBuffering = true
return op
}

// WithConcurrency sets the Concurrency option to a given value.
// It must not be zero. The default value is -1 (unlimited).
func (op MergeMapOperator[T, R]) WithConcurrency(n int) MergeMapOperator[T, R] {
@@ -124,27 +129,121 @@ type mergeMapObservable[T, R any] struct {
}

func (obs mergeMapObservable[T, R]) Subscribe(ctx context.Context, sink Observer[R]) {
if obs.UseBuffering {
obs.subscribeWithBuffering(ctx, sink)
return
}

obs.subscribe(ctx, sink)
}

func (obs mergeMapObservable[T, R]) subscribe(ctx context.Context, sink Observer[R]) {
wg := WaitGroupFromContext(ctx)
ctx, cancel := context.WithCancel(ctx)

sink = sink.OnLastNotification(cancel).Serialized()

var x struct {
Workers atomic.Uint32
Complete atomic.Bool
Queue struct {
sync.Mutex
queue.Queue[T]
Sealed bool
sync.Mutex
sync.Cond
Workers int
Complete bool
HasError bool
}

x.Cond.L = &x.Mutex

var noop bool

obs.Source.Subscribe(ctx, func(n Notification[T]) {
if noop {
return
}

switch n.Kind {
case KindNext:
x.Lock()

for x.Workers == obs.Concurrency && !x.HasError {
x.Wait()
}

if x.HasError {
x.Unlock()
noop = true
return
}

x.Workers++
x.Unlock()

obs1 := obs.Project(n.Value)

wg.Go(func() {
obs1.Subscribe(ctx, func(n Notification[R]) {
switch n.Kind {
case KindNext:
sink(n)

case KindError:
x.Lock()
x.Workers--
x.HasError = true
x.Unlock()
x.Signal()

sink(n)

case KindComplete:
x.Lock()

x.Workers--

if x.Workers == 0 && x.Complete && !x.HasError {
sink(n)
}

x.Unlock()
x.Signal()
}
})
})

case KindError:
sink.Error(n.Error)

case KindComplete:
x.Lock()

x.Complete = true

if x.Workers == 0 && !x.HasError {
sink.Complete()
}

x.Unlock()
}
})
}

func (obs mergeMapObservable[T, R]) subscribeWithBuffering(ctx context.Context, sink Observer[R]) {
wg := WaitGroupFromContext(ctx)
ctx, cancel := context.WithCancel(ctx)
sink = sink.OnLastNotification(cancel).Serialized()

var x struct {
sync.Mutex
Queue queue.Queue[T]
Workers int
Complete bool
HasError bool
}

var startWorker func()

startWorker = func() {
obs1 := obs.Project(x.Queue.Pop())

x.Queue.Unlock()
x.Unlock()

wg.Go(func() {
obs1.Subscribe(ctx, func(n Notification[R]) {
@@ -153,67 +252,74 @@ func (obs mergeMapObservable[T, R]) Subscribe(ctx context.Context, sink Observer
sink(n)

case KindError:
x.Queue.Lock()

x.Queue.Sealed = true

x.Lock()
x.Queue.Init()
x.Queue.Unlock()
x.Workers--
x.HasError = true
x.Unlock()

sink(n)

case KindComplete:
x.Queue.Lock()
x.Lock()

if x.Queue.Len() == 0 {
workers := x.Workers.Add(^uint32(0))
if x.Queue.Len() != 0 {
startWorker()
return
}

x.Queue.Unlock()
x.Workers--

if workers == 0 && x.Complete.Load() && x.Workers.Load() == 0 {
sink(n)
}

return
if x.Workers == 0 && x.Complete && !x.HasError {
sink(n)
}

startWorker()
x.Unlock()
}
})
})
}

var noop bool

obs.Source.Subscribe(ctx, func(n Notification[T]) {
if noop {
return
}

switch n.Kind {
case KindNext:
x.Queue.Lock()
x.Lock()

if x.Queue.Sealed {
x.Queue.Unlock()
if x.HasError {
x.Unlock()
noop = true
return
}

x.Queue.Push(n.Value)

if x.Workers.Load() != uint32(obs.Concurrency) {
x.Workers.Add(1)

if x.Workers != obs.Concurrency {
x.Workers++
startWorker()

return
}

x.Queue.Unlock()
x.Unlock()

case KindError:
sink.Error(n.Error)

case KindComplete:
x.Complete.Store(true)
x.Lock()

if x.Workers.Load() == 0 {
x.Complete = true

if x.Workers == 0 && !x.HasError {
sink.Complete()
}

x.Unlock()
}
})
}
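
The unbuffered path added above (subscribe) applies backpressure with a sync.Cond: the source's Next handler Waits while Workers equals Concurrency, and every inner Observable that errors or completes Signals the condition. Below is a minimal standalone sketch of that pattern using plain goroutines instead of the rx types; the squaring worker and the limit of 3 are made up for illustration and are not part of this commit.

package main

import (
	"fmt"
	"sync"
)

func main() {
	const concurrency = 3 // plays the role of the Concurrency option

	var x struct {
		sync.Mutex
		sync.Cond
		Workers int
	}
	x.Cond.L = &x.Mutex

	var wg sync.WaitGroup

	for v := 0; v < 9; v++ { // the "source" emitting values
		x.Lock()
		for x.Workers == concurrency {
			x.Wait() // park the source while the limit is reached
		}
		x.Workers++
		x.Unlock()

		wg.Add(1)
		go func(v int) { // plays the role of one inner Observable
			defer wg.Done()
			fmt.Println(v * v)

			x.Lock()
			x.Workers--
			x.Unlock()
			x.Signal() // wake the source if it is waiting
		}(v)
	}

	wg.Wait()
}

This is the trade the commit title describes: by default the source goroutine is parked rather than having its values queued, so memory stays bounded at the cost of slowing the source down.
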
63 changes: 62 additions & 1 deletion merge_test.go
@@ -34,7 +34,7 @@ func TestMerge(t *testing.T) {
)
}

func TestMerge2(t *testing.T) {
func TestMergeMap(t *testing.T) {
t.Parallel()

NewTestSuite[string](t).Case(
@@ -94,3 +94,64 @@ func TestMerge2(t *testing.T) {
"A", "B", ErrTest,
)
}

func TestMergeMapWithBuffering(t *testing.T) {
t.Parallel()

NewTestSuite[string](t).Case(
rx.Pipe1(
rx.Just(
rx.Pipe1(rx.Just("A", "B"), AddLatencyToValues[string](3, 5)),
rx.Pipe1(rx.Just("C", "D"), AddLatencyToValues[string](2, 4)),
rx.Pipe1(rx.Just("E", "F"), AddLatencyToValues[string](1, 3)),
),
rx.MergeAll[rx.Observable[string]]().WithBuffering(),
),
"E", "C", "A", "F", "D", "B", ErrComplete,
).Case(
rx.Pipe3(
rx.Range(0, 9),
rx.MergeMap(
func(v int) rx.Observable[int] {
return rx.Pipe1(rx.Just(v), DelaySubscription[int](1))
},
).WithBuffering().WithConcurrency(3),
rx.Reduce(0, func(v1, v2 int) int { return v1 + v2 }),
ToString[int](),
),
"36", ErrComplete,
).Case(
rx.Pipe1(
rx.Timer(Step(1)),
rx.MergeMapTo[time.Time](rx.Just("A")).WithBuffering(),
),
"A", ErrComplete,
).Case(
rx.Pipe1(
rx.Empty[rx.Observable[string]](),
rx.MergeAll[rx.Observable[string]]().WithBuffering(),
),
ErrComplete,
).Case(
rx.Pipe1(
rx.Throw[rx.Observable[string]](ErrTest),
rx.MergeAll[rx.Observable[string]]().WithBuffering(),
),
ErrTest,
).Case(
rx.Pipe1(
func(_ context.Context, sink rx.Observer[rx.Observable[string]]) {
sink.Next(rx.Just("A", "B"))
time.Sleep(Step(1))
sink.Next(rx.Throw[string](ErrTest))
time.Sleep(Step(1))
sink.Next(rx.Just("C", "D"))
time.Sleep(Step(1))
sink.Next(rx.Just("E", "F"))
sink.Complete()
},
rx.MergeAll[rx.Observable[string]]().WithBuffering(),
),
"A", "B", ErrTest,
)
}
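
For contrast with TestMergeMapWithBuffering above, here is what the same concurrency-limited case might look like under the new default, reusing the helpers that already appear in this diff (NewTestSuite, DelaySubscription, ToString, ErrComplete). The test name is hypothetical and this sketch is not part of the commit.

func TestMergeMapDefault(t *testing.T) { // hypothetical, for illustration only
	t.Parallel()

	NewTestSuite[string](t).Case(
		rx.Pipe3(
			rx.Range(0, 9),
			rx.MergeMap(
				func(v int) rx.Observable[int] {
					return rx.Pipe1(rx.Just(v), DelaySubscription[int](1))
				},
			).WithConcurrency(3),
			// Without WithBuffering, the source is blocked whenever three
			// inner Observables are in flight; the remaining source values
			// are not queued in memory. The sum comes out the same either way.
			rx.Reduce(0, func(v1, v2 int) int { return v1 + v2 }),
			ToString[int](),
		),
		"36", ErrComplete,
	)
}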
