Skip to content

Commit

Permalink
Merge pull request #39 from jingweno/host_no_ignore_pk
Browse files Browse the repository at this point in the history
Check server public key when host sharing a session
  • Loading branch information
owenthereal committed May 25, 2020
2 parents 1677033 + aa85f2f commit b9dd4dd
Show file tree
Hide file tree
Showing 7 changed files with 767 additions and 18 deletions.
44 changes: 30 additions & 14 deletions cmd/upterm/command/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ import (
)

var (
flagServer string
flagForceCommand string
flagPrivateKeys []string
flagAuthorizedKeys string
flagReadOnly bool
flagServer string
flagForceCommand string
flagPrivateKeys []string
flagKnownHostsFilename string
flagAuthorizedKeys string
flagReadOnly bool
)

func hostCmd() *cobra.Command {
Expand Down Expand Up @@ -55,9 +56,15 @@ func hostCmd() *cobra.Command {
RunE: shareRunE,
}

homeDir, err := os.UserHomeDir()
if err != nil {
log.Fatal(err)
}

cmd.PersistentFlags().StringVarP(&flagServer, "server", "", "ssh://uptermd.upterm.dev:22", "upterm server address (required), supported protocols are shh, ws, or wss.")
cmd.PersistentFlags().StringVarP(&flagForceCommand, "force-command", "f", "", "force execution of a command and attach its input/output to client's.")
cmd.PersistentFlags().StringSliceVarP(&flagPrivateKeys, "private-key", "i", nil, "private key for public key authentication against the upterm server (required).")
cmd.PersistentFlags().StringSliceVarP(&flagPrivateKeys, "private-key", "i", defaultPrivateKeys(homeDir), "private key for public key authentication against the upterm server (required).")
cmd.PersistentFlags().StringVarP(&flagKnownHostsFilename, "known-hosts", "", defaultKnownHost(homeDir), "a file contains the known keys for remote hosts (required).")
cmd.PersistentFlags().StringVarP(&flagAuthorizedKeys, "authorized-key", "a", "", "an authorized_keys file that lists public keys that are permitted to connect.")
cmd.PersistentFlags().BoolVarP(&flagReadOnly, "read-only", "r", false, "host a read-only session. Clients won't be able to interact.")

Expand Down Expand Up @@ -88,15 +95,11 @@ func validateShareRequiredFlags(c *cobra.Command, args []string) error {
}

if len(flagPrivateKeys) == 0 {
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
result = multierror.Append(result, fmt.Errorf("missing flag --private-key"))
}

flagPrivateKeys = defaultPrivateKeys(homeDir)
if len(flagPrivateKeys) == 0 {
result = multierror.Append(result, fmt.Errorf("missing flag --private-key"))
}
if flagKnownHostsFilename == "" {
result = multierror.Append(result, fmt.Errorf("missing flag --known-hosts"))
}

return result
Expand Down Expand Up @@ -138,16 +141,25 @@ func shareRunE(c *cobra.Command, args []string) error {
if cleanup != nil {
defer cleanup()
}

hkcb, err := host.NewPromptingHostKeyCallback(os.Stdin, os.Stdout, flagKnownHostsFilename)
if err != nil {
return err
}

h := &host.Host{
Host: flagServer,
Command: args,
ForceCommand: forceCommand,
Signers: signers,
HostKeyCallback: hkcb,
AuthorizedKeys: authorizedKeys,
KeepAliveDuration: 50 * time.Second, // nlb is 350 sec & heroku router is 55 sec
SessionCreatedCallback: displaySessionCallback,
ClientJoinedCallback: clientJoinedCallback,
ClientLeftCallback: clientLeftCallback,
Stdin: os.Stdin,
Stdout: os.Stdout,
Logger: log.New(),
ReadOnly: flagReadOnly,
}
Expand Down Expand Up @@ -208,3 +220,7 @@ func defaultPrivateKeys(homeDir string) []string {

return pks
}

func defaultKnownHost(homeDir string) string {
return filepath.Join(homeDir, ".ssh", "known_hosts")
}
1 change: 1 addition & 0 deletions ftests/ftests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ func (c *Host) Share(url string) error {
ClientLeftCallback: c.ClientLeftCallback,
KeepAliveDuration: 10 * time.Second,
Logger: logger,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Stdin: stdinr,
Stdout: stdoutw,
ReadOnly: c.ReadOnly,
Expand Down
113 changes: 113 additions & 0 deletions host/host.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package host

import (
"bufio"
"context"
"fmt"
"io"
"net"
"net/url"
"os"
"path/filepath"
"strings"
"time"

"github.com/jingweno/upterm/host/api"
Expand All @@ -17,14 +21,118 @@ import (
"github.com/olebedev/emitter"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

func NewPromptingHostKeyCallback(stdin io.Reader, stdout io.Writer, knownHostsFilename string) (ssh.HostKeyCallback, error) {
cb, err := knownhosts.New(knownHostsFilename)
if err != nil {
return nil, err
}

hkcb := hostKeyCallback{
stdin: stdin,
stdout: stdout,
file: knownHostsFilename,
HostKeyCallback: cb,
}

return hkcb.checkHostKey, nil
}

const (
errKeyMismatch = `
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
Someone could be eavesdropping on you right now (man-in-the-middle attack)!
It is also possible that a host key has just been changed.
The fingerprint for the %s key sent by the remote host is
%s.
Please contact your system administrator.
Add correct host key in %s to get rid of this message.
Offending %s key in %s:%d`
)

type hostKeyCallback struct {
stdin io.Reader
stdout io.Writer
file string
ssh.HostKeyCallback
}

func (cb hostKeyCallback) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := cb.HostKeyCallback(hostname, remote, key); err != nil {
kerr, ok := err.(*knownhosts.KeyError)
if !ok {
return err
}

// If keer.Want is non-empty, there was a mismatch, which can signify a MITM attack
if len(kerr.Want) != 0 {
kk := kerr.Want[0] // TODO: take care of multiple key mismatches
fp := utils.FingerprintSHA256(kk.Key)
kt := keyType(kk.Key.Type())
return fmt.Errorf(errKeyMismatch, kt, fp, kk.Filename, kt, kk.Filename, kk.Line)
}

return cb.promptForConfirmation(hostname, remote, key)

}

return nil
}

func (cb hostKeyCallback) promptForConfirmation(hostname string, remote net.Addr, key ssh.PublicKey) error {
fp := utils.FingerprintSHA256(key)
fmt.Fprintf(cb.stdout, "The authenticity of host '%s (%s)' can't be established.\n", knownhosts.Normalize(hostname), knownhosts.Normalize(remote.String()))
fmt.Fprintf(cb.stdout, "%s key fingerprint is %s.\n", keyType(key.Type()), fp)
fmt.Fprintf(cb.stdout, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ")

reader := bufio.NewReader(cb.stdin)
for {
confirm, err := reader.ReadString('\n')
if err != nil {
return err
}

confirm = strings.TrimSpace(confirm)

if confirm == "yes" || confirm == fp {
return cb.appendHostLine(hostname, key)
}

if confirm == "no" {
return fmt.Errorf("Host key verification failed.")
}

fmt.Fprintf(cb.stdout, "Please type 'yes', 'no' or the fingerprint: ")
}
}

func (cb hostKeyCallback) appendHostLine(hostname string, key ssh.PublicKey) error {
f, err := os.OpenFile(cb.file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer f.Close()

line := knownhosts.Line([]string{hostname}, key)
if _, err := f.WriteString(line + "\n"); err != nil {
return err
}

return nil
}

type Host struct {
Host string
KeepAliveDuration time.Duration
Command []string
ForceCommand []string
Signers []ssh.Signer
HostKeyCallback ssh.HostKeyCallback
AuthorizedKeys []ssh.PublicKey
AdminSocketFile string
SessionCreatedCallback func(*models.APIGetSessionResponse) error
Expand Down Expand Up @@ -66,6 +174,7 @@ func (c *Host) Run(ctx context.Context) error {
rt := internal.ReverseTunnel{
Host: u,
Signers: c.Signers,
HostKeyCallback: c.HostKeyCallback,
AuthorizedKeys: c.AuthorizedKeys,
KeepAliveDuration: c.KeepAliveDuration,
Logger: log.WithField("com", "reverse-tunnel"),
Expand Down Expand Up @@ -186,3 +295,7 @@ func (c *Host) Run(ctx context.Context) error {

return g.Run()
}

func keyType(t string) string {
return strings.ToUpper(strings.TrimPrefix(t, "ssh-"))
}
69 changes: 69 additions & 0 deletions host/host_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package host

import (
"bytes"
"io/ioutil"
"net"
"os"
"strings"
"testing"

"github.com/jingweno/upterm/utils"
"golang.org/x/crypto/ssh"
)

const (
testPublicKey = `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIN0EWrjdcHcuMfI8bGAyHPcGsAc/vd/gl5673pRkRBGY`
)

func Test_hostKeyCallback(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "known_hosts")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())

if _, err := tmpfile.Write([]byte("[127.0.0.1]:23 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKpVcpc3t5GZHQFlbSLyj6sQY4wWLjNZsLTkfo9Cdjit\n")); err != nil {
t.Fatal(err)
}
tmpfile.Close()

stdin := bytes.NewBufferString("yes\n") // Simulate typing "yes" in stdin
stdout := bytes.NewBuffer(nil)

pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testPublicKey))
if err != nil {
t.Fatal(err)
}
fp := utils.FingerprintSHA256(pk)

cb, err := NewPromptingHostKeyCallback(stdin, stdout, tmpfile.Name())
if err != nil {
t.Fatal(err)
}

// 127.0.0.1:22 is not in known_hosts
addr := &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 22,
}
if err := cb("127.0.0.1:22", addr, pk); err != nil {
t.Fatal(err)
}
if !strings.Contains(stdout.String(), "ED25519 key fingerprint is "+fp) {
t.Fatalf("stdout should contain fingerprint %s: %s", fp, stdout)
}

// 127.0.0.1:23 is in known_hosts
addr = &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 23,
}
err = cb("127.0.0.1:23", addr, pk)
if err == nil {
t.Fatalf("key mismatched error is expected")
}
if !strings.Contains(err.Error(), "Offending ED25519 key in "+tmpfile.Name()) {
t.Fatalf("unexpected error message: %s", err.Error())
}
}
17 changes: 13 additions & 4 deletions host/internal/reversetunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type ReverseTunnel struct {
Signers []ssh.Signer
AuthorizedKeys []ssh.PublicKey
KeepAliveDuration time.Duration
HostKeyCallback ssh.HostKeyCallback
Logger log.FieldLogger

ln net.Listener
Expand Down Expand Up @@ -72,10 +73,18 @@ func (c *ReverseTunnel) Establish(ctx context.Context) (*server.CreateSessionRes
}

config := &ssh.ClientConfig{
User: encodedID,
Auth: auths,
ClientVersion: upterm.HostSSHClientVersion,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
User: encodedID,
Auth: auths,
ClientVersion: upterm.HostSSHClientVersion,
// Enforce a restricted set of algorithms for security
// TODO: make this configurable if necessary
HostKeyAlgorithms: []string{
ssh.CertAlgoRSAv01,
ssh.CertAlgoED25519v01,
ssh.KeyAlgoED25519,
ssh.KeyAlgoRSA,
},
HostKeyCallback: c.HostKeyCallback,
}

if isWSScheme(c.Host.Scheme) {
Expand Down
Loading

0 comments on commit b9dd4dd

Please sign in to comment.