Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: validate incoming host header #163

Merged
merged 5 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { serve, createAdaptorServer } from './server'
export { getRequestListener } from './listener'
export { RequestError } from './request'
export type { HttpBindings, Http2Bindings } from './types'
22 changes: 18 additions & 4 deletions src/listener.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import type { IncomingMessage, ServerResponse, OutgoingHttpHeaders } from 'node:http'
import type { Http2ServerRequest, Http2ServerResponse } from 'node:http2'
import { getAbortController, newRequest, Request as LightweightRequest } from './request'
import {
getAbortController,
newRequest,
Request as LightweightRequest,
toRequestError,
} from './request'
import { cacheKey, getInternalBody, Response as LightweightResponse } from './response'
import type { CustomErrorHandler, FetchCallback, HttpBindings } from './types'
import { writeFromReadableStream, buildOutgoingHttpHeaders } from './utils'
Expand All @@ -10,6 +15,11 @@ import './globals'
const regBuffer = /^no$/i
const regContentType = /^(application\/json\b|text\/(?!event-stream\b))/i

const handleRequestError = (): Response =>
new Response(null, {
status: 400,
})

const handleFetchError = (e: unknown): Response =>
new Response(null, {
status:
Expand Down Expand Up @@ -140,6 +150,7 @@ const responseViaResponseObject = async (
export const getRequestListener = (
fetchCallback: FetchCallback,
options: {
hostname?: string
errorHandler?: CustomErrorHandler
overrideGlobalObjects?: boolean
} = {}
Expand All @@ -157,12 +168,13 @@ export const getRequestListener = (
incoming: IncomingMessage | Http2ServerRequest,
outgoing: ServerResponse | Http2ServerResponse
) => {
let res
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let res, req: any

try {
// `fetchCallback()` requests a Request object, but global.Request is expensive to generate,
// so generate a pseudo Request object with only the minimum required information.
const req = newRequest(incoming)
req = newRequest(incoming, options.hostname)

// Detect if request was aborted.
outgoing.on('close', () => {
Expand All @@ -181,10 +193,12 @@ export const getRequestListener = (
} catch (e: unknown) {
if (!res) {
if (options.errorHandler) {
res = await options.errorHandler(e)
res = await options.errorHandler(req ? e : toRequestError(e))
if (!res) {
return
}
} else if (!req) {
res = handleRequestError()
} else {
res = handleFetchError(e)
}
Expand Down
48 changes: 42 additions & 6 deletions src/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@ import { Http2ServerRequest } from 'node:http2'
import { Readable } from 'node:stream'
import type { TLSSocket } from 'node:tls'

export class RequestError extends Error {
static name = 'RequestError'
constructor(
message: string,
options?: {
cause?: unknown
}
) {
super(message, options)
}
}

export const toRequestError = (e: unknown): RequestError => {
if (e instanceof RequestError) {
return e
}
return new RequestError((e as Error).message, { cause: e })
}

export const GlobalRequest = global.Request
export class Request extends GlobalRequest {
constructor(input: string | Request, options?: RequestInit) {
Expand Down Expand Up @@ -111,18 +130,35 @@ const requestPrototype: Record<string | symbol, any> = {
})
Object.setPrototypeOf(requestPrototype, Request.prototype)

export const newRequest = (incoming: IncomingMessage | Http2ServerRequest) => {
export const newRequest = (
incoming: IncomingMessage | Http2ServerRequest,
defaultHostname?: string
) => {
const req = Object.create(requestPrototype)
req[incomingKey] = incoming
req[urlKey] = new URL(

const host =
(incoming instanceof Http2ServerRequest ? incoming.authority : incoming.headers.host) ||
defaultHostname
if (!host) {
throw new RequestError('Missing host header')
}
const url = new URL(
`${
incoming instanceof Http2ServerRequest ||
(incoming.socket && (incoming.socket as TLSSocket).encrypted)
? 'https'
: 'http'
}://${incoming instanceof Http2ServerRequest ? incoming.authority : incoming.headers.host}${
incoming.url
}`
).href
}://${host}${incoming.url}`
)

// check by length for performance.
// if suspicious, check by host. host header sometimes contains port.
if (url.hostname.length !== host.length && url.hostname !== host.replace(/:\d+$/, '')) {
throw new RequestError('Invalid host header')
}

req[urlKey] = url.href

return req
}
1 change: 1 addition & 0 deletions src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type { Options, ServerType } from './types'
export const createAdaptorServer = (options: Options): ServerType => {
const fetchCallback = options.fetch
const requestListener = getRequestListener(fetchCallback, {
hostname: options.hostname,
overrideGlobalObjects: options.overrideGlobalObjects,
})
// ts will complain about createServerHTTP and createServerHTTP2 not being callable, which works just fine
Expand Down
68 changes: 54 additions & 14 deletions test/listener.test.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,68 @@
import { createServer } from 'node:http'
import request from 'supertest'
import { getRequestListener } from '../src/listener'
import { GlobalRequest, Request as LightweightRequest } from '../src/request'
import { GlobalRequest, Request as LightweightRequest, RequestError } from '../src/request'
import { GlobalResponse, Response as LightweightResponse } from '../src/response'

describe('Invalid request', () => {
const requestListener = getRequestListener(jest.fn())
const server = createServer(async (req, res) => {
await requestListener(req, res)
describe('default error handler', () => {
const requestListener = getRequestListener(jest.fn())
const server = createServer(requestListener)

if (!res.writableEnded) {
res.writeHead(500, { 'Content-Type': 'text/plain' })
res.end('error handler did not return a response')
}
it('Should return server error for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(400)
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(400)
})
})

it('Should return server error for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(500)
describe('custom error handler', () => {
const requestListener = getRequestListener(jest.fn(), {
errorHandler: (e) => {
if (e instanceof RequestError) {
return new Response(e.message, { status: 400 })
} else {
return new Response('unknown error', { status: 500 })
}
},
})
const server = createServer(requestListener)

it('Should return server error for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(400)
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(400)
})

it('Should return server error for host header with path', async () => {
const res = await request(server).get('/').set('Host', 'a/b').send()
expect(res.status).toBe(400)
})
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(500)
describe('default hostname', () => {
const requestListener = getRequestListener(() => new Response('ok'), {
hostname: 'example.com',
})
const server = createServer(requestListener)

it('Should return 200 for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(200)
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(400)
})
})
})

Expand Down
50 changes: 42 additions & 8 deletions test/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
Request as LightweightRequest,
GlobalRequest,
getAbortController,
RequestError,
} from '../src/request'

Object.defineProperty(global, 'Request', {
Expand Down Expand Up @@ -40,30 +41,30 @@ describe('Request', () => {
expect(req.url).toBe('http://localhost/foo.txt')
})

it('Should resolve double dots in host header', async () => {
it('Should accept hostname and port in host header', async () => {
const req = newRequest({
headers: {
host: 'localhost/..',
host: 'localhost:8080',
},
url: '/foo.txt',
url: '/static/../foo.txt',
} as IncomingMessage)
expect(req).toBeInstanceOf(global.Request)
expect(req.url).toBe('http://localhost/foo.txt')
expect(req.url).toBe('http://localhost:8080/foo.txt')
})

it('should generate only one `AbortController` per `Request` object created', async () => {
const req = newRequest({
headers: {
host: 'localhost/..',
host: 'localhost',
},
rawHeaders: ['host', 'localhost/..'],
rawHeaders: ['host', 'localhost'],
url: '/foo.txt',
} as IncomingMessage)
const req2 = newRequest({
headers: {
host: 'localhost/..',
host: 'localhost',
},
rawHeaders: ['host', 'localhost/..'],
rawHeaders: ['host', 'localhost'],
url: '/foo.txt',
} as IncomingMessage)

Expand All @@ -78,6 +79,39 @@ describe('Request', () => {
expect(z).not.toBe(x)
expect(z).not.toBe(y)
})

it('Should throw error if host header contains path', async () => {
expect(() => {
newRequest({
headers: {
host: 'localhost/..',
},
url: '/foo.txt',
} as IncomingMessage)
}).toThrow(RequestError)
})

it('Should throw error if host header is empty', async () => {
expect(() => {
newRequest({
headers: {
host: '',
},
url: '/foo.txt',
} as IncomingMessage)
}).toThrow(RequestError)
})

it('Should throw error if host header contains query parameter', async () => {
expect(() => {
newRequest({
headers: {
host: 'localhost?foo=bar',
},
url: '/foo.txt',
} as IncomingMessage)
}).toThrow(RequestError)
})
})

describe('GlobalRequest', () => {
Expand Down
Loading