diff --git a/README.md b/README.md
index 15289a39..a7c5c0b5 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,11 @@
 
 Run a sliding sync proxy. An implementation of [MSC3575](https://github.com/matrix-org/matrix-doc/blob/kegan/sync-v3/proposals/3575-sync.md).
 
-Proxy version to MSC API specification:
+## Proxy version to MSC API specification
+
+This describes which proxy versions implement which version of the API drafted
+in MSC3575. See https://github.com/matrix-org/sliding-sync/releases for the
+changes in the proxy itself.
 
 - Version 0.1.x: [2022/04/01](https://github.com/matrix-org/matrix-spec-proposals/blob/615e8f5a7bfe4da813bc2db661ed0bd00bccac20/proposals/3575-sync.md)
   - First release
@@ -21,10 +25,12 @@ Proxy version to MSC API specification:
   - Support for `errcode` when sessions expire.
 - Version 0.99.1 [2023/01/20](https://github.com/matrix-org/matrix-spec-proposals/blob/b4b4e7ff306920d2c862c6ff4d245110f6fa5bc7/proposals/3575-sync.md)
   - Preparing for major v1.x release: lists-as-keys support.
-- Version 0.99.2 [2024/07/27](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
+- Version 0.99.2 [2023/03/31](https://github.com/matrix-org/matrix-spec-proposals/blob/eab643cb3ca63b03537a260fa343e1fb2d1ee284/proposals/3575-sync.md)
   - Experimental support for `bump_event_types` when ordering rooms by recency.
   - Support for opting in to extensions on a per-list and per-room basis.
-  - Sentry support.
+- Version 0.99.3 [2023/05/23](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md)
+  - Support for per-list `bump_event_types`.
+  - Support for [`conn_id`](https://github.com/matrix-org/matrix-spec-proposals/blob/4103ee768a4a3e1decee80c2987f50f4c6b3d539/proposals/3575-sync.md#concurrent-connections) for distinguishing multiple concurrent connections.
## Usage diff --git a/cmd/syncv3/main.go b/cmd/syncv3/main.go index 0331c9fd..9938ecc3 100644 --- a/cmd/syncv3/main.go +++ b/cmd/syncv3/main.go @@ -2,6 +2,14 @@ package main import ( "fmt" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "strings" + "syscall" + "time" + "github.com/getsentry/sentry-go" sentryhttp "github.com/getsentry/sentry-go/http" syncv3 "github.com/matrix-org/sliding-sync" @@ -10,18 +18,11 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "net/http" - _ "net/http/pprof" - "os" - "os/signal" - "strings" - "syscall" - "time" ) var GitCommit string -const version = "0.99.2" +const version = "0.99.3" const ( // Required fields @@ -163,6 +164,8 @@ func main() { h2, h3 := syncv3.Setup(args[EnvServer], args[EnvDB], args[EnvSecret], syncv3.Opts{ AddPrometheusMetrics: args[EnvPrometheus] != "", + DBMaxConns: 100, + DBConnMaxIdleTime: time.Hour, }) go h2.StartV2Pollers() diff --git a/state/device_data_table.go b/state/device_data_table.go index 61853f75..dc8f2488 100644 --- a/state/device_data_table.go +++ b/state/device_data_table.go @@ -44,7 +44,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable { func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) { err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { var row DeviceDataRow - err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID) + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID) if err != nil { if err == sql.ErrNoRows { // if there is no device data for this user, it's not an error. @@ -78,7 +78,7 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in // the device_data table. 
return nil } - _, err = t.db.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) + _, err = txn.Exec(`UPDATE syncv3_device_data SET data=$1 WHERE user_id=$2 AND device_id=$3`, data, userID, deviceID) return err }) return @@ -94,7 +94,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error { // select what already exists var row DeviceDataRow - err = t.db.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID) + err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID) if err != nil && err != sql.ErrNoRows { return err } @@ -119,7 +119,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (pos int64, err error) if err != nil { return err } - err = t.db.QueryRow( + err = txn.QueryRow( `INSERT INTO syncv3_device_data(user_id, device_id, data) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id) DO UPDATE SET data=$3, id=nextval('syncv3_device_data_seq') RETURNING id`, dd.UserID, dd.DeviceID, data, diff --git a/state/event_table.go b/state/event_table.go index 1f152969..925844c3 100644 --- a/state/event_table.go +++ b/state/event_table.go @@ -317,6 +317,25 @@ func (t *EventTable) LatestEventInRooms(txn *sqlx.Tx, roomIDs []string, highestN return } +func (t *EventTable) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) { + // the position (event nid) may be for a random different room, so we need to find the highest nid <= this position for this room + var events []Event + err = t.db.Select( + &events, + `SELECT event_nid, room_id FROM syncv3_events + WHERE event_nid IN (SELECT max(event_nid) FROM syncv3_events WHERE event_nid <= $1 AND room_id = ANY($2) GROUP BY room_id)`, + highestNID, pq.StringArray(roomIDs), + ) + if err == sql.ErrNoRows { + err = nil + } + roomToNID = make(map[string]int64) + for _, ev := range events { + roomToNID[ev.RoomID] = ev.NID + } + return +} + func (t *EventTable) SelectEventsBetween(txn *sqlx.Tx, roomID string, lowerExclusive, upperInclusive int64, limit int) ([]Event, error) { var events []Event err := txn.Select(&events, `SELECT event_nid, event FROM syncv3_events WHERE event_nid > $1 AND event_nid <= $2 AND room_id = $3 ORDER BY event_nid ASC LIMIT $4`, @@ -419,8 +438,8 @@ func (t *EventTable) SelectClosestPrevBatchByID(roomID string, eventID string) ( // Select the closest prev batch token for the provided event NID. Returns the empty string if there // is no closest. 
-func (t *EventTable) SelectClosestPrevBatch(roomID string, eventNID int64) (prevBatch string, err error) { - err = t.db.QueryRow( +func (t *EventTable) SelectClosestPrevBatch(txn *sqlx.Tx, roomID string, eventNID int64) (prevBatch string, err error) { + err = txn.QueryRow( `SELECT prev_batch FROM syncv3_events WHERE prev_batch IS NOT NULL AND room_id=$1 AND event_nid >= $2 LIMIT 1`, roomID, eventNID, ).Scan(&prevBatch) if err == sql.ErrNoRows { diff --git a/state/event_table_test.go b/state/event_table_test.go index 246c2b76..76104a65 100644 --- a/state/event_table_test.go +++ b/state/event_table_test.go @@ -4,8 +4,10 @@ import ( "bytes" "database/sql" "fmt" + "reflect" "testing" + "github.com/jmoiron/sqlx" "github.com/tidwall/gjson" "github.com/matrix-org/sliding-sync/sqlutil" @@ -776,10 +778,14 @@ func TestEventTablePrevBatch(t *testing.T) { } assertPrevBatch := func(roomID string, index int, wantPrevBatch string) { - gotPrevBatch, err := table.SelectClosestPrevBatch(roomID, int64(idToNID[events[index].ID])) - if err != nil { - t.Fatalf("failed to SelectClosestPrevBatch: %s", err) - } + var gotPrevBatch string + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + gotPrevBatch, err = table.SelectClosestPrevBatch(txn, roomID, int64(idToNID[events[index].ID])) + if err != nil { + t.Fatalf("failed to SelectClosestPrevBatch: %s", err) + } + return nil + }) if wantPrevBatch != "" { if gotPrevBatch == "" || gotPrevBatch != wantPrevBatch { t.Fatalf("SelectClosestPrevBatch: got %v want %v", gotPrevBatch, wantPrevBatch) @@ -871,6 +877,93 @@ func TestRemoveUnsignedTXNID(t *testing.T) { } } +func TestLatestEventNIDInRooms(t *testing.T) { + db, close := connectToDB(t) + defer close() + table := NewEventTable(db) + + var result map[string]int64 + var err error + // Insert the following: + // - Room FIRST: [N] + // - Room SECOND: [N+1, N+2, N+3] (replace) + // - Room THIRD: [N+4] (max) + first := "!FIRST" + second := "!SECOND" + third := "!THIRD" + err = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + result, err = table.Insert(txn, []Event{ + { + ID: "$N", + Type: "message", + RoomID: first, + JSON: []byte(`{}`), + }, + { + ID: "$N+1", + Type: "message", + RoomID: second, + JSON: []byte(`{}`), + }, + { + ID: "$N+2", + Type: "message", + RoomID: second, + JSON: []byte(`{}`), + }, + { + ID: "$N+3", + Type: "message", + RoomID: second, + JSON: []byte(`{}`), + }, + { + ID: "$N+4", + Type: "message", + RoomID: third, + JSON: []byte(`{}`), + }, + }, false) + return err + }) + assertNoError(t, err) + + testCases := []struct { + roomIDs []string + highestNID int64 + wantMap map[string]string + }{ + // We should see FIRST=N, SECOND=N+3, THIRD=N+4 when querying LatestEventNIDInRooms with N+4 + { + roomIDs: []string{first, second, third}, + highestNID: result["$N+4"], + wantMap: map[string]string{ + first: "$N", second: "$N+3", third: "$N+4", + }, + }, + // We should see FIRST=N, SECOND=N+2 when querying LatestEventNIDInRooms with N+2 + { + roomIDs: []string{first, second, third}, + highestNID: result["$N+2"], + wantMap: map[string]string{ + first: "$N", second: "$N+2", + }, + }, + } + for _, tc := range testCases { + gotRoomToNID, err := table.LatestEventNIDInRooms(tc.roomIDs, int64(tc.highestNID)) + assertNoError(t, err) + want := make(map[string]int64) // map event IDs to nids + for roomID, eventID := range tc.wantMap { + want[roomID] = int64(result[eventID]) + } + if !reflect.DeepEqual(gotRoomToNID, want) { + t.Errorf("%+v: got %v want %v", tc, gotRoomToNID, want) + } + } + +} + func 
TestEventTableSelectUnknownEventIDs(t *testing.T) { db, close := connectToDB(t) defer close() diff --git a/state/storage.go b/state/storage.go index 2c0481fa..ff70d981 100644 --- a/state/storage.go +++ b/state/storage.go @@ -31,6 +31,12 @@ type StartupSnapshot struct { AllJoinedMembers map[string][]string // room_id -> [user_id] } +type LatestEvents struct { + Timeline []json.RawMessage + PrevBatch string + LatestNID int64 +} + type Storage struct { Accumulator *Accumulator EventsTable *EventTable @@ -535,7 +541,7 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str if err != nil { return fmt.Errorf("failed to form sql query: %s", err) } - rows, err := s.Accumulator.db.Query(s.Accumulator.db.Rebind(query), args...) + rows, err := txn.Query(txn.Rebind(query), args...) if err != nil { return fmt.Errorf("failed to execute query: %s", err) } @@ -580,16 +586,16 @@ func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []str return } -func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string][]json.RawMessage, map[string]string, error) { +func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string]*LatestEvents, error) { roomIDToRanges, err := s.visibleEventNIDsBetweenForRooms(userID, roomIDs, 0, to) if err != nil { - return nil, nil, err + return nil, err } - result := make(map[string][]json.RawMessage, len(roomIDs)) - prevBatches := make(map[string]string, len(roomIDs)) + result := make(map[string]*LatestEvents, len(roomIDs)) err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error { for roomID, ranges := range roomIDToRanges { var earliestEventNID int64 + var latestEventNID int64 var roomEvents []json.RawMessage // start at the most recent range as we want to return the most recent `limit` events for i := len(ranges) - 1; i >= 0; i-- { @@ -604,6 +610,9 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, } // keep pushing to the front so we end up with A,B,C for _, ev := range events { + if latestEventNID == 0 { // set first time and never again + latestEventNID = ev.NID + } roomEvents = append([]json.RawMessage{ev.JSON}, roomEvents...) 
earliestEventNID = ev.NID if len(roomEvents) >= limit { @@ -611,19 +620,23 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, } } } + latestEvents := LatestEvents{ + LatestNID: latestEventNID, + Timeline: roomEvents, + } if earliestEventNID != 0 { // the oldest event needs a prev batch token, so find one now - prevBatch, err := s.EventsTable.SelectClosestPrevBatch(roomID, earliestEventNID) + prevBatch, err := s.EventsTable.SelectClosestPrevBatch(txn, roomID, earliestEventNID) if err != nil { return fmt.Errorf("failed to select prev_batch for room %s : %s", roomID, err) } - prevBatches[roomID] = prevBatch + latestEvents.PrevBatch = prevBatch } - result[roomID] = roomEvents + result[roomID] = &latestEvents } return nil }) - return result, prevBatches, err + return result, err } func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []string, from, to int64) (map[string][][2]int64, error) { @@ -637,7 +650,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin return nil, fmt.Errorf("VisibleEventNIDsBetweenForRooms.SelectEventsWithTypeStateKeyInRooms: %s", err) } } - joinTimingsByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents) + joinTimingsAtFromByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents) if err != nil { return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err) } @@ -648,7 +661,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin return nil, fmt.Errorf("failed to load membership events: %s", err) } - return s.visibleEventNIDsWithData(joinTimingsByRoomID, membershipEvents, userID, from, to) + return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to) } // Work out the NID ranges to pull events from for this user. 
Given a from and to event nid stream position, @@ -678,7 +691,7 @@ func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []strin // - For Room E: from=1, to=15 returns { RoomE: [ [3,3], [13,15] ] } (tests invites) func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[string][][2]int64, error) { // load *ALL* joined rooms for this user at from (inclusive) - joinTimingsByRoomID, err := s.JoinedRoomsAfterPosition(userID, from) + joinTimingsAtFromByRoomID, err := s.JoinedRoomsAfterPosition(userID, from) if err != nil { return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err) } @@ -689,10 +702,10 @@ func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[st return nil, fmt.Errorf("failed to load membership events: %s", err) } - return s.visibleEventNIDsWithData(joinTimingsByRoomID, membershipEvents, userID, from, to) + return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to) } -func (s *Storage) visibleEventNIDsWithData(joinTimingsByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][][2]int64, error) { +func (s *Storage) visibleEventNIDsWithData(joinTimingsAtFromByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][][2]int64, error) { // load membership events in order and bucket based on room ID roomIDToLogs := make(map[string][]membershipEvent) for _, ev := range membershipEvents { @@ -754,7 +767,7 @@ func (s *Storage) visibleEventNIDsWithData(joinTimingsByRoomID map[string]intern // For each joined room, perform the algorithm and delete the logs afterwards result := make(map[string][][2]int64) - for joinedRoomID, _ := range joinTimingsByRoomID { + for joinedRoomID, _ := range joinTimingsAtFromByRoomID { roomResult := calculateVisibleEventNIDs(true, from, to, roomIDToLogs[joinedRoomID]) result[joinedRoomID] = roomResult delete(roomIDToLogs, joinedRoomID) diff --git a/state/storage_test.go b/state/storage_test.go index a5067418..e4b053b5 100644 --- a/state/storage_test.go +++ b/state/storage_test.go @@ -566,10 +566,15 @@ func TestStorageLatestEventsInRoomsPrevBatch(t *testing.T) { wantPrevBatch := wantPrevBatches[i] eventNID := idsToNIDs[eventIDs[i]] // closest batch to the last event in the chunk (latest nid) is always the next prev batch token - pb, err := store.EventsTable.SelectClosestPrevBatch(roomID, eventNID) - if err != nil { - t.Fatalf("failed to SelectClosestPrevBatch: %s", err) - } + var pb string + _ = sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) (err error) { + pb, err = store.EventsTable.SelectClosestPrevBatch(txn, roomID, eventNID) + if err != nil { + t.Fatalf("failed to SelectClosestPrevBatch: %s", err) + } + return nil + }) + if pb != wantPrevBatch { t.Fatalf("SelectClosestPrevBatch: got %v want %v", pb, wantPrevBatch) } diff --git a/sync2/devices_table.go b/sync2/devices_table.go index ab6ddb9f..30e68df6 100644 --- a/sync2/devices_table.go +++ b/sync2/devices_table.go @@ -32,8 +32,8 @@ func NewDevicesTable(db *sqlx.DB) *DevicesTable { // InsertDevice creates a new devices row with a blank since token if no such row // exists. Otherwise, it does nothing. 
-func (t *DevicesTable) InsertDevice(userID, deviceID string) error { - _, err := t.db.Exec( +func (t *DevicesTable) InsertDevice(txn *sqlx.Tx, userID, deviceID string) error { + _, err := txn.Exec( ` INSERT INTO syncv3_sync2_devices(user_id, device_id, since) VALUES($1,$2,$3) ON CONFLICT (user_id, device_id) DO NOTHING`, userID, deviceID, "", diff --git a/sync2/devices_table_test.go b/sync2/devices_table_test.go index 5f70846a..1db3564d 100644 --- a/sync2/devices_table_test.go +++ b/sync2/devices_table_test.go @@ -2,6 +2,7 @@ package sync2 import ( "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" "os" "sort" "testing" @@ -41,18 +42,25 @@ func TestDevicesTableSinceColumn(t *testing.T) { aliceSecret1 := "mysecret1" aliceSecret2 := "mysecret2" - t.Log("Insert two tokens for Alice.") - aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, time.Now()) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, time.Now()) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } + var aliceToken, aliceToken2 *Token + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) { + t.Log("Insert two tokens for Alice.") + aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, time.Now()) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + aliceToken2, err = tokens.Insert(txn, aliceSecret2, alice, aliceDevice, time.Now()) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } - t.Log("Add a devices row for Alice") - err = devices.InsertDevice(alice, aliceDevice) + t.Log("Add a devices row for Alice") + err = devices.InsertDevice(txn, alice, aliceDevice) + if err != nil { + t.Fatalf("Failed to Insert device: %s", err) + } + return nil + }) t.Log("Pretend we're about to start a poller. 
Fetch Alice's token along with the since value tracked by the devices table.") accessToken, since, err := tokens.GetTokenAndSince(alice, aliceDevice, aliceToken.AccessTokenHash) @@ -104,40 +112,50 @@ func TestTokenForEachDevice(t *testing.T) { chris := "chris" chrisDevice := "chris_desktop" - t.Log("Add a device for Alice, Bob and Chris.") - err := devices.InsertDevice(alice, aliceDevice) - if err != nil { - t.Fatalf("InsertDevice returned error: %s", err) - } - err = devices.InsertDevice(bob, bobDevice) - if err != nil { - t.Fatalf("InsertDevice returned error: %s", err) - } - err = devices.InsertDevice(chris, chrisDevice) - if err != nil { - t.Fatalf("InsertDevice returned error: %s", err) - } + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + t.Log("Add a device for Alice, Bob and Chris.") + err := devices.InsertDevice(txn, alice, aliceDevice) + if err != nil { + t.Fatalf("InsertDevice returned error: %s", err) + } + err = devices.InsertDevice(txn, bob, bobDevice) + if err != nil { + t.Fatalf("InsertDevice returned error: %s", err) + } + err = devices.InsertDevice(txn, chris, chrisDevice) + if err != nil { + t.Fatalf("InsertDevice returned error: %s", err) + } + return nil + }) t.Log("Mark Alice's device with a since token.") sinceValue := "s-1-2-3-4" - devices.UpdateDeviceSince(alice, aliceDevice, sinceValue) - - t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.") - aliceLastSeen1 := time.Now() - _, err = tokens.Insert("alice_secret", alice, aliceDevice, aliceLastSeen1) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute) - aliceToken2, err := tokens.Insert("alice_secret2", alice, aliceDevice, aliceLastSeen2) + err := devices.UpdateDeviceSince(alice, aliceDevice, sinceValue) if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - bobToken, err := tokens.Insert("bob_secret", bob, bobDevice, time.Time{}) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) + t.Fatalf("UpdateDeviceSince returned error: %s", err) } + var aliceToken2, bobToken *Token + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + t.Log("Insert 2 tokens for Alice, one for Bob and none for Chris.") + aliceLastSeen1 := time.Now() + _, err = tokens.Insert(txn, "alice_secret", alice, aliceDevice, aliceLastSeen1) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + aliceLastSeen2 := aliceLastSeen1.Add(1 * time.Minute) + aliceToken2, err = tokens.Insert(txn, "alice_secret2", alice, aliceDevice, aliceLastSeen2) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + bobToken, err = tokens.Insert(txn, "bob_secret", bob, bobDevice, time.Time{}) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + return nil + }) + t.Log("Fetch a token for every device") gotTokens, err := tokens.TokenForEachDevice(nil) if err != nil { diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 37b068ce..506a4813 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -12,6 +12,9 @@ import ( "github.com/jmoiron/sqlx" "github.com/matrix-org/sliding-sync/sqlutil" + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" + "github.com/getsentry/sentry-go" "github.com/matrix-org/sliding-sync/internal" @@ -32,12 +35,11 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C // processing v2 data (as a sync2.V2DataReceiver) and publishing updates (pubsub.Payload to V2Listeners); // and receiving and 
processing EnsurePolling events. type Handler struct { - pMap *sync2.PollerMap + pMap sync2.IPollerMap v2Store *sync2.Storage Store *state.Storage v2Pub pubsub.Notifier v3Sub *pubsub.V3Sub - client sync2.Client unreadMap map[string]struct { Highlight int Notif int @@ -53,13 +55,12 @@ type Handler struct { } func NewHandler( - connStr string, pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, client sync2.Client, + pMap *sync2.PollerMap, v2Store *sync2.Storage, store *state.Storage, pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, deviceDataUpdateDuration time.Duration, ) (*Handler, error) { h := &Handler{ pMap: pMap, v2Store: v2Store, - client: client, Store: store, subSystem: "poller", unreadMap: make(map[string]struct { diff --git a/sync2/handler2/handler_test.go b/sync2/handler2/handler_test.go new file mode 100644 index 00000000..20f064ab --- /dev/null +++ b/sync2/handler2/handler_test.go @@ -0,0 +1,170 @@ +package handler2_test + +import ( + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/matrix-org/sliding-sync/pubsub" + "github.com/matrix-org/sliding-sync/state" + "github.com/matrix-org/sliding-sync/sync2" + "github.com/matrix-org/sliding-sync/sync2/handler2" + "github.com/matrix-org/sliding-sync/testutils" + "github.com/rs/zerolog" +) + +var postgresURI string + +func TestMain(m *testing.M) { + postgresURI = testutils.PrepareDBConnectionString() + exitCode := m.Run() + os.Exit(exitCode) +} + +type pollInfo struct { + pid sync2.PollerID + accessToken string + v2since string + isStartup bool +} + +type mockPollerMap struct { + calls []pollInfo +} + +func (p *mockPollerMap) NumPollers() int { + return 0 +} +func (p *mockPollerMap) Terminate() {} + +func (p *mockPollerMap) EnsurePolling(pid sync2.PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) { + p.calls = append(p.calls, pollInfo{ + pid: pid, + accessToken: accessToken, + v2since: v2since, + isStartup: isStartup, + }) +} +func (p *mockPollerMap) assertCallExists(t *testing.T, pi pollInfo) { + for _, c := range p.calls { + if reflect.DeepEqual(pi, c) { + return + } + } + t.Fatalf("assertCallExists: did not find %+v", pi) +} + +type mockPub struct { + calls []pubsub.Payload + mu *sync.Mutex + waiters map[string][]chan struct{} +} + +func newMockPub() *mockPub { + return &mockPub{ + mu: &sync.Mutex{}, + waiters: make(map[string][]chan struct{}), + } +} + +// Notify chanName that there is a new payload p. Return an error if we failed to send the notification. +func (p *mockPub) Notify(chanName string, payload pubsub.Payload) error { + p.calls = append(p.calls, payload) + p.mu.Lock() + for _, ch := range p.waiters[payload.Type()] { + close(ch) + } + p.waiters[payload.Type()] = nil // don't re-notify for 2nd+ payload + p.mu.Unlock() + return nil +} + +func (p *mockPub) WaitForPayloadType(t string) chan struct{} { + ch := make(chan struct{}) + p.mu.Lock() + p.waiters[t] = append(p.waiters[t], ch) + p.mu.Unlock() + return ch +} + +func (p *mockPub) DoWait(t *testing.T, errMsg string, ch chan struct{}) { + select { + case <-ch: + return + case <-time.After(time.Second): + t.Fatalf("DoWait: timed out waiting: %s", errMsg) + } +} + +// Close is called when we should stop listening. +func (p *mockPub) Close() error { return nil } + +type mockSub struct{} + +// Begin listening on this channel with this callback starting from this position. Blocks until Close() is called. 
+func (s *mockSub) Listen(chanName string, fn func(p pubsub.Payload)) error { return nil } + +// Close the listener. No more callbacks should fire. +func (s *mockSub) Close() error { return nil } + +func assertNoError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("assertNoError: %v", err) +} + +// Test that if you call EnsurePolling you get back V2InitialSyncComplete down pubsub and the poller +// map is called correctly +func TestHandlerFreshEnsurePolling(t *testing.T) { + store := state.NewStorage(postgresURI) + v2Store := sync2.NewStore(postgresURI, "secret") + pMap := &mockPollerMap{} + pub := newMockPub() + sub := &mockSub{} + h, err := handler2.NewHandler(pMap, v2Store, store, pub, sub, false) + assertNoError(t, err) + alice := "@alice:localhost" + deviceID := "ALICE" + token := "aliceToken" + + var tok *sync2.Token + sqlutil.WithTransaction(v2Store.DB, func(txn *sqlx.Tx) error { + // the device and token needs to already exist prior to EnsurePolling + err = v2Store.DevicesTable.InsertDevice(txn, alice, deviceID) + assertNoError(t, err) + tok, err = v2Store.TokensTable.Insert(txn, token, alice, deviceID, time.Now()) + assertNoError(t, err) + return nil + }) + + payloadInitialSyncComplete := pubsub.V2InitialSyncComplete{ + UserID: alice, + DeviceID: deviceID, + } + ch := pub.WaitForPayloadType(payloadInitialSyncComplete.Type()) + // ask the handler to start polling + h.EnsurePolling(&pubsub.V3EnsurePolling{ + UserID: alice, + DeviceID: deviceID, + AccessTokenHash: tok.AccessTokenHash, + }) + pub.DoWait(t, "didn't see V2InitialSyncComplete", ch) + + // make sure we polled with the token i.e it did a db hit + pMap.assertCallExists(t, pollInfo{ + pid: sync2.PollerID{ + UserID: alice, + DeviceID: deviceID, + }, + accessToken: token, + v2since: "", + isStartup: false, + }) + +} diff --git a/sync2/poller.go b/sync2/poller.go index 9da833d3..f6c99a18 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -59,6 +59,12 @@ type V2DataReceiver interface { OnExpiredToken(ctx context.Context, accessTokenHash, userID, deviceID string) } +type IPollerMap interface { + EnsurePolling(pid PollerID, accessToken, v2since string, isStartup bool, logger zerolog.Logger) + NumPollers() int + Terminate() +} + // PollerMap is a map of device ID to Poller type PollerMap struct { v2Client Client @@ -508,7 +514,8 @@ func (p *poller) poll(ctx context.Context, s *pollLoopState) error { } if err != nil { // check if temporary - if statusCode != 401 { + isFatal := statusCode == 401 || statusCode == 403 + if !isFatal { p.logger.Warn().Int("code", statusCode).Err(err).Msg("Poller: sync v2 poll returned temporary error") s.failCount += 1 return nil diff --git a/sync2/storage.go b/sync2/storage.go index 753c2f82..5f484179 100644 --- a/sync2/storage.go +++ b/sync2/storage.go @@ -1,10 +1,11 @@ package sync2 import ( + "os" + "github.com/getsentry/sentry-go" "github.com/jmoiron/sqlx" "github.com/rs/zerolog" - "os" ) var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{ diff --git a/sync2/tokens_table.go b/sync2/tokens_table.go index 066c6508..961e7bdd 100644 --- a/sync2/tokens_table.go +++ b/sync2/tokens_table.go @@ -171,10 +171,10 @@ func (t *TokensTable) TokenForEachDevice(txn *sqlx.Tx) (tokens []TokenForPoller, } // Insert a new token into the table. 
-func (t *TokensTable) Insert(plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) { +func (t *TokensTable) Insert(txn *sqlx.Tx, plaintextToken, userID, deviceID string, lastSeen time.Time) (*Token, error) { hashedToken := hashToken(plaintextToken) encToken := t.encrypt(plaintextToken) - _, err := t.db.Exec( + _, err := txn.Exec( `INSERT INTO syncv3_sync2_tokens(token_hash, token_encrypted, user_id, device_id, last_seen) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (token_hash) DO NOTHING;`, diff --git a/sync2/tokens_table_test.go b/sync2/tokens_table_test.go index 9249077e..c787b2a0 100644 --- a/sync2/tokens_table_test.go +++ b/sync2/tokens_table_test.go @@ -1,6 +1,8 @@ package sync2 import ( + "github.com/jmoiron/sqlx" + "github.com/matrix-org/sliding-sync/sqlutil" "testing" "time" ) @@ -26,27 +28,31 @@ func TestTokensTable(t *testing.T) { aliceSecret1 := "mysecret1" aliceToken1FirstSeen := time.Now() - // Test a single token - t.Log("Insert a new token from Alice.") - aliceToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - - t.Log("The returned Token struct should have been populated correctly.") - assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) - - t.Log("Reinsert the same token.") - reinsertedToken, err := tokens.Insert(aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } + var aliceToken, reinsertedToken *Token + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) { + // Test a single token + t.Log("Insert a new token from Alice.") + aliceToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + + t.Log("The returned Token struct should have been populated correctly.") + assertEqualTokens(t, tokens, aliceToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) + + t.Log("Reinsert the same token.") + reinsertedToken, err = tokens.Insert(txn, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + return nil + }) t.Log("This should yield an equal Token struct.") assertEqualTokens(t, tokens, reinsertedToken, aliceSecret1, alice, aliceDevice, aliceToken1FirstSeen) t.Log("Try to mark Alice's token as being used after an hour.") - err = tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour)) + err := tokens.MaybeUpdateLastSeen(aliceToken, aliceToken1FirstSeen.Add(time.Hour)) if err != nil { t.Fatalf("Failed to update last seen: %s", err) } @@ -74,17 +80,20 @@ func TestTokensTable(t *testing.T) { } assertEqualTokens(t, tokens, fetchedToken, aliceSecret1, alice, aliceDevice, aliceToken1LastSeen) - // Test a second token for Alice - t.Log("Insert a second token for Alice.") - aliceSecret2 := "mysecret2" - aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute) - aliceToken2, err := tokens.Insert(aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } - - t.Log("The returned Token struct should have been populated correctly.") - assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) + _ = sqlutil.WithTransaction(db, func(txn *sqlx.Tx) error { + // Test a second token for Alice + t.Log("Insert a second token for Alice.") + aliceSecret2 := 
"mysecret2" + aliceToken2FirstSeen := aliceToken1LastSeen.Add(time.Minute) + aliceToken2, err := tokens.Insert(txn, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + + t.Log("The returned Token struct should have been populated correctly.") + assertEqualTokens(t, tokens, aliceToken2, aliceSecret2, alice, aliceDevice, aliceToken2FirstSeen) + return nil + }) } func TestDeletingTokens(t *testing.T) { @@ -94,11 +103,15 @@ func TestDeletingTokens(t *testing.T) { t.Log("Insert a new token from Alice.") accessToken := "mytoken" - token, err := tokens.Insert(accessToken, "@bob:builders.com", "device", time.Time{}) - if err != nil { - t.Fatalf("Failed to Insert token: %s", err) - } + var token *Token + err := sqlutil.WithTransaction(db, func(txn *sqlx.Tx) (err error) { + token, err = tokens.Insert(txn, accessToken, "@bob:builders.com", "device", time.Time{}) + if err != nil { + t.Fatalf("Failed to Insert token: %s", err) + } + return nil + }) t.Log("We should be able to fetch this token without error.") _, err = tokens.Token(accessToken) if err != nil { diff --git a/sync3/caches/global.go b/sync3/caches/global.go index 28e9fd72..36c2867a 100644 --- a/sync3/caches/global.go +++ b/sync3/caches/global.go @@ -58,7 +58,7 @@ var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.C // Dispatcher for new events. type GlobalCache struct { // LoadJoinedRoomsOverride allows tests to mock out the behaviour of LoadJoinedRooms. - LoadJoinedRoomsOverride func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) + LoadJoinedRoomsOverride func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, latestNIDs map[string]int64, err error) // inserts are done by v2 poll loops, selects are done by v3 request threads // there are lots of overlapping keys as many users (threads) can be joined to the same room (key) @@ -135,23 +135,37 @@ func (c *GlobalCache) copyRoom(roomID string) *internal.RoomMetadata { // The two maps returned by this function have exactly the same set of keys. Each is nil // iff a non-nil error is returned. // TODO: remove with LoadRoomState? 
+// FIXME: return args are a mess func (c *GlobalCache) LoadJoinedRooms(ctx context.Context, userID string) ( - pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimingByRoomID map[string]internal.EventMetadata, err error, + pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimingByRoomID map[string]internal.EventMetadata, + latestNIDs map[string]int64, err error, ) { if c.LoadJoinedRoomsOverride != nil { return c.LoadJoinedRoomsOverride(userID) } initialLoadPosition, err := c.store.LatestEventNID() if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } joinTimingByRoomID, err = c.store.JoinedRoomsAfterPosition(userID, initialLoadPosition) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } + roomIDs := make([]string, len(joinTimingByRoomID)) + i := 0 + for roomID := range joinTimingByRoomID { + roomIDs[i] = roomID + i++ + } + + latestNIDs, err = c.store.EventsTable.LatestEventNIDInRooms(roomIDs, initialLoadPosition) + if err != nil { + return 0, nil, nil, nil, err + } + // TODO: no guarantee that this state is the same as latest unless called in a dispatcher loop rooms := c.LoadRoomsFromMap(ctx, joinTimingByRoomID) - return initialLoadPosition, rooms, joinTimingByRoomID, nil + return initialLoadPosition, rooms, joinTimingByRoomID, latestNIDs, nil } func (c *GlobalCache) LoadStateEvent(ctx context.Context, roomID string, loadPosition int64, evType, stateKey string) json.RawMessage { diff --git a/sync3/caches/user.go b/sync3/caches/user.go index 8b96ddd4..ef160f8b 100644 --- a/sync3/caches/user.go +++ b/sync3/caches/user.go @@ -42,9 +42,9 @@ type UserRoomData struct { HighlightCount int Invite *InviteData - // these fields are set by LazyLoadTimelines and are per-function call, and are not persisted in-memory. - RequestedPrevBatch string - RequestedTimeline []json.RawMessage + // this field is set by LazyLoadTimelines and is per-function call, and is not persisted in-memory. + // The zero value of this safe to use (0 latest nid, no prev batch, no timeline). + RequestedLatestEvents state.LatestEvents // TODO: should Canonicalised really be in RoomConMetadata? It's only set in SetRoom AFAICS CanonicalisedName string // stripped leading symbols like #, all in lower case @@ -218,7 +218,7 @@ func (c *UserCache) Unsubscribe(id int) { func (c *UserCache) OnRegistered(ctx context.Context) error { // select all spaces the user is a part of to seed the cache correctly. This has to be done in // the OnRegistered callback which has locking guarantees. This is why... 
- _, joinedRooms, joinTimings, err := c.globalCache.LoadJoinedRooms(ctx, c.UserID) + _, joinedRooms, joinTimings, _, err := c.globalCache.LoadJoinedRooms(ctx, c.UserID) if err != nil { return fmt.Errorf("failed to load joined rooms: %s", err) } @@ -295,7 +295,7 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID return c.LazyRoomDataOverride(loadPos, roomIDs, maxTimelineEvents) } result := make(map[string]UserRoomData) - roomIDToEvents, roomIDToPrevBatch, err := c.store.LatestEventsInRooms(c.UserID, roomIDs, loadPos, maxTimelineEvents) + roomIDToLatestEvents, err := c.store.LatestEventsInRooms(c.UserID, roomIDs, loadPos, maxTimelineEvents) if err != nil { logger.Err(err).Strs("rooms", roomIDs).Msg("failed to get LatestEventsInRooms") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) @@ -303,16 +303,14 @@ func (c *UserCache) LazyLoadTimelines(ctx context.Context, loadPos int64, roomID } c.roomToDataMu.Lock() for _, requestedRoomID := range roomIDs { - events := roomIDToEvents[requestedRoomID] + latestEvents := roomIDToLatestEvents[requestedRoomID] urd, ok := c.roomToData[requestedRoomID] if !ok { urd = NewUserRoomData() } - urd.RequestedTimeline = events - if len(events) > 0 { - urd.RequestedPrevBatch = roomIDToPrevBatch[requestedRoomID] + if latestEvents != nil { + urd.RequestedLatestEvents = *latestEvents } - result[requestedRoomID] = urd } c.roomToDataMu.Unlock() diff --git a/sync3/handler/connstate.go b/sync3/handler/connstate.go index 04d5e237..a0aae60b 100644 --- a/sync3/handler/connstate.go +++ b/sync3/handler/connstate.go @@ -30,8 +30,15 @@ type ConnState struct { // "is the user joined to this room?" whereas subscriptions in muxedReq are untrusted. roomSubscriptions map[string]sync3.RoomSubscription // room_id -> subscription - // TODO: remove this as it is unreliable when you have concurrent updates - loadPosition int64 + // This is some event NID which is used to anchor any requests for room data from the database + // to their per-room latest NIDs. It does this by selecting the latest NID for each requested room + // where the NID is <= this anchor value. Note that there are no ordering guarantees here: it's + // possible for the anchor to be higher than room X's latest NID and for this connection to have + // not yet seen room X's latest NID (it'll be sitting in the live buffer). This is why it's important + // that ConnState DOES NOT ignore events based on this value - it must ignore events based on the real + // load position for the room. + // If this value is negative or 0, it means that this connection has not been loaded yet. + anchorLoadPosition int64 // roomID -> latest load pos loadPositions map[string]int64 @@ -59,7 +66,7 @@ func NewConnState( userCache: userCache, userID: userID, deviceID: deviceID, - loadPosition: -1, + anchorLoadPosition: -1, loadPositions: make(map[string]int64), roomSubscriptions: make(map[string]sync3.RoomSubscription), lists: sync3.NewInternalRequestLists(), @@ -73,6 +80,8 @@ func NewConnState( ConnState: cs, updates: make(chan caches.Update, maxPendingEventUpdates), } + // subscribe for updates before loading. We risk seeing dupes but that's fine as load positions + // will stop us double-processing. cs.userCacheID = cs.userCache.Subsribe(cs) return cs } @@ -89,10 +98,13 @@ func NewConnState( // - load() bases its current state based on the latest position, which includes processing of these N events. // - post load() we read N events, processing them a 2nd time. 
func (s *ConnState) load(ctx context.Context, req *sync3.Request) error { - initialLoadPosition, joinedRooms, joinTimings, err := s.globalCache.LoadJoinedRooms(ctx, s.userID) + initialLoadPosition, joinedRooms, joinTimings, loadPositions, err := s.globalCache.LoadJoinedRooms(ctx, s.userID) if err != nil { return err } + for roomID, pos := range loadPositions { + s.loadPositions[roomID] = pos + } rooms := make([]sync3.RoomConnMetadata, len(joinedRooms)) i := 0 for _, metadata := range joinedRooms { @@ -145,16 +157,21 @@ func (s *ConnState) load(ctx context.Context, req *sync3.Request) error { for _, r := range rooms { s.lists.SetRoom(r) } - s.loadPosition = initialLoadPosition + s.anchorLoadPosition = initialLoadPosition return nil } // OnIncomingRequest is guaranteed to be called sequentially (it's protected by a mutex in conn.go) func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req *sync3.Request, isInitial bool, start time.Time) (*sync3.Response, error) { - if s.loadPosition == -1 { + if s.anchorLoadPosition <= 0 { // load() needs no ctx so drop it _, region := internal.StartSpan(ctx, "load") - s.load(ctx, req) + err := s.load(ctx, req) + if err != nil { + // in practice this means DB hit failures. If we try again later maybe it'll work, and we will because + // anchorLoadPosition is unset. + logger.Err(err).Str("conn", cid.String()).Msg("failed to load initial data") + } region.End() } setupTime := time.Since(start) @@ -165,19 +182,19 @@ func (s *ConnState) OnIncomingRequest(ctx context.Context, cid sync3.ConnID, req // onIncomingRequest is a callback which fires when the client makes a request to the server. Whilst each request may // be on their own goroutine, the requests are linearised for us by Conn so it is safe to modify ConnState without // additional locking mechanisms. 
-func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) { +func (s *ConnState) onIncomingRequest(reqCtx context.Context, req *sync3.Request, isInitial bool) (*sync3.Response, error) { start := time.Now() // ApplyDelta works fine if s.muxedReq is nil var delta *sync3.RequestDelta s.muxedReq, delta = s.muxedReq.ApplyDelta(req) - internal.Logf(ctx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists)) + internal.Logf(reqCtx, "connstate", "new subs=%v unsubs=%v num_lists=%v", len(delta.Subs), len(delta.Unsubs), len(delta.Lists)) for key, l := range delta.Lists { listData := "" if l.Curr != nil { listDataBytes, _ := json.Marshal(l.Curr) listData = string(listDataBytes) } - internal.Logf(ctx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData) + internal.Logf(reqCtx, "connstate", "list[%v] prev_empty=%v curr=%v", key, l.Prev == nil, listData) } for roomID, sub := range s.muxedReq.RoomSubscriptions { internal.Logf(ctx, "connstate", "room sub[%v] %v", roomID, sub) @@ -187,20 +204,20 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i // for it to mix together builder := NewRoomsBuilder() // works out which rooms are subscribed to but doesn't pull room data - s.buildRoomSubscriptions(ctx, builder, delta.Subs, delta.Unsubs) + s.buildRoomSubscriptions(reqCtx, builder, delta.Subs, delta.Unsubs) // works out how rooms get moved about but doesn't pull room data - respLists := s.buildListSubscriptions(ctx, builder, delta.Lists) + respLists := s.buildListSubscriptions(reqCtx, builder, delta.Lists) // pull room data and set changes on the response response := &sync3.Response{ - Rooms: s.buildRooms(ctx, builder.BuildSubscriptions()), // pull room data + Rooms: s.buildRooms(reqCtx, builder.BuildSubscriptions()), // pull room data Lists: respLists, } // Handle extensions AFTER processing lists as extensions may need to know which rooms the client // is being notified about (e.g. for room account data) - ctx, region := internal.StartSpan(ctx, "extensions") - response.Extensions = s.extensionsHandler.Handle(ctx, s.muxedReq.Extensions, extensions.Context{ + extCtx, region := internal.StartSpan(reqCtx, "extensions") + response.Extensions = s.extensionsHandler.Handle(extCtx, s.muxedReq.Extensions, extensions.Context{ UserID: s.userID, DeviceID: s.deviceID, RoomIDToTimeline: response.RoomIDsToTimelineEventIDs(), @@ -218,8 +235,8 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i } // do live tracking if we have nothing to tell the client yet - ctx, region = internal.StartSpan(ctx, "liveUpdate") - s.live.liveUpdate(ctx, req, s.muxedReq.Extensions, isInitial, response) + updateCtx, region := internal.StartSpan(reqCtx, "liveUpdate") + s.live.liveUpdate(updateCtx, req, s.muxedReq.Extensions, isInitial, response) region.End() // counts are AFTER events are applied, hence after liveUpdate @@ -232,7 +249,7 @@ func (s *ConnState) onIncomingRequest(ctx context.Context, req *sync3.Request, i // Add membership events for users sending typing notifications. We do this after live update // and initial room loading code so we LL room members in all cases. 
if response.Extensions.Typing != nil && response.Extensions.Typing.HasData(isInitial) { - s.lazyLoadTypingMembers(ctx, response) + s.lazyLoadTypingMembers(reqCtx, response) } return response, nil } @@ -495,7 +512,7 @@ func (s *ConnState) lazyLoadTypingMembers(ctx context.Context, response *sync3.R continue } // load the state event - memberEvent := s.globalCache.LoadStateEvent(ctx, roomID, s.loadPosition, "m.room.member", typingUserID.Str) + memberEvent := s.globalCache.LoadStateEvent(ctx, roomID, s.loadPositions[roomID], "m.room.member", typingUserID.Str) if memberEvent != nil { room.RequiredState = append(room.RequiredState, memberEvent) s.lazyCache.AddUser(roomID, typingUserID.Str) @@ -512,15 +529,20 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu ctx, span := internal.StartSpan(ctx, "getInitialRoomData") defer span.End() rooms := make(map[string]sync3.Room, len(roomIDs)) - // We want to grab the user room data and the room metadata for each room ID. - roomIDToUserRoomData := s.userCache.LazyLoadTimelines(ctx, s.loadPosition, roomIDs, int(roomSub.TimelineLimit)) + // We want to grab the user room data and the room metadata for each room ID. We use the globally + // highest NID we've seen to act as an anchor for the request. This anchor does not guarantee that + // events returned here have already been seen - the position is not globally ordered - so because + // room A has a position of 6 and B has 7 (so the highest is 7) does not mean that this connection + // has seen 6, as concurrent room updates cause A and B to race. This is why we then go through the + // response to this call to assign new load positions for each room. + roomIDToUserRoomData := s.userCache.LazyLoadTimelines(ctx, s.anchorLoadPosition, roomIDs, int(roomSub.TimelineLimit)) roomMetadatas := s.globalCache.LoadRooms(ctx, roomIDs...) // prepare lazy loading data structures, txn IDs roomToUsersInTimeline := make(map[string][]string, len(roomIDToUserRoomData)) roomToTimeline := make(map[string][]json.RawMessage) for roomID, urd := range roomIDToUserRoomData { set := make(map[string]struct{}) - for _, ev := range urd.RequestedTimeline { + for _, ev := range urd.RequestedLatestEvents.Timeline { set[gjson.GetBytes(ev, "sender").Str] = struct{}{} } userIDs := make([]string, len(set)) @@ -530,11 +552,22 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu i++ } roomToUsersInTimeline[roomID] = userIDs - roomToTimeline[roomID] = urd.RequestedTimeline + roomToTimeline[roomID] = urd.RequestedLatestEvents.Timeline + // remember what we just loaded so if we see these events down the live stream we know to ignore them. + // This means that requesting a direct room subscription causes the connection to jump ahead to whatever + // is in the database at the time of the call, rather than gradually converging by consuming live data. + // This is fine, so long as we jump ahead on a per-room basis. We need to make sure (ideally) that the + // room state is also pinned to the load position here, else you could see weird things in individual + // responses such as an updated room.name without the associated m.room.name event (though this will + // come through on the next request -> it converges to the right state so it isn't critical). 
+ s.loadPositions[roomID] = urd.RequestedLatestEvents.LatestNID } roomToTimeline = s.userCache.AnnotateWithTransactionIDs(ctx, s.userID, s.deviceID, roomToTimeline) rsm := roomSub.RequiredStateMap(s.userID) - roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.loadPosition, rsm, roomToUsersInTimeline) + // by reusing the same global load position anchor here, we can be sure that the state returned here + // matches the timeline we loaded earlier - the race conditions happen around pubsub updates and not + // the events table itself, so whatever position is picked based on this anchor is immutable. + roomIDToState := s.globalCache.LoadRoomState(ctx, roomIDs, s.anchorLoadPosition, rsm, roomToUsersInTimeline) if roomIDToState == nil { // e.g no required_state roomIDToState = make(map[string][]json.RawMessage) } @@ -572,7 +605,7 @@ func (s *ConnState) getInitialRoomData(ctx context.Context, roomSub sync3.RoomSu IsDM: userRoomData.IsDM, JoinedCount: metadata.JoinCount, InvitedCount: &metadata.InviteCount, - PrevBatch: userRoomData.RequestedPrevBatch, + PrevBatch: userRoomData.RequestedLatestEvents.PrevBatch, } } diff --git a/sync3/handler/connstate_live.go b/sync3/handler/connstate_live.go index 8a2e2cdc..a01a7ef6 100644 --- a/sync3/handler/connstate_live.go +++ b/sync3/handler/connstate_live.go @@ -128,9 +128,13 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, roomEventUpdate, _ := up.(*caches.RoomEventUpdate) // if this is a room event update we may not want to process this if the event nid is < loadPos, // as that means we have already taken it into account - if roomEventUpdate != nil && !roomEventUpdate.EventData.AlwaysProcess && roomEventUpdate.EventData.NID < s.loadPosition { - internal.Logf(ctx, "liveUpdate", "not process update %v < %v", roomEventUpdate.EventData.NID, s.loadPosition) - return false + if roomEventUpdate != nil && !roomEventUpdate.EventData.AlwaysProcess { + // check if we should skip this update. Do we know of this room (lp > 0) and if so, is this event + // behind what we've processed before? + lp := s.loadPositions[roomEventUpdate.RoomID()] + if lp > 0 && roomEventUpdate.EventData.NID < lp { + return false + } } // for initial rooms e.g a room comes into the window or a subscription now exists @@ -161,9 +165,6 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update, rooms := s.buildRooms(ctx, builder.BuildSubscriptions()) for roomID, room := range rooms { response.Rooms[roomID] = room - // remember what point we snapshotted this room, incase we see live events which we have - // already snapshotted here. - s.loadPositions[roomID] = s.loadPosition } // TODO: find a better way to determine if the triggering event should be included e.g ask the lists? 
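
The comments in the connstate.go hunks above all describe one rule: every per-room load position is derived from a single anchor NID by taking the highest event NID in that room which does not exceed the anchor (this is what `LatestEventNIDInRooms` computes in SQL). A rough in-memory sketch of that rule, using hypothetical data rather than code from this change:

```go
// Illustrative sketch (not code from this diff) of the rule behind
// anchorLoadPosition and the per-room loadPositions map. Event NIDs are
// globally ordered across all rooms, so a room's load position is the
// highest NID belonging to that room which does not exceed the anchor.
package example

// latestNIDPerRoom mirrors what LatestEventNIDInRooms computes in SQL.
// roomNIDs is hypothetical data mapping room ID -> that room's event NIDs.
func latestNIDPerRoom(roomNIDs map[string][]int64, anchor int64) map[string]int64 {
	result := make(map[string]int64)
	for roomID, nids := range roomNIDs {
		for _, nid := range nids {
			if nid <= anchor && nid > result[roomID] {
				result[roomID] = nid // highest NID for this room at or before the anchor
			}
		}
		if result[roomID] == 0 {
			delete(result, roomID) // room has no events at or before the anchor
		}
	}
	return result
}
```

A room with no events at or before the anchor keeps a zero load position, which is why `processLiveUpdate` above only skips a live event when `lp > 0`.
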
@@ -195,7 +196,7 @@ func (s *connStateLive) processLiveUpdate(ctx context.Context, up caches.Update,
             sender := roomEventUpdate.EventData.Sender
             if s.lazyCache.IsLazyLoading(roomID) && !s.lazyCache.IsSet(roomID, sender) {
                 // load the state event
-                memberEvent := s.globalCache.LoadStateEvent(context.Background(), roomID, s.loadPosition, "m.room.member", sender)
+                memberEvent := s.globalCache.LoadStateEvent(context.Background(), roomID, s.loadPositions[roomID], "m.room.member", sender)
                 if memberEvent != nil {
                     r.RequiredState = append(r.RequiredState, memberEvent)
                     s.lazyCache.AddUser(roomID, sender)
@@ -296,12 +297,9 @@ func (s *connStateLive) processGlobalUpdates(ctx context.Context, builder *Rooms
         })
     }
 
-    if isRoomEventUpdate {
-        // TODO: we should do this check before lists.SetRoom
-        if roomEventUpdate.EventData.NID <= s.loadPosition {
-            return // if this update is in the past then ignore it
-        }
-        s.loadPosition = roomEventUpdate.EventData.NID
+    // update the anchor for this new event
+    if isRoomEventUpdate && roomEventUpdate.EventData.NID > s.anchorLoadPosition {
+        s.anchorLoadPosition = roomEventUpdate.EventData.NID
     }
     return
 }
diff --git a/sync3/handler/connstate_test.go b/sync3/handler/connstate_test.go
index 2b06d0fe..17700ea8 100644
--- a/sync3/handler/connstate_test.go
+++ b/sync3/handler/connstate_test.go
@@ -48,7 +48,7 @@ func mockLazyRoomOverride(loadPos int64, roomIDs []string, maxTimelineEvents int
     result := make(map[string]caches.UserRoomData)
     for _, roomID := range roomIDs {
         u := caches.NewUserRoomData()
-        u.RequestedTimeline = []json.RawMessage{[]byte(`{}`)}
+        u.RequestedLatestEvents.Timeline = []json.RawMessage{[]byte(`{}`)}
         result[roomID] = u
     }
     return result
@@ -84,7 +84,7 @@ func TestConnStateInitial(t *testing.T) {
         roomB.RoomID: {userID},
         roomC.RoomID: {userID},
     })
-    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
         return 1, map[string]*internal.RoomMetadata{
             roomA.RoomID: &roomA,
             roomB.RoomID: &roomB,
@@ -93,7 +93,7 @@ func TestConnStateInitial(t *testing.T) {
             roomA.RoomID: {NID: 123, Timestamp: 123},
             roomB.RoomID: {NID: 456, Timestamp: 456},
             roomC.RoomID: {NID: 780, Timestamp: 789},
-        }, nil
+        }, nil, nil
     }
     userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
     dispatcher.Register(context.Background(), userCache.UserID, userCache)
@@ -102,7 +102,7 @@ func TestConnStateInitial(t *testing.T) {
         result := make(map[string]caches.UserRoomData)
         for _, roomID := range roomIDs {
             u := caches.NewUserRoomData()
-            u.RequestedTimeline = []json.RawMessage{timeline[roomID]}
+            u.RequestedLatestEvents.Timeline = []json.RawMessage{timeline[roomID]}
             result[roomID] = u
         }
         return result
@@ -256,7 +256,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
             roomID: {userID},
         })
     }
-    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
         roomMetadata := make(map[string]*internal.RoomMetadata)
         joinTimings = make(map[string]internal.EventMetadata)
         for i, r := range rooms {
@@ -266,7 +266,7 @@ func TestConnStateMultipleRanges(t *testing.T) {
                 Timestamp: 123456,
             }
         }
-        return 1, roomMetadata, joinTimings, nil
+        return 1, roomMetadata, joinTimings, nil, nil
     }
     userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
     userCache.LazyRoomDataOverride = mockLazyRoomOverride
@@ -433,7 +433,7 @@ func TestBumpToOutsideRange(t *testing.T) {
         roomC.RoomID: {userID},
         roomD.RoomID: {userID},
     })
-    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
         return 1, map[string]*internal.RoomMetadata{
             roomA.RoomID: &roomA,
             roomB.RoomID: &roomB,
@@ -444,7 +444,7 @@ func TestBumpToOutsideRange(t *testing.T) {
             roomB.RoomID: {NID: 2, Timestamp: 2},
             roomC.RoomID: {NID: 3, Timestamp: 3},
             roomD.RoomID: {NID: 4, Timestamp: 4},
-        }, nil
+        }, nil, nil
     }
     userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
 
@@ -537,7 +537,7 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
         roomC.RoomID: testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{"body": "c"}),
         roomD.RoomID: testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{"body": "d"}),
     }
-    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, err error) {
+    globalCache.LoadJoinedRoomsOverride = func(userID string) (pos int64, joinedRooms map[string]*internal.RoomMetadata, joinTimings map[string]internal.EventMetadata, loadPositions map[string]int64, err error) {
         return 1, map[string]*internal.RoomMetadata{
             roomA.RoomID: &roomA,
             roomB.RoomID: &roomB,
@@ -548,14 +548,14 @@ func TestConnStateRoomSubscriptions(t *testing.T) {
             roomB.RoomID: {NID: 2, Timestamp: 2},
             roomC.RoomID: {NID: 3, Timestamp: 3},
             roomD.RoomID: {NID: 4, Timestamp: 4},
-        }, nil
+        }, nil, nil
     }
     userCache := caches.NewUserCache(userID, globalCache, nil, &NopTransactionFetcher{})
     userCache.LazyRoomDataOverride = func(loadPos int64, roomIDs []string, maxTimelineEvents int) map[string]caches.UserRoomData {
         result := make(map[string]caches.UserRoomData)
         for _, roomID := range roomIDs {
             u := caches.NewUserRoomData()
-            u.RequestedTimeline = []json.RawMessage{timeline[roomID]}
+            u.RequestedLatestEvents.Timeline = []json.RawMessage{timeline[roomID]}
             result[roomID] = u
         }
         return result
diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go
index b9227cf3..84c461c3 100644
--- a/sync3/handler/handler.go
+++ b/sync3/handler/handler.go
@@ -1,6 +1,5 @@
 package handler
 
-import "C"
 import (
     "context"
     "database/sql"
@@ -67,7 +66,7 @@ type SyncLiveHandler struct {
 }
 
 func NewSync3Handler(
-    store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, postgresDBURI, secret string,
+    store *state.Storage, storev2 *sync2.Storage, v2Client sync2.Client, secret string,
     pub pubsub.Notifier, sub pubsub.Listener, enablePrometheus bool, maxPendingEventUpdates int,
 ) (*SyncLiveHandler, error) {
     logger.Info().Msg("creating handler")
@@ -225,8 +224,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
     if req.ContentLength != 0 {
         defer req.Body.Close()
         if err := json.NewDecoder(req.Body).Decode(&requestBody); err != nil {
-            log.Err(err).Msg("failed to read/decode request body")
-            internal.GetSentryHubFromContextOrDefault(req.Context()).CaptureException(err)
+            log.Warn().Err(err).Msg("failed to read/decode request body")
             return &internal.HandlerError{
                 StatusCode: 400,
                 Err:        err,
@@ -339,6 +337,8 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
 // When this function returns, the connection is alive and active.
 func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*sync3.Conn, *internal.HandlerError) {
+    taskCtx, task := internal.StartTask(req.Context(), "setupConnection")
+    defer task.End()
     var conn *sync3.Conn
 
     // Extract an access token
     accessToken, err := internal.ExtractAccessToken(req)
@@ -371,6 +371,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
     }
     log := hlog.FromRequest(req).With().Str("user", token.UserID).Str("device", token.DeviceID).Logger()
     internal.SetRequestContextUserID(req.Context(), token.UserID, token.DeviceID)
+    internal.Logf(taskCtx, "setupConnection", "identified access token as user=%s device=%s", token.UserID, token.DeviceID)
 
     // Record the fact that we've received a request from this token
     err = h.V2Store.TokensTable.MaybeUpdateLastSeen(token, time.Now())
@@ -396,8 +397,8 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
         return nil, internal.ExpiredSessionError()
     }
 
-    log.Trace().Msg("checking poller exists and is running")
     pid := sync2.PollerID{UserID: token.UserID, DeviceID: token.DeviceID}
+    log.Trace().Any("pid", pid).Msg("checking poller exists and is running")
     h.EnsurePoller.EnsurePolling(req.Context(), pid, token.AccessTokenHash)
     log.Trace().Msg("poller exists and is running")
     // this may take a while so if the client has given up (e.g timed out) by this point, just stop.
@@ -458,14 +459,14 @@ func (h *SyncLiveHandler) identifyUnknownAccessToken(accessToken string, logger
     var token *sync2.Token
     err = sqlutil.WithTransaction(h.V2Store.DB, func(txn *sqlx.Tx) error {
         // Create a brand-new row for this token.
-        token, err = h.V2Store.TokensTable.Insert(accessToken, userID, deviceID, time.Now())
+        token, err = h.V2Store.TokensTable.Insert(txn, accessToken, userID, deviceID, time.Now())
         if err != nil {
             logger.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 token")
             return err
         }
 
         // Ensure we have a device row for this token.
-        err = h.V2Store.DevicesTable.InsertDevice(userID, deviceID)
+        err = h.V2Store.DevicesTable.InsertDevice(txn, userID, deviceID)
         if err != nil {
             log.Warn().Err(err).Str("user", userID).Str("device", deviceID).Msg("failed to insert v2 device")
             return err
diff --git a/tests-integration/db_test.go b/tests-integration/db_test.go
new file mode 100644
index 00000000..3897009e
--- /dev/null
+++ b/tests-integration/db_test.go
@@ -0,0 +1,103 @@
+package syncv3
+
+import (
+    "encoding/json"
+    "fmt"
+    "sync"
+    "testing"
+    "time"
+
+    syncv3 "github.com/matrix-org/sliding-sync"
+    "github.com/matrix-org/sliding-sync/sync2"
+    "github.com/matrix-org/sliding-sync/sync3"
+    "github.com/matrix-org/sliding-sync/testutils"
+    "github.com/matrix-org/sliding-sync/testutils/m"
+)
+
+// Test that the proxy works fine with low max conns. Low max conns can be a problem
+// if a request needs 2 conns to respond but can only obtain 1, blocking forward
+// progress on the server.
+func TestMaxDBConns(t *testing.T) {
+    pqString := testutils.PrepareDBConnectionString()
+    // setup code
+    v2 := runTestV2Server(t)
+    opts := syncv3.Opts{
+        DBMaxConns: 1,
+    }
+    v3 := runTestServer(t, v2, pqString, opts)
+    defer v2.close()
+    defer v3.close()
+
+    testMaxDBConns := func() {
+        // make N users and drip feed some events, make sure they are all seen
+        numUsers := 5
+        var wg sync.WaitGroup
+        wg.Add(numUsers)
+        for i := 0; i < numUsers; i++ {
+            go func(n int) {
+                defer wg.Done()
+                userID := fmt.Sprintf("@maxconns_%d:localhost", n)
+                token := fmt.Sprintf("maxconns_%d", n)
+                roomID := fmt.Sprintf("!maxconns_%d", n)
+                v2.addAccount(t, userID, token)
+                v2.queueResponse(userID, sync2.SyncResponse{
+                    Rooms: sync2.SyncRoomsResponse{
+                        Join: v2JoinTimeline(roomEvents{
+                            roomID: roomID,
+                            state:  createRoomState(t, userID, time.Now()),
+                        }),
+                    },
+                })
+                // initial sync
+                res := v3.mustDoV3Request(t, token, sync3.Request{
+                    Lists: map[string]sync3.RequestList{"a": {
+                        Ranges: sync3.SliceRanges{
+                            [2]int64{0, 1},
+                        },
+                        RoomSubscription: sync3.RoomSubscription{
+                            TimelineLimit: 1,
+                        },
+                    }},
+                })
+                t.Logf("user %s has done an initial /sync OK", userID)
+                m.MatchResponse(t, res, m.MatchList("a", m.MatchV3Count(1), m.MatchV3Ops(
+                    m.MatchV3SyncOp(0, 0, []string{roomID}),
+                )), m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
+                    roomID: {
+                        m.MatchJoinCount(1),
+                    },
+                }))
+                // drip feed and get update
+                dripMsg := testutils.NewEvent(t, "m.room.message", userID, map[string]interface{}{
+                    "msgtype": "m.text",
+                    "body":    "drip drip",
+                })
+                v2.queueResponse(userID, sync2.SyncResponse{
+                    Rooms: sync2.SyncRoomsResponse{
+                        Join: v2JoinTimeline(roomEvents{
+                            roomID: roomID,
+                            events: []json.RawMessage{
+                                dripMsg,
+                            },
+                        }),
+                    },
+                })
+                t.Logf("user %s has queued the drip", userID)
+                v2.waitUntilEmpty(t, userID)
+                t.Logf("user %s poller has received the drip", userID)
+                res = v3.mustDoV3RequestWithPos(t, token, res.Pos, sync3.Request{})
+                m.MatchResponse(t, res, m.MatchRoomSubscriptionsStrict(map[string][]m.RoomMatcher{
+                    roomID: {
+                        m.MatchRoomTimelineMostRecent(1, []json.RawMessage{dripMsg}),
+                    },
+                }))
+                t.Logf("user %s has received the drip", userID)
+            }(i)
+        }
+        wg.Wait()
+    }
+
+    testMaxDBConns()
+    v3.restart(t, v2, pqString, opts)
+    testMaxDBConns()
+}
diff --git a/tests-integration/v3_test.go b/tests-integration/v3_test.go
index cd8bd2c1..de72d954 100644
--- a/tests-integration/v3_test.go
+++ b/tests-integration/v3_test.go
@@ -291,11 +291,11 @@ func (s *testV3Server) close() {
     s.h2.Teardown()
 }
 
-func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string) {
+func (s *testV3Server) restart(t *testing.T, v2 *testV2Server, pq string, opts ...syncv3.Opts) {
     t.Helper()
     log.Printf("restarting server")
     s.close()
-    ss := runTestServer(t, v2, pq)
+    ss := runTestServer(t, v2, pq, opts...)
     // replace all the fields which will be close()d to ensure we don't leak
     s.srv = ss.srv
     s.h2 = ss.h2
@@ -366,20 +366,22 @@ func runTestServer(t testutils.TestBenchInterface, v2Server *testV2Server, postg
     //tests often repeat requests. To ensure tests remain fast, reduce the spam protection limits.
     sync3.SpamProtectionInterval = time.Millisecond
-    metricsEnabled := false
-    maxPendingEventUpdates := 200
+    combinedOpts := syncv3.Opts{
+        TestingSynchronousPubsub: true, // critical to avoid flakey tests
+        AddPrometheusMetrics:     false,
+        MaxPendingEventUpdates:   200,
+    }
     if len(opts) > 0 {
-        metricsEnabled = opts[0].AddPrometheusMetrics
-        if opts[0].MaxPendingEventUpdates > 0 {
-            maxPendingEventUpdates = opts[0].MaxPendingEventUpdates
+        opt := opts[0]
+        combinedOpts.AddPrometheusMetrics = opt.AddPrometheusMetrics
+        combinedOpts.DBConnMaxIdleTime = opt.DBConnMaxIdleTime
+        combinedOpts.DBMaxConns = opt.DBMaxConns
+        if opt.MaxPendingEventUpdates > 0 {
+            combinedOpts.MaxPendingEventUpdates = opt.MaxPendingEventUpdates
             handler.BufferWaitTime = 5 * time.Millisecond
         }
     }
 
-    h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), syncv3.Opts{
-        TestingSynchronousPubsub: true, // critical to avoid flakey tests
-        MaxPendingEventUpdates:   maxPendingEventUpdates,
-        AddPrometheusMetrics:     metricsEnabled,
-    })
+    h2, h3 := syncv3.Setup(v2Server.url(), postgresConnectionString, os.Getenv("SYNCV3_SECRET"), combinedOpts)
     // for ease of use we don't start v2 pollers at startup in tests
     r := mux.NewRouter()
     r.Use(hlog.NewHandler(logger))
diff --git a/v3.go b/v3.go
index e189e97a..cc72d093 100644
--- a/v3.go
+++ b/v3.go
@@ -9,6 +9,7 @@ import (
     "time"
 
     "github.com/getsentry/sentry-go"
+    "github.com/jmoiron/sqlx"
 
     "github.com/gorilla/mux"
     "github.com/matrix-org/sliding-sync/internal"
@@ -36,6 +37,9 @@ type Opts struct {
     // if true, publishing messages will block until the consumer has consumed it.
     // Assumes a single producer and a single consumer.
     TestingSynchronousPubsub bool
+
+    DBMaxConns        int
+    DBConnMaxIdleTime time.Duration
 }
 
 type server struct {
@@ -75,6 +79,18 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
     }
     store := state.NewStorage(postgresURI)
     storev2 := sync2.NewStore(postgresURI, secret)
+    for _, db := range []*sqlx.DB{store.DB, storev2.DB} {
+        if opts.DBMaxConns > 0 {
+            // https://github.com/go-sql-driver/mysql#important-settings
+            // "db.SetMaxIdleConns() is recommended to be set same to db.SetMaxOpenConns(). When it is smaller
+            // than SetMaxOpenConns(), connections can be opened and closed much more frequently than you expect."
+            db.SetMaxOpenConns(opts.DBMaxConns)
+            db.SetMaxIdleConns(opts.DBMaxConns)
+        }
+        if opts.DBConnMaxIdleTime > 0 {
+            db.SetConnMaxIdleTime(opts.DBConnMaxIdleTime)
+        }
+    }
     bufferSize := 50
     deviceDataUpdateFrequency := time.Second
     if opts.TestingSynchronousPubsub {
@@ -88,14 +104,14 @@ func Setup(destHomeserver, postgresURI, secret string, opts Opts) (*handler2.Han
     pMap := sync2.NewPollerMap(v2Client, opts.AddPrometheusMetrics)
 
     // create v2 handler
-    h2, err := handler2.NewHandler(postgresURI, pMap, storev2, store, v2Client, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
+    h2, err := handler2.NewHandler(pMap, storev2, store, pubSub, pubSub, opts.AddPrometheusMetrics, deviceDataUpdateFrequency)
     if err != nil {
         panic(err)
     }
     pMap.SetCallbacks(h2)
 
     // create v3 handler
-    h3, err := handler.NewSync3Handler(store, storev2, v2Client, postgresURI, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
+    h3, err := handler.NewSync3Handler(store, storev2, v2Client, secret, pubSub, pubSub, opts.AddPrometheusMetrics, opts.MaxPendingEventUpdates)
     if err != nil {
         panic(err)
     }
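The new `DBMaxConns` and `DBConnMaxIdleTime` fields in `Opts` are the only knobs a deployment needs to bound the proxy's Postgres connection usage; `Setup` applies them to both the state store and the sync2 store as shown above. The sketch below shows one way an embedding program might pass them in. It is illustrative only: the homeserver URL, connection string, secret, and pool values are placeholders rather than recommendations, and the returned handlers still need to be wired into an HTTP server as usual.

```go
package main

import (
	"time"

	syncv3 "github.com/matrix-org/sliding-sync"
)

func main() {
	// Placeholder values: substitute your own homeserver, Postgres URI and secret.
	h2, h3 := syncv3.Setup(
		"https://matrix.example.com",
		"user=postgres dbname=syncv3 sslmode=disable",
		"some-shared-secret",
		syncv3.Opts{
			// Applied as both SetMaxOpenConns and SetMaxIdleConns on each
			// database handle, per the go-sql-driver guidance quoted in v3.go.
			DBMaxConns: 100,
			// Idle connections are recycled after this duration.
			DBConnMaxIdleTime: time.Hour,
		},
	)
	_, _ = h2, h3 // wire these into your HTTP server as appropriate
}
```

Keeping the idle-connection cap equal to the open-connection cap (a single `DBMaxConns` value feeding both setters) avoids the churn of repeatedly opening and closing connections under steady load, which is exactly the failure mode the quoted guidance warns about.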