diff --git a/src/index.ts b/src/index.ts index 1c374b3fc..0758ec377 100644 --- a/src/index.ts +++ b/src/index.ts @@ -158,7 +158,7 @@ const post = async ({ variables?: V headers?: Dom.RequestInit['headers'] operationName?: string - middleware?: (request: Dom.RequestInit) => Dom.RequestInit + middleware?: (request: Dom.RequestInit) => Dom.RequestInit | Promise }) => { const body = createRequestBody(query, variables, operationName, fetchOptions.jsonSerializer) @@ -172,7 +172,7 @@ const post = async ({ ...fetchOptions, } if (middleware) { - options = middleware(options) + options = await Promise.resolve(middleware(options)) } return await fetch(url, options) } @@ -197,7 +197,7 @@ const get = async ({ variables?: V headers?: HeadersInit operationName?: string - middleware?: (request: Dom.RequestInit) => Dom.RequestInit + middleware?: (request: Dom.RequestInit) => Dom.RequestInit | Promise }) => { const queryParams = buildGetQueryParams({ query, @@ -212,7 +212,7 @@ const get = async ({ ...fetchOptions, } if (middleware) { - options = middleware(options) + options = await Promise.resolve(middleware(options)) } return await fetch(`${url}?${queryParams}`, options) } @@ -458,7 +458,7 @@ async function makeRequest({ fetch: any method: string fetchOptions: Dom.RequestInit - middleware?: (request: Dom.RequestInit) => Dom.RequestInit + middleware?: (request: Dom.RequestInit) => Dom.RequestInit | Promise }): Promise> { const fetcher = method.toUpperCase() === 'POST' ? post : get const isBathchingQuery = Array.isArray(query) diff --git a/src/types.ts b/src/types.ts index 886621601..ce8f2c82c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -70,7 +70,7 @@ export interface Response { export type PatchedRequestInit = Omit & { headers?: MaybeFunction - requestMiddleware?: (request: Dom.RequestInit) => Dom.RequestInit + requestMiddleware?: (request: Dom.RequestInit) => Dom.RequestInit | Promise responseMiddleware?: (response: Response | Error) => void } diff --git a/tests/general.test.ts b/tests/general.test.ts index 1545e7c2c..8682054fa 100644 --- a/tests/general.test.ts +++ b/tests/general.test.ts @@ -163,6 +163,41 @@ describe('middleware', () => { }) }) + describe('async request middleware', () => { + beforeEach(() => { + ctx.res({ + body: { + data: { + result: 123, + }, + }, + }) + + requestMiddleware = jest.fn(async (req) => ({ ...req })) + client = new GraphQLClient(ctx.url, { + requestMiddleware, + }) + }) + + it('request', async () => { + const requestPromise = client.request<{ result: number }>(`x`) + expect(requestMiddleware).toBeCalledTimes(1) + await requestPromise + }) + + it('rawRequest', async () => { + const requestPromise = client.rawRequest<{ result: number }>(`x`) + expect(requestMiddleware).toBeCalledTimes(1) + await requestPromise + }) + + it('batchRequests', async () => { + const requestPromise = client.batchRequests<{ result: number }>([{ document: `x` }]) + expect(requestMiddleware).toBeCalledTimes(1) + await requestPromise + }) + }) + describe('failed requests', () => { beforeEach(() => { ctx.res({