Skip to content

Commit

Permalink
feat: add support for multiple auth strategies per target from secret…
Browse files Browse the repository at this point in the history
…s file (#5500)
  • Loading branch information
RamanaReddy0M committed Aug 16, 2024
1 parent e0466e1 commit 2609d2d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 45 deletions.
78 changes: 47 additions & 31 deletions pkg/authprovider/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
type FileAuthProvider struct {
Path string
store *authx.Authx
compiled map[*regexp.Regexp]authx.AuthStrategy
domains map[string]authx.AuthStrategy
compiled map[*regexp.Regexp][]authx.AuthStrategy
domains map[string][]authx.AuthStrategy
}

// NewFileAuthProvider creates a new file based auth provider
Expand Down Expand Up @@ -56,58 +56,70 @@ func (f *FileAuthProvider) init() {
if len(secret.DomainsRegex) > 0 {
for _, domain := range secret.DomainsRegex {
if f.compiled == nil {
f.compiled = make(map[*regexp.Regexp]authx.AuthStrategy)
f.compiled = make(map[*regexp.Regexp][]authx.AuthStrategy)
}
compiled, err := regexp.Compile(domain)
if err != nil {
continue
}
f.compiled[compiled] = secret.GetStrategy()

if ss, ok := f.compiled[compiled]; ok {
f.compiled[compiled] = append(ss, secret.GetStrategy())
} else {
f.compiled[compiled] = []authx.AuthStrategy{secret.GetStrategy()}
}
}
}
for _, domain := range secret.Domains {
if f.domains == nil {
f.domains = make(map[string]authx.AuthStrategy)
f.domains = make(map[string][]authx.AuthStrategy)
}
f.domains[strings.TrimSpace(domain)] = secret.GetStrategy()
if strings.HasSuffix(domain, ":80") {
f.domains[strings.TrimSuffix(domain, ":80")] = secret.GetStrategy()
}
if strings.HasSuffix(domain, ":443") {
f.domains[strings.TrimSuffix(domain, ":443")] = secret.GetStrategy()
domain = strings.TrimSpace(domain)
domain = strings.TrimSuffix(domain, ":80")
domain = strings.TrimSuffix(domain, ":443")
if ss, ok := f.domains[domain]; ok {
f.domains[domain] = append(ss, secret.GetStrategy())
} else {
f.domains[domain] = []authx.AuthStrategy{secret.GetStrategy()}
}
}
}
for _, dynamic := range f.store.Dynamic {
if len(dynamic.DomainsRegex) > 0 {
for _, domain := range dynamic.DomainsRegex {
if f.compiled == nil {
f.compiled = make(map[*regexp.Regexp]authx.AuthStrategy)
f.compiled = make(map[*regexp.Regexp][]authx.AuthStrategy)
}
compiled, err := regexp.Compile(domain)
if err != nil {
continue
}
f.compiled[compiled] = &authx.DynamicAuthStrategy{Dynamic: dynamic}
if ss, ok := f.compiled[compiled]; !ok {
f.compiled[compiled] = []authx.AuthStrategy{&authx.DynamicAuthStrategy{Dynamic: dynamic}}
} else {
f.compiled[compiled] = append(ss, &authx.DynamicAuthStrategy{Dynamic: dynamic})
}
}
}
for _, domain := range dynamic.Domains {
if f.domains == nil {
f.domains = make(map[string]authx.AuthStrategy)
f.domains = make(map[string][]authx.AuthStrategy)
}
f.domains[strings.TrimSpace(domain)] = &authx.DynamicAuthStrategy{Dynamic: dynamic}
if strings.HasSuffix(domain, ":80") {
f.domains[strings.TrimSuffix(domain, ":80")] = &authx.DynamicAuthStrategy{Dynamic: dynamic}
}
if strings.HasSuffix(domain, ":443") {
f.domains[strings.TrimSuffix(domain, ":443")] = &authx.DynamicAuthStrategy{Dynamic: dynamic}
domain = strings.TrimSpace(domain)
domain = strings.TrimSuffix(domain, ":80")
domain = strings.TrimSuffix(domain, ":443")

if ss, ok := f.domains[domain]; !ok {
f.domains[domain] = []authx.AuthStrategy{&authx.DynamicAuthStrategy{Dynamic: dynamic}}
} else {
f.domains[domain] = append(ss, &authx.DynamicAuthStrategy{Dynamic: dynamic})
}
}
}
}

// LookupAddr looks up a given domain/address and returns appropriate auth strategy
func (f *FileAuthProvider) LookupAddr(addr string) authx.AuthStrategy {
func (f *FileAuthProvider) LookupAddr(addr string) []authx.AuthStrategy {
if strings.Contains(addr, ":") {
// default normalization for host:port
host, port, err := net.SplitHostPort(addr)
Expand All @@ -129,12 +141,12 @@ func (f *FileAuthProvider) LookupAddr(addr string) authx.AuthStrategy {
}

// LookupURL looks up a given URL and returns appropriate auth strategy
func (f *FileAuthProvider) LookupURL(u *url.URL) authx.AuthStrategy {
func (f *FileAuthProvider) LookupURL(u *url.URL) []authx.AuthStrategy {
return f.LookupAddr(u.Host)
}

// LookupURLX looks up a given URL and returns appropriate auth strategy
func (f *FileAuthProvider) LookupURLX(u *urlutil.URL) authx.AuthStrategy {
func (f *FileAuthProvider) LookupURLX(u *urlutil.URL) []authx.AuthStrategy {
return f.LookupAddr(u.Host)
}

Expand All @@ -151,17 +163,21 @@ func (f *FileAuthProvider) GetTemplatePaths() []string {

// PreFetchSecrets pre-fetches the secrets from the auth provider
func (f *FileAuthProvider) PreFetchSecrets() error {
for _, s := range f.domains {
if val, ok := s.(*authx.DynamicAuthStrategy); ok {
if err := val.Dynamic.Fetch(false); err != nil {
return err
for _, ss := range f.domains {
for _, s := range ss {
if val, ok := s.(*authx.DynamicAuthStrategy); ok {
if err := val.Dynamic.Fetch(false); err != nil {
return err
}
}
}
}
for _, s := range f.compiled {
if val, ok := s.(*authx.DynamicAuthStrategy); ok {
if err := val.Dynamic.Fetch(false); err != nil {
return err
for _, ss := range f.compiled {
for _, s := range ss {
if val, ok := s.(*authx.DynamicAuthStrategy); ok {
if err := val.Dynamic.Fetch(false); err != nil {
return err
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/authprovider/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ var (
type AuthProvider interface {
// LookupAddr looks up a given domain/address and returns appropriate auth strategy
// for it (accepted inputs are scanme.sh or scanme.sh:443)
LookupAddr(string) authx.AuthStrategy
LookupAddr(string) []authx.AuthStrategy
// LookupURL looks up a given URL and returns appropriate auth strategy
// it accepts a valid url struct and returns the auth strategy
LookupURL(*url.URL) authx.AuthStrategy
LookupURL(*url.URL) []authx.AuthStrategy
// LookupURLX looks up a given URL and returns appropriate auth strategy
// it accepts pd url struct (i.e urlutil.URL) and returns the auth strategy
LookupURLX(*urlutil.URL) authx.AuthStrategy
LookupURLX(*urlutil.URL) []authx.AuthStrategy
// GetTemplatePaths returns the template path for the auth provider
// that will be used for dynamic secret fetching
GetTemplatePaths() []string
Expand Down
6 changes: 3 additions & 3 deletions pkg/authprovider/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func NewMultiAuthProvider(providers ...AuthProvider) AuthProvider {
return &MultiAuthProvider{Providers: providers}
}

func (m *MultiAuthProvider) LookupAddr(host string) authx.AuthStrategy {
func (m *MultiAuthProvider) LookupAddr(host string) []authx.AuthStrategy {
for _, provider := range m.Providers {
strategy := provider.LookupAddr(host)
if strategy != nil {
Expand All @@ -29,7 +29,7 @@ func (m *MultiAuthProvider) LookupAddr(host string) authx.AuthStrategy {
return nil
}

func (m *MultiAuthProvider) LookupURL(u *url.URL) authx.AuthStrategy {
func (m *MultiAuthProvider) LookupURL(u *url.URL) []authx.AuthStrategy {
for _, provider := range m.Providers {
strategy := provider.LookupURL(u)
if strategy != nil {
Expand All @@ -39,7 +39,7 @@ func (m *MultiAuthProvider) LookupURL(u *url.URL) authx.AuthStrategy {
return nil
}

func (m *MultiAuthProvider) LookupURLX(u *urlutil.URL) authx.AuthStrategy {
func (m *MultiAuthProvider) LookupURLX(u *urlutil.URL) []authx.AuthStrategy {
for _, provider := range m.Providers {
strategy := provider.LookupURLX(u)
if strategy != nil {
Expand Down
16 changes: 8 additions & 8 deletions pkg/protocols/http/build_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ func (g *generatedRequest) ApplyAuth(provider authprovider.AuthProvider) {
return
}
if g.request != nil {
auth := provider.LookupURLX(g.request.URL)
if auth != nil {
auth.ApplyOnRR(g.request)
authStrategies := provider.LookupURLX(g.request.URL)
for _, strategy := range authStrategies {
strategy.ApplyOnRR(g.request)
}
}
if g.rawRequest != nil {
Expand All @@ -101,11 +101,11 @@ func (g *generatedRequest) ApplyAuth(provider authprovider.AuthProvider) {
gologger.Warning().Msgf("[authprovider] Could not parse URL %s: %s\n", g.rawRequest.FullURL, err)
return
}
auth := provider.LookupURLX(parsed)
if auth != nil {
// here we need to apply it custom because we don't have a standard/official
// rawhttp request format ( which we probably should have )
g.rawRequest.ApplyAuthStrategy(auth)
authStrategies := provider.LookupURLX(parsed)
// here we need to apply it custom because we don't have a standard/official
// rawhttp request format ( which we probably should have )
for _, strategy := range authStrategies {
g.rawRequest.ApplyAuthStrategy(strategy)
}
}
}
Expand Down

0 comments on commit 2609d2d

Please sign in to comment.