diff --git a/state/accumulator.go b/state/accumulator.go index b0de435b..0105786f 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/getsentry/sentry-go" "github.com/jmoiron/sqlx" @@ -291,131 +292,125 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia // to exist in the database, and the sync stream is already linearised for us. // - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state) // - It adds entries to the membership log for membership events. -func (a *Accumulator) Accumulate(roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) { - if len(timeline) == 0 { - return 0, nil, nil - } - err = sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error { - // Insert the events. Check for duplicates which can happen in the real world when joining - // Matrix HQ on Synapse. - dedupedEvents := make([]Event, 0, len(timeline)) - seenEvents := make(map[string]struct{}) - for i := range timeline { - e := Event{ - JSON: timeline[i], - RoomID: roomID, - } - if err := e.ensureFieldsSetOnEvent(); err != nil { - return fmt.Errorf("event malformed: %s", err) - } - if _, ok := seenEvents[e.ID]; ok { - logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg( - "Accumulator.Accumulate: seen the same event ID twice, ignoring", - ) - continue - } - if i == 0 && prevBatch != "" { - // tag the first timeline event with the prev batch token - e.PrevBatch = sql.NullString{ - String: prevBatch, - Valid: true, - } - } - dedupedEvents = append(dedupedEvents, e) - seenEvents[e.ID] = struct{}{} +func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) { + // Insert the events. Check for duplicates which can happen in the real world when joining + // Matrix HQ on Synapse. + dedupedEvents := make([]Event, 0, len(timeline)) + seenEvents := make(map[string]struct{}) + for i := range timeline { + e := Event{ + JSON: timeline[i], + RoomID: roomID, } - eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false) - if err != nil { - return err + if err := e.ensureFieldsSetOnEvent(); err != nil { + return 0, nil, fmt.Errorf("event malformed: %s", err) } - if len(eventIDToNID) == 0 { - // nothing to do, we already know about these events - return nil + if _, ok := seenEvents[e.ID]; ok { + logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg( + "Accumulator.Accumulate: seen the same event ID twice, ignoring", + ) + continue } - numNew = len(eventIDToNID) - - var latestNID int64 - newEvents := make([]Event, 0, len(eventIDToNID)) - for _, ev := range dedupedEvents { - nid, ok := eventIDToNID[ev.ID] - if ok { - ev.NID = int64(nid) - if gjson.GetBytes(ev.JSON, "state_key").Exists() { - // XXX: reusing this to mean "it's a state event" as well as "it's part of the state v2 response" - // its important that we don't insert 'ev' at this point as this should be False in the DB. - ev.IsState = true - } - // assign the highest nid value to the latest nid. 
- // we'll return this to the caller so they can stay in-sync - if ev.NID > latestNID { - latestNID = ev.NID - } - newEvents = append(newEvents, ev) - timelineNIDs = append(timelineNIDs, ev.NID) + if i == 0 && prevBatch != "" { + // tag the first timeline event with the prev batch token + e.PrevBatch = sql.NullString{ + String: prevBatch, + Valid: true, } } + dedupedEvents = append(dedupedEvents, e) + seenEvents[e.ID] = struct{}{} + } + eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false) + if err != nil { + return 0, nil, err + } + if len(eventIDToNID) == 0 { + // nothing to do, we already know about these events + return 0, nil, nil + } + numNew = len(eventIDToNID) - // Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event) - // And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as: - // E1,E2,S3 => SNAP0 - // E4, S5 => (SNAP0 + S3) - // S6 => (SNAP0 + S3 + S5) - // E7 => (SNAP0 + S3 + S5 + S6) - // We can track this by loading the current snapshot ID (after snapshot) then rolling forward - // the timeline until we hit a state event, at which point we make a new snapshot but critically - // do NOT assign the new state event in the snapshot so as to represent the state before the event. - snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID) - if err != nil { - return err + var latestNID int64 + newEvents := make([]Event, 0, len(eventIDToNID)) + for _, ev := range dedupedEvents { + nid, ok := eventIDToNID[ev.ID] + if ok { + ev.NID = int64(nid) + if gjson.GetBytes(ev.JSON, "state_key").Exists() { + // XXX: reusing this to mean "it's a state event" as well as "it's part of the state v2 response" + // its important that we don't insert 'ev' at this point as this should be False in the DB. + ev.IsState = true + } + // assign the highest nid value to the latest nid. + // we'll return this to the caller so they can stay in-sync + if ev.NID > latestNID { + latestNID = ev.NID + } + newEvents = append(newEvents, ev) + timelineNIDs = append(timelineNIDs, ev.NID) } - for _, ev := range newEvents { - var replacesNID int64 - // the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not, - // as this is the before snapshot ID. - beforeSnapID := snapID + } - if ev.IsState { - // make a new snapshot and update the snapshot ID - var oldStripped StrippedEvents - if snapID != 0 { - oldStripped, err = a.strippedEventsForSnapshot(txn, snapID) - if err != nil { - return fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err) - } - } - newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev) + // Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event) + // And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as: + // E1,E2,S3 => SNAP0 + // E4, S5 => (SNAP0 + S3) + // S6 => (SNAP0 + S3 + S5) + // E7 => (SNAP0 + S3 + S5 + S6) + // We can track this by loading the current snapshot ID (after snapshot) then rolling forward + // the timeline until we hit a state event, at which point we make a new snapshot but critically + // do NOT assign the new state event in the snapshot so as to represent the state before the event. 
+ snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID) + if err != nil { + return 0, nil, err + } + for _, ev := range newEvents { + var replacesNID int64 + // the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not, + // as this is the before snapshot ID. + beforeSnapID := snapID + + if ev.IsState { + // make a new snapshot and update the snapshot ID + var oldStripped StrippedEvents + if snapID != 0 { + oldStripped, err = a.strippedEventsForSnapshot(txn, snapID) if err != nil { - return fmt.Errorf("failed to calculateNewSnapshot: %s", err) - } - replacesNID = replacedNID - memNIDs, otherNIDs := newStripped.NIDs() - newSnapshot := &SnapshotRow{ - RoomID: roomID, - MembershipEvents: memNIDs, - OtherEvents: otherNIDs, + return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err) } - if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil { - return fmt.Errorf("failed to insert new snapshot: %w", err) - } - snapID = newSnapshot.SnapshotID } - if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil { - return err + newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev) + if err != nil { + return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err) } + replacesNID = replacedNID + memNIDs, otherNIDs := newStripped.NIDs() + newSnapshot := &SnapshotRow{ + RoomID: roomID, + MembershipEvents: memNIDs, + OtherEvents: otherNIDs, + } + if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil { + return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err) + } + snapID = newSnapshot.SnapshotID } - - if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil { - return fmt.Errorf("HandleSpaceUpdates: %s", err) + if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil { + return 0, nil, err } + } - // the last fetched snapshot ID is the current one - info := a.roomInfoDelta(roomID, newEvents) - if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil { - return fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err) - } - return nil - }) - return numNew, timelineNIDs, err + if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil { + return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err) + } + + // the last fetched snapshot ID is the current one + info := a.roomInfoDelta(roomID, newEvents) + if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil { + return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err) + } + return numNew, timelineNIDs, nil } // Delta returns a list of events of at most `limit` for the room not including `lastEventNID`. 
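The hunk above removes both the internal sqlutil.WithTransaction wrapper and the early return for an empty timeline from Accumulate: the caller now owns the transaction boundary and the empty-timeline guard. A minimal sketch of the new calling convention follows, assuming an *Accumulator and *sqlx.DB are already in hand; accumulateInTxn is a hypothetical helper name used only for illustration, while sqlutil.WithTransaction and the new Accumulate signature are taken from this diff (the Storage.Accumulate change at the end of the diff is the real version of this pattern):

func accumulateInTxn(db *sqlx.DB, a *Accumulator, roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
	// the empty-timeline early return now lives outside Accumulate
	if len(timeline) == 0 {
		return 0, nil, nil
	}
	// the caller supplies the transaction; Accumulate no longer opens its own
	err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error {
		numNew, timelineNIDs, err = a.Accumulate(txn, roomID, prevBatch, timeline)
		return err
	})
	return
}
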
diff --git a/state/accumulator_test.go b/state/accumulator_test.go index 2be71d93..927679ee 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -7,6 +7,8 @@ import ( "sort" "testing" + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" "github.com/matrix-org/sliding-sync/sync2" "github.com/matrix-org/sliding-sync/testutils" "github.com/tidwall/gjson" @@ -115,7 +117,11 @@ func TestAccumulatorAccumulate(t *testing.T) { } var numNew int var latestNIDs []int64 - if numNew, latestNIDs, err = accumulator.Accumulate(roomID, "", newEvents); err != nil { + err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { + numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents) + return err + }) + if err != nil { t.Fatalf("failed to Accumulate: %s", err) } if numNew != len(newEvents) { @@ -185,7 +191,11 @@ func TestAccumulatorAccumulate(t *testing.T) { } // subsequent calls do nothing and are not an error - if _, _, err = accumulator.Accumulate(roomID, "", newEvents); err != nil { + err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { + _, _, err = accumulator.Accumulate(txn, roomID, "", newEvents) + return err + }) + if err != nil { t.Fatalf("failed to Accumulate: %s", err) } } @@ -207,7 +217,11 @@ func TestAccumulatorDelta(t *testing.T) { []byte(`{"event_id":"aH", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), []byte(`{"event_id":"aI", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`), } - if _, _, err = accumulator.Accumulate(roomID, "", roomEvents); err != nil { + err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { + _, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents) + return err + }) + if err != nil { t.Fatalf("failed to Accumulate: %s", err) } @@ -266,7 +280,11 @@ func TestAccumulatorMembershipLogs(t *testing.T) { // @me leaves the room []byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`), } - if _, _, err = accumulator.Accumulate(roomID, "", roomEvents); err != nil { + err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { + _, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents) + return err + }) + if err != nil { t.Fatalf("failed to Accumulate: %s", err) } txn, err := accumulator.db.Beginx() @@ -389,7 +407,10 @@ func TestAccumulatorDupeEvents(t *testing.T) { t.Fatalf("failed to Initialise accumulator: %s", err) } - _, _, err = accumulator.Accumulate(roomID, "", joinRoom.Timeline.Events) + err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { + _, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events) + return err + }) if err != nil { t.Fatalf("failed to Accumulate: %s", err) } @@ -434,7 +455,10 @@ func TestAccumulatorMisorderedGraceful(t *testing.T) { } // Accumulate events D, A, B(msg). 
- _, _, err = accumulator.Accumulate(roomID, "", []json.RawMessage{eventD, eventA, eventBMsg}) + err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { + _, _, err = accumulator.Accumulate(txn, roomID, "", []json.RawMessage{eventD, eventA, eventBMsg}) + return err + }) if err != nil { t.Fatalf("failed to Accumulate: %s", err) } @@ -630,6 +654,12 @@ func TestCalculateNewSnapshotDupe(t *testing.T) { } } +// Test that you can accumulate the same room with the same partial sequence of timeline events and +// state is updated correctly. This relies on postgres blocking subsequent transactions sensibly. +func TestAccumulatorConcurrency(t *testing.T) { + +} + func currentSnapshotNIDs(t *testing.T, snapshotTable *SnapshotTable, roomID string) []int64 { txn := snapshotTable.db.MustBeginTx(context.Background(), nil) defer txn.Commit() diff --git a/state/storage.go b/state/storage.go index 89396ab3..c849d25a 100644 --- a/state/storage.go +++ b/state/storage.go @@ -339,7 +339,14 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT } func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) { - return s.accumulator.Accumulate(roomID, prevBatch, timeline) + if len(timeline) == 0 { + return 0, nil, nil + } + err = sqlutil.WithTransaction(s.accumulator.db, func(txn *sqlx.Tx) error { + numNew, timelineNIDs, err = s.accumulator.Accumulate(txn, roomID, prevBatch, timeline) + return err + }) + return } func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
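
TestAccumulatorConcurrency is added above as an empty stub. One possible shape for its body is sketched below, assuming the accumulator is set up the same way as in the other tests in this file and that the events table deduplicates concurrent inserts of the same event, as the stub's comment implies. This is illustrative only, not the PR's implementation; the helper name, room ID and event IDs are made up, and it additionally needs the sync and sync/atomic imports. The call pattern (sqlutil.WithTransaction around Accumulate) matches the rest of the diff.

func testAccumulatorConcurrencySketch(t *testing.T, accumulator *Accumulator) {
	roomID := "!concurrency:localhost"
	timeline := []json.RawMessage{
		[]byte(`{"event_id":"con1", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`),
		[]byte(`{"event_id":"con2", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`),
		[]byte(`{"event_id":"con3", "type":"m.room.message", "content":{"body":"hello"}}`),
	}
	var wg sync.WaitGroup
	var totalNew int64
	// two goroutines accumulate overlapping prefixes of the same timeline; postgres
	// should serialise the conflicting inserts so each event counts as new exactly once.
	for i := 2; i <= len(timeline); i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			err := sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
				numNew, _, err := accumulator.Accumulate(txn, roomID, "", timeline[:n])
				atomic.AddInt64(&totalNew, int64(numNew))
				return err
			})
			if err != nil {
				t.Errorf("failed to Accumulate: %s", err)
			}
		}(i)
	}
	wg.Wait()
	if got := atomic.LoadInt64(&totalNew); got != int64(len(timeline)) {
		t.Errorf("total new events: got %d want %d", got, len(timeline))
	}
}

Whether the resulting snapshot is also correct could additionally be checked with the existing currentSnapshotNIDs helper, as the other tests in this file do.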