From 80ceb18c7860cce8d4f19ae44fa60a69bfaf5209 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Sat, 3 Jun 2023 21:47:24 +0530 Subject: [PATCH 01/21] CreateNewToken API: Initial Changes --- clientapi/auth/authtypes/logintypes.go | 1 + clientapi/routing/admin.go | 81 ++++++++++++++++++++++++++ clientapi/routing/routing.go | 5 ++ setup/config/config_clientapi.go | 9 +++ 4 files changed, 96 insertions(+) diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f806..6e08d97357 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,4 +11,5 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" + LoginTypeRegistrationToken = "m.login.registration_token" ) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 3d64454c4c..c25a76d85a 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/http" + "regexp" "time" "github.com/gorilla/mux" @@ -24,6 +25,86 @@ import ( "github.com/matrix-org/dendrite/userapi/api" ) +func AdminCreateNewToken(req *http.Request) util.JSONResponse { + request := struct { + Token string `json:"token"` + UsesAllowed int32 `json:"uses_allowed"` + ExpiryTime int64 `json:"expiry_time"` + Length int32 `json:"length"` + }{} + + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("Failed to decode request body: " + err.Error()), + } + } + token := request.Token + if len(token) > 0 { + if len(token) > 64 { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "token must not be empty and must not be longer than 64") + } + is_token_valid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) + if !is_token_valid { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "token must consist only of characters matched by the regex [A-Za-z0-9-_]") + } + } else { + length := request.Length + if length > 0 && length <= 64 { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "length must be greater than zero and not greater than 64") + } + // TODO: Generate Random Token + // token = GenerateRandomToken(length) + } + uses_allowed := request.UsesAllowed + if uses_allowed < 0 { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "uses_allowed must be a non-negative integer or null") + } + + expiry_time := request.ExpiryTime + if expiry_time != 0 && expiry_time < time.Now().UnixNano()/int64(time.Millisecond) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "expiry_time must not be in the past") + } + created := CreateToken(token, uses_allowed, expiry_time) + if !created { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + fmt.Sprintf("Token alreaady exists: %s", token)) + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + } +} + +func CreateToken(token string, uses_allowed int32, expiryTime int64) bool { + // TODO: Implement Create Token -> Inserts token into table registration_tokens. + // Returns true if token created, false if token already exists. + return true +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d3f19cae12..cef558f09f 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -162,6 +162,11 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + dendriteAdminRouter.Handle("/admin/registrationTokens/new", + httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminCreateNewToken(req) + }), + ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index b6c74a75f0..b04d617e0a 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -13,6 +13,10 @@ type ClientAPI struct { // secrets) RegistrationDisabled bool `yaml:"registration_disabled"` + // If set, requires users to submit a token during registration. + // Tokens can be managed using admin API. + RegistrationRequiresToken bool `yaml:"registration_requires_token"` + // Enable registration without captcha verification or shared secret. // This option is populated by the -really-enable-open-registration // command line parameter as it is not recommended. @@ -56,6 +60,7 @@ type ClientAPI struct { func (c *ClientAPI) Defaults(opts DefaultOpts) { c.RegistrationSharedSecret = "" + c.RegistrationRequiresToken = false c.RecaptchaPublicKey = "" c.RecaptchaPrivateKey = "" c.RecaptchaEnabled = false @@ -100,6 +105,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors) { ) } } + + if c.RegistrationDisabled && c.RegistrationRequiresToken { + configErrs.Add("registration_requires_token cannot be set to true when registration_disabled is true") + } } type TURN struct { From 6205ffb8c0a030ecf83dba5f2e2b6ed3e2fe34b2 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Sun, 4 Jun 2023 22:08:21 +0530 Subject: [PATCH 02/21] refactoring --- clientapi/routing/admin.go | 90 +++++++++++--------- clientapi/routing/routing.go | 2 +- roomserver/api/api.go | 1 + roomserver/internal/perform/perform_admin.go | 9 ++ setup/config/config_clientapi.go | 4 - 5 files changed, 59 insertions(+), 47 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index c25a76d85a..24fbb79a87 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -25,7 +25,14 @@ import ( "github.com/matrix-org/dendrite/userapi/api" ) -func AdminCreateNewToken(req *http.Request) util.JSONResponse { +func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + if !cfg.RegistrationRequiresToken { + return util.MatrixErrorResponse( + http.StatusForbidden, + string(spec.ErrorForbidden), + "Registration via tokens is not enabled on this homeserver", + ) + } request := struct { Token string `json:"token"` UsesAllowed int32 `json:"uses_allowed"` @@ -34,53 +41,58 @@ func AdminCreateNewToken(req *http.Request) util.JSONResponse { }{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown("Failed to decode request body: " + err.Error()), - } + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorBadJSON), + "Failed to decode request body:", + ) } token := request.Token - if len(token) > 0 { - if len(token) > 64 { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "token must not be empty and must not be longer than 64") - } - is_token_valid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) - if !is_token_valid { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "token must consist only of characters matched by the regex [A-Za-z0-9-_]") - } - } else { - length := request.Length - if length > 0 && length <= 64 { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "length must be greater than zero and not greater than 64") - } - // TODO: Generate Random Token - // token = GenerateRandomToken(length) + if len(token) == 0 || len(token) > 64 { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "token must not be empty and must not be longer than 64") } - uses_allowed := request.UsesAllowed - if uses_allowed < 0 { + isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) + if !isTokenValid { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "token must consist only of characters matched by the regex [A-Za-z0-9-_]") + } + length := request.Length + if !(length > 0 && length <= 64) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "length must be greater than zero and not greater than 64") + } + // TODO: Generate Random Token + // token = GenerateRandomToken(length) + usesAllowed := request.UsesAllowed + if usesAllowed < 0 { return util.MatrixErrorResponse( http.StatusBadRequest, string(spec.ErrorInvalidParam), "uses_allowed must be a non-negative integer or null") } - expiry_time := request.ExpiryTime - if expiry_time != 0 && expiry_time < time.Now().UnixNano()/int64(time.Millisecond) { + expiryTime := request.ExpiryTime + if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { return util.MatrixErrorResponse( http.StatusBadRequest, string(spec.ErrorInvalidParam), "expiry_time must not be in the past") } - created := CreateToken(token, uses_allowed, expiry_time) + created, err := rsAPI.PerformCreateToken(req.Context(), token, usesAllowed, expiryTime) + if err != nil { + return util.MatrixErrorResponse( + http.StatusInternalServerError, + string(spec.ErrorUnknown), + err.Error(), + ) + } if !created { return util.MatrixErrorResponse( http.StatusBadRequest, @@ -91,20 +103,14 @@ func AdminCreateNewToken(req *http.Request) util.JSONResponse { Code: 200, JSON: map[string]interface{}{ "token": token, - "uses_allowed": uses_allowed, + "uses_allowed": usesAllowed, "pending": 0, "completed": 0, - "expiry_time": expiry_time, + "expiry_time": expiryTime, }, } } -func CreateToken(token string, uses_allowed int32, expiryTime int64) bool { - // TODO: Implement Create Token -> Inserts token into table registration_tokens. - // Returns true if token created, false if token already exists. - return true -} - func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index cef558f09f..efa3f45e82 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -164,7 +164,7 @@ func Setup( } dendriteAdminRouter.Handle("/admin/registrationTokens/new", httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminCreateNewToken(req) + return AdminCreateNewRegistrationToken(req, cfg, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 7cb3379e03..54762b6ffc 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -173,6 +173,7 @@ type ClientRoomserverAPI interface { PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) + PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminPurgeRoom(ctx context.Context, roomID string) error diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 575525e21b..f78886035c 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -42,6 +42,15 @@ type Admin struct { Leaver *Leaver } +func (r *Admin) PerformAdminCreateRegistrationToken( + ctx context.Context, token string, + usesAllowed, pending, completed int32, + expiryTime int64) (bool, error) { + //TODO: Implement logic to save token in DB. + //Return false, if token already exists, else true. + return true, nil +} + // PerformAdminEvacuateRoom will remove all local users from the given room. func (r *Admin) PerformAdminEvacuateRoom( ctx context.Context, diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index b04d617e0a..44136e2a08 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -105,10 +105,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors) { ) } } - - if c.RegistrationDisabled && c.RegistrationRequiresToken { - configErrs.Add("registration_requires_token cannot be set to true when registration_disabled is true") - } } type TURN struct { From f5039be461519d5eecf6e19637c568ca3463b4e1 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Sun, 4 Jun 2023 22:12:44 +0530 Subject: [PATCH 03/21] refactoring --- clientapi/routing/admin.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 24fbb79a87..60d6ec7dd2 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -85,7 +85,9 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r string(spec.ErrorInvalidParam), "expiry_time must not be in the past") } - created, err := rsAPI.PerformCreateToken(req.Context(), token, usesAllowed, expiryTime) + pending := 0 + completed := 0 + created, err := rsAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, int32(pending), int32(completed), expiryTime) if err != nil { return util.MatrixErrorResponse( http.StatusInternalServerError, @@ -104,8 +106,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r JSON: map[string]interface{}{ "token": token, "uses_allowed": usesAllowed, - "pending": 0, - "completed": 0, + "pending": pending, + "completed": completed, "expiry_time": expiryTime, }, } From 2c339a6bfd90c148ab9524c66ee7e7301479dc72 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Mon, 5 Jun 2023 23:10:27 +0530 Subject: [PATCH 04/21] refactoring, implement db layer --- clientapi/routing/admin.go | 48 ++++++++++++++----- roomserver/internal/perform/perform_admin.go | 10 +++- roomserver/storage/interface.go | 1 + .../postgres/registration_tokens_table.go | 27 +++++++++++ roomserver/storage/postgres/storage.go | 3 ++ roomserver/storage/shared/storage.go | 23 +++++---- roomserver/storage/tables/interface.go | 4 ++ 7 files changed, 92 insertions(+), 24 deletions(-) create mode 100644 roomserver/storage/postgres/registration_tokens_table.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 60d6ec7dd2..558e011e30 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -5,8 +5,10 @@ import ( "encoding/json" "errors" "fmt" + "math/rand" "net/http" "regexp" + "strings" "time" "github.com/gorilla/mux" @@ -25,6 +27,17 @@ import ( "github.com/matrix-org/dendrite/userapi/api" ) +func generateRandomToken(length int) string { + allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + rand.Seed(time.Now().UnixNano()) + var sb strings.Builder + for i := 0; i < length; i++ { + randomIndex := rand.Intn(len(allowedChars)) + sb.WriteByte(allowedChars[randomIndex]) + } + return sb.String() +} + func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { return util.MatrixErrorResponse( @@ -47,13 +60,31 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r "Failed to decode request body:", ) } + token := request.Token - if len(token) == 0 || len(token) > 64 { + usesAllowed := request.UsesAllowed + expiryTime := request.ExpiryTime + length := request.Length + + if len(token) == 0 { + // Token not present in request body. Hence, generate a random token. + if !(length > 0 && length <= 64) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "length must be greater than zero and not greater than 64") + } + token = generateRandomToken(int(length)) + } + + if len(token) > 64 { + //Token present in request body, but is too long. return util.MatrixErrorResponse( http.StatusBadRequest, string(spec.ErrorInvalidParam), - "token must not be empty and must not be longer than 64") + "token must not be longer than 64") } + isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) if !isTokenValid { return util.MatrixErrorResponse( @@ -61,16 +92,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r string(spec.ErrorInvalidParam), "token must consist only of characters matched by the regex [A-Za-z0-9-_]") } - length := request.Length - if !(length > 0 && length <= 64) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "length must be greater than zero and not greater than 64") - } - // TODO: Generate Random Token - // token = GenerateRandomToken(length) - usesAllowed := request.UsesAllowed + // At this point, we have a valid token, either through request body or through random generation. + if usesAllowed < 0 { return util.MatrixErrorResponse( http.StatusBadRequest, @@ -78,7 +101,6 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r "uses_allowed must be a non-negative integer or null") } - expiryTime := request.ExpiryTime if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { return util.MatrixErrorResponse( http.StatusBadRequest, diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index f78886035c..292d91f236 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -46,8 +46,14 @@ func (r *Admin) PerformAdminCreateRegistrationToken( ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error) { - //TODO: Implement logic to save token in DB. - //Return false, if token already exists, else true. + exists, err := r.DB.RegistrationTokenExists(ctx, token) + if err != nil { + return false, err + } + if exists { + fmt.Println(fmt.Sprintf("token: %s already exists", token)) + return false, fmt.Errorf("token: %s already exists", token) + } return true, nil } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7d22df0084..4cf8f3b3a5 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -27,6 +27,7 @@ import ( ) type Database interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. diff --git a/roomserver/storage/postgres/registration_tokens_table.go b/roomserver/storage/postgres/registration_tokens_table.go new file mode 100644 index 0000000000..1f69f42d8f --- /dev/null +++ b/roomserver/storage/postgres/registration_tokens_table.go @@ -0,0 +1,27 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +func CreateRegistrationTokensTable(db *sql.DB) error { + _, err := db.Exec(registrationTokensSchema) + return err +} + +func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + fmt.Println("here!!") + return true, nil +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 19cde54105..5836ab1532 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -92,6 +92,9 @@ func executeMigration(ctx context.Context, db *sql.DB) error { } func (d *Database) create(db *sql.DB) error { + if err := CreateRegistrationTokensTable(db); err != nil { + return err + } if err := CreateEventStateKeysTable(db); err != nil { return err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cefa58a3d0..3e316b8827 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -46,15 +46,20 @@ type Database struct { // EventDatabase contains all tables needed to work with events type EventDatabase struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - PrevEventsTable tables.PreviousEvents - RedactionsTable tables.Redactions + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + PrevEventsTable tables.PreviousEvents + RedactionsTable tables.Redactions + RegistrationTokensTable tables.RegistrationTokens +} + +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokensTable.RegistrationTokenExists(ctx, nil, token) } func (d *Database) SupportsConcurrentRoomInputs() bool { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 333483b324..471b341ebd 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -19,6 +19,10 @@ type EventJSONPair struct { EventJSON []byte } +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) +} + type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error From 6cd6af150b18d374aa8475d7c936bd340f641e9d Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Tue, 6 Jun 2023 00:18:05 +0530 Subject: [PATCH 05/21] refactoring --- .../postgres/registration_tokens_table.go | 2 - roomserver/storage/shared/storage.go | 42 +++++++++---------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/roomserver/storage/postgres/registration_tokens_table.go b/roomserver/storage/postgres/registration_tokens_table.go index 1f69f42d8f..8fd0e41f99 100644 --- a/roomserver/storage/postgres/registration_tokens_table.go +++ b/roomserver/storage/postgres/registration_tokens_table.go @@ -3,7 +3,6 @@ package postgres import ( "context" "database/sql" - "fmt" ) const registrationTokensSchema = ` @@ -22,6 +21,5 @@ func CreateRegistrationTokensTable(db *sql.DB) error { } func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { - fmt.Println("here!!") return true, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 3e316b8827..0a8a358efe 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -31,31 +31,31 @@ const redactionsArePermanent = true type Database struct { DB *sql.DB EventDatabase - Cache caching.RoomServerCaches - Writer sqlutil.Writer - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - Purge tables.Purge - GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) + Cache caching.RoomServerCaches + Writer sqlutil.Writer + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + Purge tables.Purge + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) + RegistrationTokensTable tables.RegistrationTokens } // EventDatabase contains all tables needed to work with events type EventDatabase struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - PrevEventsTable tables.PreviousEvents - RedactionsTable tables.Redactions - RegistrationTokensTable tables.RegistrationTokens + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + PrevEventsTable tables.PreviousEvents + RedactionsTable tables.Redactions } func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { From fe2464fd4b3b9384cfc62937d3b7302cb56d3dfc Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Tue, 6 Jun 2023 21:52:05 +0530 Subject: [PATCH 06/21] Move DB Layer to UserAPI --- clientapi/routing/admin.go | 13 +++- clientapi/routing/routing.go | 2 +- roomserver/api/api.go | 1 - roomserver/internal/perform/perform_admin.go | 15 ----- roomserver/storage/interface.go | 1 - .../postgres/registration_tokens_table.go | 25 ------- roomserver/storage/postgres/storage.go | 3 - roomserver/storage/shared/storage.go | 27 ++++---- roomserver/storage/tables/interface.go | 4 -- userapi/api/api.go | 1 + userapi/internal/user_api.go | 15 +++++ userapi/storage/interface.go | 6 ++ .../postgres/registration_tokens_table.go | 66 +++++++++++++++++++ userapi/storage/postgres/storage.go | 5 ++ userapi/storage/shared/storage.go | 13 ++++ userapi/storage/tables/interface.go | 5 ++ 16 files changed, 133 insertions(+), 69 deletions(-) delete mode 100644 roomserver/storage/postgres/registration_tokens_table.go create mode 100644 userapi/storage/postgres/registration_tokens_table.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 558e011e30..a0a3072733 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) func generateRandomToken(length int) string { @@ -38,7 +39,7 @@ func generateRandomToken(length int) string { return sb.String() } -func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { return util.MatrixErrorResponse( http.StatusForbidden, @@ -67,7 +68,11 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r length := request.Length if len(token) == 0 { - // Token not present in request body. Hence, generate a random token. + if length == 0 { + // length not provided in request. Assign default value of 16. + length = 16 + } + // token not present in request body. Hence, generate a random token. if !(length > 0 && length <= 64) { return util.MatrixErrorResponse( http.StatusBadRequest, @@ -109,7 +114,9 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, r } pending := 0 completed := 0 - created, err := rsAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, int32(pending), int32(completed), expiryTime) + // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating + // unlimited uses / no expiration will be persisted in DB) + created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, expiryTime) if err != nil { return util.MatrixErrorResponse( http.StatusInternalServerError, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index efa3f45e82..bbca602278 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -164,7 +164,7 @@ func Setup( } dendriteAdminRouter.Handle("/admin/registrationTokens/new", httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminCreateNewRegistrationToken(req, cfg, rsAPI) + return AdminCreateNewRegistrationToken(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 54762b6ffc..7cb3379e03 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -173,7 +173,6 @@ type ClientRoomserverAPI interface { PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) // PerformRoomUpgrade upgrades a room to a newer version PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) - PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed, pending, completed int32, expiryTime int64) (bool, error) PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error) PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformAdminPurgeRoom(ctx context.Context, roomID string) error diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 292d91f236..575525e21b 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -42,21 +42,6 @@ type Admin struct { Leaver *Leaver } -func (r *Admin) PerformAdminCreateRegistrationToken( - ctx context.Context, token string, - usesAllowed, pending, completed int32, - expiryTime int64) (bool, error) { - exists, err := r.DB.RegistrationTokenExists(ctx, token) - if err != nil { - return false, err - } - if exists { - fmt.Println(fmt.Sprintf("token: %s already exists", token)) - return false, fmt.Errorf("token: %s already exists", token) - } - return true, nil -} - // PerformAdminEvacuateRoom will remove all local users from the given room. func (r *Admin) PerformAdminEvacuateRoom( ctx context.Context, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 4cf8f3b3a5..7d22df0084 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -27,7 +27,6 @@ import ( ) type Database interface { - RegistrationTokenExists(ctx context.Context, token string) (bool, error) // Do we support processing input events for more than one room at a time? SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. diff --git a/roomserver/storage/postgres/registration_tokens_table.go b/roomserver/storage/postgres/registration_tokens_table.go deleted file mode 100644 index 8fd0e41f99..0000000000 --- a/roomserver/storage/postgres/registration_tokens_table.go +++ /dev/null @@ -1,25 +0,0 @@ -package postgres - -import ( - "context" - "database/sql" -) - -const registrationTokensSchema = ` -CREATE TABLE IF NOT EXISTS roomserver_registration_tokens ( - token TEXT PRIMARY KEY, - pending BIGINT, - completed BIGINT, - uses_allowed BIGINT, - expiry_time BIGINT -); -` - -func CreateRegistrationTokensTable(db *sql.DB) error { - _, err := db.Exec(registrationTokensSchema) - return err -} - -func RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { - return true, nil -} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 5836ab1532..19cde54105 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -92,9 +92,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error { } func (d *Database) create(db *sql.DB) error { - if err := CreateRegistrationTokensTable(db); err != nil { - return err - } if err := CreateEventStateKeysTable(db); err != nil { return err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 0a8a358efe..cefa58a3d0 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -31,18 +31,17 @@ const redactionsArePermanent = true type Database struct { DB *sql.DB EventDatabase - Cache caching.RoomServerCaches - Writer sqlutil.Writer - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - Purge tables.Purge - GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) - RegistrationTokensTable tables.RegistrationTokens + Cache caching.RoomServerCaches + Writer sqlutil.Writer + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + Purge tables.Purge + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } // EventDatabase contains all tables needed to work with events @@ -58,10 +57,6 @@ type EventDatabase struct { RedactionsTable tables.Redactions } -func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { - return d.RegistrationTokensTable.RegistrationTokenExists(ctx, nil, token) -} - func (d *Database) SupportsConcurrentRoomInputs() bool { return true } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 471b341ebd..333483b324 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -19,10 +19,6 @@ type EventJSONPair struct { EventJSON []byte } -type RegistrationTokens interface { - RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) -} - type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error diff --git a/userapi/api/api.go b/userapi/api/api.go index 0504026451..1dfae8ed1d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -94,6 +94,7 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 32f3d84b5a..8f388ab821 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -63,6 +63,21 @@ type UserInternalAPI struct { Updater *DeviceListUpdater } +func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, token) + if err != nil { + return false, err + } + if exists { + return false, fmt.Errorf("token: %s already exists", token) + } + _, err = a.DB.InsertRegistrationToken(ctx, token, usesAllowed, expiryTime) + if err != nil { + return false, fmt.Errorf("Error creating token: %s"+err.Error(), token) + } + return true, nil +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4f5e99a8a1..8815df68f0 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -30,6 +30,11 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) +} + type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) @@ -144,6 +149,7 @@ type UserDatabase interface { Pusher Statistics ThreePID + RegistrationTokens } type KeyChangeDatabase interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go new file mode 100644 index 0000000000..750e53b26c --- /dev/null +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -0,0 +1,66 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatment *sql.Stmt +} + +func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatment, insertTokenSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := s.selectTokenStatement + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatment) + pending := 0 + completed := 0 + _, err := stmt.ExecContext(ctx, token, nil, expiryTime, pending, completed) + if err != nil { + return false, err + } + return true, nil +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 72e7c9cd90..d01ccc7764 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * return nil, err } + registationTokensTable, err := NewPostgresRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err) + } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) @@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + RegistrationTokens: registationTokensTable, Stats: statsTable, ServerName: serverName, DB: db, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 537bbbf4ae..9ec210391e 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -43,6 +43,7 @@ import ( type Database struct { DB *sql.DB Writer sqlutil.Writer + RegistrationTokens tables.RegistrationTokensTable Accounts tables.AccountsTable Profiles tables.ProfileTable AccountDatas tables.AccountDataTable @@ -78,6 +79,18 @@ const ( loginTokenByteLength = 32 ) +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token) +} + +func (d *Database) InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (created bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, token, usesAllowed, expiryTime) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 3c6214e7c6..41c99baed6 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -29,6 +29,11 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokensTable interface { + RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, txn *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) +} + type AccountDataTable interface { InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) From 8972c4b20d560b0a7ebe486f0f3a2700d95a376a Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Tue, 6 Jun 2023 21:57:34 +0530 Subject: [PATCH 07/21] format admin.go --- clientapi/routing/admin.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index a0a3072733..8accef5359 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -114,8 +114,7 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } pending := 0 completed := 0 - // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating - // unlimited uses / no expiration will be persisted in DB) + // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, expiryTime) if err != nil { return util.MatrixErrorResponse( From 99fa964b62f4fdeee81dfdf24c5cf6e4b9f67d46 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Tue, 6 Jun 2023 22:21:21 +0530 Subject: [PATCH 08/21] handle cases when request field is not present --- clientapi/routing/admin.go | 40 +++++++++++++------ .../postgres/registration_tokens_table.go | 16 +++++++- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8accef5359..f2a391c9ca 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -28,17 +28,6 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" ) -func generateRandomToken(length int) string { - allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" - rand.Seed(time.Now().UnixNano()) - var sb strings.Builder - for i := 0; i < length; i++ { - randomIndex := rand.Intn(len(allowedChars)) - sb.WriteByte(allowedChars[randomIndex]) - } - return sb.String() -} - func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { return util.MatrixErrorResponse( @@ -133,14 +122,39 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u Code: 200, JSON: map[string]interface{}{ "token": token, - "uses_allowed": usesAllowed, + "uses_allowed": getReturnValueForUsesAllowed(usesAllowed), "pending": pending, "completed": completed, - "expiry_time": expiryTime, + "expiry_time": getReturnValueExpiryTime(expiryTime), }, } } +func generateRandomToken(length int) string { + allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + rand.Seed(time.Now().UnixNano()) + var sb strings.Builder + for i := 0; i < length; i++ { + randomIndex := rand.Intn(len(allowedChars)) + sb.WriteByte(allowedChars[randomIndex]) + } + return sb.String() +} + +func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { + if usesAllowed == 0 { + return nil + } + return usesAllowed +} + +func getReturnValueExpiryTime(expiryTime int64) interface{} { + if expiryTime == 0 { + return nil + } + return expiryTime +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 750e53b26c..6c55444c00 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -58,9 +58,23 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex stmt := sqlutil.TxStmt(tx, s.insertTokenStatment) pending := 0 completed := 0 - _, err := stmt.ExecContext(ctx, token, nil, expiryTime, pending, completed) + _, err := stmt.ExecContext(ctx, token, nullIfZeroInt32(usesAllowed), nullIfZero(expiryTime), pending, completed) if err != nil { return false, err } return true, nil } + +func nullIfZero(value int64) interface{} { + if value == 0 { + return nil + } + return value +} + +func nullIfZeroInt32(value int32) interface{} { + if value == 0 { + return nil + } + return value +} From 4b73df5335295bc5e50d864c7be2a561c6b912bc Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 8 Jun 2023 00:39:58 +0530 Subject: [PATCH 09/21] Implement ListTokens --- clientapi/api/api.go | 8 ++ clientapi/routing/admin.go | 49 +++++++++++- clientapi/routing/routing.go | 8 +- userapi/api/api.go | 4 +- userapi/internal/user_api.go | 19 +++-- userapi/storage/interface.go | 4 +- .../postgres/registration_tokens_table.go | 74 +++++++++++++++++-- userapi/storage/shared/storage.go | 9 ++- userapi/storage/tables/interface.go | 4 +- 9 files changed, 157 insertions(+), 22 deletions(-) diff --git a/clientapi/api/api.go b/clientapi/api/api.go index 23974c8658..28ff593fcc 100644 --- a/clientapi/api/api.go +++ b/clientapi/api/api.go @@ -21,3 +21,11 @@ type ExtraPublicRoomsProvider interface { // Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately. Rooms() []fclient.PublicRoom } + +type RegistrationToken struct { + Token *string `json:"token"` + UsesAllowed *int32 `json:"uses_allowed"` + Pending *int32 `json:"pending"` + Completed *int32 `json:"completed"` + ExpiryTime *int64 `json:"expiry_time"` +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index f2a391c9ca..d0608f7aa6 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -8,6 +8,7 @@ import ( "math/rand" "net/http" "regexp" + "strconv" "strings" "time" @@ -20,6 +21,7 @@ import ( "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -101,13 +103,20 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u string(spec.ErrorInvalidParam), "expiry_time must not be in the past") } - pending := 0 - completed := 0 + pending := int32(0) + completed := int32(0) // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) - created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), token, usesAllowed, expiryTime) + registrationToken := &clientapi.RegistrationToken{ + Token: &token, + UsesAllowed: &usesAllowed, + Pending: &pending, + Completed: &completed, + ExpiryTime: &expiryTime, + } + created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) if err != nil { return util.MatrixErrorResponse( - http.StatusInternalServerError, + http.StatusBadRequest, string(spec.ErrorUnknown), err.Error(), ) @@ -148,6 +157,38 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { return usesAllowed } +func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse( + http.StatusInternalServerError, + string(spec.ErrorInvalidParam), + "unable to parse query params", + ) + } + returnAll := true + validQuery, ok := vars["valid"] + if ok { + returnAll = false + } + valid, err := strconv.ParseBool(validQuery) + tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) + if err != nil { + return util.MatrixErrorResponse( + http.StatusInternalServerError, + string(spec.ErrorUnknown), + "error fetching registration tokens", + ) + } + + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "registration_tokens": tokens, + }, + } +} + func getReturnValueExpiryTime(expiryTime int64) interface{} { if expiryTime == 0 { return nil diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index bbca602278..2d96e05ccd 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -168,11 +168,17 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/registrationTokens", + httputil.MakeAdminAPI("admin_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminListRegistrationTokens(req, cfg, userAPI) + }), + ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) }), - ).Methods(http.MethodPost, http.MethodOptions) + ).Methods(http.MethodGet, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/userapi/api/api.go b/userapi/api/api.go index 1dfae8ed1d..9f014d3f03 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" ) @@ -94,7 +95,8 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error - PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) + PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 8f388ab821..65ea6a8689 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -33,6 +33,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" @@ -63,21 +64,29 @@ type UserInternalAPI struct { Updater *DeviceListUpdater } -func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) { - exists, err := a.DB.RegistrationTokenExists(ctx, token) +func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token) if err != nil { return false, err } if exists { - return false, fmt.Errorf("token: %s already exists", token) + return false, fmt.Errorf("token: %s already exists", *registrationToken.Token) } - _, err = a.DB.InsertRegistrationToken(ctx, token, usesAllowed, expiryTime) + _, err = a.DB.InsertRegistrationToken(ctx, registrationToken) if err != nil { - return false, fmt.Errorf("Error creating token: %s"+err.Error(), token) + return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token) } return true, nil } +func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + tokens, err := a.DB.ListRegistrationTokens(ctx, returnAll, valid) + if err != nil { + return nil, err + } + return tokens, nil +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 8815df68f0..986da99b5e 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/userapi/api" @@ -32,7 +33,8 @@ import ( type RegistrationTokens interface { RegistrationTokenExists(ctx context.Context, token string) (bool, error) - InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (bool, error) + InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) } type Profile interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 6c55444c00..666fb3c3e3 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" ) @@ -24,9 +25,13 @@ const selectTokenSQL = "" + const insertTokenSQL = "" + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" +const listTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + type registrationTokenStatements struct { selectTokenStatement *sql.Stmt - insertTokenStatment *sql.Stmt + insertTokenStatement *sql.Stmt + listTokensStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -37,7 +42,8 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa } return s, sqlutil.StatementList{ {&s.selectTokenStatement, selectTokenSQL}, - {&s.insertTokenStatment, insertTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listTokensStatement, listTokensSQL}, }.Prepare(db) } @@ -54,11 +60,15 @@ func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Contex return true, nil } -func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) { - stmt := sqlutil.TxStmt(tx, s.insertTokenStatment) - pending := 0 - completed := 0 - _, err := stmt.ExecContext(ctx, token, nullIfZeroInt32(usesAllowed), nullIfZero(expiryTime), pending, completed) +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + nullIfZeroInt32(*registrationToken.UsesAllowed), + nullIfZero(*registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) if err != nil { return false, err } @@ -78,3 +88,53 @@ func nullIfZeroInt32(value int32) interface{} { } return value } + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString sql.NullString + var pending, completed, usesAllowed sql.NullInt32 + var expiryTime sql.NullInt64 + if returnAll { + stmt = s.listTokensStatement + } else if valid { + // TODO: Statement to Get All Valid Tokens + } else { + // TODO: Statement to Get All Invalid Tokens + } + rows, err := stmt.QueryContext(ctx) + if err != nil { + return tokens, err + } + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenMap := api.RegistrationToken{ + Token: &tokenString.String, + Pending: &pending.Int32, + Completed: &pending.Int32, + UsesAllowed: getReturnValueForInt32(usesAllowed), + ExpiryTime: getReturnValueForInt64(expiryTime), + } + tokens = append(tokens, tokenMap) + } + return tokens, nil +} + +func getReturnValueForInt32(value sql.NullInt32) *int32 { + if value.Valid { + returnValue := value.Int32 + return &returnValue + } + return nil +} + +func getReturnValueForInt64(value sql.NullInt64) *int64 { + if value.Valid { + returnValue := value.Int64 + return &returnValue + } + return nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 9ec210391e..58ebe30ecf 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -83,14 +84,18 @@ func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (b return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token) } -func (d *Database) InsertRegistrationToken(ctx context.Context, token string, usesAllowed int32, expiryTime int64) (created bool, err error) { +func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, token, usesAllowed, expiryTime) + created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken) return err }) return } +func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid) +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 41c99baed6..dfd52235ac 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -25,13 +25,15 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/types" ) type RegistrationTokensTable interface { RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) - InsertRegistrationToken(ctx context.Context, txn *sql.Tx, token string, usesAllowed int32, expiryTime int64) (bool, error) + InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) } type AccountDataTable interface { From 86d2aa41c152f099f3e4fc4a77960cd664636b05 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 8 Jun 2023 10:03:34 +0530 Subject: [PATCH 10/21] implement filter by valid query param --- clientapi/routing/admin.go | 22 ++++++------ .../postgres/registration_tokens_table.go | 36 ++++++++++++++----- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index d0608f7aa6..96d557ae48 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -158,20 +158,22 @@ func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { } func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.MatrixErrorResponse( - http.StatusInternalServerError, - string(spec.ErrorInvalidParam), - "unable to parse query params", - ) - } + queryParams := req.URL.Query() returnAll := true - validQuery, ok := vars["valid"] + valid := true + validQuery, ok := queryParams["valid"] if ok { returnAll = false + validValue, err := strconv.ParseBool(validQuery[0]) + if err != nil { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "invalid 'valid' query parameter", + ) + } + valid = validValue } - valid, err := strconv.ParseBool(validQuery) tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) if err != nil { return util.MatrixErrorResponse( diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 666fb3c3e3..3bb8b19e83 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "time" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -25,13 +26,24 @@ const selectTokenSQL = "" + const insertTokenSQL = "" + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" -const listTokensSQL = "" + +const listAllTokensSQL = "" + "SELECT * FROM userapi_registration_tokens" +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + type registrationTokenStatements struct { - selectTokenStatement *sql.Stmt - insertTokenStatement *sql.Stmt - listTokensStatement *sql.Stmt + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -43,7 +55,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa return s, sqlutil.StatementList{ {&s.selectTokenStatement, selectTokenSQL}, {&s.insertTokenStatement, insertTokenSQL}, - {&s.listTokensStatement, listTokensSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, }.Prepare(db) } @@ -95,14 +109,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context var tokenString sql.NullString var pending, completed, usesAllowed sql.NullInt32 var expiryTime sql.NullInt64 + var rows *sql.Rows + var err error if returnAll { - stmt = s.listTokensStatement + stmt = s.listAllTokensStatement + rows, err = stmt.QueryContext(ctx) } else if valid { - // TODO: Statement to Get All Valid Tokens + stmt = s.listValidTokensStatement + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } else { - // TODO: Statement to Get All Invalid Tokens + stmt = s.listInvalidTokenStatement + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } - rows, err := stmt.QueryContext(ctx) if err != nil { return tokens, err } From 31f3125c260885005d3599d6d41345aacad4daea Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 8 Jun 2023 19:14:35 +0530 Subject: [PATCH 11/21] Get and Delete APIs --- clientapi/routing/admin.go | 40 ++++++++++++ clientapi/routing/routing.go | 20 +++++- userapi/api/api.go | 2 + userapi/internal/user_api.go | 12 ++++ userapi/storage/interface.go | 2 + .../postgres/registration_tokens_table.go | 61 +++++++++++++++++-- userapi/storage/shared/storage.go | 8 +++ userapi/storage/tables/interface.go | 2 + 8 files changed, 141 insertions(+), 6 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 96d557ae48..1644a30091 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -198,6 +198,46 @@ func getReturnValueExpiryTime(expiryTime int64) interface{} { return expiryTime } +func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.MatrixErrorResponse( + http.StatusNotFound, + string(spec.ErrorUnknown), + fmt.Sprintf("token: %s not found", tokenText), + ) + } + return util.JSONResponse{ + Code: 200, + JSON: token, + } +} + +func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.MatrixErrorResponse( + http.StatusNotFound, + string(spec.ErrorUnknown), + fmt.Sprintf("token: %s not found", tokenText), + ) + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 2d96e05ccd..e1620c53c1 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -169,11 +169,29 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/registrationTokens", - httputil.MakeAdminAPI("admin_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminListRegistrationTokens(req, cfg, userAPI) }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", + httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if req.Method == http.MethodGet { + return AdminGetRegistrationToken(req, cfg, userAPI) + } else if req.Method == http.MethodPut { + + } else if req.Method == http.MethodDelete { + return AdminDeleteRegistrationToken(req, cfg, userAPI) + } + return util.MatrixErrorResponse( + 404, + string(spec.ErrorNotFound), + "unknown method", + ) + + }), + ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) diff --git a/userapi/api/api.go b/userapi/api/api.go index 9f014d3f03..532422d841 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -97,6 +97,8 @@ type ClientUserAPI interface { QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 65ea6a8689..07f4250970 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -87,6 +87,18 @@ func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context return tokens, nil } +func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + token, err := a.DB.GetRegistrationToken(ctx, tokenString) + if err != nil { + return nil, err + } + return token, nil +} + +func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { + return a.DB.DeleteRegistrationToken(ctx, tokenString) +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 986da99b5e..144cd1f73a 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -35,6 +35,8 @@ type RegistrationTokens interface { RegistrationTokenExists(ctx context.Context, token string) (bool, error) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, tokenString string) error } type Profile interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 3bb8b19e83..0163a29242 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/matrix-org/dendrite/clientapi/api" @@ -38,12 +39,20 @@ const listInvalidTokensSQL = "" + "SELECT * FROM userapi_registration_tokens WHERE" + "(uses_allowed <= pending + completed OR expiry_time <= $1)" +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + type registrationTokenStatements struct { selectTokenStatement *sql.Stmt insertTokenStatement *sql.Stmt listAllTokensStatement *sql.Stmt listValidTokensStatement *sql.Stmt listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -58,6 +67,8 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa {&s.listAllTokensStatement, listAllTokensSQL}, {&s.listValidTokensStatement, listValidTokensSQL}, {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, }.Prepare(db) } @@ -129,12 +140,18 @@ func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context if err != nil { return tokens, err } + tokenString := tokenString.String + pending := pending.Int32 + completed := completed.Int32 + usesAllowed := getReturnValueForInt32(usesAllowed) + expiryTime := getReturnValueForInt64(expiryTime) + tokenMap := api.RegistrationToken{ - Token: &tokenString.String, - Pending: &pending.Int32, - Completed: &pending.Int32, - UsesAllowed: getReturnValueForInt32(usesAllowed), - ExpiryTime: getReturnValueForInt64(expiryTime), + Token: &tokenString, + Pending: &pending, + Completed: &completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, } tokens = append(tokens, tokenMap) } @@ -156,3 +173,37 @@ func getReturnValueForInt64(value sql.NullInt64) *int64 { } return nil } + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := s.getTokenStatement + var pending, completed, usesAllowed sql.NullInt32 + var expiryTime sql.NullInt64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: &pending.Int32, + Completed: &completed.Int32, + UsesAllowed: getReturnValueForInt32(usesAllowed), + ExpiryTime: getReturnValueForInt64(expiryTime), + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := s.deleteTokenStatement + res, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + count, err := res.RowsAffected() + if err != nil { + return err + } + if count == 0 { + return fmt.Errorf("token: %s does not exists", tokenString) + } + return nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 58ebe30ecf..86f1fd9e62 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -96,6 +96,14 @@ func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, v return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid) } +func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) +} + +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) error { + return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index dfd52235ac..fe902481a5 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -34,6 +34,8 @@ type RegistrationTokensTable interface { RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error) ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error } type AccountDataTable interface { From 5e0da6ac0ed63ea1746e71b9d785e88ed47f2196 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Sun, 11 Jun 2023 00:51:51 +0530 Subject: [PATCH 12/21] implement update api --- clientapi/routing/admin.go | 62 ++++++++++++++++++- clientapi/routing/routing.go | 2 +- userapi/api/api.go | 1 + userapi/internal/user_api.go | 8 +++ userapi/storage/interface.go | 1 + .../postgres/registration_tokens_table.go | 57 ++++++++++++++--- userapi/storage/shared/storage.go | 8 +++ userapi/storage/tables/interface.go | 1 + 8 files changed, 130 insertions(+), 10 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 1644a30091..53b1be3cbf 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -182,7 +182,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA "error fetching registration tokens", ) } - return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ @@ -238,6 +237,67 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user } } +func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + request := make(map[string]interface{}) + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorBadJSON), + "Failed to decode request body:", + ) + } + newAttributes := make(map[string]interface{}) + usesAllowed, ok := request["uses_allowed"] + if ok { + // Only add usesAllowed to newAtrributes if it is present and valid + // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic + // Synapse's behaviour of updating the field if and only if it is present in request body. + if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "uses_allowed must be a non-negative integer or null", + ) + } + newAttributes["usesAllowed"] = usesAllowed + } + expiryTime, ok := request["expiry_time"] + if ok { + // Only add expiryTime to newAtrributes if it is present and valid + // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic + // Synapse's behaviour of updating the field if and only if it is present in request body. + if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) { + return util.MatrixErrorResponse( + http.StatusBadRequest, + string(spec.ErrorInvalidParam), + "expiry_time must be in the future", + ) + } + newAttributes["expiryTime"] = expiryTime + } + if len(newAttributes) == 0 { + // No attributes to update. Return existing token + return AdminGetRegistrationToken(req, cfg, userAPI) + } + updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) + if err != nil { + return util.MatrixErrorResponse( + http.StatusNotFound, + string(spec.ErrorUnknown), + fmt.Sprintf("token: %s not found", tokenText), + ) + } + return util.JSONResponse{ + Code: 200, + JSON: *updatedToken, + } +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e1620c53c1..79e628f3f3 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -179,7 +179,7 @@ func Setup( if req.Method == http.MethodGet { return AdminGetRegistrationToken(req, cfg, userAPI) } else if req.Method == http.MethodPut { - + return AdminUpdateRegistrationToken(req, cfg, userAPI) } else if req.Method == http.MethodDelete { return AdminDeleteRegistrationToken(req, cfg, userAPI) } diff --git a/userapi/api/api.go b/userapi/api/api.go index 532422d841..a0dce97589 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -99,6 +99,7 @@ type ClientUserAPI interface { PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error + PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 07f4250970..2cfd649a85 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -99,6 +99,14 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex return a.DB.DeleteRegistrationToken(ctx, tokenString) } +func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { + token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) + if err != nil { + return nil, err + } + return token, nil +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 144cd1f73a..125b315853 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -37,6 +37,7 @@ type RegistrationTokens interface { ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) DeleteRegistrationToken(ctx context.Context, tokenString string) error + UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) } type Profile interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 0163a29242..3f85f20935 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -45,14 +45,26 @@ const getTokenSQL = "" + const deleteTokenSQL = "" + "DELETE FROM userapi_registration_tokens WHERE token = $1" +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + type registrationTokenStatements struct { - selectTokenStatement *sql.Stmt - insertTokenStatement *sql.Stmt - listAllTokensStatement *sql.Stmt - listValidTokensStatement *sql.Stmt - listInvalidTokenStatement *sql.Stmt - getTokenStatement *sql.Stmt - deleteTokenStatement *sql.Stmt + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt } func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { @@ -69,6 +81,9 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa {&s.listInvalidTokenStatement, listInvalidTokensSQL}, {&s.getTokenStatement, getTokenSQL}, {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, }.Prepare(db) } @@ -175,7 +190,7 @@ func getReturnValueForInt64(value sql.NullInt64) *int64 { } func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { - stmt := s.getTokenStatement + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) var pending, completed, usesAllowed sql.NullInt32 var expiryTime sql.NullInt64 err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) @@ -207,3 +222,29 @@ func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Contex } return nil } + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 86f1fd9e62..481256db1b 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -104,6 +104,14 @@ func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString stri return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) } +func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index fe902481a5..3a0be73e4a 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -36,6 +36,7 @@ type RegistrationTokensTable interface { ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error + UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) } type AccountDataTable interface { From f6dbc84f4d000be2727550c69932ef07c15b96f9 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Mon, 12 Jun 2023 18:20:14 +0530 Subject: [PATCH 13/21] added sqlite support --- .../sqlite3/registration_tokens_table.go | 250 ++++++++++++++++++ userapi/storage/sqlite3/storage.go | 6 +- 2 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 userapi/storage/sqlite3/registration_tokens_table.go diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go new file mode 100644 index 0000000000..47b70d2e16 --- /dev/null +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -0,0 +1,250 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := s.selectTokenStatement + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + nullIfZeroInt32(*registrationToken.UsesAllowed), + nullIfZero(*registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func nullIfZero(value int64) interface{} { + if value == 0 { + return nil + } + return value +} + +func nullIfZeroInt32(value int32) interface{} { + if value == 0 { + return nil + } + return value +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString sql.NullString + var pending, completed, usesAllowed sql.NullInt32 + var expiryTime sql.NullInt64 + var rows *sql.Rows + var err error + if returnAll { + stmt = s.listAllTokensStatement + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = s.listValidTokensStatement + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = s.listInvalidTokenStatement + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString.String + pending := pending.Int32 + completed := completed.Int32 + usesAllowed := getReturnValueForInt32(usesAllowed) + expiryTime := getReturnValueForInt64(expiryTime) + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: &pending, + Completed: &completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, nil +} + +func getReturnValueForInt32(value sql.NullInt32) *int32 { + if value.Valid { + returnValue := value.Int32 + return &returnValue + } + return nil +} + +func getReturnValueForInt64(value sql.NullInt64) *int64 { + if value.Valid { + returnValue := value.Int64 + return &returnValue + } + return nil +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed sql.NullInt32 + var expiryTime sql.NullInt64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: &pending.Int32, + Completed: &completed.Int32, + UsesAllowed: getReturnValueForInt32(usesAllowed), + ExpiryTime: getReturnValueForInt64(expiryTime), + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := s.deleteTokenStatement + res, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + count, err := res.RowsAffected() + if err != nil { + return err + } + if count == 0 { + return fmt.Errorf("token: %s does not exists", tokenString) + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index acd9678f21..48f5c842bf 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti if err = m.Up(ctx); err != nil { return nil, err } - + registationTokensTable, err := NewSQLiteRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err) + } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) @@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti LoginTokenLifetime: loginTokenLifetime, BcryptCost: bcryptCost, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + RegistrationTokens: registationTokensTable, }, nil } From 44beddc28701e362d9324ca25db695573bbd92b1 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Mon, 12 Jun 2023 18:31:42 +0530 Subject: [PATCH 14/21] remove unused code --- clientapi/auth/authtypes/logintypes.go | 1 - 1 file changed, 1 deletion(-) diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index 6e08d97357..f01e48f806 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,5 +11,4 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" - LoginTypeRegistrationToken = "m.login.registration_token" ) From 6ea96a0909e17ccce3ed225806377407e73c6ffa Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Wed, 14 Jun 2023 22:51:56 +0530 Subject: [PATCH 15/21] addressed review comments --- clientapi/routing/admin.go | 153 +++++++------ clientapi/routing/routing.go | 21 +- cmd/dendrite/main.go | 203 ------------------ cmd/dendrite/main_test.go | 50 ----- userapi/internal/user_api.go | 18 +- .../postgres/registration_tokens_table.go | 88 +++----- userapi/storage/shared/storage.go | 8 +- .../sqlite3/registration_tokens_table.go | 88 +++----- 8 files changed, 152 insertions(+), 477 deletions(-) delete mode 100644 cmd/dendrite/main.go delete mode 100644 cmd/dendrite/main_test.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 53b1be3cbf..4a7afc58e4 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -30,13 +30,14 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" ) +var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { if !cfg.RegistrationRequiresToken { - return util.MatrixErrorResponse( - http.StatusForbidden, - string(spec.ErrorForbidden), - "Registration via tokens is not enabled on this homeserver", - ) + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"), + } } request := struct { Token string `json:"token"` @@ -46,11 +47,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u }{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorBadJSON), - "Failed to decode request body:", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } } token := request.Token @@ -65,43 +65,43 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } // token not present in request body. Hence, generate a random token. if !(length > 0 && length <= 64) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "length must be greater than zero and not greater than 64") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("length must be greater than zero and not greater than 64"), + } } token = generateRandomToken(int(length)) } if len(token) > 64 { //Token present in request body, but is too long. - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "token must not be longer than 64") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must not be longer than 64"), + } } - isTokenValid, _ := regexp.MatchString("^[[:ascii:][:digit:]_]*$", token) + isTokenValid := validRegistrationTokenRegex.Match([]byte(token)) if !isTokenValid { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "token must consist only of characters matched by the regex [A-Za-z0-9-_]") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"), + } } // At this point, we have a valid token, either through request body or through random generation. if usesAllowed < 0 { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "uses_allowed must be a non-negative integer or null") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } } if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "expiry_time must not be in the past") + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } } pending := int32(0) completed := int32(0) @@ -115,17 +115,16 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) if err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorUnknown), - err.Error(), - ) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } } if !created { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - fmt.Sprintf("Token alreaady exists: %s", token)) + return util.JSONResponse{ + Code: http.StatusConflict, + JSON: fmt.Sprintf("Token already exists: %s", token), + } } return util.JSONResponse{ Code: 200, @@ -166,21 +165,19 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA returnAll = false validValue, err := strconv.ParseBool(validQuery[0]) if err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "invalid 'valid' query parameter", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid 'valid' query parameter"), + } } valid = validValue } tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) if err != nil { - return util.MatrixErrorResponse( - http.StatusInternalServerError, - string(spec.ErrorUnknown), - "error fetching registration tokens", - ) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.ErrorUnknown, + } } return util.JSONResponse{ Code: 200, @@ -205,11 +202,10 @@ func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI tokenText := vars["token"] token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) if err != nil { - return util.MatrixErrorResponse( - http.StatusNotFound, - string(spec.ErrorUnknown), - fmt.Sprintf("token: %s not found", tokenText), - ) + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } } return util.JSONResponse{ Code: 200, @@ -225,11 +221,10 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user tokenText := vars["token"] err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) if err != nil { - return util.MatrixErrorResponse( - http.StatusNotFound, - string(spec.ErrorUnknown), - fmt.Sprintf("token: %s not found", tokenText), - ) + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } } return util.JSONResponse{ Code: 200, @@ -244,12 +239,11 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user } tokenText := vars["token"] request := make(map[string]interface{}) - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorBadJSON), - "Failed to decode request body:", - ) + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } } newAttributes := make(map[string]interface{}) usesAllowed, ok := request["uses_allowed"] @@ -258,11 +252,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic // Synapse's behaviour of updating the field if and only if it is present in request body. if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "uses_allowed must be a non-negative integer or null", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } } newAttributes["usesAllowed"] = usesAllowed } @@ -272,11 +265,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic // Synapse's behaviour of updating the field if and only if it is present in request body. if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) { - return util.MatrixErrorResponse( - http.StatusBadRequest, - string(spec.ErrorInvalidParam), - "expiry_time must be in the future", - ) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } } newAttributes["expiryTime"] = expiryTime } @@ -286,11 +278,10 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user } updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) if err != nil { - return util.MatrixErrorResponse( - http.StatusNotFound, - string(spec.ErrorUnknown), - fmt.Sprintf("token: %s not found", tokenText), - ) + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } } return util.JSONResponse{ Code: 200, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 79e628f3f3..ab4aefddd3 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -176,19 +176,20 @@ func Setup( dendriteAdminRouter.Handle("/admin/registrationTokens/{token}", httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - if req.Method == http.MethodGet { + switch req.Method { + case http.MethodGet: return AdminGetRegistrationToken(req, cfg, userAPI) - } else if req.Method == http.MethodPut { + case http.MethodPut: return AdminUpdateRegistrationToken(req, cfg, userAPI) - } else if req.Method == http.MethodDelete { + case http.MethodDelete: return AdminDeleteRegistrationToken(req, cfg, userAPI) + default: + return util.MatrixErrorResponse( + 404, + string(spec.ErrorNotFound), + "unknown method", + ) } - return util.MatrixErrorResponse( - 404, - string(spec.ErrorNotFound), - "unknown method", - ) - }), ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions) @@ -196,7 +197,7 @@ func Setup( httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminEvacuateRoom(req, rsAPI) }), - ).Methods(http.MethodGet, http.MethodOptions) + ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/evacuateUser/{userID}", httputil.MakeAdminAPI("admin_evacuate_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go deleted file mode 100644 index 66eb88f875..0000000000 --- a/cmd/dendrite/main.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "flag" - "time" - - "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib/fclient" - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - basepkg "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/mscs" - "github.com/matrix-org/dendrite/userapi" -) - -var ( - unixSocket = flag.String("unix-socket", "", - "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", - ) - unixSocketPermission = flag.String("unix-socket-permission", "755", - "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server (in chmod format like 755)", - ) - httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") - httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") - certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") - keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") -) - -func main() { - cfg := setup.ParseFlags(true) - httpAddr := config.ServerAddress{} - httpsAddr := config.ServerAddress{} - if *unixSocket == "" { - http, err := config.HTTPAddress("http://" + *httpBindAddr) - if err != nil { - logrus.WithError(err).Fatalf("Failed to parse http address") - } - httpAddr = http - https, err := config.HTTPAddress("https://" + *httpsBindAddr) - if err != nil { - logrus.WithError(err).Fatalf("Failed to parse https address") - } - httpsAddr = https - } else { - socket, err := config.UnixSocketAddress(*unixSocket, *unixSocketPermission) - if err != nil { - logrus.WithError(err).Fatalf("Failed to parse unix socket") - } - httpAddr = socket - } - - configErrors := &config.ConfigErrors{} - cfg.Verify(configErrors) - if len(*configErrors) > 0 { - for _, err := range *configErrors { - logrus.Errorf("Configuration error: %s", err) - } - logrus.Fatalf("Failed to start due to configuration errors") - } - processCtx := process.NewProcessContext() - - internal.SetupStdLogging() - internal.SetupHookLogging(cfg.Logging) - internal.SetupPprof() - - basepkg.PlatformSanityChecks() - - logrus.Infof("Dendrite version %s", internal.VersionString()) - if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { - logrus.Warn("Open registration is enabled") - } - - // create DNS cache - var dnsCache *fclient.DNSCache - if cfg.Global.DNSCache.Enabled { - dnsCache = fclient.NewDNSCache( - cfg.Global.DNSCache.CacheSize, - cfg.Global.DNSCache.CacheLifetime, - ) - logrus.Infof( - "DNS cache enabled (size %d, lifetime %s)", - cfg.Global.DNSCache.CacheSize, - cfg.Global.DNSCache.CacheLifetime, - ) - } - - // setup tracing - closer, err := cfg.SetupTracing() - if err != nil { - logrus.WithError(err).Panicf("failed to start opentracing") - } - defer closer.Close() // nolint: errcheck - - // setup sentry - if cfg.Global.Sentry.Enabled { - logrus.Info("Setting up Sentry for debugging...") - err = sentry.Init(sentry.ClientOptions{ - Dsn: cfg.Global.Sentry.DSN, - Environment: cfg.Global.Sentry.Environment, - Debug: true, - ServerName: string(cfg.Global.ServerName), - Release: "dendrite@" + internal.VersionString(), - AttachStacktrace: true, - }) - if err != nil { - logrus.WithError(err).Panic("failed to start Sentry") - } - go func() { - processCtx.ComponentStarted() - <-processCtx.WaitForShutdown() - if !sentry.Flush(time.Second * 5) { - logrus.Warnf("failed to flush all Sentry events!") - } - processCtx.ComponentFinished() - }() - } - - federationClient := basepkg.CreateFederationClient(cfg, dnsCache) - httpClient := basepkg.CreateClient(cfg, dnsCache) - - // prepare required dependencies - cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) - routers := httputil.NewRouters() - - caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) - natsInstance := jetstream.NATSInstance{} - rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) - fsAPI := federationapi.NewInternalAPI( - processCtx, cfg, cm, &natsInstance, federationClient, rsAPI, caches, nil, false, - ) - - keyRing := fsAPI.KeyRing() - - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient) - asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this - // dependency. Other components also need updating after their dependencies are up. - rsAPI.SetFederationAPI(fsAPI, keyRing) - rsAPI.SetAppserviceAPI(asAPI) - rsAPI.SetUserAPI(userAPI) - - monolith := setup.Monolith{ - Config: cfg, - Client: httpClient, - FedClient: federationClient, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - // always use the concrete impl here even in -http mode because adding public routes - // must be done on the concrete impl not an HTTP client else fedapi will call itself - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - } - monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) - - if len(cfg.MSCs.MSCs) > 0 { - if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil { - logrus.WithError(err).Fatalf("Failed to enable MSCs") - } - } - - // Expose the matrix APIs directly rather than putting them under a /api path. - go func() { - basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil) - }() - // Handle HTTPS if certificate and key are provided - if *unixSocket == "" && *certFile != "" && *keyFile != "" { - go func() { - basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpsAddr, certFile, keyFile) - }() - } - - // We want to block forever to let the HTTP and HTTPS handler serve the APIs - basepkg.WaitForShutdown(processCtx) -} diff --git a/cmd/dendrite/main_test.go b/cmd/dendrite/main_test.go deleted file mode 100644 index d51bc74340..0000000000 --- a/cmd/dendrite/main_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "os" - "os/signal" - "strings" - "syscall" - "testing" -) - -// This is an instrumented main, used when running integration tests (sytest) with code coverage. -// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite -// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml -// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html -// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc -func TestMain(_ *testing.T) { - var ( - args []string - ) - - for _, arg := range os.Args { - switch { - case strings.HasPrefix(arg, "DEVEL"): - case strings.HasPrefix(arg, "-test"): - default: - args = append(args, arg) - } - } - // only run the tests if there are args to be passed - if len(args) <= 1 { - return - } - - waitCh := make(chan int, 1) - os.Args = args - go func() { - main() - close(waitCh) - }() - - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGHUP) - - select { - case <-signalCh: - return - case <-waitCh: - return - } -} diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 2cfd649a85..4305c13a9b 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -80,19 +80,11 @@ func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Contex } func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { - tokens, err := a.DB.ListRegistrationTokens(ctx, returnAll, valid) - if err != nil { - return nil, err - } - return tokens, nil + return a.DB.ListRegistrationTokens(ctx, returnAll, valid) } func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { - token, err := a.DB.GetRegistrationToken(ctx, tokenString) - if err != nil { - return nil, err - } - return token, nil + return a.DB.GetRegistrationToken(ctx, tokenString) } func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { @@ -100,11 +92,7 @@ func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Contex } func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { - token, err := a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) - if err != nil { - return nil, err - } - return token, nil + return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 3f85f20935..45b39c8922 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -3,12 +3,13 @@ package postgres import ( "context" "database/sql" - "fmt" "time" "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" ) const registrationTokensSchema = ` @@ -89,7 +90,7 @@ func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTa func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { var existingToken string - stmt := s.selectTokenStatement + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) if err != nil { if err == sql.ErrNoRows { @@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZeroInt32(*registrationToken.UsesAllowed), + nullIfZero(*registrationToken.UsesAllowed), nullIfZero(*registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) @@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero(value int64) interface{} { - if value == 0 { +func nullIfZero[t constraints.Integer](in t) any { + if in == 0 { return nil } - return value -} - -func nullIfZeroInt32(value int32) interface{} { - if value == 0 { - return nil - } - return value + return in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { var stmt *sql.Stmt var tokens []api.RegistrationToken - var tokenString sql.NullString - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 var rows *sql.Rows var err error if returnAll { - stmt = s.listAllTokensStatement + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) rows, err = stmt.QueryContext(ctx) } else if valid { - stmt = s.listValidTokensStatement + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } else { - stmt = s.listInvalidTokenStatement + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } if err != nil { return tokens, err } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") for rows.Next() { err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) if err != nil { return tokens, err } - tokenString := tokenString.String - pending := pending.Int32 - completed := completed.Int32 - usesAllowed := getReturnValueForInt32(usesAllowed) - expiryTime := getReturnValueForInt64(expiryTime) + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime tokenMap := api.RegistrationToken{ Token: &tokenString, - Pending: &pending, - Completed: &completed, + Pending: pending, + Completed: completed, UsesAllowed: usesAllowed, ExpiryTime: expiryTime, } tokens = append(tokens, tokenMap) } - return tokens, nil -} - -func getReturnValueForInt32(value sql.NullInt32) *int32 { - if value.Valid { - returnValue := value.Int32 - return &returnValue - } - return nil -} - -func getReturnValueForInt64(value sql.NullInt64) *int64 { - if value.Valid { - returnValue := value.Int64 - return &returnValue - } - return nil + return tokens, rows.Err() } func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { stmt := sqlutil.TxStmt(tx, s.getTokenStatement) - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var pending, completed, usesAllowed *int32 + var expiryTime *int64 err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) if err != nil { return nil, err } token := api.RegistrationToken{ Token: &tokenString, - Pending: &pending.Int32, - Completed: &completed.Int32, - UsesAllowed: getReturnValueForInt32(usesAllowed), - ExpiryTime: getReturnValueForInt64(expiryTime), + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, } return &token, nil } func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { - stmt := s.deleteTokenStatement - res, err := stmt.ExecContext(ctx, tokenString) + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) if err != nil { return err } - count, err := res.RowsAffected() - if err != nil { - return err - } - if count == 0 { - return fmt.Errorf("token: %s does not exists", tokenString) - } return nil } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 481256db1b..b7acb2035e 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -100,8 +100,12 @@ func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) } -func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) error { - return d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) + return err + }) + return } func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go index 47b70d2e16..99c18c557e 100644 --- a/userapi/storage/sqlite3/registration_tokens_table.go +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -3,12 +3,13 @@ package sqlite3 import ( "context" "database/sql" - "fmt" "time" "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" ) const registrationTokensSchema = ` @@ -89,7 +90,7 @@ func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTabl func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { var existingToken string - stmt := s.selectTokenStatement + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) if err != nil { if err == sql.ErrNoRows { @@ -105,7 +106,7 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZeroInt32(*registrationToken.UsesAllowed), + nullIfZero(*registrationToken.UsesAllowed), nullIfZero(*registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) @@ -115,111 +116,82 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero(value int64) interface{} { - if value == 0 { +func nullIfZero[t constraints.Integer](in t) any { + if in == 0 { return nil } - return value -} - -func nullIfZeroInt32(value int32) interface{} { - if value == 0 { - return nil - } - return value + return in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { var stmt *sql.Stmt var tokens []api.RegistrationToken - var tokenString sql.NullString - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 var rows *sql.Rows var err error if returnAll { - stmt = s.listAllTokensStatement + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) rows, err = stmt.QueryContext(ctx) } else if valid { - stmt = s.listValidTokensStatement + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } else { - stmt = s.listInvalidTokenStatement + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) } if err != nil { return tokens, err } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") for rows.Next() { err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) if err != nil { return tokens, err } - tokenString := tokenString.String - pending := pending.Int32 - completed := completed.Int32 - usesAllowed := getReturnValueForInt32(usesAllowed) - expiryTime := getReturnValueForInt64(expiryTime) + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime tokenMap := api.RegistrationToken{ Token: &tokenString, - Pending: &pending, - Completed: &completed, + Pending: pending, + Completed: completed, UsesAllowed: usesAllowed, ExpiryTime: expiryTime, } tokens = append(tokens, tokenMap) } - return tokens, nil -} - -func getReturnValueForInt32(value sql.NullInt32) *int32 { - if value.Valid { - returnValue := value.Int32 - return &returnValue - } - return nil -} - -func getReturnValueForInt64(value sql.NullInt64) *int64 { - if value.Valid { - returnValue := value.Int64 - return &returnValue - } - return nil + return tokens, rows.Err() } func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { stmt := sqlutil.TxStmt(tx, s.getTokenStatement) - var pending, completed, usesAllowed sql.NullInt32 - var expiryTime sql.NullInt64 + var pending, completed, usesAllowed *int32 + var expiryTime *int64 err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) if err != nil { return nil, err } token := api.RegistrationToken{ Token: &tokenString, - Pending: &pending.Int32, - Completed: &completed.Int32, - UsesAllowed: getReturnValueForInt32(usesAllowed), - ExpiryTime: getReturnValueForInt64(expiryTime), + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, } return &token, nil } func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { - stmt := s.deleteTokenStatement - res, err := stmt.ExecContext(ctx, tokenString) + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) if err != nil { return err } - count, err := res.RowsAffected() - if err != nil { - return err - } - if count == 0 { - return fmt.Errorf("token: %s does not exists", tokenString) - } return nil } From 9b83df2b2b8dd6185aef82c604375c7509193f04 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 15 Jun 2023 08:53:42 +0530 Subject: [PATCH 16/21] add back dendrite --- cmd/dendrite/main.go | 203 ++++++++++++++++++++++++++++++++++++++ cmd/dendrite/main_test.go | 50 ++++++++++ 2 files changed, 253 insertions(+) create mode 100644 cmd/dendrite/main.go create mode 100644 cmd/dendrite/main_test.go diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go new file mode 100644 index 0000000000..66eb88f875 --- /dev/null +++ b/cmd/dendrite/main.go @@ -0,0 +1,203 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "time" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup" + basepkg "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/mscs" + "github.com/matrix-org/dendrite/userapi" +) + +var ( + unixSocket = flag.String("unix-socket", "", + "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)", + ) + unixSocketPermission = flag.String("unix-socket-permission", "755", + "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server (in chmod format like 755)", + ) + httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") + httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") + certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") + keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") +) + +func main() { + cfg := setup.ParseFlags(true) + httpAddr := config.ServerAddress{} + httpsAddr := config.ServerAddress{} + if *unixSocket == "" { + http, err := config.HTTPAddress("http://" + *httpBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse http address") + } + httpAddr = http + https, err := config.HTTPAddress("https://" + *httpsBindAddr) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse https address") + } + httpsAddr = https + } else { + socket, err := config.UnixSocketAddress(*unixSocket, *unixSocketPermission) + if err != nil { + logrus.WithError(err).Fatalf("Failed to parse unix socket") + } + httpAddr = socket + } + + configErrors := &config.ConfigErrors{} + cfg.Verify(configErrors) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + logrus.Fatalf("Failed to start due to configuration errors") + } + processCtx := process.NewProcessContext() + + internal.SetupStdLogging() + internal.SetupHookLogging(cfg.Logging) + internal.SetupPprof() + + basepkg.PlatformSanityChecks() + + logrus.Infof("Dendrite version %s", internal.VersionString()) + if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { + logrus.Warn("Open registration is enabled") + } + + // create DNS cache + var dnsCache *fclient.DNSCache + if cfg.Global.DNSCache.Enabled { + dnsCache = fclient.NewDNSCache( + cfg.Global.DNSCache.CacheSize, + cfg.Global.DNSCache.CacheLifetime, + ) + logrus.Infof( + "DNS cache enabled (size %d, lifetime %s)", + cfg.Global.DNSCache.CacheSize, + cfg.Global.DNSCache.CacheLifetime, + ) + } + + // setup tracing + closer, err := cfg.SetupTracing() + if err != nil { + logrus.WithError(err).Panicf("failed to start opentracing") + } + defer closer.Close() // nolint: errcheck + + // setup sentry + if cfg.Global.Sentry.Enabled { + logrus.Info("Setting up Sentry for debugging...") + err = sentry.Init(sentry.ClientOptions{ + Dsn: cfg.Global.Sentry.DSN, + Environment: cfg.Global.Sentry.Environment, + Debug: true, + ServerName: string(cfg.Global.ServerName), + Release: "dendrite@" + internal.VersionString(), + AttachStacktrace: true, + }) + if err != nil { + logrus.WithError(err).Panic("failed to start Sentry") + } + go func() { + processCtx.ComponentStarted() + <-processCtx.WaitForShutdown() + if !sentry.Flush(time.Second * 5) { + logrus.Warnf("failed to flush all Sentry events!") + } + processCtx.ComponentFinished() + }() + } + + federationClient := basepkg.CreateFederationClient(cfg, dnsCache) + httpClient := basepkg.CreateClient(cfg, dnsCache) + + // prepare required dependencies + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + routers := httputil.NewRouters() + + caches := caching.NewRistrettoCache(cfg.Global.Cache.EstimatedMaxSize, cfg.Global.Cache.MaxAge, caching.EnableMetrics) + natsInstance := jetstream.NATSInstance{} + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) + fsAPI := federationapi.NewInternalAPI( + processCtx, cfg, cm, &natsInstance, federationClient, rsAPI, caches, nil, false, + ) + + keyRing := fsAPI.KeyRing() + + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient) + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) + + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this + // dependency. Other components also need updating after their dependencies are up. + rsAPI.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetAppserviceAPI(asAPI) + rsAPI.SetUserAPI(userAPI) + + monolith := setup.Monolith{ + Config: cfg, + Client: httpClient, + FedClient: federationClient, + KeyRing: keyRing, + + AppserviceAPI: asAPI, + // always use the concrete impl here even in -http mode because adding public routes + // must be done on the concrete impl not an HTTP client else fedapi will call itself + FederationAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + } + monolith.AddAllPublicRoutes(processCtx, cfg, routers, cm, &natsInstance, caches, caching.EnableMetrics) + + if len(cfg.MSCs.MSCs) > 0 { + if err := mscs.Enable(cfg, cm, routers, &monolith, caches); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } + } + + // Expose the matrix APIs directly rather than putting them under a /api path. + go func() { + basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpAddr, nil, nil) + }() + // Handle HTTPS if certificate and key are provided + if *unixSocket == "" && *certFile != "" && *keyFile != "" { + go func() { + basepkg.SetupAndServeHTTP(processCtx, cfg, routers, httpsAddr, certFile, keyFile) + }() + } + + // We want to block forever to let the HTTP and HTTPS handler serve the APIs + basepkg.WaitForShutdown(processCtx) +} diff --git a/cmd/dendrite/main_test.go b/cmd/dendrite/main_test.go new file mode 100644 index 0000000000..d51bc74340 --- /dev/null +++ b/cmd/dendrite/main_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "os" + "os/signal" + "strings" + "syscall" + "testing" +) + +// This is an instrumented main, used when running integration tests (sytest) with code coverage. +// Compile: go test -c -race -cover -covermode=atomic -o monolith.debug -coverpkg "github.com/matrix-org/..." ./cmd/dendrite +// Run the monolith: ./monolith.debug -test.coverprofile=/somewhere/to/dump/integrationcover.out DEVEL --config dendrite.yaml +// Generate HTML with coverage: go tool cover -html=/somewhere/where/there/is/integrationcover.out -o cover.html +// Source: https://dzone.com/articles/measuring-integration-test-coverage-rate-in-pouchc +func TestMain(_ *testing.T) { + var ( + args []string + ) + + for _, arg := range os.Args { + switch { + case strings.HasPrefix(arg, "DEVEL"): + case strings.HasPrefix(arg, "-test"): + default: + args = append(args, arg) + } + } + // only run the tests if there are args to be passed + if len(args) <= 1 { + return + } + + waitCh := make(chan int, 1) + os.Args = args + go func() { + main() + close(waitCh) + }() + + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGHUP) + + select { + case <-signalCh: + return + case <-waitCh: + return + } +} From 09904290cbe5863a152dd1a4d3881599b3ab3957 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Sat, 17 Jun 2023 11:51:57 +0530 Subject: [PATCH 17/21] added test cases --- clientapi/admin_test.go | 638 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 638 insertions(+) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 1145cb12d1..e64a9bc7d2 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -2,6 +2,7 @@ package clientapi import ( "context" + "fmt" "net/http" "net/http/httptest" "reflect" @@ -23,12 +24,649 @@ import ( "github.com/matrix-org/util" "github.com/tidwall/gjson" + capi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" uapi "github.com/matrix-org/dendrite/userapi/api" ) +func TestAdminCreateToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token1", + }, + ), + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token2", + }, + ), + }, + { + name: "Alice can create a token without specifyiing any information", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{}), + }, + { + name: "Alice can to create a token specifying a name", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token3", + }, + ), + }, + { + name: "Alice cannot to create a token that already exists", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token3", + }, + ), + }, + { + name: "Alice can create a token specifying valid params", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token4", + "uses_allowed": 5, + "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid name", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token@", + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid uses_allowed", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token5", + "uses_allowed": -1, + }, + ), + }, + { + name: "Alice cannot create a token specifying invalid expiry_time", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "token": "token6", + "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice cannot to create a token specifying invalid length", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "length": 80, + }, + ), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new") + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new", tc.requestOpt) + } + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminListRegistrationTokens(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("valid"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("invalid"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + valid string + isValidSpecified bool + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + isValidSpecified: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + isValidSpecified: false, + }, + { + name: "Alice can list all tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can list all valid tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + valid: "true", + isValidSpecified: true, + }, + { + name: "Alice can list all invalid tokens", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + valid: "false", + isValidSpecified: true, + }, + { + name: "No response when valid has a bad value", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + valid: "trueee", + isValidSpecified: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + var path string + if tc.isValidSpecified { + path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid) + } else { + path = fmt.Sprintf("/_dendrite/admin/registrationTokens") + } + req := test.NewRequest(t, http.MethodGet, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminGetRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + token string + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + }, + { + name: "Alice can GET alice_token1", + token: "alice_token1", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can GET alice_token2", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + }, + { + name: "Alice cannot GET a token that does not exists", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token3", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodGet, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminDeleteRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + testCases := []struct { + name string + requestingUser *test.User + token string + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + }, + { + name: "Alice can DELETE alice_token1", + token: "alice_token1", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + }, + { + name: "Alice can DELETE alice_token2", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodDelete, path) + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func TestAdminUpdateRegistrationToken(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RegistrationRequiresToken = true + defer close() + natsInstance := jetstream.NATSInstance{} + routers := httputil.NewRouters() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) + accessTokens := map[*test.User]userDevice{ + aliceAdmin: {}, + bob: {}, + } + createAccessTokens(t, accessTokens, userAPI, ctx, routers) + tokens := []capi.RegistrationToken{ + { + Token: getPointer("alice_token1"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + { + Token: getPointer("alice_token2"), + UsesAllowed: getPointer(int32(10)), + ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)), + Pending: getPointer(int32(0)), + Completed: getPointer(int32(0)), + }, + } + for _, tkn := range tokens { + tkn := tkn + userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn) + } + testCases := []struct { + name string + requestingUser *test.User + method string + token string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + { + name: "Missing auth", + requestingUser: bob, + wantOK: false, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }, + ), + }, + { + name: "Bob is denied access", + requestingUser: bob, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }, + ), + }, + { + name: "Alice can UPDATE a token's uses_allowed property", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 10, + }), + }, + { + name: "Alice can UPDATE a token's expiry_time property", + requestingUser: aliceAdmin, + wantOK: true, + withHeader: true, + token: "alice_token2", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice can UPDATE a token's uses_allowed and expiry_time property", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 20, + "expiry_time": time.Now().Add(10*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice CANNOT update a token with invalid properties", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token2", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": -5, + "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond), + }, + ), + }, + { + name: "Alice CANNOT UPDATE a token that does not exist", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token9", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": 100, + }, + ), + }, + { + name: "Alice can UPDATE token specifying uses_allowed as null - Valid for infinite uses", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "uses_allowed": nil, + }, + ), + }, + { + name: "Alice can UPDATE token specifying expiry_time AS null - Valid for infinite time", + requestingUser: aliceAdmin, + wantOK: false, + withHeader: true, + token: "alice_token1", + requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "expiry_time": nil, + }, + ), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token) + req := test.NewRequest(t, http.MethodPut, path) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPut, path, tc.requestOpt) + } + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken) + } + rec := httptest.NewRecorder() + routers.DendriteAdmin.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} + +func getPointer[T any](s T) *T { + return &s +} + func TestAdminResetPassword(t *testing.T) { aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) From c5bb8f6578f5d34dc0790a44d693d854951ea971 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Wed, 21 Jun 2023 21:21:05 +0530 Subject: [PATCH 18/21] addressed review comments --- clientapi/routing/admin.go | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 4a7afc58e4..f22daa8de8 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -5,11 +5,9 @@ import ( "encoding/json" "errors" "fmt" - "math/rand" "net/http" "regexp" "strconv" - "strings" "time" "github.com/gorilla/mux" @@ -64,13 +62,13 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u length = 16 } // token not present in request body. Hence, generate a random token. - if !(length > 0 && length <= 64) { + if length <= 0 || length > 64 { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("length must be greater than zero and not greater than 64"), } } - token = generateRandomToken(int(length)) + token = util.RandomString(int(length)) } if len(token) > 64 { @@ -114,16 +112,18 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u ExpiryTime: &expiryTime, } created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) - if err != nil { + if !created { return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: err, + Code: http.StatusConflict, + JSON: map[string]string{ + "error": fmt.Sprintf("token: %s already exists", token), + }, } } - if !created { + if err != nil { return util.JSONResponse{ - Code: http.StatusConflict, - JSON: fmt.Sprintf("Token already exists: %s", token), + Code: http.StatusInternalServerError, + JSON: err, } } return util.JSONResponse{ @@ -138,17 +138,6 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } } -func generateRandomToken(length int) string { - allowedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" - rand.Seed(time.Now().UnixNano()) - var sb strings.Builder - for i := 0; i < length; i++ { - randomIndex := rand.Intn(len(allowedChars)) - sb.WriteByte(allowedChars[randomIndex]) - } - return sb.String() -} - func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { if usesAllowed == 0 { return nil @@ -222,8 +211,8 @@ func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, user err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) if err != nil { return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + Code: http.StatusInternalServerError, + JSON: err, } } return util.JSONResponse{ From 2fcc16fbb75442eb11d39f6d7c96521a0e54f95c Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Wed, 21 Jun 2023 23:04:17 +0530 Subject: [PATCH 19/21] type safety --- clientapi/routing/admin.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index f22daa8de8..964236f7f1 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -227,7 +227,7 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user return util.ErrorResponse(err) } tokenText := vars["token"] - request := make(map[string]interface{}) + request := make(map[string]*int64) if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -238,9 +238,7 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user usesAllowed, ok := request["uses_allowed"] if ok { // Only add usesAllowed to newAtrributes if it is present and valid - // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic - // Synapse's behaviour of updating the field if and only if it is present in request body. - if !(usesAllowed == nil || int32(usesAllowed.(float64)) >= 0) { + if !(usesAllowed == nil || *usesAllowed >= 0) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), @@ -251,9 +249,7 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user expiryTime, ok := request["expiry_time"] if ok { // Only add expiryTime to newAtrributes if it is present and valid - // Non numeric values in payload will cause panic during type conversion. But this is the best way to mimic - // Synapse's behaviour of updating the field if and only if it is present in request body. - if !(expiryTime == nil || int64(expiryTime.(float64)) > time.Now().UnixNano()/int64(time.Millisecond)) { + if !(expiryTime == nil || *expiryTime > time.Now().UnixNano()/int64(time.Millisecond)) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("expiry_time must not be in the past"), From 5346ce735abfeefdaff438e690fffa8f16980da2 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 22 Jun 2023 00:02:09 +0530 Subject: [PATCH 20/21] addressed review comments --- clientapi/routing/admin.go | 32 +++++++------------ .../postgres/registration_tokens_table.go | 10 +++--- .../sqlite3/registration_tokens_table.go | 10 +++--- 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 964236f7f1..cc9370be29 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -18,6 +18,7 @@ import ( "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" + "golang.org/x/exp/constraints" clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/httputil" @@ -39,8 +40,8 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } request := struct { Token string `json:"token"` - UsesAllowed int32 `json:"uses_allowed"` - ExpiryTime int64 `json:"expiry_time"` + UsesAllowed *int32 `json:"uses_allowed,omitempty"` + ExpiryTime *int64 `json:"expiry_time,omitempty"` Length int32 `json:"length"` }{} @@ -87,15 +88,13 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u } } // At this point, we have a valid token, either through request body or through random generation. - - if usesAllowed < 0 { + if usesAllowed != nil && *usesAllowed < 0 { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), } } - - if expiryTime != 0 && expiryTime < time.Now().UnixNano()/int64(time.Millisecond) { + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("expiry_time must not be in the past"), @@ -106,10 +105,10 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) registrationToken := &clientapi.RegistrationToken{ Token: &token, - UsesAllowed: &usesAllowed, + UsesAllowed: usesAllowed, Pending: &pending, Completed: &completed, - ExpiryTime: &expiryTime, + ExpiryTime: expiryTime, } created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) if !created { @@ -130,19 +129,19 @@ func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, u Code: 200, JSON: map[string]interface{}{ "token": token, - "uses_allowed": getReturnValueForUsesAllowed(usesAllowed), + "uses_allowed": getReturnValue(usesAllowed), "pending": pending, "completed": completed, - "expiry_time": getReturnValueExpiryTime(expiryTime), + "expiry_time": getReturnValue(expiryTime), }, } } -func getReturnValueForUsesAllowed(usesAllowed int32) interface{} { - if usesAllowed == 0 { +func getReturnValue[t constraints.Integer](in *t) any { + if in == nil { return nil } - return usesAllowed + return *in } func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { @@ -176,13 +175,6 @@ func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userA } } -func getReturnValueExpiryTime(expiryTime int64) interface{} { - if expiryTime == 0 { - return nil - } - return expiryTime -} - func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go index 45b39c8922..3c3e3fdd93 100644 --- a/userapi/storage/postgres/registration_tokens_table.go +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZero(*registrationToken.UsesAllowed), - nullIfZero(*registrationToken.ExpiryTime), + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) if err != nil { @@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero[t constraints.Integer](in t) any { - if in == 0 { +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { return nil } - return in + return *in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go index 99c18c557e..8979547317 100644 --- a/userapi/storage/sqlite3/registration_tokens_table.go +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -106,8 +106,8 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex _, err := stmt.ExecContext( ctx, *registrationToken.Token, - nullIfZero(*registrationToken.UsesAllowed), - nullIfZero(*registrationToken.ExpiryTime), + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), *registrationToken.Pending, *registrationToken.Completed) if err != nil { @@ -116,11 +116,11 @@ func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Contex return true, nil } -func nullIfZero[t constraints.Integer](in t) any { - if in == 0 { +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { return nil } - return in + return *in } func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { From 04ca62aba1e303f6d01919d66bebff4f13c1e4a6 Mon Sep 17 00:00:00 2001 From: santhoshivan23 Date: Thu, 22 Jun 2023 21:40:11 +0530 Subject: [PATCH 21/21] lint and refactor --- clientapi/admin_test.go | 2 +- clientapi/routing/admin.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index e64a9bc7d2..9d2acd68ed 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -281,7 +281,7 @@ func TestAdminListRegistrationTokens(t *testing.T) { if tc.isValidSpecified { path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid) } else { - path = fmt.Sprintf("/_dendrite/admin/registrationTokens") + path = "/_dendrite/admin/registrationTokens" } req := test.NewRequest(t, http.MethodGet, path) if tc.withHeader { diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index cc9370be29..519666076e 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -230,7 +230,7 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user usesAllowed, ok := request["uses_allowed"] if ok { // Only add usesAllowed to newAtrributes if it is present and valid - if !(usesAllowed == nil || *usesAllowed >= 0) { + if usesAllowed != nil && *usesAllowed < 0 { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), @@ -241,7 +241,7 @@ func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, user expiryTime, ok := request["expiry_time"] if ok { // Only add expiryTime to newAtrributes if it is present and valid - if !(expiryTime == nil || *expiryTime > time.Now().UnixNano()/int64(time.Millisecond)) { + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("expiry_time must not be in the past"),