Skip to content

Commit

Permalink
Add testing about getCredentialsPath
Browse files Browse the repository at this point in the history
  • Loading branch information
kenzo0107 committed Dec 19, 2019
1 parent e90db16 commit 345f4b5
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 142 deletions.
289 changes: 147 additions & 142 deletions cmd/omssh/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"flag"
"fmt"
"log"
Expand All @@ -26,53 +27,19 @@ import (
latest "github.com/tcnksm/go-latest"
)

const version = "0.0.3"

const (
defUser = "ubuntu"
name = "omssh"
version = "0.0.3"
defaultUser = "ubuntu"
)

var (
showVersion bool
buildDate string
credentialsPath string
defUsers = []string{"ubuntu", "ec2-user"}
)

func init() {
credentialsPath = os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
if credentialsPath == "" {
var configDir string
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
configDir = os.Getenv("APPDATA")
} else {
configDir = home
}
credentialsPath = filepath.Join(configDir, ".aws", "credentials")
}
}

func main() {
var (
showVersion bool
)

flag.BoolVar(&showVersion, "v", false, "show version")
flag.BoolVar(&showVersion, "version", false, "show version")

if showVersion {
fmt.Println("version:", version)
fmt.Println("build:", buildDate)
checkLatest(version)
return
}

app := cli.NewApp()

app.Name = "Oreno mssh"
app.Version = version

app.Flags = []cli.Flag{
flags = []cli.Flag{
cli.StringFlag{
Name: "region, r",
Value: "ap-northeast-1",
Expand All @@ -88,142 +55,180 @@ func main() {
Usage: "select ssh user",
},
}
)

// app.Action = omssh.Pre

app.Action = func(c *cli.Context) error {
region := c.String("region")
func init() {
credentialsPath = getCredentialsPath(runtime.GOOS)
}

profiles, err := utility.GetProfiles(credentialsPath)
if err != nil {
return err
}
func main() {
flag.BoolVar(&showVersion, "v", false, "show version")
flag.BoolVar(&showVersion, "version", false, "show version")

profileWithAssumeRole, err := utility.FinderProfile(profiles)
if err != nil {
return err
if showVersion {
fmt.Println("version:", version)
fmt.Println("build:", buildDate)
if err := checkLatest(version); err != nil {
log.Println(err)
}
return
}

_p := strings.Split(profileWithAssumeRole, "|")
app := cli.NewApp()
app.Name = name
app.Version = version
app.Flags = flags
app.Action = action

var sess *session.Session
if len(_p) > 1 {
profile, roleArn, mfaSerial, sourceProfile := awsapi.GetProfileWithAssumeRole(profileWithAssumeRole)
err := app.Run(os.Args)
if err != nil {
log.Fatal(err)
}
}

sourceSess := awsapi.NewSession(sourceProfile, region)
func getCredentialsPath(runtimeGOOS string) string {
c := os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
if c != "" {
return c
}

f := func(o *stscreds.AssumeRoleProvider) {
o.Duration = time.Hour
o.RoleSessionName = sourceProfile
o.SerialNumber = aws.String(mfaSerial)
o.TokenProvider = stscreds.StdinTokenProvider
}
var configDir string
home := os.Getenv("HOME")
if home == "" && runtimeGOOS == "windows" {
configDir = os.Getenv("APPDATA")
} else {
configDir = home
}
return filepath.Join(configDir, ".aws", "credentials")
}

creds := stscreds.NewCredentials(sourceSess, roleArn, f)
func checkLatest(version string) error {
version = fixVersionStr(version)
githubTag := &latest.GithubTag{
Owner: "kenzo0107",
Repository: "omssh",
FixVersionStrFunc: fixVersionStr,
}
res, err := latest.Check(githubTag, version)
if err != nil {
return err
}
if res.Outdated {
return errors.New("not latest, you should upgrade")
}
return nil
}

config := aws.Config{
Region: aws.String(region),
Credentials: creds,
}
func fixVersionStr(v string) string {
v = strings.TrimPrefix(v, "v")
vs := strings.Split(v, "-")
return vs[0]
}

sess = session.Must(session.NewSessionWithOptions(session.Options{
Config: config,
Profile: profile,
}))
} else {
profile := _p[0]
sess = awsapi.NewSession(profile, region)
}
func action(c *cli.Context) error {
region := c.String("region")

// get list of ec2 instances
ec2Client := awsapi.NewEC2Client(ec2.New(sess))
ec2Instances, err := ec2Client.DescribeRunningEC2s()
if err != nil {
return err
}
profiles, err := utility.GetProfiles(credentialsPath)
if err != nil {
return err
}

// select an ec2
ec2, err := awsapi.FinderEC2(ec2Instances)
if err != nil {
return err
}
profileWithAssumeRole, err := utility.FinderProfile(profiles)
if err != nil {
return err
}

user := defUser
if c.Bool("user") {
u, e := awsapi.FinderUsername(defUsers)
if e != nil {
return e
}
user = u
}
_p := strings.Split(profileWithAssumeRole, "|")

cache := cache.New(480*time.Minute, 1440*time.Minute)
publicKey, privateKey := utility.SSHKeyGen(cache)
var sess *session.Session
if len(_p) > 1 {
profile, roleArn, mfaSerial, sourceProfile := awsapi.GetProfileWithAssumeRole(profileWithAssumeRole)

// use ec2 instance connect to send public key
ec2instanceconnectSvc := ec2instanceconnect.New(sess)
sourceSess := awsapi.NewSession(sourceProfile, region)

input := ec2instanceconnect.SendSSHPublicKeyInput{
AvailabilityZone: aws.String(ec2.AvailabilityZone),
InstanceId: aws.String(ec2.InstanceID),
InstanceOSUser: aws.String(user),
SSHPublicKey: aws.String(publicKey),
f := func(o *stscreds.AssumeRoleProvider) {
o.Duration = time.Hour
o.RoleSessionName = sourceProfile
o.SerialNumber = aws.String(mfaSerial)
o.TokenProvider = stscreds.StdinTokenProvider
}

ec2InstanceConnectClient := awsapi.NewEC2InstanceConnectClient(ec2instanceconnectSvc)
r, err := ec2InstanceConnectClient.SendSSHPubKey(input)
creds := stscreds.NewCredentials(sourceSess, roleArn, f)

if err != nil || !r {
return err
config := aws.Config{
Region: aws.String(region),
Credentials: creds,
}

// ssh -i <temporary ssh private key> <user>@<public ip address>
log.Printf("ssh %s@%s -p %s [%s]\n", user, ec2.PublicIPAddress, c.String("port"), ec2.InstanceID)

signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return err
}
sess = session.Must(session.NewSessionWithOptions(session.Options{
Config: config,
Profile: profile,
}))
} else {
profile := _p[0]
sess = awsapi.NewSession(profile, region)
}

sshClientConfig := omssh.ConfigureSSHClient(user, signer)
// get list of ec2 instances
ec2Client := awsapi.NewEC2Client(ec2.New(sess))
ec2Instances, err := ec2Client.DescribeRunningEC2s()
if err != nil {
return err
}

device := omssh.NewDevice(ec2.PublicIPAddress, c.String("port"))
if err := device.SSHConnect(sshClientConfig); err != nil {
return err
}
device.SetupIO()
// select an ec2
ec2, err := awsapi.FinderEC2(ec2Instances)
if err != nil {
return err
}

if err := device.StartShell(); err != nil {
return err
user := defaultUser
if c.Bool("user") {
u, e := awsapi.FinderUsername(defUsers)
if e != nil {
return e
}
return nil
user = u
}

err := app.Run(os.Args)
if err != nil {
log.Fatal(err)
cache := cache.New(480*time.Minute, 1440*time.Minute)
publicKey, privateKey := utility.SSHKeyGen(cache)

// use ec2 instance connect to send public key
ec2instanceconnectSvc := ec2instanceconnect.New(sess)

input := ec2instanceconnect.SendSSHPublicKeyInput{
AvailabilityZone: aws.String(ec2.AvailabilityZone),
InstanceId: aws.String(ec2.InstanceID),
InstanceOSUser: aws.String(user),
SSHPublicKey: aws.String(publicKey),
}
}

func checkLatest(version string) {
version = fixVersionStr(version)
githubTag := &latest.GithubTag{
Owner: "kenzo0107",
Repository: "omssh",
FixVersionStrFunc: fixVersionStr,
ec2InstanceConnectClient := awsapi.NewEC2InstanceConnectClient(ec2instanceconnectSvc)
r, err := ec2InstanceConnectClient.SendSSHPubKey(input)

if err != nil || !r {
return err
}
res, err := latest.Check(githubTag, version)

// ssh -i <temporary ssh private key> <user>@<public ip address>
log.Printf("ssh %s@%s -p %s [%s]\n", user, ec2.PublicIPAddress, c.String("port"), ec2.InstanceID)

signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
log.Println(err)
return
return err
}
if res.Outdated {
log.Printf("%s is not latest, you should upgrade to %s\n", version, res.Current)

sshClientConfig := omssh.ConfigureSSHClient(user, signer)

device := omssh.NewDevice(ec2.PublicIPAddress, c.String("port"))
if err := device.SSHConnect(sshClientConfig); err != nil {
return err
}
}
device.SetupIO()

func fixVersionStr(v string) string {
v = strings.TrimPrefix(v, "v")
vs := strings.Split(v, "-")
return vs[0]
if err := device.StartShell(); err != nil {
return err
}
return nil
}
Loading

0 comments on commit 345f4b5

Please sign in to comment.