diff --git a/internal-packages/next-dev/src/ai.ts b/internal-packages/next-dev/src/ai.ts new file mode 100644 index 000000000..db588ce42 --- /dev/null +++ b/internal-packages/next-dev/src/ai.ts @@ -0,0 +1,44 @@ +import { Response } from 'miniflare'; +import type { Request } from 'miniflare'; + +import { cloneHeaders } from "./wrangler"; + +type Fetcher = (req: Request) => Promise; + +export function getAIFetcher({ accountId, apiToken }: {accountId: string, apiToken: string}): Fetcher { + // Tweaked version of the wrangler ai fetcher + // (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/ai/fetcher.ts) + return async function AIFetcher(request: Request) { + request.headers.delete("Host"); + request.headers.delete("Content-Length"); + + const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, { + method: "POST", + headers: Object.fromEntries(request.headers.entries()), + body: request.body as BodyInit, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + duplex: "half", + }, apiToken); + + return new Response(res.body, { status: res.status }); + }; +} + +// (Heavily) Simplified version of performApiFetch from wrangler +// (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L18) +export async function performApiFetch( + resource: string, + init: RequestInit = {}, + apiToken: string, +) { + const method = init.method ?? "GET"; + const headers = cloneHeaders(init.headers); + headers["Authorization"] = `Bearer ${apiToken}`; + + return fetch(`https://api.cloudflare.com/client/v4${resource}`, { + method, + ...init, + headers, + }); +} diff --git a/internal-packages/next-dev/src/index.ts b/internal-packages/next-dev/src/index.ts index 31be061b5..5249023c9 100644 --- a/internal-packages/next-dev/src/index.ts +++ b/internal-packages/next-dev/src/index.ts @@ -2,6 +2,7 @@ import type { WorkerOptions } from 'miniflare'; import { Miniflare, Request, Response, Headers } from 'miniflare'; import { getDOBindingInfo } from './durableObjects'; import { getServiceBindings } from './services'; +import { getAIFetcher } from './ai'; /** * Sets up the bindings that need to be available during development time (using @@ -54,6 +55,11 @@ export type DevBindingsOptions = { className: string; } >; + ai?: { + bindingName: string; + accountId: string; + apiToken: string; + }; /** * Record mapping binding names to R2 bucket names to inject as R2Bucket. * If a `string[]` of binding names is specified, the binding name and bucket name are assumed to be the same. @@ -84,6 +90,10 @@ async function instantiateMiniflare( const { workerOptions, durableObjects } = (await getDOBindingInfo(options.durableObjects)) ?? {}; + const aiBindingObj = options.ai ? { + [options.ai.bindingName]: getAIFetcher(options.ai), + } : {}; + const { kvNamespaces, r2Buckets, d1Databases, services, textBindings } = options; const bindings = { @@ -102,7 +112,10 @@ async function instantiateMiniflare( ...bindings, modules: true, script: '', - serviceBindings, + serviceBindings: { + ...serviceBindings, + ...aiBindingObj + }, }, ...(workerOptions ? [workerOptions] : []), ]; diff --git a/internal-packages/next-dev/src/wrangler.ts b/internal-packages/next-dev/src/wrangler.ts index 6cc7f3e61..e749943fa 100644 --- a/internal-packages/next-dev/src/wrangler.ts +++ b/internal-packages/next-dev/src/wrangler.ts @@ -97,3 +97,14 @@ const IDENTIFIER_UNSAFE_REGEXP = /[^a-zA-Z0-9_$]/g; export function getIdentifier(name: string) { return name.replace(IDENTIFIER_UNSAFE_REGEXP, '_'); } + +// https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L119 +export function cloneHeaders( + headers: HeadersInit | undefined +): Record { + return headers instanceof Headers + ? Object.fromEntries(headers.entries()) + : Array.isArray(headers) + ? Object.fromEntries(headers) + : { ...headers }; +}