Skip to content

Commit

Permalink
fix(lambda-api): track api key usage (#943)
Browse files Browse the repository at this point in the history
  • Loading branch information
blacha committed Jul 21, 2020
1 parent 723cad8 commit 7c4689c
Show file tree
Hide file tree
Showing 16 changed files with 79 additions and 107 deletions.
5 changes: 3 additions & 2 deletions packages/_infra/src/edge/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ export class EdgeStack extends cdk.Stack {
queryString: true,
queryStringCacheKeys: ['NOT_A_CACHE_KEY'],
},
// TODO track API keys with viewer requests
// lambdaFunctionAssociations: [lambdaViewerRequest],
lambdaFunctionAssociations: [
{ eventType: cf.LambdaEdgeEventType.VIEWER_REQUEST, lambdaFunction: this.lambda.version },
],
},
],
};
Expand Down
14 changes: 12 additions & 2 deletions packages/_infra/src/edge/lambda.edge.api.key.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ import iam = require('@aws-cdk/aws-iam');
import lambda = require('@aws-cdk/aws-lambda');
import { RetentionDays } from '@aws-cdk/aws-logs';
import { ApiKeyTableArn } from '../api.key.db';
import { VersionUtil } from '../version';
import { Env } from '@basemaps/shared';

const CODE_PATH = '../lambda-api-tracker/dist';
/**
* Create a API Key validation edge lambda
*/
export class LambdaApiKeyValidator extends cdk.Construct {
public lambda: lambda.Function;
public version: lambda.Version;

public constructor(scope: cdk.Stack, id: string) {
super(scope, id);
Expand All @@ -23,16 +26,23 @@ export class LambdaApiKeyValidator extends cdk.Construct {
managedPolicies: [{ managedPolicyArn: 'arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' }],
});

const version = VersionUtil.version();

this.lambda = new lambda.Function(this, 'ApiValidatorFunction', {
runtime: lambda.Runtime.NODEJS_10_X,
handler: 'index.handler',
code: lambda.Code.asset(CODE_PATH),
role: lambdaRole,
logRetention: RetentionDays.ONE_MONTH,
// Lambda@Edge only allows 128mb of ram
memorySize: 128,
environment: {
[Env.NodeEnv]: Env.get(Env.NodeEnv, 'dev'),
[Env.Hash]: version.hash,
[Env.Version]: version.version,
},
});

this.version = this.lambda.addVersion(version.hash);

// Allow access to all dynamoDb tables with the same name
const dynamoPolicy = new iam.PolicyStatement();
dynamoPolicy.addActions('dynamoDB:getItem', 'dynamoDB:putItem', 'dynamodb:UpdateItem');
Expand Down
28 changes: 15 additions & 13 deletions packages/lambda-api-tracker/src/__test__/validate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import o from 'ospec';
import { ValidateRequest } from '../validate';

o.spec('ApiValidate', (): void => {
const dummyApiKey = 'dummy1';
const faultyApiKey = 'fault1';
const mockApiKey = 'mock1';

Expand All @@ -16,8 +15,10 @@ o.spec('ApiValidate', (): void => {
Aws.apiKey.get = oldGet;
});

function makeContext(): LambdaContext {
return new LambdaContext({} as any, LogConfig.get());
function makeContext(apiKey: string): LambdaContext {
const ctx = new LambdaContext({} as any, LogConfig.get());
ctx.apiKey = apiKey;
return ctx;
}

o('validate should fail on faulty apikey', async () => {
Expand All @@ -33,7 +34,7 @@ o.spec('ApiValidate', (): void => {
minuteCount: 100,
} as ApiKeyTableRecord;
};
const result = await ValidateRequest.validate(faultyApiKey, makeContext());
const result = await ValidateRequest.validate(makeContext(faultyApiKey));
o(result).notEquals(null);
if (result == null) throw new Error('Validate returns null result');

Expand All @@ -42,15 +43,16 @@ o.spec('ApiValidate', (): void => {
o(result.statusDescription).equals('API key disabled');
});

o('validate should fail on null record', async () => {
Aws.apiKey.get = async (): Promise<null> => null;
const result = await ValidateRequest.validate(dummyApiKey, makeContext());
o(result).notEquals(null);
if (result == null) throw new Error('Validate returns null result');
// TODO this should be re-enabled at some stage
// o('validate should fail on null record', async () => {
// Aws.apiKey.get = async (): Promise<null> => null;
// const result = await ValidateRequest.validate(makeContext(dummyApiKey));
// o(result).notEquals(null);
// if (result == null) throw new Error('Validate returns null result');

o(result.status).equals(403);
o(result.statusDescription).equals('Invalid API Key');
});
// o(result.status).equals(403);
// o(result.statusDescription).equals('Invalid API Key');
// });

o('validate should fail with rate limit error', async () => {
const mockMinuteCount = 1e4;
Expand All @@ -65,7 +67,7 @@ o.spec('ApiValidate', (): void => {
minuteCount: mockMinuteCount,
} as ApiKeyTableRecord;
};
const result = await ValidateRequest.validate(mockApiKey, makeContext());
const result = await ValidateRequest.validate(makeContext(mockApiKey));
o(result).notEquals(null);
if (result == null) throw new Error('Validate returns null result');

Expand Down
31 changes: 0 additions & 31 deletions packages/lambda-api-tracker/src/__test__/xyz.request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,6 @@ o.spec('xyz-request', () => {
value: corrId,
},
],
'x-linz-api-key': [{ key: 'x-linz-api-key', value: '12345' }],
'x-linz-request-id': [{ key: 'x-linz-request-id', value: String(res.header(HttpHeader.RequestId)) }],
});
});

o('should not cache WMTSCapabilities', async () => {
ValidateRequest.validate = async (): Promise<LambdaHttpResponse | null> => null;

const request = req('/v1/tiles/aerial/3857/WMTSCapabilities.xml');
const res = await handleRequest(request);

o(res.status).equals(100);
const response = LambdaContext.toResponse(request, res) as CloudFrontRequestResult;

const corrId = String(res.header(HttpHeader.CorrelationId));
o(response?.headers).deepEquals({
referer: [{ key: 'Referer', value: 'from/url' }],
'user-agent': [{ key: 'User-Agent', value: 'test browser' }],
'cache-control': [
{
key: 'cache-control',
value: 'max-age=0',
},
],
'x-linz-correlation-id': [
{
key: 'x-linz-correlation-id',
value: corrId,
},
],
'x-linz-api-key': [{ key: 'x-linz-api-key', value: '12345' }],
'x-linz-request-id': [{ key: 'x-linz-request-id', value: String(res.header(HttpHeader.RequestId)) }],
});
});
Expand Down
47 changes: 16 additions & 31 deletions packages/lambda-api-tracker/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,60 +1,45 @@
import { HttpHeader, LambdaContext, LambdaFunction, LambdaHttpResponse } from '@basemaps/lambda';
import { Const, LogConfig, tileFromPath, TileType, ProjectionTileMatrixSet } from '@basemaps/shared';
import { LambdaContext, LambdaFunction, LambdaHttpResponse } from '@basemaps/lambda';
import { LogConfig, ProjectionTileMatrixSet, tileFromPath, TileType } from '@basemaps/shared';
import { ValidateRequest } from './validate';

function setTileInfo(ctx: LambdaContext): boolean {
function setTileInfo(ctx: LambdaContext): void {
const xyzData = tileFromPath(ctx.action.rest);
if (xyzData?.type === TileType.WMTS) {
return true;
}
if (xyzData == null) return;

if (xyzData?.type === TileType.Image) {
const { x, y, z } = xyzData;
if (xyzData.type === TileType.Image) {
const { x, y, z, ext } = xyzData;
ctx.set('xyz', { x, y, z });
ctx.set('projection', xyzData.projection.code);
ctx.set('extension', ext);

const latLon = ProjectionTileMatrixSet.get(xyzData.projection.code).tileCenterToLatLon(xyzData);
ctx.set('location', latLon);
}
return false;
}

/**
* Validate a CloudFront request has a valid API key and is not abusing the system
*/
export async function handleRequest(req: LambdaContext): Promise<LambdaHttpResponse> {
req.set('name', 'LambdaApiTracker');
req.set('method', req.method);
req.set('path', req.path);

// Extract request information
// ctx.set('clientIp', ctx.evt.clientIp); FIXME
if (LambdaContext.isCloudFrontEvent(req.evt)) {
req.set('clientIp', req.evt.Records[0].cf.request.clientIp);
}

req.set('referer', req.header('referer'));
req.set('userAgent', req.header('user-agent'));

const doNotCache = req.action.name === 'tiles' && setTileInfo(req);

const apiKey = req.query[Const.ApiKey.QueryString];
if (apiKey == null || Array.isArray(apiKey)) {
return new LambdaHttpResponse(400, 'Invalid API Key');
}

// Include the APIKey in the final log entry
req.set(Const.ApiKey.QueryString, apiKey);
if (req.action.name === 'tiles') setTileInfo(req);

// Validate the request throwing an error if anything goes wrong
req.timer.start('validate');
const res = await ValidateRequest.validate(apiKey, req);
const res = await ValidateRequest.validate(req);
req.timer.end('validate');

if (res != null) {
return res;
}
if (res != null) return res;

const response = new LambdaHttpResponse(100, 'Continue');
// Api key will be trimmed from the forwarded request so pass it via a well known header
response.header(HttpHeader.ApiKey, apiKey);
if (doNotCache) response.header(HttpHeader.CacheControl, 'max-age=0');
return response;
return new LambdaHttpResponse(100, 'Continue');
}

export const handler = LambdaFunction.wrap(handleRequest, LogConfig.get());
12 changes: 7 additions & 5 deletions packages/lambda-api-tracker/src/validate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ export const ValidateRequest = {
* Validate that a API Key is valid
* @param apiKey API key to validate
*/
async validate(apiKey: string, ctx: LambdaContext): Promise<LambdaHttpResponse | null> {
async validate(ctx: LambdaContext): Promise<LambdaHttpResponse | null> {
const timer = ctx.timer;

if (ctx.apiKey == null) return new LambdaHttpResponse(400, 'Invalid API Key');

// TODO increment the api counter
timer.start('validate:db');
const record = await Aws.apiKey.get(apiKey);
const record = await Aws.apiKey.get(ctx.apiKey);
timer.end('validate:db');

if (record == null) {
return new LambdaHttpResponse(403, 'Invalid API Key');
}
if (record == null) return null; // Allow invalid keys for now

ctx.log.info({ record }, 'Record');

if (record.lockToIp != null) {
Expand Down
4 changes: 0 additions & 4 deletions packages/lambda-xyz/src/__test__/xyz.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ o.spec('LambdaXyz', () => {
o(z).equals(0);

// Validate the session information has been set correctly
o(request.logContext['path']).equals(`/v1/tiles/${tileSetName}/global-mercator/0/0/0.png`);
o(request.logContext['tileSet']).equals(tileSetName);
o(request.logContext['method']).equals('get');
o(request.logContext['xyz']).deepEquals({ x: 0, y: 0, z: 0 });
o(round(request.logContext['location'])).deepEquals({ lat: 0, lon: 0 });
});
Expand All @@ -109,8 +107,6 @@ o.spec('LambdaXyz', () => {
o(z).equals(0);

// Validate the session information has been set correctly
o(request.logContext['path']).equals('/v1/tiles/aerial/3857/0/0/0.webp');
o(request.logContext['method']).equals('get');
o(request.logContext['xyz']).deepEquals({ x: 0, y: 0, z: 0 });
o(round(request.logContext['location'])).deepEquals({ lat: 0, lon: 0 });
});
Expand Down
10 changes: 2 additions & 8 deletions packages/lambda-xyz/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { LambdaContext, LambdaFunction, LambdaHttpResponse, Router } from '@basemaps/lambda';
import { LogConfig } from '@basemaps/shared';
import { Health, Ping, Version } from './routes/api';
import { TileOrWmts } from './routes/tile';
import { LogConfig, Const } from '@basemaps/shared';

const app = new Router();

Expand All @@ -11,13 +11,7 @@ app.get('version', Version);
app.get('tiles', TileOrWmts);

export async function handleRequest(req: LambdaContext): Promise<LambdaHttpResponse> {
req.set('name', 'LambdaXyzTiler');
req.set('method', req.method);
req.set('path', req.path);

const apiKey = req.query[Const.ApiKey.QueryString];
if (apiKey != null && !Array.isArray(apiKey)) req.set(Const.ApiKey.QueryString, apiKey);

req.set('name', 'LambdaTiler');
return await app.handle(req);
}

Expand Down
2 changes: 2 additions & 0 deletions packages/lambda-xyz/src/routes/tile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ export async function tile(req: LambdaContext, xyzData: TileDataXyz): Promise<La

const { x, y, z, ext } = xyzData;
req.set('xyz', { x, y, z });
req.set('projection', xyzData.projection.code);
req.set('extension', ext);
if (z > tiler.tms.maxZoom) return new LambdaHttpResponse(404, `Zoom not found : ${z}`);

const latLon = ProjectionTileMatrixSet.get(xyzData.projection.code).tileCenterToLatLon(xyzData);
Expand Down
3 changes: 3 additions & 0 deletions packages/lambda/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
"test": "ospec --globs 'build/**/*.test.js' --preload ../../scripts/test.before.js"
},
"dependencies": {
"@basemaps/shared": "^4.2.0",
"@basemaps/metrics": "^4.0.0",
"source-map-support": "^0.5.19",
"ulid": "^2.3.0"
},
"devDependencies": {
"@basemaps/geo": "^4.1.0",
"@basemaps/tiler": "^4.2.0",
"@types/aws-lambda": "^8.10.43"
},
"publishConfig": {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import { Epsg } from '@basemaps/geo';
import o from 'ospec';
import { tileFromPath, TileType } from '../api.path';
import { LambdaContext } from '@basemaps/lambda';
import { LogConfig, tileFromPath, TileType } from '@basemaps/shared';
import { ImageFormat } from '@basemaps/tiler';
import { LogConfig } from '../log';
import o from 'ospec';
import { LambdaContext } from '../lambda.context';

o.spec('api.path', () => {
function makeContext(path: string): LambdaContext {
Expand Down
6 changes: 3 additions & 3 deletions packages/lambda/src/__test__/lambda.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ o.spec('LambdaFunction', () => {
const testFunc = LambdaFunction.wrap(asyncThrow, FakeLogger());

const spy = o.spy();
await testFunc({} as any, null as any, spy);
await testFunc({ httpMethod: 'GET' } as any, null as any, spy);
o(spy.calls.length).equals(1);
const err = spy.args[0];
const res = spy.args[1] as ALBResult;
Expand All @@ -39,7 +39,7 @@ o.spec('LambdaFunction', () => {
const testFunc = LambdaFunction.wrap(asyncThrow, FakeLogger());

const spy = o.spy();
await testFunc({ Records: [{ cf: { request: { headers: {} } } }] } as any, null as any, spy);
await testFunc({ Records: [{ cf: { request: { method: 'GET', headers: {} } } }] } as any, null as any, spy);
o(spy.calls.length).equals(1);
const err = spy.args[0];
const res = spy.args[1] as CloudFrontResultResponse;
Expand Down Expand Up @@ -68,7 +68,7 @@ o.spec('LambdaFunction', () => {
}, fakeLogger);

const spy = o.spy();
await testFunc({} as any, null as any, spy);
await testFunc({ httpMethod: 'GET' } as any, null as any, spy);
o(spy.calls.length).equals(1);
o(spy.args[1]).deepEquals(LambdaContext.toAlbResponse(albOk));

Expand Down
8 changes: 7 additions & 1 deletion packages/lambda/src/lambda.context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { toAlbHeaders, toCloudFrontHeaders } from './lambda.aws';
import * as qs from 'querystring';
import { HttpHeader } from './header';
import { LambdaHttpResponse } from './lambda.response';
import { Const } from '@basemaps/shared';

export interface ActionData {
version: string;
Expand Down Expand Up @@ -50,7 +51,11 @@ export class LambdaContext {
this.evt = evt;
this.id = ulid.ulid();
this.loadHeaders();
this.apiKey = this.header(HttpHeader.ApiKey);
const apiKey = this.query[Const.ApiKey.QueryString] ?? this.header(HttpHeader.ApiKey);
if (apiKey != null && !Array.isArray(apiKey)) {
this.apiKey = apiKey;
this.set(Const.ApiKey.QueryString, this.apiKey);
}
this.correlationId = this.header(HttpHeader.CorrelationId) ?? ulid.ulid();
this.set('correlationId', this.correlationId);
this.log = logger.child({ id: this.id });
Expand Down Expand Up @@ -88,6 +93,7 @@ export class LambdaContext {
return this.evt.queryStringParameters ?? {};
}
const query = this.evt.Records[0].cf.request.querystring;
if (query == null || query[0] == null) return {};
return qs.decode(query[0] == '?' ? query.substr(1) : query);
}

Expand Down
Loading

0 comments on commit 7c4689c

Please sign in to comment.