diff --git a/src/index.ts b/src/index.ts index 6de5948..8ccbbe3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,3 +1,4 @@ export { serve, createAdaptorServer } from './server' export { getRequestListener } from './listener' +export { RequestError } from './request' export type { HttpBindings, Http2Bindings } from './types' diff --git a/src/listener.ts b/src/listener.ts index 67c3c26..6022240 100644 --- a/src/listener.ts +++ b/src/listener.ts @@ -1,6 +1,11 @@ import type { IncomingMessage, ServerResponse, OutgoingHttpHeaders } from 'node:http' import type { Http2ServerRequest, Http2ServerResponse } from 'node:http2' -import { getAbortController, newRequest, Request as LightweightRequest } from './request' +import { + getAbortController, + newRequest, + Request as LightweightRequest, + toRequestError, +} from './request' import { cacheKey, getInternalBody, Response as LightweightResponse } from './response' import type { CustomErrorHandler, FetchCallback, HttpBindings } from './types' import { writeFromReadableStream, buildOutgoingHttpHeaders } from './utils' @@ -10,6 +15,11 @@ import './globals' const regBuffer = /^no$/i const regContentType = /^(application\/json\b|text\/(?!event-stream\b))/i +const handleRequestError = (): Response => + new Response(null, { + status: 400, + }) + const handleFetchError = (e: unknown): Response => new Response(null, { status: @@ -140,6 +150,7 @@ const responseViaResponseObject = async ( export const getRequestListener = ( fetchCallback: FetchCallback, options: { + hostname?: string errorHandler?: CustomErrorHandler overrideGlobalObjects?: boolean } = {} @@ -157,12 +168,13 @@ export const getRequestListener = ( incoming: IncomingMessage | Http2ServerRequest, outgoing: ServerResponse | Http2ServerResponse ) => { - let res + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let res, req: any try { // `fetchCallback()` requests a Request object, but global.Request is expensive to generate, // so generate a pseudo Request object with only the minimum required information. - const req = newRequest(incoming) + req = newRequest(incoming, options.hostname) // Detect if request was aborted. outgoing.on('close', () => { @@ -181,10 +193,12 @@ export const getRequestListener = ( } catch (e: unknown) { if (!res) { if (options.errorHandler) { - res = await options.errorHandler(e) + res = await options.errorHandler(req ? e : toRequestError(e)) if (!res) { return } + } else if (!req) { + res = handleRequestError() } else { res = handleFetchError(e) } diff --git a/src/request.ts b/src/request.ts index 49613d8..3ec8116 100644 --- a/src/request.ts +++ b/src/request.ts @@ -6,6 +6,25 @@ import { Http2ServerRequest } from 'node:http2' import { Readable } from 'node:stream' import type { TLSSocket } from 'node:tls' +export class RequestError extends Error { + static name = 'RequestError' + constructor( + message: string, + options?: { + cause?: unknown + } + ) { + super(message, options) + } +} + +export const toRequestError = (e: unknown): RequestError => { + if (e instanceof RequestError) { + return e + } + return new RequestError((e as Error).message, { cause: e }) +} + export const GlobalRequest = global.Request export class Request extends GlobalRequest { constructor(input: string | Request, options?: RequestInit) { @@ -111,18 +130,35 @@ const requestPrototype: Record = { }) Object.setPrototypeOf(requestPrototype, Request.prototype) -export const newRequest = (incoming: IncomingMessage | Http2ServerRequest) => { +export const newRequest = ( + incoming: IncomingMessage | Http2ServerRequest, + defaultHostname?: string +) => { const req = Object.create(requestPrototype) req[incomingKey] = incoming - req[urlKey] = new URL( + + const host = + (incoming instanceof Http2ServerRequest ? incoming.authority : incoming.headers.host) || + defaultHostname + if (!host) { + throw new RequestError('Missing host header') + } + const url = new URL( `${ incoming instanceof Http2ServerRequest || (incoming.socket && (incoming.socket as TLSSocket).encrypted) ? 'https' : 'http' - }://${incoming instanceof Http2ServerRequest ? incoming.authority : incoming.headers.host}${ - incoming.url - }` - ).href + }://${host}${incoming.url}` + ) + + // check by length for performance. + // if suspicious, check by host. host header sometimes contains port. + if (url.hostname.length !== host.length && url.hostname !== host.replace(/:\d+$/, '')) { + throw new RequestError('Invalid host header') + } + + req[urlKey] = url.href + return req } diff --git a/src/server.ts b/src/server.ts index d1b3146..c02fcc8 100644 --- a/src/server.ts +++ b/src/server.ts @@ -6,6 +6,7 @@ import type { Options, ServerType } from './types' export const createAdaptorServer = (options: Options): ServerType => { const fetchCallback = options.fetch const requestListener = getRequestListener(fetchCallback, { + hostname: options.hostname, overrideGlobalObjects: options.overrideGlobalObjects, }) // ts will complain about createServerHTTP and createServerHTTP2 not being callable, which works just fine diff --git a/test/listener.test.ts b/test/listener.test.ts index 65ac1c0..9df6f03 100644 --- a/test/listener.test.ts +++ b/test/listener.test.ts @@ -1,28 +1,68 @@ import { createServer } from 'node:http' import request from 'supertest' import { getRequestListener } from '../src/listener' -import { GlobalRequest, Request as LightweightRequest } from '../src/request' +import { GlobalRequest, Request as LightweightRequest, RequestError } from '../src/request' import { GlobalResponse, Response as LightweightResponse } from '../src/response' describe('Invalid request', () => { - const requestListener = getRequestListener(jest.fn()) - const server = createServer(async (req, res) => { - await requestListener(req, res) + describe('default error handler', () => { + const requestListener = getRequestListener(jest.fn()) + const server = createServer(requestListener) - if (!res.writableEnded) { - res.writeHead(500, { 'Content-Type': 'text/plain' }) - res.end('error handler did not return a response') - } + it('Should return server error for a request w/o host header', async () => { + const res = await request(server).get('/').set('Host', '').send() + expect(res.status).toBe(400) + }) + + it('Should return server error for a request invalid host header', async () => { + const res = await request(server).get('/').set('Host', 'a b').send() + expect(res.status).toBe(400) + }) }) - it('Should return server error for a request w/o host header', async () => { - const res = await request(server).get('/').set('Host', '').send() - expect(res.status).toBe(500) + describe('custom error handler', () => { + const requestListener = getRequestListener(jest.fn(), { + errorHandler: (e) => { + if (e instanceof RequestError) { + return new Response(e.message, { status: 400 }) + } else { + return new Response('unknown error', { status: 500 }) + } + }, + }) + const server = createServer(requestListener) + + it('Should return server error for a request w/o host header', async () => { + const res = await request(server).get('/').set('Host', '').send() + expect(res.status).toBe(400) + }) + + it('Should return server error for a request invalid host header', async () => { + const res = await request(server).get('/').set('Host', 'a b').send() + expect(res.status).toBe(400) + }) + + it('Should return server error for host header with path', async () => { + const res = await request(server).get('/').set('Host', 'a/b').send() + expect(res.status).toBe(400) + }) }) - it('Should return server error for a request invalid host header', async () => { - const res = await request(server).get('/').set('Host', 'a b').send() - expect(res.status).toBe(500) + describe('default hostname', () => { + const requestListener = getRequestListener(() => new Response('ok'), { + hostname: 'example.com', + }) + const server = createServer(requestListener) + + it('Should return 200 for a request w/o host header', async () => { + const res = await request(server).get('/').set('Host', '').send() + expect(res.status).toBe(200) + }) + + it('Should return server error for a request invalid host header', async () => { + const res = await request(server).get('/').set('Host', 'a b').send() + expect(res.status).toBe(400) + }) }) }) diff --git a/test/request.test.ts b/test/request.test.ts index f86bc65..999a371 100644 --- a/test/request.test.ts +++ b/test/request.test.ts @@ -4,6 +4,7 @@ import { Request as LightweightRequest, GlobalRequest, getAbortController, + RequestError, } from '../src/request' Object.defineProperty(global, 'Request', { @@ -40,30 +41,30 @@ describe('Request', () => { expect(req.url).toBe('http://localhost/foo.txt') }) - it('Should resolve double dots in host header', async () => { + it('Should accept hostname and port in host header', async () => { const req = newRequest({ headers: { - host: 'localhost/..', + host: 'localhost:8080', }, - url: '/foo.txt', + url: '/static/../foo.txt', } as IncomingMessage) expect(req).toBeInstanceOf(global.Request) - expect(req.url).toBe('http://localhost/foo.txt') + expect(req.url).toBe('http://localhost:8080/foo.txt') }) it('should generate only one `AbortController` per `Request` object created', async () => { const req = newRequest({ headers: { - host: 'localhost/..', + host: 'localhost', }, - rawHeaders: ['host', 'localhost/..'], + rawHeaders: ['host', 'localhost'], url: '/foo.txt', } as IncomingMessage) const req2 = newRequest({ headers: { - host: 'localhost/..', + host: 'localhost', }, - rawHeaders: ['host', 'localhost/..'], + rawHeaders: ['host', 'localhost'], url: '/foo.txt', } as IncomingMessage) @@ -78,6 +79,39 @@ describe('Request', () => { expect(z).not.toBe(x) expect(z).not.toBe(y) }) + + it('Should throw error if host header contains path', async () => { + expect(() => { + newRequest({ + headers: { + host: 'localhost/..', + }, + url: '/foo.txt', + } as IncomingMessage) + }).toThrow(RequestError) + }) + + it('Should throw error if host header is empty', async () => { + expect(() => { + newRequest({ + headers: { + host: '', + }, + url: '/foo.txt', + } as IncomingMessage) + }).toThrow(RequestError) + }) + + it('Should throw error if host header contains query parameter', async () => { + expect(() => { + newRequest({ + headers: { + host: 'localhost?foo=bar', + }, + url: '/foo.txt', + } as IncomingMessage) + }).toThrow(RequestError) + }) }) describe('GlobalRequest', () => {