diff --git a/api/routes.go b/api/routes.go index d7393a52..2fbc652d 100644 --- a/api/routes.go +++ b/api/routes.go @@ -55,6 +55,7 @@ func buildRoutes() http.Handler { register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId", msc3916, router, authedDownloadRoute) register([]string{"GET"}, PrefixClient, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireAccessToken(r0.ThumbnailMedia), "thumbnail", counter)) register([]string{"GET"}, PrefixFederation, "media/download/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter)) + register([]string{"GET"}, PrefixFederation, "media/download/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter)) register([]string{"GET"}, PrefixFederation, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationThumbnailMedia), "thumbnail", counter)) // Custom features @@ -143,7 +144,7 @@ var ( //mxAllSpec matrixVersions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media" /* and MSC routes */} mxUnstable matrixVersions = []string{"unstable", "unstable/io.t2bot.media"} msc4034 matrixVersions = []string{"unstable/org.matrix.msc4034"} - msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916"} + msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916", "unstable/org.matrix.msc3916.v2"} mxSpecV3Transition matrixVersions = []string{"r0", "v1", "v3"} mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"} mxR0 matrixVersions = []string{"r0"} diff --git a/api/unstable/msc3916_download.go b/api/unstable/msc3916_download.go index 6089976a..c357cf56 100644 --- a/api/unstable/msc3916_download.go +++ b/api/unstable/msc3916_download.go @@ -6,6 +6,7 @@ import ( "github.com/t2bot/matrix-media-repo/api/_apimeta" "github.com/t2bot/matrix-media-repo/api/_responses" + "github.com/t2bot/matrix-media-repo/api/_routers" "github.com/t2bot/matrix-media-repo/api/r0" "github.com/t2bot/matrix-media-repo/common/rcontext" "github.com/t2bot/matrix-media-repo/util/readers" @@ -18,6 +19,7 @@ func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _ap func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { r.URL.Query().Set("allow_remote", "false") + r = _routers.ForceSetParam("server", r.Host, r) res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{}) if dl, ok := res.(*_responses.DownloadResponse); ok { diff --git a/pipelines/_steps/download/try_download.go b/pipelines/_steps/download/try_download.go index c087dce6..d1fbcbec 100644 --- a/pipelines/_steps/download/try_download.go +++ b/pipelines/_steps/download/try_download.go @@ -63,7 +63,7 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d var downloadUrl string usesMultipartFormat := false if ctx.Config.SigningKeyPath != "" { - downloadUrl = fmt.Sprintf("%s/_matrix/federation/unstable/org.matrix.msc3916/media/download/%s/%s?a=b", baseUrl, url.PathEscape(origin), url.PathEscape(mediaId)) + downloadUrl = fmt.Sprintf("%s/_matrix/federation/unstable/org.matrix.msc3916.v2/media/download/%s", baseUrl, url.PathEscape(mediaId)) resp, err = matrix.FederatedGet(ctx, downloadUrl, realHost, ctx.Config.SigningKeyPath) metrics.MediaDownloaded.With(prometheus.Labels{"origin": origin}).Inc() if err != nil { diff --git a/test/msc3916_downloads_suite_test.go b/test/msc3916_downloads_suite_test.go index 043379a4..0a2cdbf5 100644 --- a/test/msc3916_downloads_suite_test.go +++ b/test/msc3916_downloads_suite_test.go @@ -155,7 +155,7 @@ func (s *MSC3916DownloadsSuite) TestFederationDownloads() { assert.NotEmpty(t, mediaId) // Verify the federation download *fails* when lacking auth - uri := fmt.Sprintf("/_matrix/federation/unstable/org.matrix.msc3916/media/download/%s/%s", origin, mediaId) + uri := fmt.Sprintf("/_matrix/federation/unstable/org.matrix.msc3916.v2/media/download/%s", mediaId) raw, err := remoteClient.DoRaw("GET", uri, nil, "", nil) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) @@ -182,7 +182,7 @@ func (s *MSC3916DownloadsSuite) TestFederationMakesAuthedDownloads() { origin := "" mediaId := "abc123" testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, fmt.Sprintf("/_matrix/federation/unstable/org.matrix.msc3916/media/download/%s/%s", origin, mediaId), r.URL.Path) + assert.Equal(t, fmt.Sprintf("/_matrix/federation/unstable/org.matrix.msc3916.v2/media/download/%s", mediaId), r.URL.Path) origin, err := matrix.ValidateXMatrixAuth(r, true) assert.NoError(t, err) assert.Equal(t, client1.ServerName, origin)