From e57790b8c5d5a5adac000ebcccb7f1f3ad8c0dd0 Mon Sep 17 00:00:00 2001 From: LockedThread Date: Sun, 19 Jun 2022 21:59:35 -0400 Subject: [PATCH] Revert "Feat jwt auth" --- auth/auth.go | 98 ++++++++-------------------------------------- auth/directives.go | 69 -------------------------------- auth/middleware.go | 36 ----------------- models/auth.go | 94 -------------------------------------------- utils/gin.go | 30 -------------- 5 files changed, 16 insertions(+), 311 deletions(-) delete mode 100644 auth/directives.go delete mode 100644 auth/middleware.go delete mode 100644 models/auth.go delete mode 100644 utils/gin.go diff --git a/auth/auth.go b/auth/auth.go index 36b11aa..30c4bc0 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -7,40 +7,27 @@ import ( "crypto/rand" "errors" "fmt" - "github.com/KnightHacks/knighthacks_shared/models" "github.com/golang-jwt/jwt" "github.com/google/go-github/v45/github" "golang.org/x/oauth2" "io" "strconv" - "time" ) -type TokenType string +type Provider int const ( - RefreshTokenType TokenType = "REFRESH" - AccessTokenType TokenType = "ACCESS" -) - -var ( - TokenNotValid = errors.New("jwt token not valid") + GitHubAuthProvider Provider = iota + GmailAuthProvider Provider = iota ) type Auth struct { - ConfigMap map[models.Provider]oauth2.Config - signingKey []byte + ConfigMap map[Provider]oauth2.Config + signingKey string gcm cipher.AEAD } -type UserClaims struct { - UserID string `json:"user_id"` - Role models.Role `json:"role"` - Type TokenType `json:"type"` - jwt.StandardClaims -} - -func NewAuth(signingKey string, cipher32Bit string, configMap map[models.Provider]oauth2.Config) (*Auth, error) { +func NewAuth(signingKey string, cipher32Bit string, configMap map[Provider]oauth2.Config) (*Auth, error) { newCipher, err := aes.NewCipher([]byte(cipher32Bit)) if err != nil { return nil, err @@ -49,24 +36,25 @@ func NewAuth(signingKey string, cipher32Bit string, configMap map[models.Provide if err != nil { return nil, err } - return &Auth{ConfigMap: configMap, signingKey: []byte(signingKey), gcm: gcm}, nil + return &Auth{ConfigMap: configMap, signingKey: signingKey, gcm: gcm}, nil } -func (a *Auth) GetAuthCodeURL(provider models.Provider, state string) string { +func (a *Auth) GetAuthCodeURL(provider Provider) string { config := a.ConfigMap[provider] - return config.AuthCodeURL(state, oauth2.AccessTypeOffline) + // TODO: Implement oauth2 'state' on url to prevent CSRF https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 + return config.AuthCodeURL("state", oauth2.AccessTypeOffline) } -func (a *Auth) ExchangeCode(ctx context.Context, provider models.Provider, code string) (*oauth2.Token, error) { +func (a *Auth) ExchangeCode(ctx context.Context, provider Provider, code string) (*oauth2.Token, error) { config := a.ConfigMap[provider] return config.Exchange(ctx, code) } -func (a *Auth) GetUID(ctx context.Context, provider models.Provider, token string) (string, error) { +func (a *Auth) GetUID(ctx context.Context, provider Provider, token string) (string, error) { config := a.ConfigMap[provider] oauthClient := oauth2.NewClient(ctx, config.TokenSource(ctx, &oauth2.Token{AccessToken: token})) - if provider == models.ProviderGithub { + if provider == GitHubAuthProvider { githubClient := github.NewClient(oauthClient) user, _, err := githubClient.Users.Get(ctx, "") @@ -109,62 +97,8 @@ func (a *Auth) DecryptAccessToken(token string) ([]byte, error) { return decryptedBytes, nil } -func (a *Auth) NewTokens(userId string, role models.Role) (refreshToken string, accessToken string, err error) { - refreshToken, err = a.NewRefreshToken(userId, role) - if err != nil { - return "", "", err - } - accessToken, err = a.NewAccessToken(userId, role) - if err != nil { - return "", "", err - } - return refreshToken, accessToken, nil -} - -func (a *Auth) NewRefreshToken(userId string, role models.Role) (string, error) { - return a.newJWT(userId, role, RefreshTokenType, time.Hour*24) -} - -func (a *Auth) NewAccessToken(userId string, role models.Role) (string, error) { - return a.newJWT(userId, role, AccessTokenType, time.Minute*30) -} - -func (a *Auth) newJWT(userId string, role models.Role, tokenType TokenType, expiration time.Duration) (string, error) { - now := time.Now().UTC() - claims := UserClaims{ - userId, - role, - tokenType, - jwt.StandardClaims{ - ExpiresAt: now.Add(expiration).Unix(), - IssuedAt: now.Unix(), - Issuer: "knighthacks", - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) +func (a *Auth) NewJWT(mapClaims jwt.MapClaims) (string, error) { + token := jwt.New(jwt.SigningMethodRS256) + token.Claims = mapClaims return token.SignedString(a.signingKey) } - -func (a *Auth) ParseJWT(tokenString string, tokenType TokenType) (*UserClaims, error) { - token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return a.signingKey, nil - }) - if err != nil { - return nil, err - } - if !token.Valid { - return nil, TokenNotValid - } - - if claims, ok := token.Claims.(*UserClaims); ok { - if claims.Type != tokenType { - return nil, fmt.Errorf("you are sending a %s token while we need a %s token", claims.Type, tokenType) - } - return claims, nil - } else { - return nil, errors.New("unable to cast jwt claims to UserClaims") - } -} diff --git a/auth/directives.go b/auth/directives.go deleted file mode 100644 index 21db4d3..0000000 --- a/auth/directives.go +++ /dev/null @@ -1,69 +0,0 @@ -package auth - -import ( - "context" - "errors" - "github.com/99designs/gqlgen/graphql" - "github.com/KnightHacks/knighthacks_shared/models" - "github.com/KnightHacks/knighthacks_shared/utils" - "log" -) - -type HasRoleDirective struct { - GetUserId func(ctx context.Context, obj interface{}) (string, error) -} - -func (receiver HasRoleDirective) Direct(ctx context.Context, obj interface{}, next graphql.Resolver, role models.Role) (interface{}, error) { - ginContext, err := utils.GinContextFromContext(ctx) - if err != nil { - return nil, err - } - - var userClaims *UserClaims - - value, ok := ctx.Value("AuthorizationUserClaims").(*UserClaims) - if ok { - userClaims = value - } else { - auth, err := AuthFromContext(ctx) - if err != nil { - return nil, err - } - - authHeader := ginContext.GetHeader("authorization") - - userClaims, err = auth.ParseJWT(authHeader, AccessTokenType) - if err != nil { - return nil, err - } - } - - if userClaims.Role == models.RoleOwns { - return nil, errors.New("don't try to be sneaky") - } - - switch role { - case models.RoleAdmin: - if userClaims.Role != models.RoleAdmin { - return nil, errors.New("you must be an admin to use this resolver") - } - break - case models.RoleNormal: - break - case models.RoleOwns: - if userClaims.Role == models.RoleAdmin { - break - } - id, err := receiver.GetUserId(ctx, obj) - if err != nil { - return nil, err - } - log.Printf("Checking id:%s against userClaims=%v\n", id, *userClaims) - if id != userClaims.UserID { - return nil, errors.New("you must be own this data to use this resolver") - } - break - } - - return next(context.WithValue(ctx, "AuthorizationUserClaims", userClaims)) -} diff --git a/auth/middleware.go b/auth/middleware.go deleted file mode 100644 index a77d227..0000000 --- a/auth/middleware.go +++ /dev/null @@ -1,36 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "github.com/gin-gonic/gin" -) - -func UserClaimsFromContext(ctx context.Context) (*UserClaims, error) { - if userClaims, ok := ctx.Value("AuthorizationUserClaims").(*UserClaims); ok { - return userClaims, nil - } - return nil, errors.New("unable to retrieve user claims from context") -} - -func AuthContextMiddleware(auth *Auth) gin.HandlerFunc { - return func(c *gin.Context) { - ctx := context.WithValue(c.Request.Context(), "Auth", auth) - c.Request = c.Request.WithContext(ctx) - c.Next() - } -} - -func AuthFromContext(ctx context.Context) (*Auth, error) { - auth := ctx.Value("Auth") - if auth == nil { - err := fmt.Errorf("could not retrieve auth.Auth") - return nil, err - } - - if gc, ok := auth.(*Auth); ok { - return gc, nil - } - return nil, errors.New("auth.Auth has wrong type") -} diff --git a/models/auth.go b/models/auth.go deleted file mode 100644 index 3fcba54..0000000 --- a/models/auth.go +++ /dev/null @@ -1,94 +0,0 @@ -package models - -import ( - "fmt" - "io" - "strconv" -) - -type Role string - -const ( - RoleAdmin Role = "ADMIN" - // for now keep this the same - RoleSponsor Role = "SPONSOR" - RoleNormal Role = "NORMAL" - RoleOwns Role = "OWNS" -) - -var AllRole = []Role{ - RoleAdmin, - RoleSponsor, - RoleNormal, - RoleOwns, -} - -func (e Role) IsValid() bool { - switch e { - case RoleAdmin, RoleSponsor, RoleNormal, RoleOwns: - return true - } - return false -} - -func (e Role) String() string { - return string(e) -} - -func (e *Role) UnmarshalGQL(v interface{}) error { - str, ok := v.(string) - if !ok { - return fmt.Errorf("enums must be strings") - } - - *e = Role(str) - if !e.IsValid() { - return fmt.Errorf("%s is not a valid Role", str) - } - return nil -} - -func (e Role) MarshalGQL(w io.Writer) { - fmt.Fprint(w, strconv.Quote(e.String())) -} - -type Provider string - -const ( - ProviderGithub Provider = "GITHUB" - ProviderGmail Provider = "GMAIL" -) - -var AllProvider = []Provider{ - ProviderGithub, - ProviderGmail, -} - -func (e Provider) IsValid() bool { - switch e { - case ProviderGithub, ProviderGmail: - return true - } - return false -} - -func (e Provider) String() string { - return string(e) -} - -func (e *Provider) UnmarshalGQL(v interface{}) error { - str, ok := v.(string) - if !ok { - return fmt.Errorf("enums must be strings") - } - - *e = Provider(str) - if !e.IsValid() { - return fmt.Errorf("%s is not a valid Provider", str) - } - return nil -} - -func (e Provider) MarshalGQL(w io.Writer) { - fmt.Fprint(w, strconv.Quote(e.String())) -} diff --git a/utils/gin.go b/utils/gin.go deleted file mode 100644 index 24521ad..0000000 --- a/utils/gin.go +++ /dev/null @@ -1,30 +0,0 @@ -package utils - -import ( - "context" - "fmt" - "github.com/gin-gonic/gin" -) - -func GinContextFromContext(ctx context.Context) (*gin.Context, error) { - ginContext := ctx.Value("GinContextKey") - if ginContext == nil { - err := fmt.Errorf("could not retrieve gin.Context") - return nil, err - } - - gc, ok := ginContext.(*gin.Context) - if !ok { - err := fmt.Errorf("gin.Context has wrong type") - return nil, err - } - return gc, nil -} - -func GinContextMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := context.WithValue(c.Request.Context(), "GinContextKey", c) - c.Request = c.Request.WithContext(ctx) - c.Next() - } -}