From d985e4b89c0ecb86ca1cbdcd74ab215ecc0c77c2 Mon Sep 17 00:00:00 2001 From: Tuomo Tanskanen Date: Mon, 20 May 2024 13:04:02 +0300 Subject: [PATCH] refactor TLS config TLS config code causes security linters to report false positive about TLS versions that can be configured. This is porting over CAPO PR 2037. Signed-off-by: Tuomo Tanskanen --- main.go | 36 +++++++++++++--------------------- main_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/main.go b/main.go index fea144f7e0..6b30977a41 100644 --- a/main.go +++ b/main.go @@ -366,14 +366,19 @@ func main() { func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) { var tlsOptions []func(config *tls.Config) - tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion) - if err != nil { - return nil, err - } - - tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion) - if err != nil { - return nil, err + // To make a static analyzer happy, this block ensures there is no code + // path that sets a TLS version outside the acceptable values, even in + // case of unexpected user input. + var tlsMinVersion, tlsMaxVersion uint16 + for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} { + switch option { + case TLSVersion12: + *version = tls.VersionTLS12 + case TLSVersion13: + *version = tls.VersionTLS13 + default: + return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", ")) + } } if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion { @@ -418,21 +423,6 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) return tlsOptions, nil } -// GetTLSVersion returns the corresponding tls.Version or error. -func GetTLSVersion(version string) (uint16, error) { - var v uint16 - - switch version { - case TLSVersion12: - v = tls.VersionTLS12 - case TLSVersion13: - v = tls.VersionTLS13 - default: - return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", ")) - } - return v, nil -} - func getMaxConcurrentReconciles(controllerConcurrency int) (int, error) { if controllerConcurrency > 0 { ctrl.Log.Info(fmt.Sprintf("controller concurrency will be set to %d according to command line flag", controllerConcurrency)) diff --git a/main_test.go b/main_test.go index 6e4c65945e..7a3999f05b 100644 --- a/main_test.go +++ b/main_test.go @@ -18,6 +18,7 @@ package main import ( "bytes" + "crypto/tls" "testing" . "github.com/onsi/gomega" @@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) { klog.SetOutput(bufWriter) klog.LogToStderr(false) // this is important, because klog by default logs to stderr only _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) - g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers")) g.Expect(err).ShouldNot(HaveOccurred()) + g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers")) }) } -func TestGetTLSVersion(t *testing.T) { - t.Run("should error out when incorrect tls version passed", func(t *testing.T) { +func TestGetTLSOverrideFuncs(t *testing.T) { + t.Run("should error out when incorrect min tls version passed", func(t *testing.T) { + g := NewWithT(t) + _, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS11", + TLSMaxVersion: "TLS12", + }) + g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)")) + }) + t.Run("should error out when incorrect max tls version passed", func(t *testing.T) { g := NewWithT(t) - tlsVersion := "TLS11" - _, err := GetTLSVersion(tlsVersion) + _, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS12", + TLSMaxVersion: "TLS11", + }) g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)")) }) - t.Run("should pass and output correct tls version", func(t *testing.T) { - const VersionTLS12 uint16 = 771 + t.Run("should apply the requested TLS versions", func(t *testing.T) { + g := NewWithT(t) + tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS12", + TLSMaxVersion: "TLS13", + }) + + var tlsConfig tls.Config + for _, apply := range tlsOptionOverrides { + apply(&tlsConfig) + } + + g.Expect(err).ShouldNot(HaveOccurred()) + g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12))) + g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13))) + }) + t.Run("should apply the requested non-default TLS versions", func(t *testing.T) { g := NewWithT(t) - tlsVersion := "TLS12" - version, err := GetTLSVersion(tlsVersion) - g.Expect(version).To(Equal(VersionTLS12)) + tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{ + TLSMinVersion: "TLS13", + TLSMaxVersion: "TLS13", + }) + + var tlsConfig tls.Config + for _, apply := range tlsOptionOverrides { + apply(&tlsConfig) + } + g.Expect(err).ShouldNot(HaveOccurred()) + g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13))) + g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13))) }) }