From 56448c1c744c43601f02b7af7d7d6c34987d70b8 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 1 Jun 2023 17:52:52 -0600 Subject: [PATCH 01/28] Split senderID & userID fields on PDU --- appservice/consumers/roomserver.go | 9 ++++++++- clientapi/routing/redaction.go | 11 +++++++++-- federationapi/internal/perform.go | 4 ++-- federationapi/routing/leave.go | 10 ++++++---- go.mod | 2 +- go.sum | 4 ++-- internal/pushrules/evaluate.go | 4 ++-- roomserver/internal/alias.go | 4 ++-- roomserver/internal/input/input_events.go | 14 +++++++------- roomserver/internal/perform/perform_backfill.go | 6 ++++-- roomserver/internal/perform/perform_invite.go | 2 +- roomserver/producers/roomevent.go | 2 +- roomserver/storage/shared/membership_updater.go | 2 +- roomserver/storage/shared/storage.go | 14 ++++++++++---- setup/mscs/msc2946/msc2946.go | 2 +- syncapi/consumers/roomserver.go | 2 +- syncapi/routing/memberships.go | 2 +- syncapi/routing/search.go | 10 +++++----- .../storage/postgres/current_room_state_table.go | 2 +- .../storage/postgres/output_room_events_table.go | 2 +- syncapi/storage/shared/storage_consumer.go | 9 ++++++++- .../storage/sqlite3/current_room_state_table.go | 2 +- .../storage/sqlite3/output_room_events_table.go | 2 +- syncapi/streams/stream_invite.go | 6 +++++- syncapi/streams/stream_pdu.go | 6 +++--- syncapi/synctypes/clientevent.go | 6 +++++- syncapi/synctypes/clientevent_test.go | 4 ++-- userapi/consumers/roomserver.go | 12 ++++++++---- userapi/consumers/roomserver_test.go | 8 ++++---- 29 files changed, 103 insertions(+), 60 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index c02d904041..634031fe73 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -233,10 +233,17 @@ func (s *appserviceState) backoffAndPause(err error) error { // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { + userID, err := event.UserID() + if err != nil { + log.WithFields(log.Fields{ + "appservice": appservice.ID, + "room_id": event.RoomID(), + }).WithError(err).Errorf("invalid userID") + } switch { case appservice.URL == "": return false - case appservice.IsInterestedInUserID(event.Sender()): + case appservice.IsInterestedInUserID(userID.String()): return true case appservice.IsInterestedInRoomID(event.RoomID()): return true diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 883126423d..76377a5bd1 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -76,7 +76,14 @@ func SendRedaction( // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.Sender() == device.UserID + userID, err := ev.UserID() + if err != nil { + return util.JSONResponse{ + Code: 400, + JSON: spec.BadJSON("invalid userID"), + } + } + allowedToRedact := userID.String() == device.UserID if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, @@ -119,7 +126,7 @@ func SendRedaction( Type: spec.MRoomRedaction, Redacts: eventID, } - err := proto.SetContent(r) + err = proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") return util.JSONResponse{ diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index ed800d03a5..70fa2b841b 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -509,7 +509,7 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.Sender()) + originUser, err := event.UserID() if err != nil { return nil, err } @@ -542,7 +542,7 @@ func (r *FederationInternalAPI) SendInvite( return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, originUser.Domain(), destination, inviteReq) if err != nil { return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a767168d83..91b05e3825 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -213,7 +213,7 @@ func SendLeave( JSON: spec.BadJSON("No state key was provided in the leave event."), } } - if !event.StateKeyEquals(event.Sender()) { + if !event.StateKeyEquals(event.SenderID()) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("Event state key must match the event sender."), @@ -224,12 +224,14 @@ func SendLeave( // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + userID, err := event.UserID() + if err != nil { return util.JSONResponse{ - Code: http.StatusForbidden, + Code: http.StatusBadRequest, JSON: spec.Forbidden("The sender of the join is invalid"), } - } else if serverName != request.Origin() { + } + if userID.Domain() != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender does not match the server that originated the request"), diff --git a/go.mod b/go.mod index a20757bbcc..88aaf7f4a0 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230601235030-c4138953b254 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index a1946adaaf..fc1293a296 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6 h1:Kh1TNvJDhWN5CdgtICNUC4G0wV2km51LGr46Dvl153A= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230531155817-0e3adf17bee6/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230601235030-c4138953b254 h1:javi7tINJfP86mdYeVg7IzDmeRLeaLA0kRzIbEZlK/I= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230601235030-c4138953b254/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 7c98efd30f..fc65b975f9 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -113,7 +113,7 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati return rule.RuleID == event.RoomID(), nil case SenderKind: - return rule.RuleID == event.Sender(), nil + return rule.RuleID == event.SenderID(), nil default: return false, nil @@ -143,7 +143,7 @@ func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec Evaluatio return cmp(n), nil case SenderNotificationPermissionCondition: - return ec.HasPowerLevel(event.Sender(), cond.Key) + return ec.HasPowerLevel(event.SenderID(), cond.Key) default: return false, nil diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 52b90cf4e1..407f4014fe 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -173,8 +173,8 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } sender := request.UserID - if request.UserID != ev.Sender() { - sender = ev.Sender() + if sender != ev.SenderID() { + sender = ev.SenderID() } _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 386083f6e5..52051f7a4d 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,9 +128,9 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + sender, err := event.UserID() if err != nil { - return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) + return fmt.Errorf("event has invalid sender %q", event.SenderID()) } // If we already know about this outlier and it hasn't been rejected @@ -193,9 +193,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { - serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) - delete(servers, senderDomain) + if sender.Domain() != input.Origin && sender.Domain() != r.Cfg.Matrix.ServerName { + serverRes.ServerNames = append(serverRes.ServerNames, sender.Domain()) + delete(servers, sender.Domain()) } for server := range servers { serverRes.ServerNames = append(serverRes.ServerNames, server) @@ -451,7 +451,7 @@ func (r *Inputer) processRoomEvent( } // Handle remote room upgrades, e.g. remove published room - if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { return fmt.Errorf("failed to handle remote room upgrade: %w", err) } @@ -493,7 +493,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID()) } // processStateBefore works out what the state is before the event and diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index fb579f03a0..e9318235f0 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -484,9 +484,11 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { - serverSet[senderDomain] = true + sender, err := event.UserID() + if err != nil { + continue } + serverSet[sender.Domain()] = true } var servers []spec.ServerName for server := range serverSet { diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 1930b5acec..34984eea6f 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -125,7 +125,7 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := spec.NewUserID(event.Sender(), true) + sender, err := event.UserID() if err != nil { return spec.InvalidParam("The user ID is invalid") } diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index febe8ddf41..165304d49c 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu "adds_state": len(update.NewRoomEvent.AddsStateEventIDs), "removes_state": len(update.NewRoomEvent.RemovesStateEventIDs), "send_as_server": update.NewRoomEvent.SendAsServer, - "sender": update.NewRoomEvent.Event.Sender(), + "sender": update.NewRoomEvent.Event.SenderID(), }) if update.NewRoomEvent.Event.StateKey() != nil { logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey()) diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f9c889cb15..105e61df62 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event * var inserted bool // Did the query result in a membership change? var retired []string // Did we retire any updates in the process? return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID()) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cefa58a3d0..78cc15c916 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,8 +988,14 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) - _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + sender1, err := redactedEvent.UserID() + if err != nil { + return err + } + sender2, err := redactionEvent.UserID() + if err != nil { + return err + } var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -997,9 +1003,9 @@ func (d *EventDatabase) MaybeRedactEvent( } switch { - case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: + case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1 == sender2: + case sender1.Domain() == sender2.Domain(): // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 291e0f3b20..f380d3d4f5 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent { Type: ev.Type(), StateKey: *ev.StateKey(), Content: ev.Content(), - Sender: ev.Sender(), + Sender: ev.SenderID(), OriginServerTS: ev.OriginServerTS(), } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 56285dbf4a..c083646589 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) prev := types.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.Sender(), + PrevSender: prevEvent.SenderID(), } event.PDU, err = event.SetUnsigned(prev) diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 7d2e137d3e..f92dffa2ea 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -144,7 +144,7 @@ func GetMemberships( JSON: spec.InternalServerError{}, } } - res.Joined[ev.Sender()] = joinedMember(content) + res.Joined[ev.SenderID()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index b7191873ec..21dbabf7e1 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -204,11 +204,11 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - profile, ok := knownUsersProfiles[event.Sender()] + profile, ok := knownUsersProfiles[event.SenderID()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.Sender()) + stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) if err != nil { - logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + logrus.WithError(err).WithField("user_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -218,9 +218,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, } - knownUsersProfiles[event.Sender()] = profile + knownUsersProfiles[event.SenderID()] = profile } - profileInfos[ev.Sender()] = profile + profileInfos[ev.SenderID()] = profile } results = append(results, Result{ diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 0cc9637318..bfe5e9bdda 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -343,7 +343,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.SenderID(), containsURL, *event.StateKey(), headeredJSON, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3aadbccf85..e068afab1b 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -407,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.SenderID(), containsURL, pq.StringArray(addState), pq.StringArray(removeState), diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index ecfd418fc7..3a4758337f 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -195,7 +195,14 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { + userID, err := in[i].UserID() + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Event has invalid userID") + continue + } + if device.UserID == userID.String() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( "transaction_id", in[i].TransactionID.TransactionID, ) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 1b8632eb67..e432e483b9 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -342,7 +342,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.SenderID(), containsURL, *event.StateKey(), headeredJSON, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d63e760672..5a47aec44e 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -348,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.SenderID(), containsURL, string(addStateJSON), string(removeStateJSON), diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index becd863a98..b35c24facf 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -63,7 +63,11 @@ func (p *InviteStreamProvider) IncrementalSync( for roomID, inviteEvent := range invites { // skip ignored user events - if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok { + userID, err := inviteEvent.UserID() + if err != nil { + continue + } + if _, ok := req.IgnoredUsers.List[userID.String()]; ok { continue } ir := types.NewInviteResponse(inviteEvent) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index b5fd5be8e8..46140060f8 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -425,7 +425,7 @@ func applyHistoryVisibilityFilter( for _, ev := range recentEvents { if ev.StateKey() != nil { stateTypes = append(stateTypes, ev.Type()) - senders = append(senders, ev.Sender()) + senders = append(senders, ev.SenderID()) } } @@ -577,8 +577,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { - timelineUsers[event.Sender()] = struct{}{} + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok { + timelineUsers[event.SenderID()] = struct{}{} } } // Preallocate with the same amount, even if it will end up with fewer values diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index c722fe60a2..2d5daa2828 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -57,9 +57,13 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) // ToClientEvent converts a single server event to a client event. func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { + sender, err := se.UserID() + if err != nil { + panic("uh oh! userID does not exist...") + } ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: se.Sender(), + Sender: sender.String(), Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index b914e64f1d..88a2e7e8d5 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -62,8 +62,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - if ce.Sender != ev.Sender() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender) + if ce.Sender != ev.SenderID() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.SenderID(), ce.Sender) } j, err := json.Marshal(ce) if err != nil { diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3cfdc0ce92..b74f843745 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms } if s.cfg.Matrix.ReportStats.Enabled { - go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID()) + go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID()) } log.WithFields(log.Fields{ @@ -615,7 +615,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - if event.Sender() == mem.UserID { + userID, err := event.UserID() + if err != nil { + return nil, err + } + if userID.String() == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. return nil, nil @@ -632,7 +636,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if err != nil { return nil, err } - sender := event.Sender() + sender := event.SenderID() if _, ok := ignored.List[sender]; ok { return nil, fmt.Errorf("user %s is ignored", sender) } @@ -767,7 +771,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes ID: event.EventID(), RoomID: event.RoomID(), RoomName: roomName, - Sender: event.Sender(), + Sender: event.SenderID(), Type: event.Type(), }, } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 53977206f7..977b7fda86 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -61,13 +61,13 @@ func Test_evaluatePushRules(t *testing.T) { }{ { name: "m.receipt doesn't notify", - eventContent: `{"type":"m.receipt"}`, + eventContent: `{"type":"m.receipt","sender":"@test:remotehost"}`, wantAction: pushrules.UnknownAction, wantActions: nil, }, { name: "m.reaction doesn't notify", - eventContent: `{"type":"m.reaction"}`, + eventContent: `{"type":"m.reaction","sender":"@test:remotehost"}`, wantAction: pushrules.DontNotifyAction, wantActions: []*pushrules.Action{ { @@ -77,7 +77,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message notifies", - eventContent: `{"type":"m.room.message"}`, + eventContent: `{"type":"m.room.message","sender":"@test:remotehost"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ @@ -86,7 +86,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"},"sender":"@test:remotehost" }`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ From b479b062c5984d80e14f5a72096064fb0c347da8 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 10:42:33 -0600 Subject: [PATCH 02/28] Fix sytest failures by using only SenderID everywhere for now --- appservice/consumers/roomserver.go | 9 +-------- clientapi/routing/redaction.go | 11 ++--------- federationapi/internal/perform.go | 4 ++-- federationapi/routing/leave.go | 8 +++----- go.mod | 2 +- go.sum | 4 ++-- roomserver/internal/alias.go | 2 +- roomserver/internal/input/input_events.go | 10 +++++----- roomserver/internal/perform/perform_backfill.go | 6 ++---- roomserver/internal/perform/perform_invite.go | 2 +- roomserver/storage/shared/storage.go | 12 +++--------- syncapi/storage/shared/storage_consumer.go | 9 +-------- syncapi/streams/stream_invite.go | 6 +----- syncapi/synctypes/clientevent.go | 6 +----- userapi/consumers/roomserver.go | 6 +----- 15 files changed, 27 insertions(+), 70 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 634031fe73..97dd90deba 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -233,17 +233,10 @@ func (s *appserviceState) backoffAndPause(err error) error { // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { - userID, err := event.UserID() - if err != nil { - log.WithFields(log.Fields{ - "appservice": appservice.ID, - "room_id": event.RoomID(), - }).WithError(err).Errorf("invalid userID") - } switch { case appservice.URL == "": return false - case appservice.IsInterestedInUserID(userID.String()): + case appservice.IsInterestedInUserID(event.SenderID()): return true case appservice.IsInterestedInRoomID(event.RoomID()): return true diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 76377a5bd1..bd3a79ebc5 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -76,14 +76,7 @@ func SendRedaction( // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - userID, err := ev.UserID() - if err != nil { - return util.JSONResponse{ - Code: 400, - JSON: spec.BadJSON("invalid userID"), - } - } - allowedToRedact := userID.String() == device.UserID + allowedToRedact := ev.SenderID() == device.UserID if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, @@ -126,7 +119,7 @@ func SendRedaction( Type: spec.MRoomRedaction, Redacts: eventID, } - err = proto.SetContent(r) + err := proto.SetContent(r) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed") return util.JSONResponse{ diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 70fa2b841b..960a4461e4 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -509,7 +509,7 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - originUser, err := event.UserID() + _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.SenderID()) if err != nil { return nil, err } @@ -542,7 +542,7 @@ func (r *FederationInternalAPI) SendInvite( return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, originUser.Domain(), destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) if err != nil { return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 91b05e3825..84976901aa 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -224,14 +224,12 @@ func SendLeave( // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. var serverName spec.ServerName - userID, err := event.UserID() - if err != nil { + if _, serverName, err = gomatrixserverlib.SplitID('@', event.SenderID()); err != nil { return util.JSONResponse{ - Code: http.StatusBadRequest, + Code: http.StatusForbidden, JSON: spec.Forbidden("The sender of the join is invalid"), } - } - if userID.Domain() != request.Origin() { + } else if serverName != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender does not match the server that originated the request"), diff --git a/go.mod b/go.mod index 88aaf7f4a0..71e6ffc7a9 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230601235030-c4138953b254 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230602163838-69e05dbabdad github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index fc1293a296..565057b44c 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230601235030-c4138953b254 h1:javi7tINJfP86mdYeVg7IzDmeRLeaLA0kRzIbEZlK/I= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230601235030-c4138953b254/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230602163838-69e05dbabdad h1:l+mNyh4S3v0z320VOFuTIGZLjCAeTenVjYBiUNTZm/o= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230602163838-69e05dbabdad/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 407f4014fe..0951172a09 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -173,7 +173,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( } sender := request.UserID - if sender != ev.SenderID() { + if request.UserID != ev.SenderID() { sender = ev.SenderID() } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 52051f7a4d..e203ebf3af 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,7 +128,7 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - sender, err := event.UserID() + _, senderDomain, err := gomatrixserverlib.SplitID('@', event.SenderID()) if err != nil { return fmt.Errorf("event has invalid sender %q", event.SenderID()) } @@ -193,9 +193,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if sender.Domain() != input.Origin && sender.Domain() != r.Cfg.Matrix.ServerName { - serverRes.ServerNames = append(serverRes.ServerNames, sender.Domain()) - delete(servers, sender.Domain()) + if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) + delete(servers, senderDomain) } for server := range servers { serverRes.ServerNames = append(serverRes.ServerNames, server) @@ -451,7 +451,7 @@ func (r *Inputer) processRoomEvent( } // Handle remote room upgrades, e.g. remove published room - if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { return fmt.Errorf("failed to handle remote room upgrade: %w", err) } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index e9318235f0..9b6d7f6788 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -484,11 +484,9 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - sender, err := event.UserID() - if err != nil { - continue + if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.SenderID()); err == nil { + serverSet[senderDomain] = true } - serverSet[sender.Domain()] = true } var servers []spec.ServerName for server := range serverSet { diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 34984eea6f..bb41e0dfcf 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -125,7 +125,7 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := event.UserID() + sender, err := spec.NewUserID(event.SenderID(), true) if err != nil { return spec.InvalidParam("The user ID is invalid") } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 78cc15c916..1ee436f99c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,14 +988,8 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - sender1, err := redactedEvent.UserID() - if err != nil { - return err - } - sender2, err := redactionEvent.UserID() - if err != nil { - return err - } + _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.SenderID()) + _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.SenderID()) var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -1005,7 +999,7 @@ func (d *EventDatabase) MaybeRedactEvent( switch { case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1.Domain() == sender2.Domain(): + case sender1 == sender2: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 3a4758337f..2211c257c0 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -195,14 +195,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent if device != nil && in[i].TransactionID != nil { - userID, err := in[i].UserID() - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Event has invalid userID") - continue - } - if device.UserID == userID.String() && device.SessionID == in[i].TransactionID.SessionID { + if device.UserID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( "transaction_id", in[i].TransactionID.TransactionID, ) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index b35c24facf..2f05382bbb 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -63,11 +63,7 @@ func (p *InviteStreamProvider) IncrementalSync( for roomID, inviteEvent := range invites { // skip ignored user events - userID, err := inviteEvent.UserID() - if err != nil { - continue - } - if _, ok := req.IgnoredUsers.List[userID.String()]; ok { + if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok { continue } ir := types.NewInviteResponse(inviteEvent) diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 2d5daa2828..bbd4a75c70 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -57,13 +57,9 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) // ToClientEvent converts a single server event to a client event. func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { - sender, err := se.UserID() - if err != nil { - panic("uh oh! userID does not exist...") - } ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: sender.String(), + Sender: se.SenderID(), Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b74f843745..a4cb3defe8 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -615,11 +615,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - userID, err := event.UserID() - if err != nil { - return nil, err - } - if userID.String() == mem.UserID { + if event.SenderID() == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. return nil, nil From 222b67bfebba07068771386aca4b80b1c6e739ea Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 12:57:08 -0600 Subject: [PATCH 03/28] Use UserID in appservice --- appservice/consumers/roomserver.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 97dd90deba..ce693e6482 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -233,10 +233,16 @@ func (s *appserviceState) backoffAndPause(err error) error { // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { + user := "" + userID, err := event.UserID() + if err == nil { + user = userID.String() + } + switch { case appservice.URL == "": return false - case appservice.IsInterestedInUserID(event.SenderID()): + case appservice.IsInterestedInUserID(user): return true case appservice.IsInterestedInRoomID(event.RoomID()): return true From d81b63f5564766866bc347ddfcc5b3dc29e8b13c Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 14:35:14 -0600 Subject: [PATCH 04/28] Refine SenderID/UserID usage --- clientapi/routing/redaction.go | 2 +- federationapi/internal/perform.go | 9 +++++++-- federationapi/routing/leave.go | 8 ++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index bd3a79ebc5..e94c7748ee 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -76,7 +76,7 @@ func SendRedaction( // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.SenderID() == device.UserID + allowedToRedact := ev.SenderID() == device.UserID // TODO: Should replace device.UserID with device...PerRoomKey if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 960a4461e4..c4c7156636 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -509,10 +509,15 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.SenderID()) + inviter, err := event.UserID() if err != nil { return nil, err } + // For portable accounts, we need to verify the inviter domain is still associated with this server. + // The userID of the inviter may have changed to another server in which case we cannot send the invite. + if !r.cfg.Matrix.IsLocalServerName(inviter.Domain()) { + return nil, fmt.Errorf("the invite must be from a local user") + } if event.StateKey() == nil { return nil, errors.New("invite must be a state event") @@ -542,7 +547,7 @@ func (r *FederationInternalAPI) SendInvite( return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, inviter.Domain(), destination, inviteReq) if err != nil { return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 84976901aa..33cd54fad9 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -223,13 +223,13 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.SenderID()); err != nil { + sender, err := event.UserID() + if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender of the join is invalid"), } - } else if serverName != request.Origin() { + } else if sender.Domain() != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender does not match the server that originated the request"), @@ -291,7 +291,7 @@ func SendLeave( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, + ServerName: sender.Domain(), Message: redacted, AtTS: event.OriginServerTS(), StrictValidityChecking: true, From 26ef0246e8bbc3319d30af7005700d633464db96 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 15:01:58 -0600 Subject: [PATCH 05/28] Refine SenderID/UserID usage for push rules --- internal/pushrules/evaluate.go | 6 +++++- userapi/consumers/roomserver.go | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index fc65b975f9..399041ad3d 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -113,7 +113,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati return rule.RuleID == event.RoomID(), nil case SenderKind: - return rule.RuleID == event.SenderID(), nil + sender, err := event.UserID() + if err != nil { + return false, nil + } + return rule.RuleID == sender.String(), nil default: return false, nil diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index a4cb3defe8..a512e62805 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -682,7 +682,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } -func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { +func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) { req := &rsapi.QueryLatestEventsAndStateRequest{ RoomID: rse.roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ @@ -702,7 +702,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err if err != nil { return false, err } - return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil } return true, nil } From 0c902f49b681b2407e40b141f5663e7f7fe6a8ed Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 15:58:42 -0600 Subject: [PATCH 06/28] Refine SenderID/UserID usage --- clientapi/routing/directory.go | 41 ++++++++++++++++--- internal/pushrules/evaluate.go | 7 ++-- internal/pushrules/evaluate_test.go | 8 ++-- roomserver/api/alias.go | 4 +- roomserver/api/api.go | 6 +++ roomserver/internal/alias.go | 12 +++--- .../internal/perform/perform_create_room.go | 6 +-- .../internal/perform/perform_upgrade.go | 4 +- roomserver/internal/query/query.go | 5 +++ roomserver/roomserver_test.go | 2 +- 10 files changed, 69 insertions(+), 26 deletions(-) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index c786f8cc4e..f1f729cc3e 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -181,10 +181,25 @@ func SetLocalAlias( return *resErr } + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "UserID for device is invalid"}, + } + } + + deviceSenderID, err := rsAPI.QuerySenderIDForRoom(req.Context(), alias, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, + } + } queryReq := roomserverAPI.SetRoomAliasRequest{ - UserID: device.UserID, - RoomID: r.RoomID, - Alias: alias, + SenderID: deviceSenderID, + RoomID: r.RoomID, + Alias: alias, } var queryRes roomserverAPI.SetRoomAliasResponse if err := rsAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { @@ -215,9 +230,25 @@ func RemoveLocalAlias( alias string, rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "UserID for device is invalid"}, + } + } + + deviceSenderID, err := rsAPI.QuerySenderIDForRoom(req.Context(), alias, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, + } + } + queryReq := roomserverAPI.RemoveRoomAliasRequest{ - Alias: alias, - UserID: device.UserID, + Alias: alias, + SenderID: deviceSenderID, } var queryRes roomserverAPI.RemoveRoomAliasResponse if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 399041ad3d..646cc90148 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -113,11 +113,12 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati return rule.RuleID == event.RoomID(), nil case SenderKind: + userID := "" sender, err := event.UserID() - if err != nil { - return false, nil + if err == nil { + userID = sender.String() } - return rule.RuleID == sender.String(), nil + return rule.RuleID == userID, nil default: return false, nil diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 5045a864eb..e3bfe48d0a 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -82,11 +82,11 @@ func TestRuleMatches(t *testing.T) { {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, - {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, - {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 37892a44a1..0404c1cdd7 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -19,7 +19,7 @@ import "regexp" // SetRoomAliasRequest is a request to SetRoomAlias type SetRoomAliasRequest struct { // ID of the user setting the alias - UserID string `json:"user_id"` + SenderID string `json:"user_id"` // New alias for the room Alias string `json:"alias"` // The room ID the alias is referring to @@ -62,7 +62,7 @@ type GetAliasesForRoomIDResponse struct { // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias - UserID string `json:"user_id"` + SenderID string `json:"user_id"` // The room alias to remove Alias string `json:"alias"` } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 7cb3379e03..be0e9b16e9 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -65,6 +65,11 @@ type InputRoomEventsAPI interface { ) } +type QuerySenderIDAPI interface { + // Accepts either roomID or alias + QuerySenderIDForRoom(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) +} + // Query the latest events and state for a room from the room server. type QueryLatestEventsAndStateAPI interface { QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error @@ -158,6 +163,7 @@ type ClientRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QueryEventsAPI + QuerySenderIDAPI QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 0951172a09..da275725db 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -51,7 +51,7 @@ func (r *RoomserverInternalAPI) SetRoomAlias( response.AliasExists = false // Save the new alias - if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil { + if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.SenderID); err != nil { return err } @@ -119,7 +119,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID) + _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.SenderID) if err != nil { return err } @@ -140,7 +140,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) } - if creatorID != request.UserID { + if creatorID != request.SenderID { var plEvent *types.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent @@ -154,7 +154,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("plEvent.PowerLevels: %w", err) } - if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { + if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { response.Removed = false return nil } @@ -172,8 +172,8 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - sender := request.UserID - if request.UserID != ev.SenderID() { + sender := request.SenderID + if request.SenderID != ev.SenderID() { sender = ev.SenderID() } diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 41194832d4..1ae47d9453 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -350,9 +350,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo // been taken. if roomAlias != "" { aliasReq := api.SetRoomAliasRequest{ - Alias: roomAlias, - RoomID: roomID.String(), - UserID: userID.String(), + Alias: roomAlias, + RoomID: roomID.String(), + SenderID: userID.String(), } var aliasResp api.SetRoomAliasResponse diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index ff4a6a1dc8..7e615f2a3c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -176,13 +176,13 @@ func moveLocalAliases(ctx context.Context, } for _, alias := range aliasRes.Aliases { - removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} + removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { return fmt.Errorf("Failed to remove old room alias: %w", err) } - setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} + setAliasReq := api.SetRoomAliasRequest{SenderID: userID, Alias: alias, RoomID: newRoomID} setAliasRes := api.SetRoomAliasResponse{} if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { return fmt.Errorf("Failed to set new room alias: %w", err) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index effcc90d7e..15d226a74d 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -1034,3 +1034,8 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query } return nil } + +func (r *Queryer) QuerySenderIDForRoom(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { + // TODO: implement this properly with pseudoIDs + return userID.String(), nil +} diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index d19ebebe43..6ff03025af 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -267,7 +267,7 @@ func TestPurgeRoom(t *testing.T) { } aliasResp := &api.SetRoomAliasResponse{} - if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", SenderID: alice.ID}, aliasResp); err != nil { t.Fatal(err) } // check the alias is actually there From 00e719f7e7a2cab4a624960dbdbac72847114d59 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 16:36:11 -0600 Subject: [PATCH 07/28] Refine SenderID/UserID usage --- clientapi/routing/directory.go | 4 ++-- roomserver/api/api.go | 2 +- roomserver/internal/input/input_events.go | 12 +++++++----- roomserver/internal/perform/perform_backfill.go | 4 ++-- roomserver/internal/perform/perform_invite.go | 4 ++-- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index f1f729cc3e..bc8ef6d655 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -189,7 +189,7 @@ func SetLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForRoom(req.Context(), alias, *userID) + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), alias, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -238,7 +238,7 @@ func RemoveLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForRoom(req.Context(), alias, *userID) + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), alias, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/api/api.go b/roomserver/api/api.go index be0e9b16e9..99ec537193 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -67,7 +67,7 @@ type InputRoomEventsAPI interface { type QuerySenderIDAPI interface { // Accepts either roomID or alias - QuerySenderIDForRoom(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) + QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index e203ebf3af..7fb5bfc131 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,7 +128,7 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.SenderID()) + sender, err := event.UserID() if err != nil { return fmt.Errorf("event has invalid sender %q", event.SenderID()) } @@ -193,9 +193,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { - serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) - delete(servers, senderDomain) + if sender.Domain() != input.Origin && sender.Domain() != r.Cfg.Matrix.ServerName { + serverRes.ServerNames = append(serverRes.ServerNames, sender.Domain()) + delete(servers, sender.Domain()) } for server := range servers { serverRes.ServerNames = append(serverRes.ServerNames, server) @@ -451,7 +451,7 @@ func (r *Inputer) processRoomEvent( } // Handle remote room upgrades, e.g. remove published room - if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { return fmt.Errorf("failed to handle remote room upgrade: %w", err) } @@ -828,11 +828,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r continue } + // TODO: pseudoIDs: get userID for room using state key (which is now senderID) localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) if err != nil { continue } + // TODO: pseudoIDs: query account by state key (which is now senderID) accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ Localpart: localpart, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 9b6d7f6788..686bcc1647 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -484,8 +484,8 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.SenderID()); err == nil { - serverSet[senderDomain] = true + if sender, err := event.UserID(); err == nil { + serverSet[sender.Domain()] = true } } var servers []spec.ServerName diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index bb41e0dfcf..7a12bc2cf6 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -125,9 +125,9 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := spec.NewUserID(event.SenderID(), true) + sender, err := event.UserID() if err != nil { - return spec.InvalidParam("The user ID is invalid") + return spec.InvalidParam("The sender user ID is invalid") } if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} From 8a3c01fa19a3e932f67cc26a5c9a68dd9f994189 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 18:00:43 -0600 Subject: [PATCH 08/28] Refine SenderID/UserID usage --- roomserver/api/api.go | 1 + roomserver/internal/perform/perform_invite.go | 5 +++-- roomserver/internal/query/query.go | 11 ++++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 99ec537193..ec1cd6abbd 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -68,6 +68,7 @@ type InputRoomEventsAPI interface { type QuerySenderIDAPI interface { // Accepts either roomID or alias QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 7a12bc2cf6..329395e631 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership( ) ([]api.OutputEvent, error) { var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey()) + + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 15d226a74d..d249b28969 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -1035,7 +1035,16 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query return nil } -func (r *Queryer) QuerySenderIDForRoom(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { // TODO: implement this properly with pseudoIDs return userID.String(), nil } + +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) { + // TODO: implement this properly with pseudoIDs + userID, err := spec.NewUserID(senderID, true) + if err != nil { + return spec.UserID{}, err + } + return *userID, err +} From bf5bf8363de17f86d5ca49b1fd4d3480e3170fbb Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 18:25:46 -0600 Subject: [PATCH 09/28] Refine SenderID/UserID usage --- roomserver/storage/shared/storage.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 1ee436f99c..8c37651dd0 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,8 +988,16 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.SenderID()) - _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.SenderID()) + sender1Domain := "" + sender1, err := redactedEvent.UserID() + if err == nil { + sender1Domain = string(sender1.Domain()) + } + sender2Domain := "" + sender2, err := redactionEvent.UserID() + if err == nil { + sender2Domain = string(sender2.Domain()) + } var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -999,7 +1007,7 @@ func (d *EventDatabase) MaybeRedactEvent( switch { case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1 == sender2: + case sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true From 752c89fb028cd49c8d928627e5a6f8d985a528ab Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 19:01:51 -0600 Subject: [PATCH 10/28] Refine SenderID/UserID usage --- syncapi/storage/shared/storage_consumer.go | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 2211c257c0..8fa2f3863c 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -195,7 +195,21 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( "transaction_id", in[i].TransactionID.TransactionID, ) @@ -210,6 +224,11 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea return out } +func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { + // TODO: Repalce with actual logic for pseudoIDs + return userID.String(), nil +} + // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. From 7c82de99c91fdbd8ca422bb04a0a7ef78833e39a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 19:05:16 -0600 Subject: [PATCH 11/28] Refine SenderID/UserID usage --- syncapi/synctypes/clientevent.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index bbd4a75c70..ba8ac77965 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -57,9 +57,14 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) // ToClientEvent converts a single server event to a client event. func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { + user := "" + userID, err := se.UserID() + if err == nil { + user = userID.String() + } ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: se.SenderID(), + Sender: user, Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), From 7cf96a865d39a7aa7758796eae64f32d03bf40e6 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 19:07:07 -0600 Subject: [PATCH 12/28] Refine SenderID/UserID usage --- syncapi/synctypes/clientevent_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 88a2e7e8d5..a4e7b3ff87 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -62,8 +62,13 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - if ce.Sender != ev.SenderID() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.SenderID(), ce.Sender) + user := "" + userID, err := ev.UserID() + if err == nil { + user = userID.String() + } + if ce.Sender != user { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", user, ce.Sender) } j, err := json.Marshal(ce) if err != nil { From b324f7eccc27f06f32817977c0bbd4fa8ff1a19c Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 19:37:08 -0600 Subject: [PATCH 13/28] Refine SenderID/UserID usage --- userapi/consumers/roomserver.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index a512e62805..f51874a13d 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -615,7 +615,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - if event.SenderID() == mem.UserID { + user := "" + userID, err := event.UserID() + if err == nil { + user = userID.String() + } + if user == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. return nil, nil @@ -767,7 +772,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes ID: event.EventID(), RoomID: event.RoomID(), RoomName: roomName, - Sender: event.SenderID(), + Sender: event.SenderID(), // TODO: Should push use UserIDs? Type: event.Type(), }, } From 3582ec7de7746c05a049f2d62521ee202b016134 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 2 Jun 2023 20:12:00 -0600 Subject: [PATCH 14/28] Refine SenderID/UserID usage --- go.mod | 2 +- go.sum | 4 ++-- roomserver/api/query.go | 4 ++-- roomserver/storage/shared/storage.go | 8 ++++---- syncapi/storage/shared/storage_consumer.go | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index 71e6ffc7a9..9ff9ab2869 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230602163838-69e05dbabdad + github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 565057b44c..24544c90c7 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230602163838-69e05dbabdad h1:l+mNyh4S3v0z320VOFuTIGZLjCAeTenVjYBiUNTZm/o= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230602163838-69e05dbabdad/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced h1:pbCM+nno+r2wW3jwxP65xmkzk6008CdMNZaOWYBwB1c= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b33698c826..0449dd97cf 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -463,10 +463,10 @@ type MembershipQuerier struct { Roomserver FederationRoomserverAPI } -func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID string) (string, error) { req := QueryMembershipForUserRequest{ RoomID: roomID.String(), - UserID: userID.String(), + UserID: senderID, } res := QueryMembershipForUserResponse{} err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 8c37651dd0..61183bc420 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -989,13 +989,13 @@ func (d *EventDatabase) MaybeRedactEvent( } sender1Domain := "" - sender1, err := redactedEvent.UserID() - if err == nil { + sender1, err1 := redactedEvent.UserID() + if err1 == nil { sender1Domain = string(sender1.Domain()) } sender2Domain := "" - sender2, err := redactionEvent.UserID() - if err == nil { + sender2, err2 := redactionEvent.UserID() + if err2 == nil { sender2Domain = string(sender2.Domain()) } var powerlevels *gomatrixserverlib.PowerLevelContent diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 8fa2f3863c..17a6a69c30 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -224,7 +224,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea return out } -func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { +func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint // TODO: Repalce with actual logic for pseudoIDs return userID.String(), nil } From d7e41177f30f0f4f7fd995c351e134236b2323e6 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Mon, 5 Jun 2023 20:20:01 -0600 Subject: [PATCH 15/28] Remove pdu.UserID & pass in the required accessor --- appservice/consumers/roomserver.go | 6 ++-- clientapi/routing/sendevent.go | 4 ++- clientapi/routing/state.go | 12 +++++-- cmd/resolve-state/main.go | 5 ++- federationapi/federationapi_test.go | 4 +++ federationapi/internal/perform.go | 33 ++++++++++++------- federationapi/routing/join.go | 21 +++++++----- federationapi/routing/leave.go | 2 +- go.mod | 4 +-- go.sum | 8 ++--- internal/pushrules/evaluate.go | 9 ++--- internal/pushrules/evaluate_test.go | 9 +++-- internal/transactionrequest.go | 4 ++- internal/transactionrequest_test.go | 8 +++++ roomserver/api/api.go | 8 ++++- roomserver/api/query.go | 4 +-- roomserver/internal/helpers/auth.go | 4 ++- roomserver/internal/input/input_events.go | 18 +++++++--- .../internal/input/input_events_test.go | 2 +- roomserver/internal/input/input_missing.go | 24 ++++++++++---- roomserver/internal/perform/perform_admin.go | 8 +++-- .../internal/perform/perform_backfill.go | 10 ++++-- .../internal/perform/perform_create_room.go | 4 ++- roomserver/internal/perform/perform_invite.go | 5 ++- .../internal/perform/perform_upgrade.go | 8 +++-- roomserver/internal/query/query.go | 28 ++++++++-------- roomserver/state/state.go | 9 ++++- roomserver/storage/interface.go | 5 +++ roomserver/storage/shared/room_updater.go | 5 +++ roomserver/storage/shared/storage.go | 16 +++++++-- setup/mscs/msc2836/msc2836.go | 8 +++-- setup/mscs/msc2836/msc2836_test.go | 4 +++ syncapi/routing/context.go | 20 ++++++++--- syncapi/routing/getevent.go | 4 ++- syncapi/routing/memberships.go | 4 ++- syncapi/routing/messages.go | 12 ++++--- syncapi/routing/relations.go | 4 ++- syncapi/routing/routing.go | 2 +- syncapi/routing/search.go | 27 ++++++++++----- syncapi/routing/search_test.go | 10 +++++- syncapi/streams/stream_invite.go | 6 +++- syncapi/streams/stream_pdu.go | 32 +++++++++++++----- syncapi/streams/streams.go | 1 + syncapi/syncapi_test.go | 4 +++ syncapi/synctypes/clientevent.go | 8 ++--- syncapi/synctypes/clientevent_test.go | 11 +++++-- syncapi/types/types.go | 4 +-- syncapi/types/types_test.go | 7 +++- test/room.go | 6 +++- userapi/consumers/roomserver.go | 14 +++++--- userapi/consumers/roomserver_test.go | 9 ++++- userapi/util/notify_test.go | 7 +++- 52 files changed, 355 insertions(+), 136 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ce693e6482..5650a26b7f 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), }, ) if err != nil { @@ -234,7 +236,7 @@ func (s *appserviceState) backoffAndPause(err error) error { // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { user := "" - userID, err := event.UserID() + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err == nil { user = userID.String() } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 1a2e25c9d0..5fb755f578 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -331,7 +331,9 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 319f4eba51..5dda21abfc 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -142,7 +142,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), ) } } else { @@ -164,7 +166,9 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateAfterRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), ) } } @@ -335,7 +339,9 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll), + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), } var res interface{} diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 3a4255bae0..1ebb6cd8dd 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -18,6 +18,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // This is a utility for inspecting state snapshots and running state resolution @@ -182,7 +183,9 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return roomserverDB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, ) if err != nil { panic(err) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index beb648a48f..fc09440dad 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,6 +36,10 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + // PerformJoin will call this function func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { if f.inputRoomEvents == nil { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index c4c7156636..4f05979838 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -156,15 +156,20 @@ func (r *FederationInternalAPI) performJoinUsingServer( } joinInput := gomatrixserverlib.PerformJoinInput{ - UserID: user, - RoomID: room, - ServerName: serverName, - Content: content, - Unsigned: unsigned, - PrivateKey: r.cfg.Matrix.PrivateKey, - KeyID: r.cfg.Matrix.KeyID, - KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName), + UserID: user, + RoomID: room, + ServerName: serverName, + Content: content, + Unsigned: unsigned, + PrivateKey: r.cfg.Matrix.PrivateKey, + KeyID: r.cfg.Matrix.KeyID, + KeyRing: r.keyRing, + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -358,8 +363,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() + userIDProvider := func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( - ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), + ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, ) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) @@ -509,7 +517,7 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - inviter, err := event.UserID() + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return nil, err } @@ -640,6 +648,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error { func federatedEventProvider( ctx context.Context, federation fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, + userIDForSender spec.UserIDForSender, ) gomatrixserverlib.EventProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -689,7 +698,7 @@ func federatedEventProvider( } // Check the signatures of the event. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil { return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index c6f96375ee..f00830e50a 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -147,15 +147,18 @@ func MakeJoin( } input := gomatrixserverlib.HandleMakeJoinInput{ - Context: httpReq.Context(), - UserID: userID, - RoomID: roomID, - RoomVersion: roomVersion, - RemoteVersions: remoteVersions, - RequestOrigin: request.Origin(), - LocalServerName: cfg.Matrix.ServerName, - LocalServerInRoom: res.RoomExists && res.IsInRoom, - RoomQuerier: &roomQuerier, + Context: httpReq.Context(), + UserID: userID, + RoomID: roomID, + RoomVersion: roomVersion, + RemoteVersions: remoteVersions, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + RoomQuerier: &roomQuerier, + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + }, BuildEventTemplate: createJoinTemplate, } response, internalErr := gomatrixserverlib.HandleMakeJoin(input) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 33cd54fad9..25365b4ecd 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -223,7 +223,7 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - sender, err := event.UserID() + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/go.mod b/go.mod index 9ff9ab2869..35f28c03f2 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 - github.com/sirupsen/logrus v1.9.2 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 24544c90c7..8bac1e8079 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced h1:pbCM+nno+r2wW3jwxP65xmkzk6008CdMNZaOWYBwB1c= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230603021032-d30b8fdc7ced/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 h1:G+Do1UoWazY0Fetq+eAX1h1+fimf19NGGyaS86hWg8s= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= -github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 646cc90148..da33d38623 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // A RuleSetEvaluator encapsulates context to evaluate an event @@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat // MatchEvent returns the first matching rule. Returns nil if there // was no match rule. -func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) { +func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) { // TODO: server-default rules have lower priority than user rules, // but they are stored together with the user rules. It's a bit // unclear what the specification (11.14.1.4 Predefined rules) @@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err if rule.Default != defRules { continue } - ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err return nil, nil } -func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) { +func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) { if !rule.Enabled { return false, nil } @@ -114,7 +115,7 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati case SenderKind: userID := "" - sender, err := event.UserID() + sender, err := userIDForSender(event.RoomID(), event.SenderID()) if err == nil { userID = sender.String() } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index e3bfe48d0a..616dac894c 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -5,8 +5,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestRuleSetEvaluatorMatchEvent(t *testing.T) { ev := mustEventFromJSON(t, `{}`) defaultEnabled := &Rule{ @@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) { for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet) - got, err := rse.MatchEvent(tst.Event) + got, err := rse.MatchEvent(tst.Event, UserIDForSender) if err != nil { t.Fatalf("MatchEvent failed: %v", err) } @@ -90,7 +95,7 @@ func TestRuleMatches(t *testing.T) { } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { - got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender) if err != nil { t.Fatalf("ruleMatches failed: %v", err) } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index c9d321f259..5d9be08c5c 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return t.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index fb30d410e5..bcda6520e2 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,6 +70,10 @@ type FakeRsAPI struct { bannedFromRoom bool } +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (r *FakeRsAPI) QueryRoomVersionForRoom( ctx context.Context, roomID string, @@ -638,6 +642,10 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *rsAPI.InputRoomEventsRequest, diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ec1cd6abbd..38b4d80512 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -39,6 +39,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + QuerySenderIDAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -68,7 +69,7 @@ type InputRoomEventsAPI interface { type QuerySenderIDAPI interface { // Accepts either roomID or alias QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) - QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) + QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. @@ -98,6 +99,7 @@ type QueryEventsAPI interface { type SyncRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine @@ -138,6 +140,7 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { + QuerySenderIDAPI // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // which room to use by querying the first events roomID. QueryEventsByID( @@ -197,6 +200,7 @@ type ClientRoomserverAPI interface { } type UserRoomserverAPI interface { + QuerySenderIDAPI QueryLatestEventsAndStateAPI KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error @@ -209,6 +213,8 @@ type FederationRoomserverAPI interface { InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 0449dd97cf..b1c30c2894 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -463,10 +463,10 @@ type MembershipQuerier struct { Roomserver FederationRoomserverAPI } -func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID string) (string, error) { +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { req := QueryMembershipForUserRequest{ RoomID: roomID.String(), - UserID: senderID, + UserID: string(senderID), } res := QueryMembershipForUserResponse{} err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 7ec0892e41..cb5bdcc40f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,7 +76,9 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { // return true, nil return true, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 7fb5bfc131..ffaa20afaf 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,7 +128,7 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - sender, err := event.UserID() + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return fmt.Errorf("event has invalid sender %q", event.SenderID()) } @@ -276,7 +276,9 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { isRejected = true rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) @@ -579,7 +581,9 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return } @@ -690,7 +694,9 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { continue nextAuthEvent } @@ -706,7 +712,9 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 568038132f..17a5aa1294 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomAliasOrID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 10486138df..1a3c2c142d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, ) if err != nil { return nil, err } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) @@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { continue } missingEvents = append(missingEvents, t.cacheAndReturn(ev)) @@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil) + }, roomVersion, t.keys, nil, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }) if err != nil { return nil, err } @@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event), nil } -func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error { +func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) @@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli return err } } - return gomatrixserverlib.Allowed(e, &authUsingState) + return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender) } func (t *missingStateReq) hadEvent(eventID string) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 575525e21b..488a106489 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { continue } stateEventMap[stateEvent.EventID()] = stateEvent diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 686bcc1647..bdb9bf6c94 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, ) // Only return an error if we really couldn't get any events. if err != nil && len(events) == 0 { @@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") continue @@ -484,7 +488,7 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if sender, err := event.UserID(); err == nil { + if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { serverSet[sender.Domain()] = true } } diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 1ae47d9453..7448bca4f6 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return c.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 329395e631..4050c294f6 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -126,7 +126,7 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := event.UserID() + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return spec.InvalidParam("The sender user ID is invalid") } @@ -156,6 +156,9 @@ func (r *Inviter) PerformInvite( StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) if err != nil { diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 7e615f2a3c..99ad21ccd5 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } @@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index d249b28969..fdc37c7b23 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -122,7 +122,9 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) @@ -349,7 +351,9 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -398,7 +402,9 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -588,7 +594,9 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, ) if err != nil { return err @@ -1036,15 +1044,9 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query } func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { - // TODO: implement this properly with pseudoIDs - return userID.String(), nil + return r.DB.GetSenderIDForUser(ctx, roomAliasOrID, userID) } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (spec.UserID, error) { - // TODO: implement this properly with pseudoIDs - userID, err := spec.NewUserID(senderID, true) - if err != nil { - return spec.UserID{}, err - } - return *userID, err +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index f38d8f96a3..1607971f74 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -24,6 +24,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) } type StateResolution struct { @@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }) // Map from the full events back to numeric state entries. for _, resolvedEvent := range resolvedEvents { @@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, + func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7d22df0084..89e40b49be 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -166,6 +166,10 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownUsers tries to obtain the current mxid for a given user. + GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + // GetKnownUsers tries to obtain the current senderID for a given user. + GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -211,6 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) + GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 70672a33e1..4068caebc9 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } + +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return u.d.GetUserIDForSender(ctx, roomAliasOrID, senderID) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 61183bc420..c88f272956 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,13 +988,15 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } + // TODO: Don't hack senderID into userID here (pseudoIDs) sender1Domain := "" - sender1, err1 := redactedEvent.UserID() + sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true) if err1 == nil { sender1Domain = string(sender1.Domain()) } + // TODO: Don't hack senderID into userID here (pseudoIDs) sender2Domain := "" - sender2, err2 := redactionEvent.UserID() + sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true) if err2 == nil { sender2Domain = string(sender2.Domain()) } @@ -1522,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } +func (d *Database) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return spec.NewUserID(senderID, true) +} + +func (d *Database) GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return userID.String(), nil +} + // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index f468b048a6..6e5bc30276 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct { ParsedAuthChain []gomatrixserverlib.PDU } -func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { +func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP return util.JSONResponse{ Code: 200, - JSON: toClientResponse(res), + JSON: toClientResponse(req.Context(), res, rsAPI), } } } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 2c6f63d452..c8bc8b509f 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,6 +525,10 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { for _, eventID := range req.EventIDs { ev := r.events[eventID] diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index ac17d39d2b..5046ba1cde 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,14 +193,20 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll) + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll) + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") @@ -211,12 +217,16 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll) + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll), + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), } if len(response.State) > filter.Limit { diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 0d3d412f64..878c5d90ab 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -103,6 +103,8 @@ func GetEvent( return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + }), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index f92dffa2ea..699e08015c 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -153,6 +153,8 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)}, + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + })}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 58f663d0ba..6e6ac8539c 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -241,7 +241,7 @@ func OnIncomingMessagesRequest( device: device, } - clientEvents, start, end, err := mReq.retrieveEvents() + clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") return util.JSONResponse{ @@ -273,7 +273,9 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + })...) } // If we didn't return any events, set the end to an empty string, so it will be omitted @@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api. // homeserver in the room for older events. // Returns an error if there was an issue talking to the database or with the // remote homeserver. -func (r *messagesReq) retrieveEvents() ( +func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) ( clientEvents []synctypes.ClientEvent, start, end types.TopologyToken, err error, ) { @@ -382,7 +384,9 @@ func (r *messagesReq) retrieveEvents() ( "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), start, end, err } func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 8374bf5b04..3913f2d86d 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -116,7 +116,9 @@ func Relations( for _, event := range filteredEvents { res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + }), ) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 9ad0c04766..8542c0b73e 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -171,7 +171,7 @@ func Setup( nb := req.FormValue("next_batch") nextBatch = &nb } - return Search(req, device, syncDB, fts, nextBatch) + return Search(req, device, syncDB, fts, nextBatch, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 21dbabf7e1..991545315e 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -38,7 +39,7 @@ import ( ) // nolint:gocyclo -func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse { start := time.Now() var ( searchReq SearchRequest @@ -225,14 +226,20 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts results = append(results, Result{ Context: SearchContextResponse{ - Start: startToken.String(), - End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync), - ProfileInfo: profileInfos, + Start: startToken.String(), + End: endToken.String(), + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + }), + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + }), + ProfileInfo: profileInfos, }, - Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll), + Rank: eventScore[event.EventID()].Score, + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + }), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) @@ -247,7 +254,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + }) } } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 1cc95a873a..02410820ba 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -9,6 +10,7 @@ import ( "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + rsapi "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -21,6 +23,12 @@ import ( "github.com/stretchr/testify/assert" ) +type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } + +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSearch(t *testing.T) { alice := test.NewUser(t) aliceDevice := userapi.Device{UserID: alice.ID} @@ -247,7 +255,7 @@ func TestSearch(t *testing.T) { assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/", reqBody) - res := Search(req, tc.device, db, fts, tc.from) + res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{}) if !tc.wantOK && !res.Is2xx() { return } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 2f05382bbb..fd7b9dbb44 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -17,6 +18,7 @@ import ( type InviteStreamProvider struct { DefaultStreamProvider + rsAPI api.SyncRoomserverAPI } func (p *InviteStreamProvider) Setup( @@ -66,7 +68,9 @@ func (p *InviteStreamProvider) IncrementalSync( if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent) + ir := types.NewInviteResponse(inviteEvent, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 46140060f8..4473a0196f 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) req.Response.Rooms.Join[delta.RoomID] = jr case spec.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) req.Response.Rooms.Peek[delta.RoomID] = jr case spec.Leave: @@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) req.Response.Rooms.Leave[delta.RoomID] = lr } @@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) return jr, nil } diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index a35491acf4..f25bc978fb 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -45,6 +45,7 @@ func NewSyncStreamProviders( }, InviteStreamProvider: &InviteStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, + rsAPI: rsAPI, }, SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index bc766e6633..ac46e10eb1 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,6 +40,10 @@ type syncRoomserverAPI struct { rooms []*test.Room } +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { var room *test.Room for _, r := range s.rooms { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index ba8ac77965..6ab12102d0 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -44,21 +44,21 @@ type ClientEvent struct { } // ToClientEvents converts server events to client events. -func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent { +func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent { evs := make([]ClientEvent, 0, len(serverEvs)) for _, se := range serverEvs { if se == nil { continue // TODO: shouldn't happen? } - evs = append(evs, ToClientEvent(se, format)) + evs = append(evs, ToClientEvent(se, format, userIDForSender)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent { user := "" - userID, err := se.UserID() + userID, err := userIDForSender(se.RoomID(), se.SenderID()) if err == nil { user = userID.String() } diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index a4e7b3ff87..5ae353f799 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -21,8 +21,13 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestToClientEvent(t *testing.T) { // nolint: gocyclo ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", @@ -43,7 +48,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatAll) + ce := ToClientEvent(ev, FormatAll, UserIDForSender) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -63,7 +68,7 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } user := "" - userID, err := ev.UserID() + userID, err := UserIDForSender("", ev.SenderID()) if err == nil { user = userID.String() } @@ -103,7 +108,7 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatSync) + ce := ToClientEvent(ev, FormatSync, UserIDForSender) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 22c27fea5d..3b3bd79872 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 8e0448fe76..2cd8e74723 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -8,8 +8,13 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ "s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(), @@ -56,7 +61,7 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}) + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/test/room.go b/test/room.go index 852e31533b..aefdd87697 100644 --- a/test/room.go +++ b/test/room.go @@ -39,6 +39,10 @@ var ( roomIDCounter = int64(0) ) +func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + type Room struct { ID string Version gomatrixserverlib.RoomVersion @@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err != nil { t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) } - if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } headeredEvent := &rstypes.HeaderedEvent{PDU: ev} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index f51874a13d..2d7bc4814e 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -301,7 +301,9 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll) + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -534,7 +536,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it @@ -616,7 +620,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { user := "" - userID, err := event.UserID() + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err == nil { user = userID.String() } @@ -655,7 +659,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU) + rule, err := eval.MatchEvent(event.PDU, func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + }) if err != nil { return nil, err } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 977b7fda86..09e57b8955 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi/storage" @@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { return &types.HeaderedEvent{PDU: ev} } +type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } + +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func Test_evaluatePushRules(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - consumer := OutputRoomEventConsumer{db: db} + consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}} testCases := []struct { name string diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index e1c88d47f9..778461161b 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "golang.org/x/crypto/bcrypt" @@ -22,6 +23,10 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) +func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestNotifyUserCountsAsync(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -100,7 +105,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { // Insert a dummy event if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender), }); err != nil { t.Error(err) } From 39545d5000132dfb7775d94325e54a4c2bffc6c5 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Mon, 5 Jun 2023 20:55:53 -0600 Subject: [PATCH 16/28] Fix unpopulated struct fields --- federationapi/routing/invite.go | 3 +++ federationapi/routing/join.go | 3 +++ federationapi/routing/leave.go | 3 +++ 3 files changed, 9 insertions(+) diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 78a09d9496..81dbf26a75 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,6 +95,9 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) if jsonErr != nil { diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index f00830e50a..f09c5aff39 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -253,6 +253,9 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + }, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) switch e := joinErr.(type) { diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 25365b4ecd..de95330b41 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -95,6 +95,9 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + }, } response, internalErr := gomatrixserverlib.HandleMakeLeave(input) From 872f8e4af8deb7caf2abdf960787bec591087804 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Mon, 5 Jun 2023 21:18:18 -0600 Subject: [PATCH 17/28] Fix unpopulated struct fields for invite v1 --- federationapi/routing/invite.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 81dbf26a75..59f8dd1bea 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -188,6 +188,9 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, + UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) if jsonErr != nil { From e388c9ca425641847731225019cf38c25165bca0 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 09:44:16 -0600 Subject: [PATCH 18/28] Make senderID query use only roomID --- appservice/consumers/roomserver.go | 4 +-- clientapi/routing/directory.go | 14 ++++++-- clientapi/routing/sendevent.go | 4 +-- clientapi/routing/state.go | 12 +++---- cmd/resolve-state/main.go | 4 +-- federationapi/federationapi_test.go | 2 +- federationapi/internal/perform.go | 12 +++---- federationapi/routing/invite.go | 8 ++--- federationapi/routing/join.go | 8 ++--- federationapi/routing/leave.go | 4 +-- go.mod | 2 +- go.sum | 2 ++ internal/pushrules/evaluate_test.go | 2 +- internal/transactionrequest.go | 4 +-- internal/transactionrequest_test.go | 4 +-- roomserver/api/api.go | 5 ++- roomserver/internal/helpers/auth.go | 4 +-- roomserver/internal/input/input_events.go | 16 +++++----- .../internal/input/input_events_test.go | 2 +- roomserver/internal/input/input_missing.go | 20 ++++++------ roomserver/internal/perform/perform_admin.go | 8 ++--- .../internal/perform/perform_backfill.go | 8 ++--- .../internal/perform/perform_create_room.go | 4 +-- roomserver/internal/perform/perform_invite.go | 4 +-- .../internal/perform/perform_upgrade.go | 8 ++--- roomserver/internal/query/query.go | 24 +++++++------- roomserver/state/state.go | 10 +++--- roomserver/storage/interface.go | 6 ++-- roomserver/storage/shared/room_updater.go | 4 +-- roomserver/storage/shared/storage.go | 4 +-- setup/mscs/msc2836/msc2836.go | 4 +-- setup/mscs/msc2836/msc2836_test.go | 2 +- syncapi/routing/context.go | 20 ++++++------ syncapi/routing/getevent.go | 4 +-- syncapi/routing/memberships.go | 4 +-- syncapi/routing/messages.go | 8 ++--- syncapi/routing/relations.go | 4 +-- syncapi/routing/search.go | 16 +++++----- syncapi/routing/search_test.go | 2 +- syncapi/streams/stream_invite.go | 4 +-- syncapi/streams/stream_pdu.go | 32 +++++++++---------- syncapi/syncapi_test.go | 2 +- syncapi/synctypes/clientevent_test.go | 2 +- syncapi/types/types_test.go | 2 +- test/room.go | 2 +- userapi/consumers/roomserver.go | 12 +++---- userapi/consumers/roomserver_test.go | 2 +- userapi/util/notify_test.go | 2 +- 48 files changed, 174 insertions(+), 163 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 5650a26b7f..06625ad7e6 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,8 +181,8 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), }, ) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index bc8ef6d655..c702b136d4 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -189,7 +189,7 @@ func SetLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), alias, *userID) + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), r.RoomID, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -238,7 +238,17 @@ func RemoveLocalAlias( } } - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), alias, *userID) + roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias} + roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{} + err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 5fb755f578..8b09f399a7 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -331,8 +331,8 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 5dda21abfc..8567e617d1 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -142,8 +142,8 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), ) } @@ -166,8 +166,8 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a for _, ev := range stateAfterRes.StateEvents { stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), ) } @@ -339,8 +339,8 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 1ebb6cd8dd..360403094f 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -183,8 +183,8 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return roomserverDB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index fc09440dad..a97bcdeab8 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,7 +36,7 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } -func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 4f05979838..ef1396f704 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -164,11 +164,11 @@ func (r *FederationInternalAPI) performJoinUsingServer( PrivateKey: r.cfg.Matrix.PrivateKey, KeyID: r.cfg.Matrix.KeyID, KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -363,8 +363,8 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - userIDProvider := func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 59f8dd1bea..d792335b95 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,8 +95,8 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) @@ -188,8 +188,8 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index f09c5aff39..cf6e430328 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -156,8 +156,8 @@ func MakeJoin( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, RoomQuerier: &roomQuerier, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, BuildEventTemplate: createJoinTemplate, } @@ -253,8 +253,8 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index de95330b41..ec7628d308 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -95,8 +95,8 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(httpReq.Context(), roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/go.mod b/go.mod index 35f28c03f2..1e13d4a10b 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154326-77b5ce0c692d github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 8bac1e8079..35c3d713a0 100644 --- a/go.sum +++ b/go.sum @@ -325,6 +325,8 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30 h1:G+Do1UoWazY0Fetq+eAX1h1+fimf19NGGyaS86hWg8s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230606021710-b68a1b0eef30/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154326-77b5ce0c692d h1:AUrJcbgtIPtVYTTfV7DUyarW7hOMgsZQZUuy9r8fMv8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154326-77b5ce0c692d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 616dac894c..34c1436f46 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 5d9be08c5c..0bbe0720cc 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,8 +167,8 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index bcda6520e2..6f3ce0b3b8 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,7 +70,7 @@ type FakeRsAPI struct { bannedFromRoom bool } -func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } @@ -642,7 +642,7 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 38b4d80512..8a115823c6 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -67,9 +67,8 @@ type InputRoomEventsAPI interface { } type QuerySenderIDAPI interface { - // Accepts either roomID or alias - QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) - QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } // Query the latest events and state for a room from the room server. diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index cb5bdcc40f..932ce61551 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,8 +76,8 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { // return true, nil return true, err diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index ffaa20afaf..b692a9513c 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -276,8 +276,8 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { isRejected = true rejectionErr = err @@ -581,8 +581,8 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return @@ -694,8 +694,8 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue nextAuthEvent } @@ -712,8 +712,8 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 17a5aa1294..0ba7d19f59 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomAliasOrID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 1a3c2c142d..ac0670fc3c 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: @@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } @@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }) if err != nil { return nil, err @@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return t.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 488a106489..ca736cb651 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -262,16 +262,16 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { continue } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index bdb9bf6c94..0f743f4e48 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,8 +121,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) // Only return an error if we really couldn't get any events. @@ -212,8 +212,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 7448bca4f6..bfa1f631ea 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -308,8 +308,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return c.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return c.DB.GetUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 4050c294f6..e8e20ede29 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -156,8 +156,8 @@ func (r *Inviter) PerformInvite( StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, - UserIDQuerier: func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 99ad21ccd5..7513aab46e 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -484,8 +484,8 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } @@ -569,8 +569,8 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.URSAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index fdc37c7b23..7828c8b8c7 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -122,8 +122,8 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -351,8 +351,8 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -402,8 +402,8 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -594,8 +594,8 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -1043,10 +1043,10 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query return nil } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { - return r.DB.GetSenderIDForUser(ctx, roomAliasOrID, userID) +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + return r.DB.GetSenderIDForUser(ctx, roomID, userID) } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomAliasOrID, senderID) +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 1607971f74..3131cbff2f 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -44,7 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) - GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type StateResolution struct { @@ -947,8 +947,8 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) }) // Map from the full events back to numeric state entries. @@ -1061,8 +1061,8 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, - func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return v.db.GetUserIDForSender(ctx, roomAliasOrID, senderID) + func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 89e40b49be..2d007bed5b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -167,9 +167,9 @@ type Database interface { // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownUsers tries to obtain the current mxid for a given user. - GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) // GetKnownUsers tries to obtain the current senderID for a given user. - GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) + GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -215,7 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) - GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 4068caebc9..7350013836 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -252,6 +252,6 @@ func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, ta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } -func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { - return u.d.GetUserIDForSender(ctx, roomAliasOrID, senderID) +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return u.d.GetUserIDForSender(ctx, roomID, senderID) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index c88f272956..406d7cf1cb 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1524,12 +1524,12 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } -func (d *Database) GetUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { // TODO: Use real logic once DB for pseudoIDs is in place return spec.NewUserID(senderID, true) } -func (d *Database) GetSenderIDForUser(ctx context.Context, roomAliasOrID string, userID spec.UserID) (string, error) { +func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { // TODO: Use real logic once DB for pseudoIDs is in place return userID.String(), nil } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 6e5bc30276..5ce3b430b3 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -94,8 +94,8 @@ type MSC2836EventRelationshipsResponse struct { func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), Limited: res.Limited, NextBatch: res.NextBatch, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index c8bc8b509f..c463fd72bd 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,7 +525,7 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } -func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 5046ba1cde..a646622fcb 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,19 +193,19 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { @@ -217,15 +217,15 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 878c5d90ab..0d47477aa8 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -103,8 +103,8 @@ func GetEvent( return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 699e08015c..9c2319dd92 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -153,8 +153,8 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 9e7c33d5d8..879739d00d 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -273,8 +273,8 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })...) } @@ -385,8 +385,8 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, err } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 3913f2d86d..30c02d2934 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -116,8 +116,8 @@ func Relations( for _, event := range filteredEvents { res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 991545315e..546624cbe2 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -228,17 +228,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), }) roomGroup := groups[event.RoomID()] @@ -254,8 +254,8 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomAliasOrID, senderID) + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 02410820ba..b36be82383 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -25,7 +25,7 @@ import ( type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } -func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index fd7b9dbb44..52bb380c87 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -68,8 +68,8 @@ func (p *InviteStreamProvider) IncrementalSync( if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 6bc4bfdf43..8f83a08963 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,14 +376,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -391,12 +391,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -406,14 +406,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr } @@ -564,14 +564,14 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index ac46e10eb1..78c857ab9b 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,7 +40,7 @@ type syncRoomserverAPI struct { rooms []*test.Room } -func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 5ae353f799..b76b38c178 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 2cd8e74723..a93e9122f7 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/test/room.go b/test/room.go index aefdd87697..4cdb73aa33 100644 --- a/test/room.go +++ b/test/room.go @@ -39,7 +39,7 @@ var ( roomIDCounter = int64(0) ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 2d7bc4814e..edc48e52a9 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -301,8 +301,8 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) var member *localMembership member, err = newLocalMembership(&cevent) @@ -536,8 +536,8 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomAliasOrID string, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID string, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't @@ -659,8 +659,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU, func(roomAliasOrID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomAliasOrID, senderID) + rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) if err != nil { return nil, err diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 09e57b8955..d677a7ba9c 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } -func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomAliasOrID string, senderID string) (*spec.UserID, error) { +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 778461161b..86446b2db1 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,7 +23,7 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) -func UserIDForSender(roomAliasOrID string, senderID string) (*spec.UserID, error) { +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) } From c7bdc2e7f474095cca592e283d4da42f421d6ae7 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 10:38:44 -0600 Subject: [PATCH 19/28] UserAPI roomserver uses userIDs for push rules --- userapi/consumers/roomserver.go | 16 ++++++++++------ userapi/consumers/roomserver_test.go | 8 ++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index edc48e52a9..3cb44f0fb0 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -620,9 +620,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { user := "" - userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err == nil { - user = userID.String() + user = sender.String() } if user == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for @@ -641,9 +641,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if err != nil { return nil, err } - sender := event.SenderID() - if _, ok := ignored.List[sender]; ok { - return nil, fmt.Errorf("user %s is ignored", sender) + if _, ok := ignored.List[sender.String()]; ok { + return nil, fmt.Errorf("user %s is ignored", sender.String()) } } ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) @@ -767,6 +766,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes } default: + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err != nil { + logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) + return nil, err + } req = pushgateway.NotifyRequest{ Notification: pushgateway.Notification{ Content: event.Content(), @@ -778,7 +782,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes ID: event.EventID(), RoomID: event.RoomID(), RoomName: roomName, - Sender: event.SenderID(), // TODO: Should push use UserIDs? + Sender: sender.String(), Type: event.Type(), }, } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index d677a7ba9c..899a5aaf09 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) { }{ { name: "m.receipt doesn't notify", - eventContent: `{"type":"m.receipt","sender":"@test:remotehost"}`, + eventContent: `{"type":"m.receipt"}`, wantAction: pushrules.UnknownAction, wantActions: nil, }, { name: "m.reaction doesn't notify", - eventContent: `{"type":"m.reaction","sender":"@test:remotehost"}`, + eventContent: `{"type":"m.reaction"}`, wantAction: pushrules.DontNotifyAction, wantActions: []*pushrules.Action{ { @@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message notifies", - eventContent: `{"type":"m.room.message","sender":"@test:remotehost"}`, + eventContent: `{"type":"m.room.message"}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ @@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"},"sender":"@test:remotehost" }`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ From fbbf9b085ad68e7543a2f1caff07f868c7893542 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 10:50:59 -0600 Subject: [PATCH 20/28] Use userID for sync stream invites --- syncapi/streams/stream_invite.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 52bb380c87..0e62f70bfb 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -64,8 +64,14 @@ func (p *InviteStreamProvider) IncrementalSync( } for roomID, inviteEvent := range invites { + user := "" + sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) + if err == nil { + user = sender.String() + } + // skip ignored user events - if _, ok := req.IgnoredUsers.List[inviteEvent.SenderID()]; ok { + if _, ok := req.IgnoredUsers.List[user]; ok { continue } ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) { From d30bae35092c5464bc3b28ceb6b4b9f8c057167e Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 11:00:32 -0600 Subject: [PATCH 21/28] Use userID for sync profile search --- syncapi/routing/search.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 546624cbe2..e4e173df2d 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -205,11 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - profile, ok := knownUsersProfiles[event.SenderID()] + userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if err != nil { + logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + continue + } + + profile, ok := knownUsersProfiles[userID.String()] if !ok { stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) if err != nil { - logrus.WithError(err).WithField("user_id", event.SenderID()).Warn("failed to query userprofile") + logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -219,9 +225,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, } - knownUsersProfiles[event.SenderID()] = profile + knownUsersProfiles[userID.String()] = profile } - profileInfos[ev.SenderID()] = profile + profileInfos[userID.String()] = profile } results = append(results, Result{ From ac61081624e63bb746a7cfaf7d77177f86613e77 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 11:08:26 -0600 Subject: [PATCH 22/28] Remove redundant check in fed/SendInvite --- federationapi/internal/perform.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index ef1396f704..2d59d0f930 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -521,11 +521,6 @@ func (r *FederationInternalAPI) SendInvite( if err != nil { return nil, err } - // For portable accounts, we need to verify the inviter domain is still associated with this server. - // The userID of the inviter may have changed to another server in which case we cannot send the invite. - if !r.cfg.Matrix.IsLocalServerName(inviter.Domain()) { - return nil, fmt.Errorf("the invite must be from a local user") - } if event.StateKey() == nil { return nil, errors.New("invite must be a state event") From 1c909ceb011ba98911b083ce427e70d65fc54940 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 11:17:16 -0600 Subject: [PATCH 23/28] When handling input events, nil userID is valid --- roomserver/internal/input/input_events.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index b692a9513c..3d396d6aac 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -130,7 +130,11 @@ func (r *Inputer) processRoomEvent( } sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return fmt.Errorf("event has invalid sender %q", event.SenderID()) + return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) + } + senderDomain := spec.ServerName("") + if sender != nil { + senderDomain = sender.Domain() } // If we already know about this outlier and it hasn't been rejected @@ -193,9 +197,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if sender.Domain() != input.Origin && sender.Domain() != r.Cfg.Matrix.ServerName { - serverRes.ServerNames = append(serverRes.ServerNames, sender.Domain()) - delete(servers, sender.Domain()) + if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) + delete(servers, senderDomain) } for server := range servers { serverRes.ServerNames = append(serverRes.ServerNames, server) @@ -453,7 +457,7 @@ func (r *Inputer) processRoomEvent( } // Handle remote room upgrades, e.g. remove published room - if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { return fmt.Errorf("failed to handle remote room upgrade: %w", err) } From 7420e731cc0eb0a9bbec65272e1b2190ff52062a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 11:26:22 -0600 Subject: [PATCH 24/28] Dont fail processing event when mxid_mapping is unknown --- roomserver/internal/input/input_events.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 3d396d6aac..764bdfe2cf 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -197,7 +197,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + // Only perform this check if the sender mxid_mapping can be resolved. + // Don't fail processing the event if we have no mxid_maping. + if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) delete(servers, senderDomain) } From c0be935e59e9a70f0fe2ec95822245d5dacdf5e4 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 11:33:10 -0600 Subject: [PATCH 25/28] Use userID for room alias handling --- clientapi/routing/directory.go | 21 +++---------------- roomserver/api/alias.go | 2 +- roomserver/internal/alias.go | 2 +- .../internal/perform/perform_create_room.go | 6 +++--- .../internal/perform/perform_upgrade.go | 2 +- roomserver/roomserver_test.go | 2 +- 6 files changed, 10 insertions(+), 25 deletions(-) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index c702b136d4..0c842e6a50 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -181,25 +181,10 @@ func SetLocalAlias( return *resErr } - userID, err := spec.NewUserID(device.UserID, true) - if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{Err: "UserID for device is invalid"}, - } - } - - deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), r.RoomID, *userID) - if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, - } - } queryReq := roomserverAPI.SetRoomAliasRequest{ - SenderID: deviceSenderID, - RoomID: r.RoomID, - Alias: alias, + UserID: device.UserID, + RoomID: r.RoomID, + Alias: alias, } var queryRes roomserverAPI.SetRoomAliasResponse if err := rsAPI.SetRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 0404c1cdd7..1b94754048 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -19,7 +19,7 @@ import "regexp" // SetRoomAliasRequest is a request to SetRoomAlias type SetRoomAliasRequest struct { // ID of the user setting the alias - SenderID string `json:"user_id"` + UserID string `json:"user_id"` // New alias for the room Alias string `json:"alias"` // The room ID the alias is referring to diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index da275725db..6d1f451864 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -51,7 +51,7 @@ func (r *RoomserverInternalAPI) SetRoomAlias( response.AliasExists = false // Save the new alias - if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.SenderID); err != nil { + if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil { return err } diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index bfa1f631ea..897bd3a0e0 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -352,9 +352,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo // been taken. if roomAlias != "" { aliasReq := api.SetRoomAliasRequest{ - Alias: roomAlias, - RoomID: roomID.String(), - SenderID: userID.String(), + Alias: roomAlias, + RoomID: roomID.String(), + UserID: userID.String(), } var aliasResp api.SetRoomAliasResponse diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 7513aab46e..8c0df1c46b 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -182,7 +182,7 @@ func moveLocalAliases(ctx context.Context, return fmt.Errorf("Failed to remove old room alias: %w", err) } - setAliasReq := api.SetRoomAliasRequest{SenderID: userID, Alias: alias, RoomID: newRoomID} + setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID} setAliasRes := api.SetRoomAliasResponse{} if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil { return fmt.Errorf("Failed to set new room alias: %w", err) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 274b4f6933..11a0f58175 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -267,7 +267,7 @@ func TestPurgeRoom(t *testing.T) { } aliasResp := &api.SetRoomAliasResponse{} - if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", SenderID: alice.ID}, aliasResp); err != nil { + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { t.Fatal(err) } // check the alias is actually there From c2aac0f19e72a119cbb4a54e05113aa2d597d1ed Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 11:36:52 -0600 Subject: [PATCH 26/28] Properly lookup userID when removing room alias --- roomserver/internal/alias.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 6d1f451864..dcfb26b8e4 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.SenderID) - if err != nil { - return err - } - roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -134,6 +129,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } + sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) + if err != nil { + return fmt.Errorf("r.QueryUserIDForSender: %w", err) + } + virtualHost := sender.Domain() + response.Found = true creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) if err != nil { From 34edfff85ccfd1c31c689ab44ce66092cae4592f Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 14:11:21 -0600 Subject: [PATCH 27/28] ToClientEvent now directly uses the provided userID for sender field --- clientapi/routing/state.go | 27 ++++++++++++++++++--------- roomserver/internal/query/query.go | 18 ++++++++++++------ syncapi/routing/context.go | 9 ++++++--- syncapi/routing/getevent.go | 9 ++++++--- syncapi/routing/relations.go | 9 ++++++--- syncapi/routing/search.go | 23 +++++++++++++---------- syncapi/streams/stream_invite.go | 12 +++++------- syncapi/synctypes/clientevent.go | 16 ++++++++-------- syncapi/synctypes/clientevent_test.go | 25 ++++++++++++------------- syncapi/types/types.go | 4 ++-- syncapi/types/types_test.go | 7 ++++++- userapi/consumers/roomserver.go | 18 ++++++++++++------ userapi/util/notify_test.go | 12 ++++++------ 13 files changed, 112 insertions(+), 77 deletions(-) diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 8567e617d1..13f308998e 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -140,11 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } else { @@ -164,11 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } } for _, ev := range stateAfterRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } @@ -338,10 +344,13 @@ func OnIncomingStateTypeRequest( } } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID string, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), } var res interface{} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 7329504db2..707e95b2a1 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -388,9 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -439,9 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index a646622fcb..27e99a357f 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -217,9 +217,12 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 0d47477aa8..63df7e8372 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -101,10 +101,13 @@ func GetEvent( } } + sender := spec.UserID{} + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) + if err == nil && senderUserID != nil { + sender = *senderUserID + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) - }), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), } } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 30c02d2934..f21c684c8d 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -114,11 +114,14 @@ func Relations( // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) - }), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index e4e173df2d..9cf3eabe2e 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -205,17 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) - if err != nil { - logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if queryErr != nil { + logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) - if err != nil { - logrus.WithError(err).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) + if stateErr != nil { + logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -230,6 +230,11 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos[userID.String()] = profile } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -242,10 +247,8 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts }), ProfileInfo: profileInfos, }, - Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) - }), + Rank: eventScore[event.EventID()].Score, + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 0e62f70bfb..a8b0a7b667 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -64,19 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync( } for roomID, inviteEvent := range invites { - user := "" + user := spec.UserID{} sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) - if err == nil { - user = sender.String() + if err == nil && sender != nil { + user = *sender } // skip ignored user events - if _, ok := req.IgnoredUsers.List[user]; ok { + if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, func(roomID, senderID string) (*spec.UserID, error) { - return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }) + ir := types.NewInviteResponse(inviteEvent, user) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index 6ab12102d0..66fb1d01f4 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -50,21 +50,21 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if se == nil { continue // TODO: shouldn't happen? } - evs = append(evs, ToClientEvent(se, format, userIDForSender)) + sender := spec.UserID{} + userID, err := userIDForSender(se.RoomID(), se.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + evs = append(evs, ToClientEvent(se, format, sender)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) ClientEvent { - user := "" - userID, err := userIDForSender(se.RoomID(), se.SenderID()) - if err == nil { - user = userID.String() - } +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: user, + Sender: sender.String(), Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index b76b38c178..341795081c 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -24,10 +24,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) -} - func TestToClientEvent(t *testing.T) { // nolint: gocyclo ev, err := gomatrixserverlib.MustGetRoomVersion(gomatrixserverlib.RoomVersionV1).NewEventFromTrustedJSON([]byte(`{ "type": "m.room.name", @@ -48,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatAll, UserIDForSender) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatAll, *userID) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -67,13 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - user := "" - userID, err := UserIDForSender("", ev.SenderID()) - if err == nil { - user = userID.String() - } - if ce.Sender != user { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", user, ce.Sender) + if ce.Sender != userID.String() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) } j, err := json.Marshal(ce) if err != nil { @@ -108,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatSync, UserIDForSender) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatSync, *userID) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 3b3bd79872..526a120d01 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDForSender) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userIDForSender spec.UserIDFo // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userIDForSender) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a93e9122f7..a79ce54175 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -61,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, UserIDForSender) + sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + if err != nil { + t.Fatal(err) + } + + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3cb44f0fb0..c025deee06 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -301,9 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }) + sender := spec.UserID{} + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -531,14 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype return fmt.Errorf("s.localPushDevices: %w", err) } + sender := spec.UserID{} + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID string, senderID string) (*spec.UserID, error) { - return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 86446b2db1..27dd373c24 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,10 +23,6 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) -} - func TestNotifyUserCountsAsync(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -92,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Prepare pusher with our test server URL - if err := db.UpsertPusher(ctx, api.Pusher{ + if err = db.UpsertPusher(ctx, api.Pusher{ Kind: api.HTTPKind, AppID: appID, PushKey: pushKey, @@ -104,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event + sender, err := spec.NewUserID(alice.ID, true) + if err != nil { + t.Error(err) + } if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, UserIDForSender), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), }); err != nil { t.Error(err) } From 051aa22b15792b95f777e5477d5c698b81bbe3b7 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 6 Jun 2023 14:30:20 -0600 Subject: [PATCH 28/28] Update to latest gmsl --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 65914f51d0..10551f7026 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154401-7f56d079aa92 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 diff --git a/go.sum b/go.sum index 4cd1c35335..3ec1c115c5 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154401-7f56d079aa92 h1:Dq8ZqshsAW3qD/XPyp89ZsL2xM0hORC0TzwSNlyGUPM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606154401-7f56d079aa92/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 h1:6SixhMmB5Ir10xUJ6zh3A4NBxSaZCSz2s5U63Wg0eEU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=