Skip to content

Commit

Permalink
fix: validate arn
Browse files Browse the repository at this point in the history
  • Loading branch information
mazyu36 committed Jul 4, 2024
1 parent 17506d6 commit c55c1df
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Token } from '../../../core';
import { Arn, ArnFormat, Token } from '../../../core';

/**
* Guradrail settings for BedrockInvokeModel
Expand Down Expand Up @@ -34,10 +34,21 @@ export class Guardrail {
*/
private constructor(public readonly guardrailIdentifier: string, public readonly guardrailVersion: string) {
if (!Token.isUnresolved(guardrailIdentifier)) {
const guardrailPattern = /^(([a-z0-9]+)|(arn:aws(-[^:]+)?:bedrock:[a-z0-9-]{1,20}:[0-9]{12}:guardrail\/[a-z0-9]+))$/;
let gurdrailId = undefined;

if (!guardrailPattern.test(guardrailIdentifier)) {
throw new Error(`You must set guardrailIdentifier to the id or the arn of Guardrail, got ${guardrailIdentifier}`);
if (guardrailIdentifier.startsWith('arn:')) {
const arn = Arn.split(guardrailIdentifier, ArnFormat.SLASH_RESOURCE_NAME);
if (!arn.resourceName) {
throw new Error(`Invalid ARN format. The ARN of Guradrail should have the format: \`arn:aws:bedrock:<region>:<account-id>:guardrail/<guardrail-name>\`, got ${guardrailIdentifier}.`);
}
gurdrailId = arn.resourceName;
} else {
gurdrailId = guardrailIdentifier;
}

const guardrailPattern = /^[a-z0-9]+$/;
if (!guardrailPattern.test(gurdrailId)) {
throw new Error(`The id of Guardrail must contain only lowercase letters and numbers, got ${gurdrailId}.`);
}

if (guardrailIdentifier.length > 2048) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ describe('Invoke Model', () => {
});
});

test('guardrail fails when invalid guardrailIdentifier is set', () => {
test('guardrail fails when guardrailIdentifier is set to invalid id', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -526,7 +526,27 @@ describe('Invoke Model', () => {
guardrail: Guardrail.enableDraft('invalid-id'),
});
// THEN
}).toThrow('You must set guardrailIdentifier to the id or the arn of Guardrail, got invalid-id');
}).toThrow('The id of Guardrail must contain only lowercase letters and numbers, got invalid-id');
});

test('guardrail fails when guardrailIdentifier is set to invalid ARN', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');

expect(() => {
// WHEN
new BedrockInvokeModel(stack, 'Invoke', {
model,
body: sfn.TaskInput.fromObject(
{
prompt: 'Hello world',
},
),
guardrail: Guardrail.enableDraft('arn:aws:bedrock:us-turbo-2:123456789012:guardrail'),
});
// THEN
}).toThrow('Invalid ARN format. The ARN of Guradrail should have the format: `arn:aws:bedrock:<region>:<account-id>:guardrail/<guardrail-name>`, got arn:aws:bedrock:us-turbo-2:123456789012:guardrail.');
});

test('guardrail fails when guardrailIdentifier length is invalid', () => {
Expand Down

0 comments on commit c55c1df

Please sign in to comment.