Skip to content

Commit

Permalink
fixup!
Browse files Browse the repository at this point in the history
- fix code
- add tests
- minor cleanups in code + tests
  • Loading branch information
vicb committed Oct 30, 2024
1 parent 8a8f9cd commit 2e0455f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 27 deletions.
11 changes: 7 additions & 4 deletions packages/open-next/src/core/routing/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,18 @@ export async function handleMiddleware(
// We bypass the middleware if the request is internal
if (internalEvent.headers["x-isr"]) return internalEvent;

const protocol = new URL(internalEvent.url).protocol;
// Retrieve the url protocol:
// - In lambda, the url only contains the rawPath and the query - default to https.
// - In cloudflare, the protocol is usually http in dev and https in production.
const protocol = internalEvent.url.startsWith("http://") ? "http:" : "https:";

const host = internalEvent.headers.host
? `${protocol}//${internalEvent.headers.host}`
: "http://localhost:3000";

const initialUrl = new URL(normalizedPath, host);
initialUrl.search = convertToQueryString(query);
const url = initialUrl.toString();
// console.log("url", url, normalizedPath);

const middleware = await middlewareLoader();

Expand Down Expand Up @@ -127,7 +130,7 @@ export async function handleMiddleware(
.get("location")
?.replace(
"http://localhost:3000",
`https://${internalEvent.headers.host}`,
`${protocol}//${internalEvent.headers.host}`,
) ?? resHeaders.location;
// res.setHeader("Location", location);
return {
Expand Down Expand Up @@ -192,7 +195,7 @@ export async function handleMiddleware(
responseHeaders: resHeaders,
url: newUrl,
rawPath: rewritten
? (newUrl ?? internalEvent.rawPath)
? newUrl ?? internalEvent.rawPath

Check failure on line 198 in packages/open-next/src/core/routing/middleware.ts

View workflow job for this annotation

GitHub Actions / validate

Replace `newUrl·??·internalEvent.rawPath` with `(newUrl·??·internalEvent.rawPath)`
: internalEvent.rawPath,
type: internalEvent.type,
headers: { ...internalEvent.headers, ...reqHeaders },
Expand Down
94 changes: 76 additions & 18 deletions packages/tests-unit/tests/core/routing/middleware.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { handleMiddleware } from "@opennextjs/aws/core/routing/middleware.js";
import { convertFromQueryString } from "@opennextjs/aws/core/routing/util.js";
import {
convertFromQueryString,
isExternal,
} from "@opennextjs/aws/core/routing/util.js";
import type { InternalEvent } from "@opennextjs/aws/types/open-next.js";
import { toReadableStream } from "@opennextjs/aws/utils/stream.js";
import { vi } from "vitest";
Expand Down Expand Up @@ -48,15 +51,25 @@ type PartialEvent = Partial<
> & { body?: string };

function createEvent(event: PartialEvent): InternalEvent {
const [rawPath, qs] = (event.url ?? "/").split("?", 2);
let rawPath: string;
let qs: string;
if (isExternal(event.url)) {
const url = new URL(event.url!);
rawPath = url.pathname;
qs = url.search;
} else {
const parts = (event.url ?? "/").split("?", 2);
rawPath = parts[0];
qs = parts[1] ?? "";
}
return {
type: "core",
method: event.method ?? "GET",
rawPath,
url: event.url ?? "/",
body: Buffer.from(event.body ?? ""),
headers: event.headers ?? {},
query: convertFromQueryString(qs ?? ""),
query: convertFromQueryString(qs),
cookies: event.cookies ?? {},
remoteAddress: event.remoteAddress ?? "::1",
};
Expand All @@ -70,19 +83,19 @@ beforeEach(() => {
* Ideally these tests would be broken up and tests smaller parts of the middleware rather than the entire function.
*/
describe("handleMiddleware", () => {
it("should bypass middlware for internal requests", async () => {
it("should bypass middleware for internal requests", async () => {
const event = createEvent({
headers: {
"x-isr": "1",
},
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).not.toBeCalled();
expect(middlewareLoader).not.toHaveBeenCalled();
expect(result).toEqual(event);
});

it("should invoke middlware with redirect", async () => {
it("should invoke middleware with redirect", async () => {
const event = createEvent({});
middleware.mockResolvedValue({
status: 302,
Expand All @@ -92,12 +105,12 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result.statusCode).toEqual(302);
expect(result.headers.location).toEqual("/redirect");
});

it("should invoke middlware with external redirect", async () => {
it("should invoke middleware with external redirect", async () => {
const event = createEvent({});
middleware.mockResolvedValue({
status: 302,
Expand All @@ -107,12 +120,12 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result.statusCode).toEqual(302);
expect(result.headers.location).toEqual("http://external/redirect");
});

it("should invoke middlware with rewrite", async () => {
it("should invoke middleware with rewrite", async () => {
const event = createEvent({
headers: {
host: "localhost",
Expand All @@ -125,7 +138,7 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result).toEqual({
...event,
rawPath: "/rewrite",
Expand All @@ -137,7 +150,7 @@ describe("handleMiddleware", () => {
});
});

it("should invoke middlware with rewrite with __nextDataReq", async () => {
it("should invoke middleware with rewrite with __nextDataReq", async () => {
const event = createEvent({
url: "/rewrite?__nextDataReq=1&key=value",
headers: {
Expand All @@ -151,7 +164,7 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result).toEqual({
...event,
rawPath: "/rewrite",
Expand All @@ -167,7 +180,7 @@ describe("handleMiddleware", () => {
});
});

it("should invoke middlware with external rewrite", async () => {
it("should invoke middleware with external rewrite", async () => {
const event = createEvent({
headers: {
host: "localhost",
Expand All @@ -180,7 +193,7 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result).toEqual({
...event,
rawPath: "http://external/rewrite",
Expand All @@ -201,7 +214,7 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result).toEqual({
...event,
headers: {
Expand All @@ -223,7 +236,7 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result).toEqual({
type: "core",
statusCode: 200,
Expand All @@ -246,7 +259,7 @@ describe("handleMiddleware", () => {
});
const result = await handleMiddleware(event, middlewareLoader);

expect(middlewareLoader).toBeCalled();
expect(middlewareLoader).toHaveBeenCalled();
expect(result).toEqual({
type: "core",
statusCode: 200,
Expand All @@ -257,4 +270,49 @@ describe("handleMiddleware", () => {
isBase64Encoded: false,
});
});

it("should use the http event protocol when specified", async () => {
const event = createEvent({
url: "http://test.me/path",
headers: {
host: "test.me",
},
});
await handleMiddleware(event, middlewareLoader);
expect(middleware).toHaveBeenCalledWith(
expect.objectContaining({
url: "http://test.me/path",
}),
);
});

it("should use the https event protocol when specified", async () => {
const event = createEvent({
url: "https://test.me/path",
headers: {
host: "test.me/path",
},
});
await handleMiddleware(event, middlewareLoader);
expect(middleware).toHaveBeenCalledWith(
expect.objectContaining({
url: "https://test.me/path",
}),
);
});

it("should default to https protocol", async () => {
const event = createEvent({
url: "/path",
headers: {
host: "test.me",
},
});
await handleMiddleware(event, middlewareLoader);
expect(middleware).toHaveBeenCalledWith(
expect.objectContaining({
url: "https://test.me/path",
}),
);
});
});
10 changes: 5 additions & 5 deletions packages/tests-unit/tests/core/routing/util.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ describe("getUrlParts", () => {

describe("external", () => {
it("throws for empty url", () => {
expect(() => getUrlParts("", true)).toThrowError();
expect(() => getUrlParts("", true)).toThrow();
});

it("throws for invalid url", () => {
expect(() => getUrlParts("/relative", true)).toThrowError();
expect(() => getUrlParts("/relative", true)).toThrow();
});

it("returns url parts for /", () => {
Expand Down Expand Up @@ -581,7 +581,7 @@ describe("revalidateIfRequired", () => {
const headers: Record<string, string> = {};
await revalidateIfRequired("localhost", "/path", headers);

expect(sendMock).not.toBeCalled();
expect(sendMock).not.toHaveBeenCalled();
});

it("should send to queue when x-nextjs-cache is STALE", async () => {
Expand All @@ -590,7 +590,7 @@ describe("revalidateIfRequired", () => {
};
await revalidateIfRequired("localhost", "/path", headers);

expect(sendMock).toBeCalledWith({
expect(sendMock).toHaveBeenCalledWith({
MessageBody: { host: "localhost", url: "/path" },
MessageDeduplicationId: expect.any(String),
MessageGroupId: expect.any(String),
Expand All @@ -604,7 +604,7 @@ describe("revalidateIfRequired", () => {
sendMock.mockRejectedValueOnce(new Error("Failed to send"));
await revalidateIfRequired("localhost", "/path", headers);

expect(sendMock).toBeCalledWith({
expect(sendMock).toHaveBeenCalledWith({
MessageBody: { host: "localhost", url: "/path" },
MessageDeduplicationId: expect.any(String),
MessageGroupId: expect.any(String),
Expand Down

0 comments on commit 2e0455f

Please sign in to comment.