Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable ai binding in next-dev module #597

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions internal-packages/next-dev/src/ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { Response } from 'miniflare';
import type { Request } from 'miniflare';

import { cloneHeaders } from './wrangler';

type Fetcher = (req: Request) => Promise<Response>;

export function getAIFetcher({
accountId,
apiToken,
}: {
accountId: string;
apiToken: string;
}): Fetcher {
// Tweaked version of the wrangler ai fetcher
// (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/ai/fetcher.ts)
return async function AIFetcher(request: Request) {
request.headers.delete('Host');
request.headers.delete('Content-Length');

const res = await performApiFetch(
`/accounts/${accountId}/ai/run/proxy`,
{
method: 'POST',
headers: Object.fromEntries(request.headers.entries()),
body: request.body as BodyInit,
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
duplex: 'half',
},
apiToken,
);

return new Response(res.body, { status: res.status });
};
}

// (Heavily) Simplified version of performApiFetch from wrangler
// (source: https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L18)
export async function performApiFetch(
resource: string,
init: RequestInit = {},
apiToken: string,
) {
const method = init.method ?? 'GET';
const headers = cloneHeaders(init.headers);
headers['Authorization'] = `Bearer ${apiToken}`;

return fetch(`https://api.cloudflare.com/client/v4${resource}`, {
method,
...init,
headers,
});
}
17 changes: 16 additions & 1 deletion internal-packages/next-dev/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { WorkerOptions } from 'miniflare';
import { Miniflare, Request, Response, Headers } from 'miniflare';
import { getDOBindingInfo } from './durableObjects';
import { getServiceBindings } from './services';
import { getAIFetcher } from './ai';

/**
* Sets up the bindings that need to be available during development time (using
Expand Down Expand Up @@ -54,6 +55,11 @@ export type DevBindingsOptions = {
className: string;
}
>;
ai?: {
bindingName: string;
accountId: string;
apiToken: string;
};
/**
* Record mapping binding names to R2 bucket names to inject as R2Bucket.
* If a `string[]` of binding names is specified, the binding name and bucket name are assumed to be the same.
Expand Down Expand Up @@ -84,6 +90,12 @@ async function instantiateMiniflare(
const { workerOptions, durableObjects } =
(await getDOBindingInfo(options.durableObjects)) ?? {};

const aiBindingObj = options.ai
? {
[options.ai.bindingName]: getAIFetcher(options.ai),
}
: {};

const { kvNamespaces, r2Buckets, d1Databases, services, textBindings } =
options;
const bindings = {
Expand All @@ -102,7 +114,10 @@ async function instantiateMiniflare(
...bindings,
modules: true,
script: '',
serviceBindings,
serviceBindings: {
...serviceBindings,
...aiBindingObj,
},
},
...(workerOptions ? [workerOptions] : []),
];
Expand Down
11 changes: 11 additions & 0 deletions internal-packages/next-dev/src/wrangler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,14 @@ const IDENTIFIER_UNSAFE_REGEXP = /[^a-zA-Z0-9_$]/g;
export function getIdentifier(name: string) {
return name.replace(IDENTIFIER_UNSAFE_REGEXP, '_');
}

// https://github.com/cloudflare/workers-sdk/blob/912bfe/packages/wrangler/src/cfetch/internal.ts#L119
export function cloneHeaders(
headers: HeadersInit | undefined,
): Record<string, string> {
return headers instanceof Headers
? Object.fromEntries(headers.entries())
: Array.isArray(headers)
? Object.fromEntries(headers)
: { ...headers };
}
Loading