Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(middleware-flexible-checksums): delay checksum validation until stream read #6629

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions clients/client-s3/test/e2e/S3.e2e.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import "@aws-sdk/signature-v4-crt";

import { S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3";
import { ChecksumAlgorithm, S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3";
import { afterAll, afterEach, beforeAll, describe, expect, test as it } from "vitest";

import { getIntegTestResources } from "../../../../tests/e2e/get-integ-test-resources";
Expand All @@ -24,9 +24,7 @@ describe("@aws-sdk/client-s3", () => {

Key = ``;

client = new S3({
region,
});
client = new S3({ region });
});

describe("PutObject", () => {
Expand Down Expand Up @@ -74,26 +72,43 @@ describe("@aws-sdk/client-s3", () => {
await client.deleteObject({ Bucket, Key });
});

it("should succeed with valid body payload", async () => {
it("should succeed with valid body payload with checksums", async () => {
// prepare the object.
const body = createBuffer("1MB");
let bodyChecksum = "";

const bodyChecksumReader = (next) => async (args) => {
const checksumValue = args.request.headers["x-amz-checksum-crc32"];
if (checksumValue) {
bodyChecksum = checksumValue;
}
return next(args);
};
client.middlewareStack.addRelativeTo(bodyChecksumReader, {
name: "bodyChecksumReader",
relation: "before",
toMiddleware: "deserializerMiddleware",
});

try {
await client.putObject({ Bucket, Key, Body: body });
await client.putObject({ Bucket, Key, Body: body, ChecksumAlgorithm: ChecksumAlgorithm.CRC32 });
} catch (e) {
console.error("failed to put");
throw e;
}

expect(bodyChecksum).not.toEqual("");

try {
// eslint-disable-next-line no-var
var result = await client.getObject({ Bucket, Key });
var result = await client.getObject({ Bucket, Key, ChecksumMode: "ENABLED" });
} catch (e) {
console.error("failed to get");
throw e;
}

expect(result.$metadata.httpStatusCode).toEqual(200);
expect(result.ChecksumCRC32).toEqual(bodyChecksum);
const { Readable } = require("stream");
expect(result.Body).toBeInstanceOf(Readable);
});
Expand Down
1 change: 1 addition & 0 deletions packages/middleware-flexible-checksums/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"@smithy/types": "^3.6.0",
"@smithy/util-middleware": "^3.0.8",
"@smithy/util-utf8": "^3.0.0",
"@smithy/util-stream": "^3.2.1",
"tslib": "^2.6.2"
},
"devDependencies": {
Expand Down
16 changes: 1 addition & 15 deletions packages/middleware-flexible-checksums/src/getChecksum.spec.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { getChecksum } from "./getChecksum";
import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

vi.mock("./isStreaming");
vi.mock("./stringHasher");

describe(getChecksum.name, () => {
const mockOptions = {
streamHasher: vi.fn(),
checksumAlgorithmFn: vi.fn(),
base64Encoder: vi.fn(),
};
Expand All @@ -26,21 +23,10 @@ describe(getChecksum.name, () => {
vi.clearAllMocks();
});

it("gets checksum from streamHasher if body is streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
mockOptions.streamHasher.mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(stringHasher).not.toHaveBeenCalled();
expect(mockOptions.streamHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});

it("gets checksum from stringHasher if body is not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);
it("gets checksum from stringHasher", async () => {
vi.mocked(stringHasher).mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(mockOptions.streamHasher).not.toHaveBeenCalled();
expect(stringHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});
});
13 changes: 3 additions & 10 deletions packages/middleware-flexible-checksums/src/getChecksum.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import { ChecksumConstructor, Encoder, HashConstructor, StreamHasher } from "@smithy/types";
import { ChecksumConstructor, Encoder, HashConstructor } from "@smithy/types";

import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

export interface GetChecksumDigestOptions {
streamHasher: StreamHasher<any>;
checksumAlgorithmFn: ChecksumConstructor | HashConstructor;
base64Encoder: Encoder;
}

export const getChecksum = async (
body: unknown,
{ streamHasher, checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions
) => {
const digest = isStreaming(body) ? streamHasher(checksumAlgorithmFn, body) : stringHasher(checksumAlgorithmFn, body);
return base64Encoder(await digest);
};
export const getChecksum = async (body: unknown, { checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions) =>
base64Encoder(await stringHasher(checksumAlgorithmFn, body));
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import { HttpResponse } from "@smithy/protocol-http";
import { createChecksumStream } from "@smithy/util-stream";
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";

vi.mock("@smithy/util-stream");
vi.mock("./getChecksum");
vi.mock("./getChecksumLocationName");
vi.mock("./getChecksumAlgorithmListForResponse");
vi.mock("./isStreaming");
vi.mock("./selectChecksumAlgorithmFunction");

describe(validateChecksumFromResponse.name, () => {
const mockConfig = {
streamHasher: vi.fn(),
base64Encoder: vi.fn(),
} as unknown as PreviouslyResolved;

const mockBody = {};
const mockBodyStream = { isStream: true };
const mockHeaders = {};
const mockResponse = {
body: mockBody,
Expand Down Expand Up @@ -50,6 +54,7 @@ describe(validateChecksumFromResponse.name, () => {
vi.mocked(getChecksumAlgorithmListForResponse).mockImplementation((responseAlgorithms) => responseAlgorithms);
vi.mocked(selectChecksumAlgorithmFunction).mockReturnValue(mockChecksumAlgorithmFn);
vi.mocked(getChecksum).mockResolvedValue(mockChecksum);
vi.mocked(createChecksumStream).mockReturnValue(mockBodyStream);
});

afterEach(() => {
Expand Down Expand Up @@ -85,31 +90,56 @@ describe(validateChecksumFromResponse.name, () => {
});

describe("successful validation", () => {
afterEach(() => {
const validateCalls = (isStream: boolean, checksumAlgoFn: ChecksumAlgorithm) => {
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
});

it("when checksum is populated for first algorithm", async () => {
if (isStream) {
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
expect(createChecksumStream).toHaveBeenCalledWith({
expectedChecksum: mockChecksum,
checksumSourceLocation: checksumAlgoFn,
checksum: new mockChecksumAlgorithmFn(),
source: mockBody,
base64Encoder: mockConfig.base64Encoder,
});
} else {
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledWith(mockBody, {
checksumAlgorithmFn: mockChecksumAlgorithmFn,
base64Encoder: mockConfig.base64Encoder,
});
expect(createChecksumStream).not.toHaveBeenCalled();
}
};

it.each([false, true])("when checksum is populated for first algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledWith(mockResponseAlgorithms[0]);
validateCalls(isStream, mockResponseAlgorithms[0]);
});

it("when checksum is populated for second algorithm", async () => {
it.each([false, true])("when checksum is populated for second algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[1], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(2);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(1, mockResponseAlgorithms[0]);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(2, mockResponseAlgorithms[1]);
validateCalls(isStream, mockResponseAlgorithms[1]);
});
});

it("throw error if checksum value is not accurate", async () => {
it("throw error if checksum value is not accurate when not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);

const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);

try {
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
fail("should throw checksum mismatch error");
Expand All @@ -119,9 +149,28 @@ describe(validateChecksumFromResponse.name, () => {
` in response header "${mockResponseAlgorithms[0]}".`
);
}

expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(createChecksumStream).not.toHaveBeenCalled();
});

it("return if checksum value is not accurate when streaming, as error will be thrown when stream is consumed", async () => {
vi.mocked(isStreaming).mockReturnValue(true);

// This override does not matter for the purpose of unit test, but is kept for completeness.
const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);

await validateChecksumFromResponse(responseWithChecksum, mockOptions);

expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
expect(responseWithChecksum.body).toBe(mockBodyStream);
});
});
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { HttpResponse } from "@smithy/protocol-http";
import { Checksum } from "@smithy/types";
import { createChecksumStream } from "@smithy/util-stream";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";

export interface ValidateChecksumFromResponseOptions {
Expand All @@ -29,9 +32,20 @@ export const validateChecksumFromResponse = async (
const checksumFromResponse = responseHeaders[responseHeader];
if (checksumFromResponse) {
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(algorithm as ChecksumAlgorithm, config);
const { streamHasher, base64Encoder } = config;
const checksum = await getChecksum(responseBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
const { base64Encoder } = config;

if (isStreaming(responseBody)) {
response.body = createChecksumStream({
trivikr marked this conversation as resolved.
Show resolved Hide resolved
expectedChecksum: checksumFromResponse,
checksumSourceLocation: responseHeader,
checksum: new checksumAlgorithmFn() as Checksum,
source: responseBody,
base64Encoder,
});
return;
}

const checksum = await getChecksum(responseBody, { checksumAlgorithmFn, base64Encoder });
if (checksum === checksumFromResponse) {
// The checksum for response payload is valid.
break;
Expand Down
Loading