Skip to content

Commit

Permalink
Add KMS Signer (#229)
Browse files Browse the repository at this point in the history
* kms signer

* add kms signer

* update signatures

* remove GATEWAY_LISTEN

* comment reason why v is 0/1
  • Loading branch information
ian-shim committed May 6, 2024
1 parent 1fcd296 commit 069dbf2
Show file tree
Hide file tree
Showing 12 changed files with 548 additions and 38 deletions.
36 changes: 36 additions & 0 deletions aws/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package aws

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
)

func GetAWSConfig(accessKey, secretAccessKey, region, endpointURL string) (*aws.Config, error) {
createClient := func(service, region string, options ...interface{}) (aws.Endpoint, error) {
if endpointURL != "" {
return aws.Endpoint{
PartitionID: "aws",
URL: endpointURL,
SigningRegion: region,
}, nil
}

// returning EndpointNotFoundError will allow the service to fallback to its default resolution
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
}
customResolver := aws.EndpointResolverWithOptionsFunc(createClient)

cfg, errCfg := config.LoadDefaultConfig(context.Background(),
config.WithRegion(region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretAccessKey, "")),
config.WithEndpointResolverWithOptions(customResolver),
config.WithRetryMode(aws.RetryModeStandard),
)
if errCfg != nil {
return nil, errCfg
}
return &cfg, nil
}
19 changes: 19 additions & 0 deletions aws/kms/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package kms

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/kms"
)

func NewKMSClient(ctx context.Context, region string) (*kms.Client, error) {
config, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %w", err)
}

c := kms.NewFromConfig(config)
return c, nil
}
46 changes: 46 additions & 0 deletions aws/kms/get_public_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package kms

import (
"context"
"crypto/ecdsa"
"encoding/asn1"
"fmt"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/ethereum/go-ethereum/crypto"
)

type asn1EcPublicKey struct {
EcPublicKeyInfo asn1EcPublicKeyInfo
PublicKey asn1.BitString
}

type asn1EcPublicKeyInfo struct {
Algorithm asn1.ObjectIdentifier
Parameters asn1.ObjectIdentifier
}

// GetECDSAPublicKey retrieves the ECDSA public key for a KMS key
// It assumes the key is set up with `ECC_SECG_P256K1` key spec and `SIGN_VERIFY` key usage
func GetECDSAPublicKey(ctx context.Context, svc *kms.Client, keyId string) (*ecdsa.PublicKey, error) {
getPubKeyOutput, err := svc.GetPublicKey(ctx, &kms.GetPublicKeyInput{
KeyId: aws.String(keyId),
})
if err != nil {
return nil, fmt.Errorf("failed to get public key for KeyId=%s: %w", keyId, err)
}

var asn1pubk asn1EcPublicKey
_, err = asn1.Unmarshal(getPubKeyOutput.PublicKey, &asn1pubk)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal public key for KeyId=%s: %w", keyId, err)
}

pubkey, err := crypto.UnmarshalPubkey(asn1pubk.PublicKey.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal public key for KeyId=%s: %w", keyId, err)
}

return pubkey, nil
}
68 changes: 68 additions & 0 deletions aws/kms/get_public_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package kms_test

import (
"context"
"fmt"
"os"
"testing"

eigenkms "github.com/Layr-Labs/eigensdk-go/aws/kms"
"github.com/Layr-Labs/eigensdk-go/testutils"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go"
)

var (
mappedLocalstackPort string
keyMetadata *types.KeyMetadata
container testcontainers.Container
)

func TestMain(m *testing.M) {
err := setup()
if err != nil {
fmt.Println("Error setting up test environment:", err)
teardown()
os.Exit(1)
}
exitCode := m.Run()
teardown()
os.Exit(exitCode)
}

func setup() error {
var err error
container, err = testutils.StartLocalstackContainer("get_public_key_test")
if err != nil {
return err
}
mappedPort, err := container.MappedPort(context.Background(), testutils.LocalStackPort)
if err != nil {
return err
}
mappedLocalstackPort = string(mappedPort)
keyMetadata, err = testutils.CreateKMSKey(mappedLocalstackPort)
if err != nil {
return err
}
return nil
}

