Skip to content

Commit

Permalink
Allow explicitly specified /state and /state_ids requests to comp…
Browse files Browse the repository at this point in the history
…lete
  • Loading branch information
Sean Quah committed Jul 21, 2022
1 parent d085f11 commit 2cbeaae
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions tests/federation_room_join_partial_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ type partialStateJoinResult struct {
ServerRoom *federation.ServerRoom
fedStateIdsRequestReceivedWaiter *Waiter
fedStateIdsSendResponseWaiter *Waiter
// the set of events for which we will not block `/state` or `/state_ids` requests.
fedStateIdsAllowedEvents map[string]bool
}

// beginPartialStateJoin spins up a room on a complement server,
Expand Down Expand Up @@ -627,6 +629,7 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU
// some things for orchestration
result.fedStateIdsRequestReceivedWaiter = NewWaiter()
result.fedStateIdsSendResponseWaiter = NewWaiter()
result.fedStateIdsAllowedEvents = make(map[string]bool)

// create the room on the complement server, with charlie and derek as members
roomVer := joiningUser.GetDefaultRoomVersion(t)
Expand All @@ -642,10 +645,17 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU

// register a handler for /state_ids requests, which finishes fedStateIdsRequestReceivedWaiter, then
// waits for fedStateIdsSendResponseWaiter and sends a reply
handleStateIdsRequests(t, result.Server, result.ServerRoom, result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter)
handleStateIdsRequests(
t,
result.Server,
result.ServerRoom,
result.fedStateIdsRequestReceivedWaiter,
result.fedStateIdsSendResponseWaiter,
result.fedStateIdsAllowedEvents,
)

// a handler for /state requests, which sends a sensible response
handleStateRequests(t, result.Server, result.ServerRoom, nil, nil)
handleStateRequests(t, result.Server, result.ServerRoom, nil, nil, nil)

// have joiningUser join the room by room ID.
joiningUser.JoinRoom(t, result.ServerRoom.RoomID, []string{result.Server.ServerName()})
Expand Down Expand Up @@ -693,6 +703,12 @@ func (psj *partialStateJoinResult) CreateMessageEvent(t *testing.T, senderLocalp
return event
}

// allow a /state_ids request for a given event to complete before FinishStateRequest has been called.
// only applies to new incoming requests, and not any currently blocked ones.
func (psj *partialStateJoinResult) AllowStateRequestForEvent(eventID string) {
psj.fedStateIdsAllowedEvents[eventID] = true
}

// wait for a /state_ids request for the test room to arrive
func (psj *partialStateJoinResult) AwaitStateIdsRequest(t *testing.T) {
psj.fedStateIdsRequestReceivedWaiter.Waitf(t, 5*time.Second, "Waiting for /state_ids request")
Expand All @@ -709,7 +725,7 @@ func (psj *partialStateJoinResult) FinishStateRequest() {
// if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response.
func handleStateIdsRequests(
t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, allowedEvents map[string]bool,
) {
srv.Mux().Handle(
fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", serverRoom.RoomID),
Expand All @@ -719,7 +735,8 @@ func handleStateIdsRequests(
if requestReceivedWaiter != nil {
requestReceivedWaiter.Finish()
}
if sendResponseWaiter != nil {
if !allowedEvents[queryParams["event_id"][0]] &&
sendResponseWaiter != nil {
sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state_ids request")
}
t.Logf("Replying to /state_ids request")
Expand All @@ -744,7 +761,7 @@ func handleStateIdsRequests(
// if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response.
func handleStateRequests(
t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter, allowedEvents map[string]bool,
) {
srv.Mux().Handle(
fmt.Sprintf("/_matrix/federation/v1/state/%s", serverRoom.RoomID),
Expand All @@ -754,7 +771,8 @@ func handleStateRequests(
if requestReceivedWaiter != nil {
requestReceivedWaiter.Finish()
}
if sendResponseWaiter != nil {
if !allowedEvents[queryParams["event_id"][0]] &&
sendResponseWaiter != nil {
sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state request")
}
res := gomatrixserverlib.RespState{
Expand Down

0 comments on commit 2cbeaae

Please sign in to comment.