Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

advancedTLS: Rename get root certs related pieces #7207

Merged
merged 4 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 72 additions & 39 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,52 @@ type PostHandshakeVerificationFunc func(params *HandshakeVerificationInfo) (*Pos
// Deprecated: use PostHandshakeVerificationFunc instead.
type CustomVerificationFunc = PostHandshakeVerificationFunc

// GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
type GetRootCAsParams struct {
RawConn net.Conn
// ConnectionInfo contains the parameters available to users when
// implementing GetRootCertificates.
type ConnectionInfo struct {
// RawConn is the raw net.Conn representing a connection.
RawConn net.Conn
// RawCerts is the byte representation of the presented peer cert chain.
RawCerts [][]byte
}

// GetRootCAsResults contains the results of GetRootCAs.
// GetRootCAsParams contains the parameters available to users when
// implementing GetRootCAs.
//
// Deprecated: use ConnectionInfo instead.
type GetRootCAsParams = ConnectionInfo

// RootCertificates is the result of GetRootCertificates.
// If users want to reload the root trust certificate, it is required to return
// the proper TrustCerts in GetRootCAs.
type GetRootCAsResults struct {
type RootCertificates struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's just a wrapper and used as an output of a single function - how about RootCertificatesResults?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Results is just noise - this struct fundamentally is root certificates, and it's a given that it's a Result when it's the output of a function. To me this would be sort of like naming the function GetRootCAsFunction

// TrustCerts is the pool of trusted certificates.
TrustCerts *x509.CertPool
}

// GetRootCAsResults contains the results of GetRootCAs.
// If users want to reload the root trust certificate, it is required to return
// the proper TrustCerts in GetRootCAs.
//
// Deprecated: use RootCertificates instead.
type GetRootCAsResults = RootCertificates

// RootCertificateOptions contains options to obtain root trust certificates
// for both the client and the server.
// At most one option could be set. If none of them are set, we
// use the system default trust certificates.
type RootCertificateOptions struct {
// If RootCertificates is set, it will be used every time when verifying
// the peer certificates, without performing root certificate reloading.
RootCertificates *x509.CertPool
// If RootCACerts is set, it will be used every time when verifying
// the peer certificates, without performing root certificate reloading.
//
// Deprecated: use RootCertificates instead.
RootCACerts *x509.CertPool
// If GetRootCertificates is set, it will be invoked to obtain root certs for
// every new connection.
GetRootCertificates func(params *GetRootCAsParams) (*GetRootCAsResults, error)
GetRootCertificates func(params *ConnectionInfo) (*RootCertificates, error)
// If RootProvider is set, we will use the root certs from the Provider's
// KeyMaterial() call in the new connections. The Provider must have initial
// credentials if specified. Otherwise, KeyMaterial() will block forever.
Expand Down Expand Up @@ -277,6 +298,12 @@ func (o *Options) clientConfig() (*tls.Config, error) {
if o.MaxTLSVersion == 0 {
o.MaxTLSVersion = o.MaxVersion
}
// TODO(gtcooke94) RootCACerts is deprecated, eventually remove this block.
// This will ensure that users still explicitly setting RootCACerts will get
// the setting int the right place.
if o.RootOptions.RootCACerts != nil {
o.RootOptions.RootCertificates = o.RootOptions.RootCACerts
}
if o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil {
return nil, fmt.Errorf("client needs to provide custom verification mechanism if choose to skip default verification")
}
Expand Down Expand Up @@ -312,19 +339,19 @@ func (o *Options) clientConfig() (*tls.Config, error) {
}
// Propagate root-certificate-related fields in tls.Config.
switch {
case o.RootOptions.RootCACerts != nil:
config.RootCAs = o.RootOptions.RootCACerts
case o.RootOptions.RootCertificates != nil:
config.RootCAs = o.RootOptions.RootCertificates
case o.RootOptions.GetRootCertificates != nil:
// In cases when users provide GetRootCertificates callback, since this
// callback is not contained in tls.Config, we have nothing to set here.
// We will invoke the callback in ClientHandshake.
case o.RootOptions.RootProvider != nil:
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
o.RootOptions.GetRootCertificates = func(*ConnectionInfo) (*RootCertificates, error) {
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
return &RootCertificates{TrustCerts: km.Roots}, nil
}
default:
// No root certificate options specified by user. Use the certificates
Expand Down Expand Up @@ -381,6 +408,12 @@ func (o *Options) serverConfig() (*tls.Config, error) {
if o.MaxTLSVersion == 0 {
o.MaxTLSVersion = o.MaxVersion
}
// TODO(gtcooke94) RootCACerts is deprecated, eventually remove this block.
// This will ensure that users still explicitly setting RootCACerts will get
// the setting int the right place.
if o.RootOptions.RootCACerts != nil {
o.RootOptions.RootCertificates = o.RootOptions.RootCACerts
}
if o.RequireClientCert && o.VerificationType == SkipVerification && o.AdditionalPeerVerification == nil {
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
}
Expand Down Expand Up @@ -420,19 +453,19 @@ func (o *Options) serverConfig() (*tls.Config, error) {
}
// Propagate root-certificate-related fields in tls.Config.
switch {
case o.RootOptions.RootCACerts != nil:
config.ClientCAs = o.RootOptions.RootCACerts
case o.RootOptions.RootCertificates != nil:
config.ClientCAs = o.RootOptions.RootCertificates
case o.RootOptions.GetRootCertificates != nil:
// In cases when users provide GetRootCertificates callback, since this
// callback is not contained in tls.Config, we have nothing to set here.
// We will invoke the callback in ServerHandshake.
case o.RootOptions.RootProvider != nil:
o.RootOptions.GetRootCertificates = func(*GetRootCAsParams) (*GetRootCAsResults, error) {
o.RootOptions.GetRootCertificates = func(*ConnectionInfo) (*RootCertificates, error) {
km, err := o.RootOptions.RootProvider.KeyMaterial(context.Background())
if err != nil {
return nil, err
}
return &GetRootCAsResults{TrustCerts: km.Roots}, nil
return &RootCertificates{TrustCerts: km.Roots}, nil
}
default:
// No root certificate options specified by user. Use the certificates
Expand Down Expand Up @@ -477,12 +510,12 @@ func (o *Options) serverConfig() (*tls.Config, error) {
// advancedTLSCreds is the credentials required for authenticating a connection
// using TLS.
type advancedTLSCreds struct {
config *tls.Config
verifyFunc PostHandshakeVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool
revocationOptions *RevocationOptions
verificationType VerificationType
config *tls.Config
verifyFunc PostHandshakeVerificationFunc
getRootCertificates func(params *ConnectionInfo) (*RootCertificates, error)
isClient bool
revocationOptions *RevocationOptions
verificationType VerificationType
}

func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
Expand Down Expand Up @@ -548,10 +581,10 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti

func (c *advancedTLSCreds) Clone() credentials.TransportCredentials {
return &advancedTLSCreds{
config: credinternal.CloneTLSConfig(c.config),
verifyFunc: c.verifyFunc,
getRootCAs: c.getRootCAs,
isClient: c.isClient,
config: credinternal.CloneTLSConfig(c.config),
verifyFunc: c.verifyFunc,
getRootCertificates: c.getRootCertificates,
isClient: c.isClient,
}
}

Expand Down Expand Up @@ -588,8 +621,8 @@ func buildVerifyFunc(c *advancedTLSCreds,
rootCAs = c.config.ClientCAs
}
// Reload root CA certs.
if rootCAs == nil && c.getRootCAs != nil {
results, err := c.getRootCAs(&GetRootCAsParams{
if rootCAs == nil && c.getRootCertificates != nil {
results, err := c.getRootCertificates(&ConnectionInfo{
RawConn: rawConn,
RawCerts: rawCerts,
})
Expand Down Expand Up @@ -661,12 +694,12 @@ func NewClientCreds(o *Options) (credentials.TransportCredentials, error) {
return nil, err
}
tc := &advancedTLSCreds{
config: conf,
isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
config: conf,
isClient: true,
getRootCertificates: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
Expand All @@ -680,12 +713,12 @@ func NewServerCreds(o *Options) (credentials.TransportCredentials, error) {
return nil, err
}
tc := &advancedTLSCreds{
config: conf,
isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
config: conf,
isClient: false,
getRootCertificates: o.RootOptions.GetRootCertificates,
verifyFunc: o.AdditionalPeerVerification,
revocationOptions: o.RevocationOptions,
verificationType: o.VerificationType,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
Expand Down
30 changes: 15 additions & 15 deletions security/advancedtls/advancedtls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ func (s) TestEnd2End(t *testing.T) {
clientCert []tls.Certificate
clientGetCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientGetRoot func(params *ConnectionInfo) (*RootCertificates, error)
clientVerifyFunc PostHandshakeVerificationFunc
clientVerificationType VerificationType
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverGetRoot func(params *ConnectionInfo) (*RootCertificates, error)
serverVerifyFunc PostHandshakeVerificationFunc
serverVerificationType VerificationType
}{
Expand Down Expand Up @@ -180,12 +180,12 @@ func (s) TestEnd2End(t *testing.T) {
},
clientVerificationType: CertVerification,
serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
serverGetRoot: func(params *ConnectionInfo) (*RootCertificates, error) {
switch stage.read() {
case 0, 1:
return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
return &RootCertificates{TrustCerts: cs.ServerTrust1}, nil
default:
return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
return &RootCertificates{TrustCerts: cs.ServerTrust2}, nil
}
},
serverVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
Expand All @@ -208,12 +208,12 @@ func (s) TestEnd2End(t *testing.T) {
{
desc: "test the reloading feature for server identity callback and client trust callback",
clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
clientGetRoot: func(params *ConnectionInfo) (*RootCertificates, error) {
switch stage.read() {
case 0, 1:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust1}, nil
default:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
Expand Down Expand Up @@ -250,12 +250,12 @@ func (s) TestEnd2End(t *testing.T) {
{
desc: "test client custom verification",
clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
clientGetRoot: func(params *ConnectionInfo) (*RootCertificates, error) {
switch stage.read() {
case 0:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust1}, nil
default:
return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
return &RootCertificates{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *HandshakeVerificationInfo) (*PostHandshakeVerificationResults, error) {
Expand Down Expand Up @@ -342,7 +342,7 @@ func (s) TestEnd2End(t *testing.T) {
GetIdentityCertificatesForServer: test.serverGetCert,
},
RootOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
RootCertificates: test.serverRoot,
GetRootCertificates: test.serverGetRoot,
},
RequireClientCert: true,
Expand Down Expand Up @@ -370,7 +370,7 @@ func (s) TestEnd2End(t *testing.T) {
},
AdditionalPeerVerification: test.clientVerifyFunc,
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
RootCertificates: test.clientRoot,
GetRootCertificates: test.clientGetRoot,
},
VerificationType: test.clientVerificationType,
Expand Down Expand Up @@ -787,7 +787,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
go s.Serve(lis)
clientOptions := &Options{
RootOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
RootCertificates: test.clientRoot,
},
VerificationType: test.clientVerificationType,
}
Expand Down Expand Up @@ -927,7 +927,7 @@ func (s) TestTLSVersions(t *testing.T) {
go s.Serve(lis)
clientOptions := &Options{
RootOptions: RootCertificateOptions{
RootCACerts: cs.ClientTrust1,
RootCertificates: cs.ClientTrust1,
},
VerificationType: CertAndHostVerification,
MinTLSVersion: test.clientMinVersion,
Expand Down
Loading
Loading