Skip to content

Commit

Permalink
Await connection and flush before closing socket.
Browse files Browse the repository at this point in the history
  • Loading branch information
dom96 committed Jan 24, 2024
1 parent 20b9eb5 commit 9b89c8a
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 6 deletions.
61 changes: 58 additions & 3 deletions src/workerd/api/sockets.c++
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "system-streams.h"
#include <workerd/io/worker-interface.h>
#include "url-standard.h"
#include <workerd/util/autogate.h>


namespace workerd::api {
Expand Down Expand Up @@ -262,7 +263,11 @@ jsg::Ref<Socket> connectImpl(
return connectImplNoOutputLock(js, kj::mv(fetcher), kj::mv(address), kj::mv(options));
}

jsg::Promise<void> Socket::close(jsg::Lock& js) {
// Closes the underlying socket connection. This is an old implementation and will be removed soon.
// See closeImplNew below for the new implementation.
//
// TODO(later): remove once safe
jsg::Promise<void> Socket::closeImplOld(jsg::Lock& js) {
// Forcibly close the readable/writable streams.
auto cancelPromise = readable->getController().cancel(js, kj::none);
auto abortPromise = writable->getController().abort(js, kj::none);
Expand All @@ -271,8 +276,58 @@ jsg::Promise<void> Socket::close(jsg::Lock& js) {
return abortPromise.then(js, [this](jsg::Lock& js) {
resolveFulfiller(js, kj::none);
return js.resolvedPromise();
}, [this](jsg::Lock& js, jsg::Value err) { return errorHandler(js, kj::mv(err)); });
}, [this](jsg::Lock& js, jsg::Value err) { return errorHandler(js, kj::mv(err)); });
}, [this](jsg::Lock& js, jsg::Value err) {
errorHandler(js, kj::mv(err));
return js.resolvedPromise();
});
}, [this](jsg::Lock& js, jsg::Value err) {
errorHandler(js, kj::mv(err));
return js.resolvedPromise();
});
}

// Closes the underlying socket connection, but only after the socket connection is properly
// established through any configured proxy. This method also flushes the writable stream prior to
// closing.
jsg::Promise<void> Socket::closeImplNew(jsg::Lock& js) {
if (isClosing) {
return closedPromiseCopy.whenResolved(js);
}

isClosing = true;
writable->getController().setPendingClosure();
readable->getController().setPendingClosure();

// Wait until the socket connects (successfully or otherwise)
return openedPromiseCopy.whenResolved(js).then(js, [this](jsg::Lock& js) {
if (!writable->getController().isClosedOrClosing()) {
return writable->getController().flush(js);
} else {
return js.resolvedPromise();
}
}).then(js, [this](jsg::Lock& js) {
// Forcibly abort the readable/writable streams.
auto cancelPromise = readable->getController().cancel(js, kj::none);
auto abortPromise = writable->getController().abort(js, kj::none);
// The below is effectively `Promise.all(cancelPromise, abortPromise)`
return cancelPromise.then(js,
[abortPromise = kj::mv(abortPromise)](jsg::Lock& js) mutable {
return kj::mv(abortPromise);
});
}).then(js, [this](jsg::Lock& js) {
resolveFulfiller(js, kj::none);
return js.resolvedPromise();
}).catch_(js, [this](jsg::Lock& js, jsg::Value err) {
errorHandler(js, kj::mv(err));
});
}

jsg::Promise<void> Socket::close(jsg::Lock& js) {
if (util::Autogate::isEnabled(util::AutogateKey::SOCKETS_AWAIT_PROXY_BEFORE_CLOSE)) {
return closeImplNew(js);
} else {
return closeImplOld(js);
}
}

