From 705a319c409415b48f22380ceca4488b63f176a8 Mon Sep 17 00:00:00 2001 From: Sebastiaan Koppe Date: Wed, 18 Sep 2024 13:42:47 +0100 Subject: [PATCH] Add sample operator --- source/concurrency/sequence.d | 112 ++++++++++++++++++++++++++++++++ tests/ut/concurrency/sequence.d | 24 +++++++ 2 files changed, 136 insertions(+) diff --git a/source/concurrency/sequence.d b/source/concurrency/sequence.d index cc04595..3ab9ba6 100644 --- a/source/concurrency/sequence.d +++ b/source/concurrency/sequence.d @@ -1085,6 +1085,118 @@ private auto getSome(T)(ref T t) { return t.get(); } +auto proxyNext(Sequence, Receiver)(Sequence sequence, Receiver receiver) { + return ProxyNextSequence!(Sequence, Receiver)(sequence, receiver); +} + +struct ProxyNextSequence(Sequence, NextReceiver) { + alias Value = Sequence.Value; + + Sequence sequence; + NextReceiver nextReceiver; + + auto connect(Receiver)(return Receiver receiver) @safe return scope { + // ensure NRVO + auto op = sequence.connect(ProxyNextReceiver!(NextReceiver, Receiver)(nextReceiver, receiver)); + return op; + } +} + +struct ProxyNextReceiver(NextReceiver, Receiver) { + NextReceiver nextReceiver; + Receiver receiver; + + auto setNext(Sender)(Sender sender) { + return nextReceiver.setNext(sender); + } + auto setValue() { + receiver.setValue(); + } + auto setDone() nothrow @safe { + receiver.setDone(); + } + auto setError(Throwable t) nothrow @safe { + receiver.setError(t); + } + import concurrency.receiver : ForwardExtensionPoints; + mixin ForwardExtensionPoints!receiver; +} + +auto sample(BaseSequence, TriggerSequence)(BaseSequence base, TriggerSequence trigger) { + return SampleSequence!(BaseSequence, TriggerSequence)(base, trigger); +} + +struct SampleSequence(BaseSequence, TriggerSequence) { + alias Value = void; + alias Element = BaseSequence.Element; + + BaseSequence base; + TriggerSequence trigger; + + auto connect(Receiver)(return Receiver receiver) @safe return scope { + // ensure NRVO + auto op = SampleSequenceOp!(BaseSequence, TriggerSequence, Receiver)(base, trigger, receiver); + return op; + } +} + +struct SampleSequenceOp(BaseSequence, TriggerSequence, Receiver) { + import concurrency.bitfield : SharedBitField; + import concurrency.sender : OpType; + import concurrency.operations : RaceSender; + + import std.typecons : Nullable; + enum Flags : size_t { + locked = 0x1, + valid = 0x2 + } + shared SharedBitField!Flags state; + alias Element = BaseSequence.Element; + Element item; + alias RaceAllSender = RaceSender!( + SequenceCollect!(BaseSequence, void delegate(Element) shared @safe nothrow @nogc), + ProxyNextSequence!(FilterMapSequence!(TriggerSequence, Nullable!Element delegate() shared @safe nothrow @nogc), Receiver) + ); + alias Op = OpType!(RaceAllSender, Receiver); + + Op op; + + @disable this(ref return scope typeof(this) rhs); + @disable this(this); + + @disable void opAssign(typeof(this) rhs) nothrow @safe @nogc; + @disable void opAssign(ref typeof(this) rhs) nothrow @safe @nogc; + + this(BaseSequence base, TriggerSequence trigger, return Receiver receiver) @trusted return scope { + import concurrency.operations : raceAll; + op = raceAll( + base.collect(&(cast(shared)this).produced), + trigger.filterMap(&(cast(shared)this).triggered).proxyNext(receiver) + ).connect(receiver); + } + + void start() { + op.start(); + } + + private void produced(Element item) shared @safe nothrow @nogc { + with (state.lock(Flags.valid)) { + this.item = item; + } + } + + private Nullable!Element triggered() shared @safe nothrow @nogc{ + with (state.lock()) { + if (was(Flags.valid)) { + auto localElement = item; + release(Flags.valid); + return Nullable!Element(localElement); + } + return Nullable!Element.init; + } + } +} + // cron - create a sequence like interval but using cron spec diff --git a/tests/ut/concurrency/sequence.d b/tests/ut/concurrency/sequence.d index b980ea1..3c7312f 100644 --- a/tests/ut/concurrency/sequence.d +++ b/tests/ut/concurrency/sequence.d @@ -173,3 +173,27 @@ import unit_threaded; return Nullable!int(1); }).take(4).toList().syncWait.value.should == [1,1,1,1]; } + +@("sample") +@safe unittest { + import core.time : msecs; + import concurrency.scheduler : ManualTimeWorker; + import concurrency.operations : then, whenAll, withScheduler; + + auto worker = new shared ManualTimeWorker(); + auto driver = just(worker).then((shared ManualTimeWorker worker) shared { + for(;;) { + auto span = worker.timeUntilNextEvent(); + if (span.isNull) + break; + worker.advance(span.get()); + } + }); + whenAll( + sample( + interval(1.msecs, false).scan((int acc) => acc + 1, 0), + interval(2.msecs, false) + ).take(4).toList().withScheduler(worker.getScheduler()), + driver, + ).syncWait.value.should == [1,3,5,7]; +}