Skip to content

Commit

Permalink
Merge pull request #1737 from Nordix/tuomo/refactor-tls-config
Browse files Browse the repository at this point in the history
🌱 refactor TLS config
  • Loading branch information
metal3-io-bot committed May 22, 2024
2 parents 2066e6b + d985e4b commit c8cf1c4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
36 changes: 13 additions & 23 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,19 @@ func main() {
func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) {
var tlsOptions []func(config *tls.Config)

tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion)
if err != nil {
return nil, err
}

tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion)
if err != nil {
return nil, err
// To make a static analyzer happy, this block ensures there is no code
// path that sets a TLS version outside the acceptable values, even in
// case of unexpected user input.
var tlsMinVersion, tlsMaxVersion uint16
for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} {
switch option {
case TLSVersion12:
*version = tls.VersionTLS12
case TLSVersion13:
*version = tls.VersionTLS13
default:
return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", "))
}
}

if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion {
Expand Down Expand Up @@ -418,21 +423,6 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error)
return tlsOptions, nil
}

// GetTLSVersion returns the corresponding tls.Version or error.
func GetTLSVersion(version string) (uint16, error) {
var v uint16

switch version {
case TLSVersion12:
v = tls.VersionTLS12
case TLSVersion13:
v = tls.VersionTLS13
default:
return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", "))
}
return v, nil
}

func getMaxConcurrentReconciles(controllerConcurrency int) (int, error) {
if controllerConcurrency > 0 {
ctrl.Log.Info(fmt.Sprintf("controller concurrency will be set to %d according to command line flag", controllerConcurrency))
Expand Down
55 changes: 45 additions & 10 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"bytes"
"crypto/tls"
"testing"

. "github.com/onsi/gomega"
Expand Down Expand Up @@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) {
klog.SetOutput(bufWriter)
klog.LogToStderr(false) // this is important, because klog by default logs to stderr only
_, err := GetTLSOptionOverrideFuncs(tlsMockOptions)
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
g.Expect(err).ShouldNot(HaveOccurred())
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
})
}

func TestGetTLSVersion(t *testing.T) {
t.Run("should error out when incorrect tls version passed", func(t *testing.T) {
func TestGetTLSOverrideFuncs(t *testing.T) {
t.Run("should error out when incorrect min tls version passed", func(t *testing.T) {
g := NewWithT(t)
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS11",
TLSMaxVersion: "TLS12",
})
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
})
t.Run("should error out when incorrect max tls version passed", func(t *testing.T) {
g := NewWithT(t)
tlsVersion := "TLS11"
_, err := GetTLSVersion(tlsVersion)
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS12",
TLSMaxVersion: "TLS11",
})
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
})
t.Run("should pass and output correct tls version", func(t *testing.T) {
const VersionTLS12 uint16 = 771
t.Run("should apply the requested TLS versions", func(t *testing.T) {
g := NewWithT(t)
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS12",
TLSMaxVersion: "TLS13",
})

var tlsConfig tls.Config
for _, apply := range tlsOptionOverrides {
apply(&tlsConfig)
}

g.Expect(err).ShouldNot(HaveOccurred())
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
})
t.Run("should apply the requested non-default TLS versions", func(t *testing.T) {
g := NewWithT(t)
tlsVersion := "TLS12"
version, err := GetTLSVersion(tlsVersion)
g.Expect(version).To(Equal(VersionTLS12))
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
TLSMinVersion: "TLS13",
TLSMaxVersion: "TLS13",
})

var tlsConfig tls.Config
for _, apply := range tlsOptionOverrides {
apply(&tlsConfig)
}

g.Expect(err).ShouldNot(HaveOccurred())
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
})
}

Expand Down

0 comments on commit c8cf1c4

Please sign in to comment.