Skip to content

Commit

Permalink
RAI-13007 Returning err in case GetAccessToken returns not 2xx http c…
Browse files Browse the repository at this point in the history
…ode (#85)

* returning err in case GetAccessToken returns not 2xx http code

* reading local config vs env for running workflows

* refactored getConfig function to be shared by TestMain and NewClientTest

---------

Co-authored-by: Anatoli Kurtsevich <anatolikurtsevich@relational.ai>
  • Loading branch information
antikus and Anatoli Kurtsevich committed Jun 5, 2023
1 parent 673aa44 commit 9bb1816
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 19 deletions.
5 changes: 2 additions & 3 deletions rai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,12 @@ func (c *Client) GetAccessToken(creds *ClientCredentials) (*AccessToken, error)
if err != nil {
return nil, err
}
req = req.WithContext(c.ctx)
c.ensureHeaders(req, nil)
rsp, err := c.HttpClient.Do(req)
rsp, err := c.Do(req)
if err != nil {
return nil, err
}
defer rsp.Body.Close()

token := &AccessToken{}
if err = token.Load(rsp.Body); err != nil {
return nil, err
Expand Down
32 changes: 32 additions & 0 deletions rai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package rai

import (
"context"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -60,6 +61,37 @@ func findModel(models []Model, name string) *Model {
return nil
}

func TestNewClient(t *testing.T) {
var testClient *Client
var cfg Config

err := getConfig(&cfg)
assert.Nil(t, err)

opts := ClientOptions{Config: cfg}
testClient = NewClient(context.Background(), &opts)

creds := &ClientCredentials{
ClientID: cfg.Credentials.ClientID,
ClientSecret: cfg.Credentials.ClientSecret,
ClientCredentialsUrl: cfg.Credentials.ClientCredentialsUrl,
Audience: cfg.Credentials.Audience,
}
token, err := testClient.GetAccessToken(creds)
assert.Nil(t, err)
assert.NotNil(t, token)

missingCreds := &ClientCredentials{
ClientID: cfg.Credentials.ClientID,
ClientSecret: cfg.Credentials.ClientSecret,
ClientCredentialsUrl: cfg.Credentials.ClientCredentialsUrl,
}

token, err = testClient.GetAccessToken(missingCreds)
assert.Nil(t, token)
assert.NotNil(t, err)
}

// Test database management APIs.
func TestDatabase(t *testing.T) {
client := test.client
Expand Down
33 changes: 17 additions & 16 deletions rai/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,11 @@ func (h headerRoundTrip) RoundTrip(r *http.Request) (*http.Response, error) {
return h.defaultRoundTrip.RoundTrip(r)
}

// todo: fix client init logic, load from config only if env vars are not
// available.
func newTestClient() (*Client, error) {
func getConfig(cfg *Config) error {
configPath, _ := expandUser(DefaultConfigFile)
var testClient *Client
if _, err := os.Stat(configPath); err == nil {
testClient, err = NewDefaultClient()
if err != nil {
panic(err)
}

return LoadConfig(cfg)
} else {
var cfg Config

clientId := os.Getenv("CLIENT_ID")
clientSecret := os.Getenv("CLIENT_SECRET")
clientCredentialsUrl := os.Getenv("CLIENT_CREDENTIALS_URL")
Expand All @@ -131,12 +122,22 @@ func newTestClient() (*Client, error) {
client_credentials_url=%s
`
configSrc := fmt.Sprintf(configFormat, raiHost, clientId, clientSecret, clientCredentialsUrl)
if err := LoadConfigString(configSrc, "default", &cfg); err != nil {
return nil, err
}
opts := ClientOptions{Config: cfg}
testClient = NewClient(context.Background(), &opts)
return LoadConfigString(configSrc, "default", cfg)
}
}

// todo: fix client init logic, load from config only if env vars are not
// available.
func newTestClient() (*Client, error) {
var testClient *Client
var cfg Config

if err := getConfig(&cfg); err != nil {
return nil, err
}

opts := ClientOptions{Config: cfg}
testClient = NewClient(context.Background(), &opts)

// get custom headers
var customHeaders map[string]string
Expand Down

0 comments on commit 9bb1816

Please sign in to comment.