diff --git a/crypto.go b/crypto.go index 119d843..2c31746 100644 --- a/crypto.go +++ b/crypto.go @@ -37,14 +37,35 @@ type Identity struct { config tls.Config } +type IdentityConfig struct { + CertTemplate *x509.Certificate +} + +type IdentityOption func(r *IdentityConfig) + +// WithCertTemplate specifies the template to use when generating a new certificate. +func WithCertTemplate(template *x509.Certificate) IdentityOption { + return func(c *IdentityConfig) { + c.CertTemplate = template + } +} + // NewIdentity creates a new identity -func NewIdentity(privKey ic.PrivKey) (*Identity, error) { - certTmpl, err := defaultCertTemplate() - if err != nil { - return nil, err +func NewIdentity(privKey ic.PrivKey, opts ...IdentityOption) (*Identity, error) { + config := IdentityConfig{} + for _, opt := range opts { + opt(&config) + } + + var err error + if config.CertTemplate == nil { + config.CertTemplate, err = DefaultCertTemplate() + if err != nil { + return nil, err + } } - cert, err := KeyToCertificate(privKey, certTmpl) + cert, err := keyToCertificate(privKey, config.CertTemplate) if err != nil { return nil, err } @@ -163,10 +184,10 @@ func PubKeyFromCertChain(chain []*x509.Certificate) (ic.PubKey, error) { return pubKey, nil } -// GenerateSignedExtension uses the provided private key to sign the public key, and returns the +// generateSignedExtension uses the provided private key to sign the public key, and returns the // signature within a pkix.Extension. // This extension is included in a certificate to cryptographically tie it to the libp2p private key. -func GenerateSignedExtension(sk ic.PrivKey, pubKey crypto.PublicKey) (*pkix.Extension, error) { +func generateSignedExtension(sk ic.PrivKey, pubKey crypto.PublicKey) (*pkix.Extension, error) { keyBytes, err := ic.MarshalPublicKey(sk.GetPublic()) if err != nil { return nil, err @@ -190,17 +211,17 @@ func GenerateSignedExtension(sk ic.PrivKey, pubKey crypto.PublicKey) (*pkix.Exte return &pkix.Extension{Id: extensionID, Critical: extensionCritical, Value: value}, nil } -// KeyToCertificate generates a new ECDSA private key and corresponding x509 certificate. +// keyToCertificate generates a new ECDSA private key and corresponding x509 certificate. // The certificate includes an extension that cryptographically ties it to the provided libp2p // private key to authenticate TLS connections. -func KeyToCertificate(sk ic.PrivKey, certTmpl *x509.Certificate) (*tls.Certificate, error) { +func keyToCertificate(sk ic.PrivKey, certTmpl *x509.Certificate) (*tls.Certificate, error) { certKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, err } // after calling CreateCertificate, these will end up in Certificate.Extensions - extension, err := GenerateSignedExtension(sk, certKey.Public()) + extension, err := generateSignedExtension(sk, certKey.Public()) if err != nil { return nil, err } @@ -216,7 +237,7 @@ func KeyToCertificate(sk ic.PrivKey, certTmpl *x509.Certificate) (*tls.Certifica }, nil } -func defaultCertTemplate() (*x509.Certificate, error) { +func DefaultCertTemplate() (*x509.Certificate, error) { bigNum := big.NewInt(1 << 62) sn, err := rand.Int(rand.Reader, bigNum) if err != nil { diff --git a/crypto_test.go b/crypto_test.go new file mode 100644 index 0000000..ce20ab2 --- /dev/null +++ b/crypto_test.go @@ -0,0 +1,48 @@ +package libp2ptls + +import ( + "crypto/x509" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewIdentityCertificates(t *testing.T) { + _, key := createPeer(t) + cn := "a.test.name" + email := "unittest@example.com" + + t.Run("NewIdentity with default template", func(t *testing.T) { + // Generate an identity using the default template + id, err := NewIdentity(key) + assert.NoError(t, err) + + // Extract the x509 certificate + x509Cert, err := x509.ParseCertificate(id.config.Certificates[0].Certificate[0]) + assert.NoError(t, err) + + // verify the common name and email are not set + assert.Empty(t, x509Cert.Subject.CommonName) + assert.Empty(t, x509Cert.EmailAddresses) + }) + + t.Run("NewIdentity with custom template", func(t *testing.T) { + tmpl, err := DefaultCertTemplate() + assert.NoError(t, err) + + tmpl.Subject.CommonName = cn + tmpl.EmailAddresses = []string{email} + + // Generate an identity using the custom template + id, err := NewIdentity(key, WithCertTemplate(tmpl)) + assert.NoError(t, err) + + // Extract the x509 certificate + x509Cert, err := x509.ParseCertificate(id.config.Certificates[0].Certificate[0]) + assert.NoError(t, err) + + // verify the common name and email are set + assert.Equal(t, cn, x509Cert.Subject.CommonName) + assert.Equal(t, email, x509Cert.EmailAddresses[0]) + }) +}