diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index 5fa56c48..9e694599 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -14,17 +14,16 @@ import ( "net" "net/http" "net/url" + "reflect" "strconv" "strings" "testing" "time" - "github.com/gorilla/mux" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" "github.com/matrix-org/complement/internal/b" "github.com/matrix-org/complement/internal/client" @@ -182,7 +181,8 @@ func TestPartialStateJoin(t *testing.T) { federation.HandleEventAuthRequests()(psjResult.Server) // the HS will make a /get_missing_events request for the missing prev events of event B - handleGetMissingEventsRequests(t, psjResult.Server, psjResult.ServerRoom, []*gomatrixserverlib.Event{eventA}) + handleGetMissingEventsRequests(t, psjResult.Server, psjResult.ServerRoom, + []string{eventB.EventID()}, []*gomatrixserverlib.Event{eventA}) // send event B to hs1 testReceiveEventDuringPartialStateJoin(t, deployment, alice, psjResult, eventB) @@ -218,7 +218,8 @@ func TestPartialStateJoin(t *testing.T) { federation.HandleEventAuthRequests()(psjResult.Server) // the HS will make a /get_missing_events request for the missing prev event of event B - handleGetMissingEventsRequests(t, psjResult.Server, psjResult.ServerRoom, []*gomatrixserverlib.Event{eventA}) + handleGetMissingEventsRequests(t, psjResult.Server, psjResult.ServerRoom, + []string{eventB.EventID()}, []*gomatrixserverlib.Event{eventA}) // send event B to hs1 testReceiveEventDuringPartialStateJoin(t, deployment, alice, psjResult, eventB) @@ -255,7 +256,8 @@ func TestPartialStateJoin(t *testing.T) { // the HS will make a /get_missing_events request for the missing prev event of event C, // to which we respond with event B only. - handleGetMissingEventsRequests(t, psjResult.Server, psjResult.ServerRoom, []*gomatrixserverlib.Event{eventB}) + handleGetMissingEventsRequests(t, psjResult.Server, psjResult.ServerRoom, + []string{eventC.EventID()}, []*gomatrixserverlib.Event{eventB}) // dedicated state_ids and state handlers for event A handleStateIdsRequests(t, psjResult.Server, psjResult.ServerRoom, eventA.EventID(), psjResult.ServerRoom.AllCurrentState(), nil, nil) @@ -891,45 +893,33 @@ func handleStateRequests( } // register a handler for `/get_missing_events` requests +// +// This can (currently) only handle a single `/get_missing_events` request, and the "latest_events" in the request +// must match those listed in "expectedLatestEvents" (otherwise the test is failed). func handleGetMissingEventsRequests( t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom, - eventsToReturn []*gomatrixserverlib.Event, + expectedLatestEvents []string, eventsToReturn []*gomatrixserverlib.Event, ) { - srv.Mux().HandleFunc("/_matrix/federation/v1/get_missing_events/{roomID}", func(w http.ResponseWriter, req *http.Request) { - roomID := mux.Vars(req)["roomID"] - if roomID != serverRoom.RoomID { - t.Fatalf("Received unexpected /get_missing_events request for room: %s", roomID) - } - + srv.Mux().HandleFunc(fmt.Sprintf("/_matrix/federation/v1/get_missing_events/%s", serverRoom.RoomID), func(w http.ResponseWriter, req *http.Request) { body, err := ioutil.ReadAll(req.Body) if err != nil { - t.Fatalf("unable to read request body: %v", err) - } - var getMissingEventsRequest struct { - EarliestEvents []string `json:"earliest_events"` - LatestEvents []string `json:"latest_events"` - Limit int `json:"int"` - MinDepth int `json:"min_depth"` + t.Fatalf("unable to read /get_missing_events request body: %s", err) } + var getMissingEventsRequest gomatrixserverlib.MissingEvents err = json.Unmarshal(body, &getMissingEventsRequest) if err != nil { - errResp := util.MessageResponse(400, err.Error()) - w.WriteHeader(errResp.Code) - b, _ := json.Marshal(errResp.JSON) - w.Write(b) - return + t.Fatalf("unable to unmarshall /get_missing_events request body: %s", err) } - t.Logf("Incoming get_missing_events request for prev events of %s in room %s", getMissingEventsRequest.LatestEvents, roomID) + t.Logf("Incoming get_missing_events request for prev events of %s in room %s", getMissingEventsRequest.LatestEvents, serverRoom.RoomID) + if !reflect.DeepEqual(expectedLatestEvents, getMissingEventsRequest.LatestEvents) { + t.Fatalf("getMissingEventsRequest.LatestEvents: got %v, wanted %v", getMissingEventsRequest, expectedLatestEvents) + } - // TODO: return events based on those requested + responseBytes, _ := json.Marshal(gomatrixserverlib.RespMissingEvents{ + Events: gomatrixserverlib.NewEventJSONsFromEvents(eventsToReturn), + }) w.WriteHeader(200) - res := struct { - Events []*gomatrixserverlib.Event `json:"events"` - }{ - Events: eventsToReturn, - } - responseBytes, _ := json.Marshal(&res) w.Write(responseBytes) }).Methods("POST") }