diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 0da32d29b7..7f2c0bd68a 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -822,14 +822,10 @@ func (v *StateResolution) resolveConflictsV2( key := conflictedEvent.EventID() // Store the newly found auth events in the auth set for this event. - var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) + authSets[key], err = v.loadAuthEvents(ctx, conflictedEvent) if err != nil { return nil, err } - for k, v := range authEventMap { - eventIDMap[k] = v - } // Only add auth events into the authEvents slice once, otherwise the // check for the auth difference can become expensive and produce @@ -975,14 +971,13 @@ func (v *StateResolution) loadStateEvents( return result, eventIDMap, nil } -// loadAuthEvents loads all of the auth events for a given event recursively, -// along with a map that contains state entries for all of the auth events. +// loadAuthEvents loads all of the auth events for a given event recursively. func (v *StateResolution) loadAuthEvents( ctx context.Context, event *gomatrixserverlib.Event, -) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { +) ([]*gomatrixserverlib.Event, error) { eventMap := map[string]struct{}{} - var getEvents func(eventIDs []string) ([]types.Event, error) - getEvents = func(eventIDs []string) ([]types.Event, error) { + var getEvents func(eventIDs []string) ([]*gomatrixserverlib.Event, error) + getEvents = func(eventIDs []string) ([]*gomatrixserverlib.Event, error) { lookup := make([]string, 0, len(event.AuthEventIDs())) for _, eventID := range eventIDs { if _, ok := eventMap[eventID]; ok { @@ -997,54 +992,19 @@ func (v *StateResolution) loadAuthEvents( if err != nil { return nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) } - eventMap[event.EventID()] = struct{}{} - next, err := getEvents(event.AuthEventIDs()) - if err != nil { - return nil, err - } - return append(events, next...), nil - } - authEvents, err := getEvents(event.AuthEventIDs()) - if err != nil { - return nil, nil, fmt.Errorf("getEvents: %w", err) - } - authEventTypes := map[string]struct{}{} - authEventStateKeys := map[string]struct{}{} - for _, authEvent := range authEvents { - authEventTypes[authEvent.Type()] = struct{}{} - authEventStateKeys[*authEvent.StateKey()] = struct{}{} - } - lookupAuthEventTypes := make([]string, 0, len(authEventTypes)) - lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys)) - for eventType := range authEventTypes { - lookupAuthEventTypes = append(lookupAuthEventTypes, eventType) - } - for eventStateKey := range authEventStateKeys { - lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) - } - eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) - if err != nil { - return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) - } - eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) - if err != nil { - return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) - } - stateEntryMap := map[string]types.StateEntry{} - for _, authEvent := range authEvents { - stateEntryMap[authEvent.EventID()] = types.StateEntry{ - EventNID: authEvent.EventNID, - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypes[authEvent.Type()], - EventStateKeyNID: eventStateKeys[*authEvent.StateKey()], - }, + result := make([]*gomatrixserverlib.Event, 0, len(events)) + for _, event := range events { + result = append(result, event.Event) + eventMap[event.EventID()] = struct{}{} + next, err := getEvents(event.AuthEventIDs()) + if err != nil { + return nil, err + } + result = append(result, next...) } + return result, nil } - nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents)) - for _, authEvent := range authEvents { - nakedEvents = append(nakedEvents, authEvent.Event) - } - return nakedEvents, stateEntryMap, nil + return getEvents(event.AuthEventIDs()) } // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.