func teardown() {
_ = container.Terminate(context.Background())
}

func TestGetPublicKey(t *testing.T) {
c, err := testutils.NewKMSClient(mappedLocalstackPort)
assert.Nil(t, err)
assert.NotNil(t, keyMetadata.KeyId)
pk, err := eigenkms.GetECDSAPublicKey(context.Background(), c, *keyMetadata.KeyId)
assert.Nil(t, err)
assert.NotNil(t, pk)
keyAddr := crypto.PubkeyToAddress(*pk)
t.Logf("Public key address: %s", keyAddr.String())
assert.NotEqual(t, keyAddr, common.Address{0})
}
40 changes: 40 additions & 0 deletions aws/kms/get_signature.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package kms

import (
"context"
"encoding/asn1"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
)

type asn1EcSig struct {
R asn1.RawValue
S asn1.RawValue
}

// GetECDSASignature retrieves the ECDSA signature for a message using a KMS key
func GetECDSASignature(
ctx context.Context, svc *kms.Client, keyId string, msg []byte,
) (r []byte, s []byte, err error) {
signInput := &kms.SignInput{
KeyId: aws.String(keyId),
SigningAlgorithm: types.SigningAlgorithmSpecEcdsaSha256,
MessageType: types.MessageTypeDigest,
Message: msg,
}

signOutput, err := svc.Sign(ctx, signInput)
if err != nil {
return nil, nil, err
}

var sigAsn1 asn1EcSig
_, err = asn1.Unmarshal(signOutput.Signature, &sigAsn1)
if err != nil {
return nil, nil, err
}

return sigAsn1.R.Bytes, sigAsn1.S.Bytes, nil
}
46 changes: 9 additions & 37 deletions cmd/egnaddrs/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ package main

import (
"context"
"os"
"path/filepath"
"testing"

"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/Layr-Labs/eigensdk-go/testutils"
)

const (
Expand All @@ -20,7 +17,10 @@ const (

func TestEgnAddrsWithServiceManagerFlag(t *testing.T) {

anvilC := startAnvilTestContainer()
anvilC, err := testutils.StartAnvilContainer(anvilStateFileName)
if err != nil {
t.Fatal(err)
}
anvilEndpoint, err := anvilC.Endpoint(context.Background(), "")
if err != nil {
t.Error(err)
Expand All @@ -35,7 +35,10 @@ func TestEgnAddrsWithServiceManagerFlag(t *testing.T) {

func TestEgnAddrsWithRegistryCoordinatorFlag(t *testing.T) {

anvilC := startAnvilTestContainer()
anvilC, err := testutils.StartAnvilContainer(anvilStateFileName)
if err != nil {
t.Fatal(err)
}
anvilEndpoint, err := anvilC.Endpoint(context.Background(), "")
if err != nil {
t.Error(err)
Expand All @@ -47,34 +50,3 @@ func TestEgnAddrsWithRegistryCoordinatorFlag(t *testing.T) {
// we just make sure it doesn't crash
run(args)
}

func startAnvilTestContainer() testcontainers.Container {
integrationDir, err := os.Getwd()
if err != nil {
panic(err)
}

ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: "ghcr.io/foundry-rs/foundry:latest",
Files: []testcontainers.ContainerFile{
{
HostFilePath: filepath.Join(integrationDir, "test_data", anvilStateFileName),
ContainerFilePath: "/root/.anvil/state.json",
FileMode: 0644, // Adjust the FileMode according to your requirements
},
},
Entrypoint: []string{"anvil"},
Cmd: []string{"--host", "0.0.0.0", "--load-state", "/root/.anvil/state.json"},
ExposedPorts: []string{"8545/tcp"},
WaitingFor: wait.ForLog("Listening on"),
}
anvilC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
panic(err)
}
return anvilC
}
3 changes: 2 additions & 1 deletion go.mod