Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(middleware-flexible-checksums): use RequestChecksumCalculation and ResponseChecksumValidation without interceptor middleware #6484

Closed
wants to merge 9 commits into from
13 changes: 13 additions & 0 deletions packages/middleware-flexible-checksums/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import {
Encoder,
GetAwsChunkedEncodingStream,
HashConstructor,
Provider,
StreamCollector,
StreamHasher,
} from "@smithy/types";

import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";

export interface PreviouslyResolved {
/**
* The function that will be used to convert binary data to a base64-encoded string.
Expand All @@ -31,6 +34,16 @@ export interface PreviouslyResolved {
*/
md5: ChecksumConstructor | HashConstructor;

/**
* Determines when a checksum will be calculated for request payloads
*/
requestChecksumCalculation: Provider<RequestChecksumCalculation>;

/**
* Determines when a checksum will be calculated for response payloads
*/
responseChecksumValidation: Provider<ResponseChecksumValidation>;

/**
* A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes.
* @internal
Expand Down
5 changes: 4 additions & 1 deletion packages/middleware-flexible-checksums/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ export const DEFAULT_RESPONSE_CHECKSUM_VALIDATION = RequestChecksumCalculation.W
* Checksum Algorithms supported by the SDK.
*/
export enum ChecksumAlgorithm {
/**
* @deprecated Use {@link ChecksumAlgorithm.CRC32} instead.
*/
MD5 = "MD5",
CRC32 = "CRC32",
CRC32C = "CRC32C",
Expand All @@ -70,7 +73,7 @@ export enum ChecksumLocation {
/**
* @internal
*/
export const DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.MD5;
export const DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.CRC32;

/**
* @internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { HttpRequest } from "@smithy/protocol-http";
import { BuildHandlerArguments } from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, RequestChecksumCalculation } from "./constants";
import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware";
import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest";
import { getChecksumLocationName } from "./getChecksumLocationName";
Expand All @@ -27,7 +27,9 @@ describe(flexibleChecksumsMiddleware.name, () => {
const mockChecksumLocationName = "mock-checksum-location-name";

const mockInput = {};
const mockConfig = {} as PreviouslyResolved;
const mockConfig = {
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_REQUIRED),
} as PreviouslyResolved;
const mockMiddlewareConfig = { input: mockInput, requestChecksumRequired: false };

const mockBody = { body: "mockRequestBody" };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ export const flexibleChecksumsMiddleware =
const { request } = args;
const { body: requestBody, headers } = request;
const { base64Encoder, streamHasher } = config;
const requestChecksumCalculation = await config.requestChecksumCalculation();
const { input, requestChecksumRequired, requestAlgorithmMember } = middlewareConfig;

const checksumAlgorithm = getChecksumAlgorithmForRequest(
input,
{
requestChecksumRequired,
requestAlgorithmMember,
requestChecksumCalculation,
},
!!context.isS3ExpressBucket
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { HttpRequest } from "@smithy/protocol-http";
import { DeserializeHandlerArguments } from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, ResponseChecksumValidation } from "./constants";
import { flexibleChecksumsResponseMiddleware } from "./flexibleChecksumsResponseMiddleware";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin";
Expand All @@ -23,7 +23,9 @@ describe(flexibleChecksumsResponseMiddleware.name, () => {
commandName: "mockCommandName",
};

const mockConfig = {} as PreviouslyResolved;
const mockConfig = {
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED),
} as PreviouslyResolved;
const mockRequestValidationModeMember = "ChecksumEnabled";
const mockResponseAlgorithms = [ChecksumAlgorithm.CRC32, ChecksumAlgorithm.CRC32C];
const mockMiddlewareConfig = {
Expand Down Expand Up @@ -59,52 +61,66 @@ describe(flexibleChecksumsResponseMiddleware.name, () => {
});

describe("skips", () => {
it("if not an instance of HttpRequest", async () => {
const { isInstance } = HttpRequest;
(isInstance as unknown as jest.Mock).mockReturnValue(false);
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);
it("if requestValidationModeMember is not defined", async () => {
const mockMwConfig = Object.assign({}, mockMiddlewareConfig) as FlexibleChecksumsMiddlewareConfig;
delete mockMwConfig.requestValidationModeMember;
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMwConfig)(mockNext, mockContext);
await handler(mockArgs);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

describe("response checksum", () => {
it("if requestValidationModeMember is not defined", async () => {
const mockMwConfig = Object.assign({}, mockMiddlewareConfig) as FlexibleChecksumsMiddlewareConfig;
delete mockMwConfig.requestValidationModeMember;
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMwConfig)(mockNext, mockContext);
await handler(mockArgs);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
});
it("if requestValidationModeMember is not enabled in input", async () => {
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);

it("if requestValidationModeMember is not enabled in input", async () => {
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);
await handler({ ...mockArgs, input: {} });
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
});
const mockArgsWithoutEnabled = { ...mockArgs, input: {} };
await handler(mockArgsWithoutEnabled);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgsWithoutEnabled);
});

it("if checksum is for S3 whole-object multipart GET", async () => {
(isChecksumWithPartNumber as jest.Mock).mockReturnValue(true);
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {
clientName: "S3Client",
commandName: "GetObjectCommand",
});
await handler(mockArgs);
expect(isChecksumWithPartNumber).toHaveBeenCalledTimes(1);
expect(isChecksumWithPartNumber).toHaveBeenCalledWith(mockChecksum);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
it("if checksum is for S3 whole-object multipart GET", async () => {
(isChecksumWithPartNumber as jest.Mock).mockReturnValue(true);
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {
clientName: "S3Client",
commandName: "GetObjectCommand",
});
await handler(mockArgs);
expect(isChecksumWithPartNumber).toHaveBeenCalledTimes(1);
expect(isChecksumWithPartNumber).toHaveBeenCalledWith(mockChecksum);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});
});

describe("validates checksum from response header", () => {
it("generic case", async () => {
it("if requestValidationModeMember is enabled in input", async () => {
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);

await handler(mockArgs);
expect(validateChecksumFromResponse).toHaveBeenCalledWith(mockResult.response, {
config: mockConfig,
responseAlgorithms: mockResponseAlgorithms,
});
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

it(`if requestValidationModeMember is not enabled in input, but responseChecksumValidation returns ${ResponseChecksumValidation.WHEN_SUPPORTED}`, async () => {
const mockConfigWithResponseChecksumValidationSupported = {
...mockConfig,
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED),
};
const handler = flexibleChecksumsResponseMiddleware(
mockConfigWithResponseChecksumValidationSupported,
mockMiddlewareConfig
)(mockNext, mockContext);

await handler({ ...mockArgs, input: {} });
expect(validateChecksumFromResponse).toHaveBeenCalledWith(mockResult.response, {
config: mockConfigWithResponseChecksumValidationSupported,
responseAlgorithms: mockResponseAlgorithms,
});
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

it("if checksum is for S3 GET without part number", async () => {
Expand All @@ -120,6 +136,7 @@ describe(flexibleChecksumsResponseMiddleware.name, () => {
config: mockConfig,
responseAlgorithms: mockResponseAlgorithms,
});
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});
});
});
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import { HttpRequest, HttpResponse } from "@smithy/protocol-http";
import { HttpResponse } from "@smithy/protocol-http";
import {
DeserializeHandler,
DeserializeHandlerArguments,
DeserializeHandlerOutput,
DeserializeMiddleware,
HandlerExecutionContext,
MetadataBearer,
RelativeMiddlewareOptions,
SerializeHandler,
SerializeHandlerArguments,
SerializeHandlerOutput,
SerializeMiddleware,
} from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, ResponseChecksumValidation } from "./constants";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isChecksumWithPartNumber } from "./isChecksumWithPartNumber";
Expand All @@ -37,8 +37,8 @@ export interface FlexibleChecksumsResponseMiddlewareConfig {
*/
export const flexibleChecksumsResponseMiddlewareOptions: RelativeMiddlewareOptions = {
name: "flexibleChecksumsResponseMiddleware",
toMiddleware: "deserializerMiddleware",
relation: "after",
toMiddleware: "serializerMiddleware",
relation: "before",
tags: ["BODY_CHECKSUM"],
override: true,
};
Expand All @@ -52,32 +52,38 @@ export const flexibleChecksumsResponseMiddleware =
(
config: PreviouslyResolved,
middlewareConfig: FlexibleChecksumsResponseMiddlewareConfig
): DeserializeMiddleware<any, any> =>
): SerializeMiddleware<any, any> =>
<Output extends MetadataBearer>(
next: DeserializeHandler<any, Output>,
next: SerializeHandler<any, Output>,
context: HandlerExecutionContext
): DeserializeHandler<any, Output> =>
async (args: DeserializeHandlerArguments<any>): Promise<DeserializeHandlerOutput<Output>> => {
if (!HttpRequest.isInstance(args.request)) {
return next(args);
): SerializeHandler<any, Output> =>
async (args: SerializeHandlerArguments<any>): Promise<SerializeHandlerOutput<Output>> => {
const input = args.input;
const { requestValidationModeMember, responseAlgorithms } = middlewareConfig;
const responseChecksumValidation = await config.responseChecksumValidation();

const isResponseChecksumValidationNeeded =
requestValidationModeMember &&
(input[requestValidationModeMember] === "ENABLED" ||
responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED);

if (isResponseChecksumValidationNeeded) {
input[requestValidationModeMember] = "ENABLED";
}

const input = args.input;
const result = await next(args);

const response = result.response as HttpResponse;
let collectedStream: Uint8Array | undefined = undefined;

const { requestValidationModeMember, responseAlgorithms } = middlewareConfig;
// @ts-ignore Element implicitly has an 'any' type for input[requestValidationModeMember]
if (requestValidationModeMember && input[requestValidationModeMember] === "ENABLED") {
if (isResponseChecksumValidationNeeded) {
const { clientName, commandName } = context;
const isS3WholeObjectMultipartGetResponseChecksum =
clientName === "S3Client" &&
commandName === "GetObjectCommand" &&
getChecksumAlgorithmListForResponse(responseAlgorithms).every((algorithm: ChecksumAlgorithm) => {
const responseHeader = getChecksumLocationName(algorithm);
const checksumFromResponse = response.headers[responseHeader];
const checksumFromResponse = response.headers?.[responseHeader];
return !checksumFromResponse || isChecksumWithPartNumber(checksumFromResponse);
});
if (isS3WholeObjectMultipartGetResponseChecksum) {
Expand Down
Loading
Loading