From 78f92e8d9c4743aabe3069be6c10b62f8e7d44c1 Mon Sep 17 00:00:00 2001 From: Dario Piotrowicz Date: Sun, 17 Dec 2023 16:49:41 +0000 Subject: [PATCH] enable ai binding in next-dev module --- internal-packages/next-dev/src/ai.ts | 54 ++++++++++++++++++++++ internal-packages/next-dev/src/index.ts | 17 ++++++- internal-packages/next-dev/src/wrangler.ts | 11 +++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 internal-packages/next-dev/src/ai.ts diff --git a/internal-packages/next-dev/src/ai.ts b/internal-packages/next-dev/src/ai.ts new file mode 100644 index 000000000..34294819e --- /dev/null +++ b/internal-packages/next-dev/src/ai.ts @@ -0,0 +1,54 @@ +import { Response } from 'miniflare'; +import type { Request } from 'miniflare'; + +import { cloneHeaders } from './wrangler'; + +type Fetcher = (req: Request) => Promise; + +export function getAIFetcher({ + accountId, + apiToken, +}: { + accountId: string; + apiToken: string; +}): Fetcher { + // Tweaked version of the wrangler ai fetcher + // (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/ai/fetcher.ts) + return async function AIFetcher(request: Request) { + request.headers.delete('Host'); + request.headers.delete('Content-Length'); + + const res = await performApiFetch( + `/accounts/${accountId}/ai/run/proxy`, + { + method: 'POST', + headers: Object.fromEntries(request.headers.entries()), + body: request.body as BodyInit, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + duplex: 'half', + }, + apiToken, + ); + + return new Response(res.body, { status: res.status }); + }; +} + +// (Heavily) Simplified version of performApiFetch from wrangler +// (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L18) +export async function performApiFetch( + resource: string, + init: RequestInit = {}, + apiToken: string, +) { + const method = init.method ?? 'GET'; + const headers = cloneHeaders(init.headers); + headers['Authorization'] = `Bearer ${apiToken}`; + + return fetch(`https://api.cloudflare.com/client/v4${resource}`, { + method, + ...init, + headers, + }); +} diff --git a/internal-packages/next-dev/src/index.ts b/internal-packages/next-dev/src/index.ts index 31be061b5..7d374277c 100644 --- a/internal-packages/next-dev/src/index.ts +++ b/internal-packages/next-dev/src/index.ts @@ -2,6 +2,7 @@ import type { WorkerOptions } from 'miniflare'; import { Miniflare, Request, Response, Headers } from 'miniflare'; import { getDOBindingInfo } from './durableObjects'; import { getServiceBindings } from './services'; +import { getAIFetcher } from './ai'; /** * Sets up the bindings that need to be available during development time (using @@ -54,6 +55,11 @@ export type DevBindingsOptions = { className: string; } >; + ai?: { + bindingName: string; + accountId: string; + apiToken: string; + }; /** * Record mapping binding names to R2 bucket names to inject as R2Bucket. * If a `string[]` of binding names is specified, the binding name and bucket name are assumed to be the same. @@ -84,6 +90,12 @@ async function instantiateMiniflare( const { workerOptions, durableObjects } = (await getDOBindingInfo(options.durableObjects)) ?? {}; + const aiBindingObj = options.ai + ? { + [options.ai.bindingName]: getAIFetcher(options.ai), + } + : {}; + const { kvNamespaces, r2Buckets, d1Databases, services, textBindings } = options; const bindings = { @@ -102,7 +114,10 @@ async function instantiateMiniflare( ...bindings, modules: true, script: '', - serviceBindings, + serviceBindings: { + ...serviceBindings, + ...aiBindingObj, + }, }, ...(workerOptions ? [workerOptions] : []), ]; diff --git a/internal-packages/next-dev/src/wrangler.ts b/internal-packages/next-dev/src/wrangler.ts index 6cc7f3e61..33f56c2e9 100644 --- a/internal-packages/next-dev/src/wrangler.ts +++ b/internal-packages/next-dev/src/wrangler.ts @@ -97,3 +97,14 @@ const IDENTIFIER_UNSAFE_REGEXP = /[^a-zA-Z0-9_$]/g; export function getIdentifier(name: string) { return name.replace(IDENTIFIER_UNSAFE_REGEXP, '_'); } + +// https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L119 +export function cloneHeaders( + headers: HeadersInit | undefined, +): Record { + return headers instanceof Headers + ? Object.fromEntries(headers.entries()) + : Array.isArray(headers) + ? Object.fromEntries(headers) + : { ...headers }; +}