jsg::Ref<Socket> Socket::startTls(jsg::Lock& js, jsg::Optional<TlsOptions> tlsOptions) {
Expand Down
17 changes: 14 additions & 3 deletions src/workerd/api/sockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class Socket: public jsg::Object {
: connectionStream(context.addObject(kj::mv(connectionStream))),
readable(kj::mv(readableParam)), writable(kj::mv(writable)),
closedResolver(kj::mv(closedPrPair.resolver)),
closedPromiseCopy(closedPrPair.promise.whenResolved(js)),
closedPromise(kj::mv(closedPrPair.promise)),
watchForDisconnectTask(context.addObject(kj::heap(kj::mv(watchForDisconnectTask)))),
options(kj::mv(options)),
Expand All @@ -66,7 +67,9 @@ class Socket: public jsg::Object {
domain(kj::mv(domain)),
isDefaultFetchPort(isDefaultFetchPort),
openedResolver(kj::mv(openedPrPair.resolver)),
openedPromise(kj::mv(openedPrPair.promise)) { };
openedPromiseCopy(openedPrPair.promise.whenResolved(js)),
openedPromise(kj::mv(openedPrPair.promise)),
isClosing(false) { };

jsg::Ref<ReadableStream> getReadable() { return readable.addRef(); }
jsg::Ref<WritableStream> getWritable() { return writable.addRef(); }
Expand Down Expand Up @@ -119,6 +122,9 @@ class Socket: public jsg::Object {
jsg::Ref<WritableStream> writable;
// This fulfiller is used to resolve the `closedPromise` below.
jsg::Promise<void>::Resolver closedResolver;
// Copy kept so that it can be returned from `close`.
jsg::Promise<void> closedPromiseCopy;
// Memoized copy that is returned by the `closed` attribute.
jsg::MemoizedIdentity<jsg::Promise<void>> closedPromise;
IoOwn<kj::Promise<void>> watchForDisconnectTask;
jsg::Optional<SocketOptions> options;
Expand All @@ -133,10 +139,16 @@ class Socket: public jsg::Object {
bool isDefaultFetchPort;
// This fulfiller is used to resolve the `openedPromise` below.
jsg::Promise<SocketInfo>::Resolver openedResolver;
// Copy kept so that it can be used in `close`.
jsg::Promise<void> openedPromiseCopy;
jsg::MemoizedIdentity<jsg::Promise<SocketInfo>> openedPromise;
// Used to keep track of a pending `close` operation on the socket.
bool isClosing;

kj::Promise<kj::Own<kj::AsyncIoStream>> processConnection();
jsg::Promise<void> maybeCloseWriteSide(jsg::Lock& js);
jsg::Promise<void> closeImplOld(jsg::Lock& js);
jsg::Promise<void> closeImplNew(jsg::Lock& js);

// Helper method for handleProxyStatus implementations.
void handleProxyError(jsg::Lock& js, kj::Exception e);
Expand All @@ -149,10 +161,9 @@ class Socket: public jsg::Object {
}
};

jsg::Promise<void> errorHandler(jsg::Lock& js, jsg::Value err) {
void errorHandler(jsg::Lock& js, jsg::Value err) {
auto jsException = err.getHandle(js);
resolveFulfiller(js, jsg::createTunneledException(js.v8Isolate, jsException));
return js.resolvedPromise();
};

void visitForGc(jsg::GcVisitor& visitor) {
Expand Down
8 changes: 8 additions & 0 deletions src/workerd/api/streams/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ class ReadableStreamController {
jsg::Lock& js, kj::Own<WritableStreamSink> sink, bool end) = 0;

virtual kj::Own<ReadableStreamController> detach(jsg::Lock& js, bool ignoreDisturbed) = 0;

// Used by sockets to signal that the ReadableStream shouldn't allow reads due to pending
// closure.
virtual void setPendingClosure() = 0;
};

kj::Own<ReadableStreamController> newReadableStreamJsController();
Expand Down Expand Up @@ -679,6 +683,10 @@ class WritableStreamController {

// True is this controller requires ArrayBuffer(Views) to be written to it.
virtual bool isByteOriented() const = 0;

// Used by sockets to signal that the WritableStream shouldn't allow writes due to pending
// closure.
virtual void setPendingClosure() = 0;
};

kj::Own<WritableStreamController> newWritableStreamJsController();
Expand Down
25 changes: 25 additions & 0 deletions src/workerd/api/streams/internal.c++
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,12 @@ jsg::Ref<ReadableStream> ReadableStreamInternalController::addRef() {
kj::Maybe<jsg::Promise<ReadResult>> ReadableStreamInternalController::read(
jsg::Lock& js,
kj::Maybe<ByobOptions> maybeByobOptions) {

if (isPendingClosure) {
return js.rejectedPromise<ReadResult>(
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
}

std::shared_ptr<v8::BackingStore> store;
size_t byteLength = 0;
size_t byteOffset = 0;
Expand Down Expand Up @@ -596,6 +602,11 @@ jsg::Promise<void> ReadableStreamInternalController::pipeTo(
KJ_DASSERT(!isLockedToReader());
KJ_DASSERT(!destination.isLockedToWriter());

if (isPendingClosure) {
return js.rejectedPromise<void>(
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
}

disturbed = true;
KJ_IF_SOME(promise, destination.tryPipeFrom(js,
KJ_ASSERT_NONNULL(owner).addRef(),
Expand Down Expand Up @@ -655,6 +666,8 @@ void ReadableStreamInternalController::doError(jsg::Lock& js, v8::Local<v8::Valu
ReadableStreamController::Tee ReadableStreamInternalController::tee(jsg::Lock& js) {
JSG_REQUIRE(!isLockedToReader(), TypeError,
"This ReadableStream is currently locked to a reader.");
JSG_REQUIRE(!isPendingClosure, TypeError,
"This ReadableStream belongs to an object that is closing.");
readState.init<Locked>();
disturbed = true;
KJ_SWITCH_ONEOF(state) {
Expand Down Expand Up @@ -815,6 +828,10 @@ jsg::Ref<WritableStream> WritableStreamInternalController::addRef() {
jsg::Promise<void> WritableStreamInternalController::write(
jsg::Lock& js,
jsg::Optional<v8::Local<v8::Value>> value) {
if (isPendingClosure) {
return js.rejectedPromise<void>(
js.v8TypeError("This WritableStream belongs to an object that is closing."_kj));
}
if (isClosedOrClosing()) {
return js.rejectedPromise<void>(
js.v8TypeError("This WritableStream has been closed."_kj));
Expand Down Expand Up @@ -1916,6 +1933,10 @@ jsg::Promise<kj::Array<byte>> ReadableStreamInternalController::readAllBytes(
return js.rejectedPromise<kj::Array<byte>>(KJ_EXCEPTION(FAILED,
"jsg.TypeError: This ReadableStream is currently locked to a reader."));
}
if (isPendingClosure) {
return js.rejectedPromise<kj::Array<byte>>(
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
}
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(closed, StreamStates::Closed) {
return js.resolvedPromise(kj::Array<byte>());
Expand All @@ -1939,6 +1960,10 @@ jsg::Promise<kj::String> ReadableStreamInternalController::readAllText(
return js.rejectedPromise<kj::String>(KJ_EXCEPTION(FAILED,
"jsg.TypeError: This ReadableStream is currently locked to a reader."));
}
if (isPendingClosure) {
return js.rejectedPromise<kj::String>(
js.v8TypeError("This ReadableStream belongs to an object that is closing."_kj));
}
KJ_SWITCH_ONEOF(state) {
KJ_CASE_ONEOF(closed, StreamStates::Closed) {
return js.resolvedPromise(kj::String());
Expand Down
16 changes: 16 additions & 0 deletions src/workerd/api/streams/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class ReadableStreamInternalController: public ReadableStreamController {

kj::Own<ReadableStreamController> detach(jsg::Lock& js, bool ignoreDisturbed) override;

void setPendingClosure() override {
isPendingClosure = true;
}

private:
void doCancel(jsg::Lock& js, jsg::Optional<v8::Local<v8::Value>> reason);
void doClose(jsg::Lock& js);
Expand Down Expand Up @@ -143,6 +147,10 @@ class ReadableStreamInternalController: public ReadableStreamController {
bool disturbed = false;
bool readPending = false;

// Used by Sockets code to signal to the ReadableStream that it should error when read from
// because the socket is currently being closed.
bool isPendingClosure = false;

friend class ReadableStream;
friend class WritableStreamInternalController;
friend class PipeLocked;
Expand Down Expand Up @@ -212,6 +220,10 @@ class WritableStreamInternalController: public WritableStreamController {
bool isErrored() override;

inline bool isByteOriented() const override { return true; }

void setPendingClosure() override {
isPendingClosure = true;
}
private:

struct AbortOptions {
Expand Down Expand Up @@ -259,6 +271,10 @@ class WritableStreamInternalController: public WritableStreamController {
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable;
bool waitingOnClosureWritableAlready = false;

// Used by Sockets code to signal to the WritableStream that it should error when written to
// because the socket is currently being closed.
bool isPendingClosure = false;

void increaseCurrentWriteBufferSize(jsg::Lock& js, uint64_t amount);
void decreaseCurrentWriteBufferSize(jsg::Lock& js, uint64_t amount);
void updateBackpressure(jsg::Lock& js, bool backpressure);
Expand Down
9 changes: 9 additions & 0 deletions src/workerd/api/streams/standard.c++
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// https://opensource.org/licenses/Apache-2.0

#include "standard.h"
#include <kj/debug.h>
#include "readable.h"
#include "writable.h"
#include <workerd/jsg/buffersource.h>
Expand Down Expand Up @@ -684,6 +685,10 @@ public:

kj::Own<ReadableStreamController> detach(jsg::Lock& js, bool ignoreDisturbed) override;

void setPendingClosure() override {
KJ_UNIMPLEMENTED("only implemented for WritableStreamInternalController");
}

private:
bool hasPendingReadRequests();

Expand Down Expand Up @@ -816,6 +821,10 @@ public:

inline bool isByteOriented() const override { return false; }

void setPendingClosure() override {
KJ_UNIMPLEMENTED("only implemented for WritableStreamInternalController");
}

private:
jsg::Promise<void> pipeLoop(jsg::Lock& js);

Expand Down
2 changes: 2 additions & 0 deletions src/workerd/util/autogate.c++
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ kj::StringPtr KJ_STRINGIFY(AutogateKey key) {
return "test-workerd"_kj;
case AutogateKey::BUILTIN_WASM_MODULES:
return "builtin-wasm-modules"_kj;
case AutogateKey::SOCKETS_AWAIT_PROXY_BEFORE_CLOSE:
return "sockets-await-proxy-before-close"_kj;
case AutogateKey::NumOfKeys:
KJ_FAIL_ASSERT("NumOfKeys should not be used in getName");
}
Expand Down
2 changes: 2 additions & 0 deletions src/workerd/util/autogate.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ enum class AutogateKey {
// Allow builtin modules to be wasm modules. Used for Python project.
// Gates code in jsg/modules.h
BUILTIN_WASM_MODULES,
// Enable new behaviour of Socket::close (specifically waiting for proxy result before closing).
SOCKETS_AWAIT_PROXY_BEFORE_CLOSE,
NumOfKeys // Reserved for iteration.
};

Expand Down

0 comments on commit 9b89c8a

Please sign in to comment.