Skip to content

Commit

Permalink
enable ai binding in next-dev module
Browse files Browse the repository at this point in the history
  • Loading branch information
dario-piotrowicz committed Dec 17, 2023
1 parent a4efc7b commit 4be7860
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
44 changes: 44 additions & 0 deletions internal-packages/next-dev/src/ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { Response } from 'miniflare';
import type { Request } from 'miniflare';

import { cloneHeaders } from "./wrangler";

type Fetcher = (req: Request) => Promise<Response>;

export function getAIFetcher({ accountId, apiToken }: {accountId: string, apiToken: string}): Fetcher {
// Tweaked version of the wrangler ai fetcher
// (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/ai/fetcher.ts)
return async function AIFetcher(request: Request) {
request.headers.delete("Host");
request.headers.delete("Content-Length");

const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, {
method: "POST",
headers: Object.fromEntries(request.headers.entries()),
body: request.body as BodyInit,
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
duplex: "half",
}, apiToken);

return new Response(res.body, { status: res.status });
};
}

// (Heavily) Simplified version of performApiFetch from wrangler
// (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L18)
export async function performApiFetch(
resource: string,
init: RequestInit = {},
apiToken: string,
) {
const method = init.method ?? "GET";
const headers = cloneHeaders(init.headers);
headers["Authorization"] = `Bearer ${apiToken}`;

return fetch(`https://api.cloudflare.com/client/v4${resource}`, {
method,
...init,
headers,
});
}
15 changes: 14 additions & 1 deletion internal-packages/next-dev/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { WorkerOptions } from 'miniflare';
import { Miniflare, Request, Response, Headers } from 'miniflare';
import { getDOBindingInfo } from './durableObjects';
import { getServiceBindings } from './services';
import { getAIFetcher } from './ai';

/**
* Sets up the bindings that need to be available during development time (using
Expand Down Expand Up @@ -54,6 +55,11 @@ export type DevBindingsOptions = {
className: string;
}
>;
ai?: {
bindingName: string;
accountId: string;
apiToken: string;
};
/**
* Record mapping binding names to R2 bucket names to inject as R2Bucket.
* If a `string[]` of binding names is specified, the binding name and bucket name are assumed to be the same.
Expand Down Expand Up @@ -84,6 +90,10 @@ async function instantiateMiniflare(
const { workerOptions, durableObjects } =
(await getDOBindingInfo(options.durableObjects)) ?? {};

const aiBindingObj = options.ai ? {
[options.ai.bindingName]: getAIFetcher(options.ai),
} : {};

const { kvNamespaces, r2Buckets, d1Databases, services, textBindings } =
options;
const bindings = {
Expand All @@ -102,7 +112,10 @@ async function instantiateMiniflare(
...bindings,
modules: true,
script: '',
serviceBindings,
serviceBindings: {
...serviceBindings,
...aiBindingObj
},
},
...(workerOptions ? [workerOptions] : []),
];
Expand Down
11 changes: 11 additions & 0 deletions internal-packages/next-dev/src/wrangler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,14 @@ const IDENTIFIER_UNSAFE_REGEXP = /[^a-zA-Z0-9_$]/g;
export function getIdentifier(name: string) {
return name.replace(IDENTIFIER_UNSAFE_REGEXP, '_');
}

// https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L119
export function cloneHeaders(
headers: HeadersInit | undefined
): Record<string, string> {
return headers instanceof Headers
? Object.fromEntries(headers.entries())
: Array.isArray(headers)
? Object.fromEntries(headers)
: { ...headers };
}

0 comments on commit 4be7860

Please sign in to comment.