Skip to content

Commit

Permalink
fix: fix policy statement when guardrailIdentifier is set to arn
Browse files Browse the repository at this point in the history
  • Loading branch information
mazyu36 committed Jun 22, 2024
1 parent 121bb62 commit d5e5a58
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,18 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
}

if (this.props.guardrail) {
const isArn = this.props.guardrail.guardrailIdentifier.startsWith('arn:');
policyStatements.push(
new iam.PolicyStatement({
actions: ['bedrock:ApplyGuardrail'],
resources: [
Stack.of(this).formatArn({
service: 'bedrock',
resource: 'guardrail',
resourceName: this.props.guardrail.guardrailIdentifier,
}),
isArn
? this.props.guardrail.guardrailIdentifier
: Stack.of(this).formatArn({
service: 'bedrock',
resource: 'guardrail',
resourceName: this.props.guardrail.guardrailIdentifier,
}),
],
}),
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ describe('Invoke Model', () => {
}).toThrow(/Output S3 object version is not supported./);
});

test('guardrail', () => {
test('guardrail when gurdarilIdentifier is set to arn', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -381,6 +381,10 @@ describe('Invoke Model', () => {
},
});

new sfn.StateMachine(stack, 'StateMachine', {
definitionBody: sfn.DefinitionBody.fromChainable(task),
});

// THEN
expect(stack.resolve(task.toStateJson())).toEqual({
Type: 'Task',
Expand All @@ -407,6 +411,111 @@ describe('Invoke Model', () => {
GuardrailVersion: 'DRAFT',
},
});

Template.fromStack(stack).hasResourceProperties('AWS::IAM::Policy', {
PolicyDocument: {
Statement: [
{
Action: 'bedrock:InvokeModel',
Effect: 'Allow',
Resource: 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123',
},
{
Action: 'bedrock:ApplyGuardrail',
Effect: 'Allow',
Resource: 'arn:aws:bedrock:us-turbo-2:123456789012:guardrail/testid',
},
],
},
});
});

test('guardrail when gurdarilIdentifier is set to id', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');

// WHEN
const task = new BedrockInvokeModel(stack, 'Invoke', {
model,
contentType: 'application/json',
body: sfn.TaskInput.fromObject(
{
prompt: 'Hello world',
},
),
guardrail: {
guardrailIdentifier: 'testid',
guardrailVersion: 'DRAFT',
},
});

new sfn.StateMachine(stack, 'StateMachine', {
definitionBody: sfn.DefinitionBody.fromChainable(task),
});

// THEN
expect(stack.resolve(task.toStateJson())).toEqual({
Type: 'Task',
Resource: {
'Fn::Join': [
'',
[
'arn:',
{
Ref: 'AWS::Partition',
},
':states:::bedrock:invokeModel',
],
],
},
End: true,
Parameters: {
ModelId: 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123',
Body: {
prompt: 'Hello world',
},
ContentType: 'application/json',
GuardrailIdentifier: 'testid',
GuardrailVersion: 'DRAFT',
},
});

Template.fromStack(stack).hasResourceProperties('AWS::IAM::Policy', {
PolicyDocument: {
Statement: [
{
Action: 'bedrock:InvokeModel',
Effect: 'Allow',
Resource: 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123',
},
{
Action: 'bedrock:ApplyGuardrail',
Effect: 'Allow',
Resource: {
'Fn::Join': [
'',
[
'arn:',
{
Ref: 'AWS::Partition',
},
':bedrock:',
{
Ref: 'AWS::Region',
},
':',
{
Ref: 'AWS::AccountId',
},
':guardrail/testid',
],
],
},
},
],
},
});
});

test('guardrail fails when invalid guardrailIdentifier is set', () => {
Expand Down

0 comments on commit d5e5a58

Please sign in to comment.