Skip to content

Commit

Permalink
Cache SSH connections
Browse files Browse the repository at this point in the history
The underlying SSH connections are kept open and are reused
across several SSH sessions. This is due to upstream issues in
which concurrent/parallel SSH connections may lead to instability.

golang/go#51926
golang/go#27140
Signed-off-by: Paulo Gomes <paulo.gomes@weave.works>
  • Loading branch information
Paulo Gomes committed Mar 25, 2022
1 parent 58d6828 commit 08c3490
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 39 deletions.
161 changes: 122 additions & 39 deletions pkg/git/libgit2/managed/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ import (
"net/url"
"runtime"
"strings"
"sync"
"time"

"golang.org/x/crypto/ssh"

Expand All @@ -62,6 +64,17 @@ import (
// registerManagedSSH registers a Go-native implementation of
// SSH transport that doesn't rely on any lower-level libraries
// such as libssh2.
//
// The underlying SSH connections are kept open and are reused
// across several SSH sessions. This is due to upstream issues in
// which concurrent/parallel SSH connections may lead to instability.
//
// Connections are created on first attempt to use a given remote. The
// connection is removed from the cache on the first failed session related
// operation.
//
// https://github.com/golang/go/issues/51926
// https://github.com/golang/go/issues/27140
func registerManagedSSH() error {
for _, protocol := range []string{"ssh", "ssh+git", "git+ssh"} {
_, err := git2go.NewRegisteredSmartTransport(protocol, false, sshSmartSubtransportFactory)
Expand Down Expand Up @@ -89,6 +102,9 @@ type sshSmartSubtransport struct {
currentStream *sshSmartSubtransportStream
}

var aMux sync.RWMutex
var sshClients map[string]*ssh.Client = make(map[string]*ssh.Client)

func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServiceAction) (git2go.SmartSubtransportStream, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
Expand Down Expand Up @@ -135,7 +151,14 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
}
defer cred.Free()

sshConfig, err := getSSHConfigFromCredential(cred)
var addr string
if u.Port() != "" {
addr = fmt.Sprintf("%s:%s", u.Hostname(), u.Port())
} else {
addr = fmt.Sprintf("%s:22", u.Hostname())
}

ckey, sshConfig, err := cacheKeyAndConfig(addr, cred)
if err != nil {
return nil, err
}
Expand All @@ -156,52 +179,66 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
return t.transport.SmartCertificateCheck(cert, true, hostname)
}

var addr string
if u.Port() != "" {
addr = fmt.Sprintf("%s:%s", u.Hostname(), u.Port())
} else {
addr = fmt.Sprintf("%s:22", u.Hostname())
aMux.RLock()
if c, ok := sshClients[ckey]; ok {
traceLog.Info("[ssh]: cache hit", "remoteAddress", addr)
t.client = c
}
aMux.RUnlock()

if t.client == nil {
traceLog.Info("[ssh]: cache miss", "remoteAddress", addr)

aMux.Lock()
defer aMux.Unlock()

// In some scenarios the ssh handshake can hang indefinitely at
// golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop.
//
// xref: https://github.com/golang/go/issues/51926
done := make(chan error, 1)
go func() {
t.client, err = ssh.Dial("tcp", addr, sshConfig)
done <- err
}()

dialTimeout := sshConfig.Timeout + (30 * time.Second)

select {
case doneErr := <-done:
if doneErr != nil {
err = fmt.Errorf("ssh.Dial: %w", doneErr)
}
case <-time.After(dialTimeout):
err = fmt.Errorf("timed out waiting for ssh.Dial after %s", dialTimeout)
}

// In some scenarios the ssh handshake can hang indefinitely at
// golang.org/x/crypto/ssh.(*handshakeTransport).kexLoop.
//
// xref: https://github.com/golang/go/issues/51926
done := make(chan error, 1)
go func() {
t.client, err = ssh.Dial("tcp", addr, sshConfig)
done <- err
}()

select {
case doneErr := <-done:
if doneErr != nil {
err = fmt.Errorf("ssh.Dial: %w", doneErr)
if err != nil {
return nil, err
}
case <-time.After(sshConfig.Timeout + (5 * time.Second)):
err = fmt.Errorf("timed out waiting for ssh.Dial")
}

if err != nil {
return nil, err
sshClients[ckey] = t.client
}

traceLog.Info("[ssh]: creating new ssh session")
if t.session, err = t.client.NewSession(); err != nil {
discardCachedSshClient(ckey)
return nil, err
}

t.stdin, err = t.session.StdinPipe()
if err != nil {
if t.stdin, err = t.session.StdinPipe(); err != nil {
discardCachedSshClient(ckey)
return nil, err
}

t.stdout, err = t.session.StdoutPipe()
if err != nil {
if t.stdout, err = t.session.StdoutPipe(); err != nil {
discardCachedSshClient(ckey)
return nil, err
}

traceLog.Info("[ssh]: run on remote", "cmd", cmd)
if err := t.session.Start(cmd); err != nil {
discardCachedSshClient(ckey)
return nil, err
}

Expand All @@ -214,15 +251,25 @@ func (t *sshSmartSubtransport) Action(urlString string, action git2go.SmartServi
}

func (t *sshSmartSubtransport) Close() error {
var returnErr error

traceLog.Info("[ssh]: sshSmartSubtransport.Close()")
t.currentStream = nil
if t.client != nil {
t.stdin.Close()
t.session.Wait()
t.session.Close()
if err := t.stdin.Close(); err != nil {
returnErr = fmt.Errorf("cannot close stdin: %w", err)
}
t.client = nil
}
return nil
if t.session != nil {
traceLog.Info("[ssh]: skipping session.wait")
traceLog.Info("[ssh]: session.Close()")
if err := t.session.Close(); err != nil {
returnErr = fmt.Errorf("cannot close session: %w", err)
}
}

return returnErr
}

func (t *sshSmartSubtransport) Free() {
Expand All @@ -245,19 +292,23 @@ func (stream *sshSmartSubtransportStream) Free() {
traceLog.Info("[ssh]: sshSmartSubtransportStream.Free()")
}

func getSSHConfigFromCredential(cred *git2go.Credential) (*ssh.ClientConfig, error) {
func cacheKeyAndConfig(remoteAddress string, cred *git2go.Credential) (string, *ssh.ClientConfig, error) {
username, _, privatekey, passphrase, err := cred.GetSSHKey()
if err != nil {
return nil, err
return "", nil, err
}

var pemBytes []byte
if cred.Type() == git2go.CredentialTypeSSHMemory {
pemBytes = []byte(privatekey)
} else {
return nil, fmt.Errorf("file based SSH credential is not supported")
return "", nil, fmt.Errorf("file based SSH credential is not supported")
}

// must include the passphrase, otherwise a caller that knows the private key, but
// not its passphrase would be able to bypass auth.
ck := cacheKey(remoteAddress, username, passphrase, pemBytes)

var key ssh.Signer
if passphrase != "" {
key, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase))
Expand All @@ -266,12 +317,44 @@ func getSSHConfigFromCredential(cred *git2go.Credential) (*ssh.ClientConfig, err
}

if err != nil {
return nil, err
return "", nil, err
}

return &ssh.ClientConfig{
cfg := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
Timeout: sshConnectionTimeOut,
}, nil
}

return ck, cfg, nil
}

// cacheKey generates a cache key that is multi-tenancy safe.
//
// Stablishing multiple and concurrent ssh connections leads to stability
// issues documented above. However, the caching/sharing of already stablished
// connections could represent a vector for users to bypass the ssh authentication
// mechanism.
//
// cacheKey tries to ensure that connections are only shared by users that
// have the exact same remoteAddress and credentials.
func cacheKey(remoteAddress, userName, passphrase string, pubKey []byte) string {
h := sha256.New()

v := fmt.Sprintf("%s-%s-%s-%v", remoteAddress, userName, passphrase, pubKey)

h.Write([]byte(v))
return fmt.Sprintf("%x", h.Sum(nil))
}

// discardCachedSshClient discards the cached ssh client, forcing the next git operation
// to create a new one via ssh.Dial.
func discardCachedSshClient(key string) {
aMux.Lock()
defer aMux.Unlock()

if _, found := sshClients[key]; found {
traceLog.Info("[ssh]: discard cached ssh client")
delete(sshClients, key)
}
}
98 changes: 98 additions & 0 deletions pkg/git/libgit2/managed/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
Copyright 2022 The Flux authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package managed

import (
"testing"
)

func TestCacheKey(t *testing.T) {
tests := []struct {
description string
remoteAddress1 string
user1 string
passphrase1 string
pubKey1 []byte
remoteAddress2 string
user2 string
passphrase2 string
pubKey2 []byte
expectMatch bool
}{
{
description: "same remote addresses with no config",
remoteAddress1: "1.1.1.1",
remoteAddress2: "1.1.1.1",
expectMatch: true,
},
{
description: "same remote addresses with different config",
remoteAddress1: "1.1.1.1",
user1: "joe",
remoteAddress2: "1.1.1.1",
user2: "another-joe",
expectMatch: false,
},
{
description: "different remote addresses with no config",
remoteAddress1: "8.8.8.8",
remoteAddress2: "1.1.1.1",
expectMatch: false,
},
{
description: "different remote addresses with same config",
remoteAddress1: "8.8.8.8",
user1: "legit",
remoteAddress2: "1.1.1.1",
user2: "legit",
expectMatch: false,
},
{
description: "same remote addresses with same pubkey signers",
remoteAddress1: "1.1.1.1",
user1: "same-jane",
pubKey1: []byte{255, 123, 0},
remoteAddress2: "1.1.1.1",
user2: "same-jane",
pubKey2: []byte{255, 123, 0},
expectMatch: true,
},
{
description: "same remote addresses with different pubkey signers",
remoteAddress1: "1.1.1.1",
user1: "same-jane",
pubKey1: []byte{255, 123, 0},
remoteAddress2: "1.1.1.1",
user2: "same-jane",
pubKey2: []byte{0, 123, 0},
expectMatch: false,
},
}

for _, tt := range tests {
cacheKey1 := cacheKey(tt.remoteAddress1, tt.user1, tt.passphrase1, tt.pubKey1)
cacheKey2 := cacheKey(tt.remoteAddress2, tt.user2, tt.passphrase2, tt.pubKey2)

if tt.expectMatch && cacheKey1 != cacheKey2 {
t.Errorf("%s: cache keys '%s' and '%s' should match", tt.description, cacheKey1, cacheKey2)
}

if !tt.expectMatch && cacheKey1 == cacheKey2 {
t.Errorf("%s: cache keys '%s' and '%s' should not match", tt.description, cacheKey1, cacheKey2)
}
}
}

0 comments on commit 08c3490

Please sign in to comment.