Skip to content

Commit

Permalink
addressed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
santhoshivan23 committed Jun 21, 2023
1 parent 2fcc16f commit 5346ce7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 30 deletions.
32 changes: 12 additions & 20 deletions clientapi/routing/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"golang.org/x/exp/constraints"

clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
Expand All @@ -39,8 +40,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}
request := struct {
Token string `json:"token"`
UsesAllowed int32 `json:"uses_allowed"`
ExpiryTime int64 `json:"expiry_time"`
UsesAllowed *int32 `json:"uses_allowed,omitempty"`
ExpiryTime *int64 `json:"expiry_time,omitempty"`
Length int32 `json:"length"`
}{}

Expand Down Expand Up @@ -87,15 +88,13 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
}
}
// At this point, we have a valid token, either through request body or through random generation.

if usesAllowed < 0 {
if usesAllowed != nil && *usesAllowed < 0 {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
}
}

if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) {
if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("expiry_time must not be in the past"),
Expand All @@ -106,10 +105,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
// If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
registrationToken := &clientapi.RegistrationToken{
Token: &token,
UsesAllowed: &usesAllowed,
UsesAllowed: usesAllowed,
Pending: &pending,
Completed: &completed,
ExpiryTime: &expiryTime,
ExpiryTime: expiryTime,
}
created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
if !created {
Expand All @@ -130,19 +129,19 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u
Code: 200,
JSON: map[string]interface{}{
"token": token,
"uses_allowed": getReturnValueForUsesAllowed(usesAllowed),
"uses_allowed": getReturnValue(usesAllowed),
"pending": pending,
"completed": completed,
"expiry_time": getReturnValueExpiryTime(expiryTime),
"expiry_time": getReturnValue(expiryTime),
},
}
}

func getReturnValueForUsesAllowed(usesAllowed int32) interface{} {
if usesAllowed == 0 {
func getReturnValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return usesAllowed
return *in
}

func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
Expand Down Expand Up @@ -176,13 +175,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA
}
}

func getReturnValueExpiryTime(expiryTime int64) interface{} {
if expiryTime == 0 {
return nil
}
return expiryTime
}

func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions userapi/storage/postgres/registration_tokens_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
nullIfZero(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.ExpiryTime),
getInsertValue(registrationToken.UsesAllowed),
getInsertValue(registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
if err != nil {
Expand All @@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
return true, nil
}

func nullIfZero[t constraints.Integer](in t) any {
if in == 0 {
func getInsertValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return in
return *in
}

func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
Expand Down
10 changes: 5 additions & 5 deletions userapi/storage/sqlite3/registration_tokens_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
_, err := stmt.ExecContext(
ctx,
*registrationToken.Token,
nullIfZero(*registrationToken.UsesAllowed),
nullIfZero(*registrationToken.ExpiryTime),
getInsertValue(registrationToken.UsesAllowed),
getInsertValue(registrationToken.ExpiryTime),
*registrationToken.Pending,
*registrationToken.Completed)
if err != nil {
Expand All @@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex
return true, nil
}

func nullIfZero[t constraints.Integer](in t) any {
if in == 0 {
func getInsertValue[t constraints.Integer](in *t) any {
if in == nil {
return nil
}
return in
return *in
}

func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
Expand Down

0 comments on commit 5346ce7

Please sign in to comment.