Skip to content

Commit

Permalink
Support multiple SSH keys for the same host
Browse files Browse the repository at this point in the history
  • Loading branch information
imjasonh committed Oct 17, 2019
1 parent 5160e9f commit b853cc6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 65 deletions.
2 changes: 1 addition & 1 deletion pkg/credentials/gitcreds/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func flags(fs *flag.FlagSet) {
basicConfig = basicGitConfig{entries: make(map[string]basicEntry)}
fs.Var(&basicConfig, basicAuthFlag, "List of secret=url pairs.")

sshConfig = sshGitConfig{entries: make(map[string]sshEntry)}
sshConfig = sshGitConfig{entries: make(map[string][]sshEntry)}
fs.Var(&sshConfig, sshFlag, "List of secret=url pairs.")
}

Expand Down
54 changes: 16 additions & 38 deletions pkg/credentials/gitcreds/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ func TestSSHFlagHandling(t *testing.T) {

expectedSSHConfig := fmt.Sprintf(`Host github.com
HostName github.com
IdentityFile %s
Port 22
`, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"))
if string(b) != expectedSSHConfig {
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig)
IdentityFile %s/.ssh/id_foo
`, credentials.VolumePath)
if d := cmp.Diff(expectedSSHConfig, string(b)); d != "" {
t.Errorf("ssh_config diff: %s", d)
}

b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "known_hosts"))
Expand Down Expand Up @@ -283,8 +283,10 @@ func TestSSHFlagHandlingThrice(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
flags(fs)
err := fs.Parse([]string{
// Two secrets target github.com, and both will end up in the
// ssh config.
"-ssh-git=foo=github.com",
"-ssh-git=bar=gitlab.com",
"-ssh-git=bar=github.com",
"-ssh-git=baz=gitlab.example.com:2222",
})
if err != nil {
Expand All @@ -303,21 +305,16 @@ func TestSSHFlagHandlingThrice(t *testing.T) {

expectedSSHConfig := fmt.Sprintf(`Host github.com
HostName github.com
IdentityFile %s
Port 22
Host gitlab.com
HostName gitlab.com
IdentityFile %s
Port 22
IdentityFile %s/.ssh/id_foo
IdentityFile %s/.ssh/id_bar
Host gitlab.example.com
HostName gitlab.example.com
IdentityFile %s
Port 2222
`, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"),
filepath.Join(os.Getenv("HOME"), ".ssh", "id_bar"),
filepath.Join(os.Getenv("HOME"), ".ssh", "id_baz"))
if string(b) != expectedSSHConfig {
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig)
IdentityFile %s/.ssh/id_baz
`, credentials.VolumePath, credentials.VolumePath, credentials.VolumePath)
if d := cmp.Diff(expectedSSHConfig, string(b)); d != "" {
t.Errorf("ssh_config diff: %s", d)
}

b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "known_hosts"))
Expand All @@ -327,8 +324,8 @@ Host gitlab.example.com
expectedSSHKnownHosts := `ssh-rsa aaaa
ssh-rsa bbbb
ssh-rsa cccc`
if string(b) != expectedSSHKnownHosts {
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHKnownHosts)
if d := cmp.Diff(expectedSSHKnownHosts, string(b)); d != "" {
t.Errorf("known_hosts diff: %s", d)
}

b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "id_foo"))
Expand Down Expand Up @@ -370,31 +367,12 @@ func TestSSHFlagHandlingMissingFiles(t *testing.T) {
}
// No ssh-privatekey files yields an error.

cfg := sshGitConfig{entries: make(map[string]sshEntry)}
cfg := sshGitConfig{entries: make(map[string][]sshEntry)}
if err := cfg.Set("not-found=github.com"); err == nil {
t.Error("Set(); got success, wanted error.")
}
}

func TestSSHFlagHandlingURLCollision(t *testing.T) {
credentials.VolumePath, _ = ioutil.TempDir("", "")
dir := credentials.VolumeName("foo")
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
t.Fatalf("os.MkdirAll(%s) = %v", dir, err)
}
if err := ioutil.WriteFile(filepath.Join(dir, corev1.SSHAuthPrivateKey), []byte("bar"), 0777); err != nil {
t.Fatalf("ioutil.WriteFile(ssh-privatekey) = %v", err)
}

cfg := sshGitConfig{entries: make(map[string]sshEntry)}
if err := cfg.Set("foo=github.com"); err != nil {
t.Fatalf("First Set() = %v", err)
}
if err := cfg.Set("bar=github.com"); err == nil {
t.Error("Second Set(); got success, wanted error.")
}
}

