diff --git a/README.markdown b/README.markdown index 0711e34b0..698ee1c75 100644 --- a/README.markdown +++ b/README.markdown @@ -412,6 +412,8 @@ Generated code will be placed in the Gradle build directory. - With `--ts_proto_opt=outputServices=false`, or `=none`, ts-proto will output NO service definitions. +- With `--ts_proto_opt=useAbortSignal=true`, the generated services will accept an `AbortSignal` to cancel RPC calls. + - With `--ts_proto_opt=useAsyncIterable=true`, the generated services will use `AsyncIterable` instead of `Observable`. - With `--ts_proto_opt=emitImportedFiles=false`, ts-proto will not emit `google/protobuf/*` files unless you explicit add files to `protoc` like this diff --git a/integration/async-iterable-services-abort-signal/parameters.txt b/integration/async-iterable-services-abort-signal/parameters.txt new file mode 100644 index 000000000..aab111ad9 --- /dev/null +++ b/integration/async-iterable-services-abort-signal/parameters.txt @@ -0,0 +1 @@ +useAsyncIterable=true,useAbortSignal=true diff --git a/integration/async-iterable-services-abort-signal/simple.bin b/integration/async-iterable-services-abort-signal/simple.bin new file mode 100644 index 000000000..51973c2d8 Binary files /dev/null and b/integration/async-iterable-services-abort-signal/simple.bin differ diff --git a/integration/async-iterable-services-abort-signal/simple.proto b/integration/async-iterable-services-abort-signal/simple.proto new file mode 100644 index 000000000..73d6a5180 --- /dev/null +++ b/integration/async-iterable-services-abort-signal/simple.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package simple; + +// Echoer service returns the given message. +service Echoer { + // Echo returns the given message. + rpc Echo(EchoMsg) returns (EchoMsg); + // EchoServerStream is an example of a server -> client one-way stream. + rpc EchoServerStream(EchoMsg) returns (stream EchoMsg); + // EchoClientStream is an example of client->server one-way stream. + rpc EchoClientStream(stream EchoMsg) returns (EchoMsg); + // EchoBidiStream is an example of a two-way stream. + rpc EchoBidiStream(stream EchoMsg) returns (stream EchoMsg); +} + +// EchoMsg is the message body for Echo. +message EchoMsg { + string body = 1; +} diff --git a/integration/async-iterable-services-abort-signal/simple.ts b/integration/async-iterable-services-abort-signal/simple.ts new file mode 100644 index 000000000..c6ec3f2f8 --- /dev/null +++ b/integration/async-iterable-services-abort-signal/simple.ts @@ -0,0 +1,178 @@ +/* eslint-disable */ +import * as _m0 from "protobufjs/minimal"; + +export const protobufPackage = "simple"; + +/** EchoMsg is the message body for Echo. */ +export interface EchoMsg { + body: string; +} + +function createBaseEchoMsg(): EchoMsg { + return { body: "" }; +} + +export const EchoMsg = { + encode(message: EchoMsg, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.body !== "") { + writer.uint32(10).string(message.body); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): EchoMsg { + const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseEchoMsg(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + message.body = reader.string(); + break; + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }, + + // encodeTransform encodes a source of message objects. + // Transform + async *encodeTransform( + source: AsyncIterable | Iterable, + ): AsyncIterable { + for await (const pkt of source) { + if (Array.isArray(pkt)) { + for (const p of pkt) { + yield* [EchoMsg.encode(p).finish()]; + } + } else { + yield* [EchoMsg.encode(pkt).finish()]; + } + } + }, + + // decodeTransform decodes a source of encoded messages. + // Transform + async *decodeTransform( + source: AsyncIterable | Iterable, + ): AsyncIterable { + for await (const pkt of source) { + if (Array.isArray(pkt)) { + for (const p of pkt) { + yield* [EchoMsg.decode(p)]; + } + } else { + yield* [EchoMsg.decode(pkt)]; + } + } + }, + + fromJSON(object: any): EchoMsg { + return { body: isSet(object.body) ? String(object.body) : "" }; + }, + + toJSON(message: EchoMsg): unknown { + const obj: any = {}; + message.body !== undefined && (obj.body = message.body); + return obj; + }, + + fromPartial, I>>(object: I): EchoMsg { + const message = createBaseEchoMsg(); + message.body = object.body ?? ""; + return message; + }, +}; + +/** Echoer service returns the given message. */ +export interface Echoer { + /** Echo returns the given message. */ + Echo(request: EchoMsg, abortSignal?: AbortSignal): Promise; + /** EchoServerStream is an example of a server -> client one-way stream. */ + EchoServerStream(request: EchoMsg, abortSignal?: AbortSignal): AsyncIterable; + /** EchoClientStream is an example of client->server one-way stream. */ + EchoClientStream(request: AsyncIterable, abortSignal?: AbortSignal): Promise; + /** EchoBidiStream is an example of a two-way stream. */ + EchoBidiStream(request: AsyncIterable, abortSignal?: AbortSignal): AsyncIterable; +} + +export class EchoerClientImpl implements Echoer { + private readonly rpc: Rpc; + private readonly service: string; + constructor(rpc: Rpc, opts?: { service?: string }) { + this.service = opts?.service || "simple.Echoer"; + this.rpc = rpc; + this.Echo = this.Echo.bind(this); + this.EchoServerStream = this.EchoServerStream.bind(this); + this.EchoClientStream = this.EchoClientStream.bind(this); + this.EchoBidiStream = this.EchoBidiStream.bind(this); + } + Echo(request: EchoMsg, abortSignal?: AbortSignal): Promise { + const data = EchoMsg.encode(request).finish(); + const promise = this.rpc.request(this.service, "Echo", data, abortSignal || undefined); + return promise.then((data) => EchoMsg.decode(new _m0.Reader(data))); + } + + EchoServerStream(request: EchoMsg, abortSignal?: AbortSignal): AsyncIterable { + const data = EchoMsg.encode(request).finish(); + const result = this.rpc.serverStreamingRequest(this.service, "EchoServerStream", data, abortSignal || undefined); + return EchoMsg.decodeTransform(result); + } + + EchoClientStream(request: AsyncIterable, abortSignal?: AbortSignal): Promise { + const data = EchoMsg.encodeTransform(request); + const promise = this.rpc.clientStreamingRequest(this.service, "EchoClientStream", data, abortSignal || undefined); + return promise.then((data) => EchoMsg.decode(new _m0.Reader(data))); + } + + EchoBidiStream(request: AsyncIterable, abortSignal?: AbortSignal): AsyncIterable { + const data = EchoMsg.encodeTransform(request); + const result = this.rpc.bidirectionalStreamingRequest( + this.service, + "EchoBidiStream", + data, + abortSignal || undefined, + ); + return EchoMsg.decodeTransform(result); + } +} + +interface Rpc { + request(service: string, method: string, data: Uint8Array, abortSignal?: AbortSignal): Promise; + clientStreamingRequest( + service: string, + method: string, + data: AsyncIterable, + abortSignal?: AbortSignal, + ): Promise; + serverStreamingRequest( + service: string, + method: string, + data: Uint8Array, + abortSignal?: AbortSignal, + ): AsyncIterable; + bidirectionalStreamingRequest( + service: string, + method: string, + data: AsyncIterable, + abortSignal?: AbortSignal, + ): AsyncIterable; +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends Array ? Array> : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/src/generate-grpc-web.ts b/src/generate-grpc-web.ts index dda41b3eb..2bfdb1efe 100644 --- a/src/generate-grpc-web.ts +++ b/src/generate-grpc-web.ts @@ -25,7 +25,7 @@ export function generateGrpcClientImpl( // Create the constructor(rpc: Rpc) chunks.push(code` private readonly rpc: Rpc; - + constructor(rpc: Rpc) { `); chunks.push(code`this.rpc = rpc;`); diff --git a/src/generate-services.ts b/src/generate-services.ts index 6c12b722c..ec9a6a9b5 100644 --- a/src/generate-services.ts +++ b/src/generate-services.ts @@ -62,6 +62,9 @@ export function generateService( const partialInput = options.outputClientImpl === "grpc-web"; const inputType = requestType(ctx, methodDesc, partialInput); params.push(code`request: ${inputType}`); + if (options.useAbortSignal) { + params.push(code`abortSignal?: AbortSignal`); + } // Use metadata as last argument for interface only configuration if (options.outputClientImpl === "grpc-web") { @@ -103,22 +106,21 @@ export function generateService( return joinCode(chunks, { on: "\n" }); } -function generateRegularRpcMethod( - ctx: Context, - fileDesc: FileDescriptorProto, - serviceDesc: ServiceDescriptorProto, - methodDesc: MethodDescriptorProto -): Code { +function generateRegularRpcMethod(ctx: Context, methodDesc: MethodDescriptorProto): Code { assertInstanceOf(methodDesc, FormattedMethodDescriptor); - const { options, utils } = ctx; + const { options } = ctx; const Reader = impFile(ctx.options, "Reader@protobufjs/minimal"); const rawInputType = rawRequestType(ctx, methodDesc, { keepValueType: true }); const inputType = requestType(ctx, methodDesc); - const outputType = responseType(ctx, methodDesc); const rawOutputType = responseType(ctx, methodDesc, { keepValueType: true }); - const params = [...(options.context ? [code`ctx: Context`] : []), code`request: ${inputType}`]; + const params = [ + ...(options.context ? [code`ctx: Context`] : []), + code`request: ${inputType}`, + ...(options.useAbortSignal ? [code`abortSignal?: AbortSignal`] : []), + ]; const maybeCtx = options.context ? "ctx," : ""; + const maybeAbortSignal = options.useAbortSignal ? "abortSignal || undefined," : ""; let encode = code`${rawInputType}.encode(request).finish()`; let decode = code`data => ${rawOutputType}.decode(new ${Reader}(data))`; @@ -166,7 +168,8 @@ function generateRegularRpcMethod( ${maybeCtx} this.service, "${methodDesc.name}", - data + data, + ${maybeAbortSignal} ); return ${decode}; } @@ -216,7 +219,7 @@ export function generateServiceClientImpl( if (options.context && methodDesc.name.match(/^Get[A-Z]/)) { chunks.push(generateCachingRpcMethod(ctx, fileDesc, serviceDesc, methodDesc)); } else { - chunks.push(generateRegularRpcMethod(ctx, fileDesc, serviceDesc, methodDesc)); + chunks.push(generateRegularRpcMethod(ctx, methodDesc)); } } @@ -332,6 +335,7 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod const { options } = ctx; const maybeContext = options.context ? "" : ""; const maybeContextParam = options.context ? "ctx: Context," : ""; + const maybeAbortSignalParam = options.useAbortSignal ? "abortSignal?: AbortSignal," : ""; const methods = [[code`request`, code`Uint8Array`, code`Promise`]]; if (hasStreamingMethods) { const observable = observableType(ctx); @@ -351,7 +355,8 @@ export function generateRpcType(ctx: Context, hasStreamingMethods: boolean): Cod ${maybeContextParam} service: string, method: string, - data: ${method[1]} + data: ${method[1]}, + ${maybeAbortSignalParam} ): ${method[2]};`); }); chunks.push(code` }`); diff --git a/src/main.ts b/src/main.ts index 46dab9c9c..548c33e27 100644 --- a/src/main.ts +++ b/src/main.ts @@ -161,7 +161,7 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri visit( fileDesc, sourceInfo, - (fullName, message, sInfo, fullProtoTypeName) => { + (fullName, message, _sInfo, fullProtoTypeName) => { const fullTypeName = maybePrefixPackage(fileDesc, fullProtoTypeName); chunks.push(generateBaseInstanceFactory(ctx, fullName, message, fullTypeName)); @@ -268,7 +268,7 @@ export function generateFile(ctx: Context, fileDesc: FileDescriptorProto): [stri } }); } - serviceDesc.method.forEach((methodDesc, index) => { + serviceDesc.method.forEach((methodDesc, _index) => { if (methodDesc.serverStreaming || methodDesc.clientStreaming) { hasStreamingMethods = true; } @@ -335,7 +335,7 @@ export function makeUtils(options: Options): Utils { return { ...bytes, ...makeDeepPartial(options, longs), - ...makeObjectIdMethods(options), + ...makeObjectIdMethods(), ...makeTimestampMethods(options, longs), ...longs, ...makeComparisonUtils(), @@ -538,7 +538,7 @@ function makeDeepPartial(options: Options, longs: ReturnType { "stringEnums": false, "unknownFields": false, "unrecognizedEnum": true, + "useAbortSignal": false, "useAsyncIterable": false, "useDate": "timestamp", "useExactTypes": true,