From 5346ce735abfeefdaff438e690fffa8f16980da2 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 22 Jun 2023 00:02:09 +0530 Subject: [PATCH] addressed review comments --- clientapi/routing/admin.go | 32 +++++++------------ .../postgres/registration_tokens_table.go | 10 +++--- .../sqlite3/registration_tokens_table.go | 10 +++--- 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 964236f7f1..cc9370be29 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -18,6 +18,7 @@ import ( "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" + "golang.org/x/exp/constraints" clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/httputil" @@ -39,8 +40,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } request := struct { Token string `json:"token"` - UsesAllowed int32 `json:"uses_allowed"` - ExpiryTime int64 `json:"expiry_time"` + UsesAllowed *int32 `json:"uses_allowed,omitempty"` + ExpiryTime *int64 `json:"expiry_time,omitempty"` Length int32 `json:"length"` }{} @@ -87,15 +88,13 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } } // At this point, we have a valid token, either through request body or through random generation. - - if usesAllowed < 0 { + if usesAllowed != nil && *usesAllowed < 0 { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), } } - - if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("expiry_time must not be in the past"), @@ -106,10 +105,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) registrationToken := &clientapi.RegistrationToken{ Token: &token, - UsesAllowed: &usesAllowed, + UsesAllowed: usesAllowed, Pending: &pending, Completed: &completed, - ExpiryTime: &expiryTime, + ExpiryTime: expiryTime, } created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) if !created { @@ -130,19 +129,19 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u Code: 200, JSON: map[string]interface{}{ "token": token, - "uses_allowed": getReturnValueForUsesAllowed(usesAllowed), + "uses_allowed": getReturnValue(usesAllowed), "pending": pending, "completed": completed, - "expiry_time": getReturnValueExpiryTime(expiryTime), + "expiry_time": getReturnValue(expiryTime), }, } } -func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { - if usesAllowed == 0 { +func getReturnValue[t constraints.Integer](in *t) any { + if in == nil { return nil } - return usesAllowed + return *in } func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { @@ -176,13 +175,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA } } -func getReturnValueExpiryTime(expiryTime int64) interface{} { - if expiryTime == 0 { - return nil - } - return expiryTime -} - func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 45b39c8922..3c3e3fdd93 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZero(*registrationToken.UsesAllowed), - nullIfZero(*registrationToken.ExpiryTime), + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) if err != nil { @@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero[t constraints.Integer](in t) any { - if in == 0 { +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { return nil } - return in + return *in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go index 99c18c557e..8979547317 100644 --- a/userapi/storage/sqlite3/registration_tokens_table.go +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZero(*registrationToken.UsesAllowed), - nullIfZero(*registrationToken.ExpiryTime), + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) if err != nil { @@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero[t constraints.Integer](in t) any { - if in == 0 { +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { return nil } - return in + return *in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {