Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Add context to each raw request call #4987

Merged
merged 1 commit into from
Jul 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions api/auth_token.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package api

import "context"

// TokenAuth is used to perform token backend operations on Vault
type TokenAuth struct {
c *Client
Expand All @@ -16,7 +18,9 @@ func (c *TokenAuth) Create(opts *TokenCreateRequest) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -31,7 +35,9 @@ func (c *TokenAuth) CreateOrphan(opts *TokenCreateRequest) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -46,7 +52,9 @@ func (c *TokenAuth) CreateWithRole(opts *TokenCreateRequest, roleName string) (*
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -63,7 +71,9 @@ func (c *TokenAuth) Lookup(token string) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -79,7 +89,10 @@ func (c *TokenAuth) LookupAccessor(accessor string) (*Secret, error) {
}); err != nil {
return nil, err
}
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -91,7 +104,9 @@ func (c *TokenAuth) LookupAccessor(accessor string) (*Secret, error) {
func (c *TokenAuth) LookupSelf() (*Secret, error) {
r := c.c.NewRequest("GET", "/v1/auth/token/lookup-self")

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -109,7 +124,9 @@ func (c *TokenAuth) Renew(token string, increment int) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -126,7 +143,9 @@ func (c *TokenAuth) RenewSelf(increment int) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -146,7 +165,9 @@ func (c *TokenAuth) RenewTokenAsSelf(token string, increment int) (*Secret, erro
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -164,7 +185,10 @@ func (c *TokenAuth) RevokeAccessor(accessor string) error {
}); err != nil {
return err
}
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand All @@ -183,7 +207,9 @@ func (c *TokenAuth) RevokeOrphan(token string) error {
return err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand All @@ -197,7 +223,10 @@ func (c *TokenAuth) RevokeOrphan(token string) error {
// an effect.
func (c *TokenAuth) RevokeSelf(token string) error {
r := c.c.NewRequest("PUT", "/v1/auth/token/revoke-self")
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand All @@ -217,7 +246,9 @@ func (c *TokenAuth) RevokeTree(token string) error {
return err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand Down
19 changes: 10 additions & 9 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,13 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
return c.RawRequestWithContext(context.Background(), r)
}

// RawRequestWithContext performs the raw request given. This request may be against
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Response, error) {
c.modifyLock.RLock()
token := c.token

Expand All @@ -622,7 +629,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
c.modifyLock.RUnlock()

if limiter != nil {
limiter.Wait(context.Background())
limiter.Wait(ctx)
}

// Sanity check the token before potentially erroring from the API
Expand All @@ -643,13 +650,10 @@ START:
return nil, fmt.Errorf("nil request created")
}

// Set the timeout, if any
var cancelFunc context.CancelFunc
if timeout != 0 {
var ctx context.Context
ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
req.Request = req.Request.WithContext(ctx)
ctx, _ = context.WithTimeout(ctx, timeout)
}
req.Request = req.Request.WithContext(ctx)

if backoff == nil {
backoff = retryablehttp.LinearJitterBackoff
Expand All @@ -667,9 +671,6 @@ START:

var result *Response
resp, err := client.Do(req)
if cancelFunc != nil {
cancelFunc()
}
if resp != nil {
result = &Response{Response: resp}
}
Expand Down
6 changes: 5 additions & 1 deletion api/help.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package api

import (
"context"
"fmt"
)

// Help reads the help information for the given path.
func (c *Client) Help(path string) (*Help, error) {
r := c.NewRequest("GET", fmt.Sprintf("/v1/%s", path))
r.Params.Add("help", "1")
resp, err := c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand Down
24 changes: 19 additions & 5 deletions api/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"bytes"
"context"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -46,7 +47,10 @@ func (c *Client) Logical() *Logical {

func (c *Logical) Read(path string) (*Secret, error) {
r := c.c.NewRequest("GET", "/v1/"+path)
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -77,7 +81,10 @@ func (c *Logical) List(path string) (*Secret, error) {
// handle the wrapping lookup function
r.Method = "GET"
r.Params.Set("list", "true")
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -108,7 +115,9 @@ func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, erro
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand All @@ -134,7 +143,10 @@ func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, erro

func (c *Logical) Delete(path string) (*Secret, error) {
r := c.c.NewRequest("DELETE", "/v1/"+path)
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -175,7 +187,9 @@ func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down
13 changes: 10 additions & 3 deletions api/ssh.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package api

import "fmt"
import (
"context"
"fmt"
)

// SSH is used to return a client to invoke operations on SSH backend.
type SSH struct {
Expand Down Expand Up @@ -28,7 +31,9 @@ func (c *SSH) Credential(role string, data map[string]interface{}) (*Secret, err
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -45,7 +50,9 @@ func (c *SSH) SignKey(role string, data map[string]interface{}) (*Secret, error)
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion api/ssh_agent.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -207,7 +208,9 @@ func (c *SSHHelper) Verify(otp string) (*SSHVerifyResponse, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand Down
22 changes: 18 additions & 4 deletions api/sys_audit.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"fmt"

"github.com/mitchellh/mapstructure"
Expand All @@ -16,7 +17,9 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {
return "", err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return "", err
}
Expand All @@ -37,7 +40,11 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {

func (c *Sys) ListAudit() (map[string]*Audit, error) {
r := c.c.NewRequest("GET", "/v1/sys/audit")
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -87,7 +94,10 @@ func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) e
return err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)

if err != nil {
return err
}
Expand All @@ -98,7 +108,11 @@ func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) e

func (c *Sys) DisableAudit(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/audit/%s", path))
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)

if err == nil {
defer resp.Body.Close()
}
Expand Down
Loading