func TestBasicMalformedValues(t *testing.T) {
tests := []string{
"bar=baz=blah",
Expand Down
53 changes: 27 additions & 26 deletions pkg/credentials/gitcreds/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const sshKnownHosts = "known_hosts"
// As the flag is read, this status is populated.
// sshGitConfig implements flag.Value
type sshGitConfig struct {
entries map[string]sshEntry
entries map[string][]sshEntry
// The order we see things, for iterating over the above.
order []string
}
Expand All @@ -48,8 +48,9 @@ func (dc *sshGitConfig) String() string {
}
var urls []string
for _, k := range dc.order {
v := dc.entries[k]
urls = append(urls, fmt.Sprintf("%s=%s", v.secret, k))
for _, e := range dc.entries[k] {
urls = append(urls, fmt.Sprintf("%s=%s", e.secretName, k))
}
}
return strings.Join(urls, ",")
}
Expand All @@ -59,19 +60,17 @@ func (dc *sshGitConfig) Set(value string) error {
if len(parts) != 2 {
return xerrors.Errorf("Expect entries of the form secret=url, got: %v", value)
}
secret := parts[0]
secretName := parts[0]
url := parts[1]

if _, ok := dc.entries[url]; ok {
return xerrors.Errorf("Multiple entries for url: %v", url)
}

e, err := newSshEntry(url, secret)
e, err := newSshEntry(url, secretName)
if err != nil {
return err
}
dc.entries[url] = *e
dc.order = append(dc.order, url)
if _, exists := dc.entries[url]; !exists {
dc.order = append(dc.order, url)
}
dc.entries[url] = append(dc.entries[url], *e)
return nil
}

Expand All @@ -82,7 +81,7 @@ func (dc *sshGitConfig) Write() error {
}

// Walk each of the entries and for each do three things:
// 1. Write out: ~/.ssh/id_{secret} with the secret key
// 1. Write out: ~/.ssh/id_{secretName} with the secret key
// 2. Compute its part of "~/.ssh/config"
// 3. Compute its part of "~/.ssh/known_hosts"
var configEntries []string
Expand All @@ -95,17 +94,19 @@ func (dc *sshGitConfig) Write() error {
host = k
port = defaultPort
}
v := dc.entries[k]
if err := v.Write(sshDir); err != nil {
return err
}
configEntries = append(configEntries, fmt.Sprintf(`Host %s
configEntry := fmt.Sprintf(`Host %s
HostName %s
IdentityFile %s
Port %s
`, host, host, v.path(sshDir), port))

knownHosts = append(knownHosts, v.knownHosts)
`, host, host, port)
for _, e := range dc.entries[k] {
if err := e.Write(sshDir); err != nil {
return err
}
configEntry += fmt.Sprintf(` IdentityFile %s
`, e.path(sshDir))
knownHosts = append(knownHosts, e.knownHosts)
}
configEntries = append(configEntries, configEntry)
}
configPath := filepath.Join(sshDir, "config")
configContent := strings.Join(configEntries, "")
Expand All @@ -118,13 +119,13 @@ func (dc *sshGitConfig) Write() error {
}

type sshEntry struct {
secret string
secretName string
privateKey string
knownHosts string
}

func (be *sshEntry) path(sshDir string) string {
return filepath.Join(sshDir, "id_"+be.secret)
return filepath.Join(sshDir, "id_"+be.secretName)
}

func sshKeyScan(domain string) ([]byte, error) {
Expand All @@ -142,8 +143,8 @@ func (be *sshEntry) Write(sshDir string) error {
return ioutil.WriteFile(be.path(sshDir), []byte(be.privateKey), 0600)
}

func newSshEntry(u, secret string) (*sshEntry, error) {
secretPath := credentials.VolumeName(secret)
func newSshEntry(u, secretName string) (*sshEntry, error) {
secretPath := credentials.VolumeName(secretName)

pk, err := ioutil.ReadFile(filepath.Join(secretPath, corev1.SSHAuthPrivateKey))
if err != nil {
Expand All @@ -161,7 +162,7 @@ func newSshEntry(u, secret string) (*sshEntry, error) {
knownHosts := string(kh)

return &sshEntry{
secret: secret,
secretName: secretName,
privateKey: privateKey,
knownHosts: knownHosts,
}, nil
Expand Down

0 comments on commit b853cc6

Please sign in to comment.