diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 91212987..cff31f0e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -44,5 +44,5 @@ jobs: - name: "Run: compile assets" run: "$PWD/bin/compile_assets" - name: "Run: tests" - run: "go test -c -v ./test && ./test.test '-test.v'" # cheat and work around working directory issues + run: "go test -c -v ./test && ./test.test '-test.v' -test.parallel 1" # cheat and work around working directory issues timeout-minutes: 30 diff --git a/CHANGELOG.md b/CHANGELOG.md index b953bb7c..8f86975f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * S3 datastores can now specify a `prefixLength` to improve S3 performance on some providers. See `config.sample.yaml` for details. * Add `multipartUploads` flag for running MMR against unsupported S3 providers. See `config.sample.yaml` for details. * A new "leaky bucket" rate limit algorithm has been applied to downloads. See `rateLimit.buckets` in the config for details. +* Add support for [MSC3916: Authentication for media](https://github.com/matrix-org/matrix-spec-proposals/pull/3916). + * To enable full support, use `signingKeyPath` in your config. See sample config for details. ### Changed diff --git a/api/_apimeta/auth.go b/api/_apimeta/auth.go index 5bd5dfd6..d16df56f 100644 --- a/api/_apimeta/auth.go +++ b/api/_apimeta/auth.go @@ -16,6 +16,10 @@ type UserInfo struct { IsShared bool } +type ServerInfo struct { + ServerName string +} + func GetRequestUserAdminStatus(r *http.Request, rctx rcontext.RequestContext, user UserInfo) (bool, bool) { isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr) diff --git a/api/_auth_cache/auth_cache.go b/api/_auth_cache/auth_cache.go index da1f2d5c..8ac47758 100644 --- a/api/_auth_cache/auth_cache.go +++ b/api/_auth_cache/auth_cache.go @@ -12,7 +12,7 @@ import ( "github.com/t2bot/matrix-media-repo/matrix" ) -var tokenCache = cache.New(0*time.Second, 30*time.Second) +var tokenCache = cache.New(cache.NoExpiration, 30*time.Second) var rwLock = &sync.RWMutex{} var regexCache = make(map[string]*regexp.Regexp) diff --git a/api/_routers/97-require-server-auth.go b/api/_routers/97-require-server-auth.go new file mode 100644 index 00000000..28ad632d --- /dev/null +++ b/api/_routers/97-require-server-auth.go @@ -0,0 +1,45 @@ +package _routers + +import ( + "errors" + "net/http" + + "github.com/t2bot/matrix-media-repo/api/_apimeta" + "github.com/t2bot/matrix-media-repo/api/_responses" + "github.com/t2bot/matrix-media-repo/common" + "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/matrix" +) + +type GeneratorWithServerFn = func(r *http.Request, ctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} + +func RequireServerAuth(generator GeneratorWithServerFn) GeneratorFn { + return func(r *http.Request, ctx rcontext.RequestContext) interface{} { + serverName, err := matrix.ValidateXMatrixAuth(r, true) + if err != nil { + ctx.Log.Debug("Error with X-Matrix auth: ", err) + if errors.Is(err, matrix.ErrNoXMatrixAuth) { + return &_responses.ErrorResponse{ + Code: common.ErrCodeUnauthorized, + Message: "no auth provided (required)", + InternalCode: common.ErrCodeMissingToken, + } + } + if errors.Is(err, matrix.ErrWrongDestination) { + return &_responses.ErrorResponse{ + Code: common.ErrCodeUnauthorized, + Message: "no auth provided for this destination (required)", + InternalCode: common.ErrCodeBadRequest, + } + } + return &_responses.ErrorResponse{ + Code: common.ErrCodeForbidden, + Message: "invalid auth provided (required)", + InternalCode: common.ErrCodeBadRequest, + } + } + return generator(r, ctx, _apimeta.ServerInfo{ + ServerName: serverName, + }) + } +} diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go index 35e36369..420df864 100644 --- a/api/_routers/98-use-rcontext.go +++ b/api/_routers/98-use-rcontext.go @@ -101,20 +101,24 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { beforeParseDownload: log.Infof("Replying with result: %T %+v", res, res) if downloadRes, isDownload := res.(*_responses.DownloadResponse); isDownload { - ranges, err := http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes) - if errors.Is(err, http_range.ErrInvalid) { - proposedStatusCode = http.StatusRequestedRangeNotSatisfiable - res = _responses.BadRequest("invalid range header") - goto beforeParseDownload // reprocess `res` - } else if errors.Is(err, http_range.ErrNoOverlap) { - proposedStatusCode = http.StatusRequestedRangeNotSatisfiable - res = _responses.BadRequest("out of range") - goto beforeParseDownload // reprocess `res` - } - if len(ranges) > 1 { - proposedStatusCode = http.StatusRequestedRangeNotSatisfiable - res = _responses.BadRequest("only 1 range is supported") - goto beforeParseDownload // reprocess `res` + var ranges []http_range.Range + var err error + if downloadRes.SizeBytes > 0 { + ranges, err = http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes) + if errors.Is(err, http_range.ErrInvalid) { + proposedStatusCode = http.StatusRequestedRangeNotSatisfiable + res = _responses.BadRequest("invalid range header") + goto beforeParseDownload // reprocess `res` + } else if errors.Is(err, http_range.ErrNoOverlap) { + proposedStatusCode = http.StatusRequestedRangeNotSatisfiable + res = _responses.BadRequest("out of range") + goto beforeParseDownload // reprocess `res` + } + if len(ranges) > 1 { + proposedStatusCode = http.StatusRequestedRangeNotSatisfiable + res = _responses.BadRequest("only 1 range is supported") + goto beforeParseDownload // reprocess `res` + } } contentType = downloadRes.ContentType diff --git a/api/custom/federation.go b/api/custom/federation.go index 48a0fb92..e92a1393 100644 --- a/api/custom/federation.go +++ b/api/custom/federation.go @@ -33,7 +33,10 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user _apim } versionUrl := url + "/_matrix/federation/v1/version" - versionResponse, err := matrix.FederatedGet(versionUrl, hostname, rctx) + versionResponse, err := matrix.FederatedGet(rctx, versionUrl, hostname, matrix.NoSigningKey) + if versionResponse != nil { + defer versionResponse.Body.Close() + } if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) diff --git a/api/r0/versions.go b/api/r0/versions.go new file mode 100644 index 00000000..47fc5e79 --- /dev/null +++ b/api/r0/versions.go @@ -0,0 +1,36 @@ +package r0 + +import ( + "net/http" + "slices" + + "github.com/getsentry/sentry-go" + "github.com/t2bot/matrix-media-repo/api/_apimeta" + "github.com/t2bot/matrix-media-repo/api/_responses" + "github.com/t2bot/matrix-media-repo/matrix" + + "github.com/t2bot/matrix-media-repo/common/rcontext" +) + +func ClientVersions(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + versions, err := matrix.ClientVersions(rctx, r.Host, user.UserId, user.AccessToken, r.RemoteAddr) + if err != nil { + rctx.Log.Error(err) + sentry.CaptureException(err) + return _responses.InternalServerError("unable to get versions") + } + + // This is where we'd add our feature/version support as needed + if versions.Versions == nil { + versions.Versions = make([]string, 1) + } + + // We add v1.11 by force, even though we can't reliably say the rest of the server implements it. This + // is because server admins which point `/versions` at us are effectively opting in to whatever features + // we need to advertise support for. In our case, it's at least Authenticated Media (MSC3916). + if !slices.Contains(versions.Versions, "v1.11") { + versions.Versions = append(versions.Versions, "v1.11") + } + + return versions +} diff --git a/api/routes.go b/api/routes.go index 49f51ebc..5bbdbe63 100644 --- a/api/routes.go +++ b/api/routes.go @@ -18,6 +18,7 @@ import ( const PrefixMedia = "/_matrix/media" const PrefixClient = "/_matrix/client" +const PrefixFederation = "/_matrix/federation" func buildRoutes() http.Handler { counter := &_routers.RequestCounter{} @@ -36,12 +37,23 @@ func buildRoutes() http.Handler { register([]string{"GET", "HEAD"}, PrefixMedia, "download/:server/:mediaId/:filename", mxSpecV3Transition, router, downloadRoute) register([]string{"GET", "HEAD"}, PrefixMedia, "download/:server/:mediaId", mxSpecV3Transition, router, downloadRoute) register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMedia), "thumbnail", counter)) - register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter)) + previewUrlRoute := makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter) + register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, previewUrlRoute) register([]string{"GET"}, PrefixMedia, "identicon/*seed", mxR0, router, makeRoute(_routers.OptionalAccessToken(r0.Identicon), "identicon", counter)) - register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter)) + configRoute := makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter) + register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, configRoute) register([]string{"POST"}, PrefixClient, "logout", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.Logout), "logout", counter)) register([]string{"POST"}, PrefixClient, "logout/all", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.LogoutAll), "logout_all", counter)) register([]string{"POST"}, PrefixMedia, "create", mxV1, router, makeRoute(_routers.RequireAccessToken(v1.CreateMedia), "create", counter)) + register([]string{"GET"}, PrefixClient, "versions", mxNoVersion, router, makeRoute(_routers.OptionalAccessToken(r0.ClientVersions), "client_versions", counter)) + register([]string{"GET"}, PrefixClient, "media/preview_url", mxV1, router, previewUrlRoute) + register([]string{"GET"}, PrefixClient, "media/config", mxV1, router, configRoute) + authedDownloadRoute := makeRoute(_routers.RequireAccessToken(v1.ClientDownloadMedia), "download", counter) + register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId/:filename", mxV1, router, authedDownloadRoute) + register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId", mxV1, router, authedDownloadRoute) + register([]string{"GET"}, PrefixClient, "media/thumbnail/:server/:mediaId", mxV1, router, makeRoute(_routers.RequireAccessToken(v1.ClientThumbnailMedia), "thumbnail", counter)) + register([]string{"GET"}, PrefixFederation, "media/download/:mediaId", mxV1, router, makeRoute(_routers.RequireServerAuth(v1.FederationDownloadMedia), "download", counter)) + register([]string{"GET"}, PrefixFederation, "media/thumbnail/:mediaId", mxV1, router, makeRoute(_routers.RequireServerAuth(v1.FederationThumbnailMedia), "thumbnail", counter)) // Custom features register([]string{"GET"}, PrefixMedia, "local_copy/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.LocalCopy), "local_copy", counter)) @@ -134,12 +146,16 @@ var ( mxR0 matrixVersions = []string{"r0"} mxV1 matrixVersions = []string{"v1"} mxV3 matrixVersions = []string{"v3"} + mxNoVersion matrixVersions = []string{""} ) func register(methods []string, prefix string, postfix string, versions matrixVersions, router *httprouter.Router, handler http.Handler) { for _, method := range methods { for _, version := range versions { path := fmt.Sprintf("%s/%s/%s", prefix, version, postfix) + if version == "" { + path = fmt.Sprintf("%s/%s", prefix, postfix) + } router.Handler(method, path, http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { defer func() { // hopefully the body was already closed, but maybe it wasn't diff --git a/api/v1/download.go b/api/v1/download.go new file mode 100644 index 00000000..6fd439ec --- /dev/null +++ b/api/v1/download.go @@ -0,0 +1,54 @@ +package v1 + +import ( + "bytes" + "net/http" + + "github.com/t2bot/matrix-media-repo/api/_apimeta" + "github.com/t2bot/matrix-media-repo/api/_responses" + "github.com/t2bot/matrix-media-repo/api/_routers" + "github.com/t2bot/matrix-media-repo/api/r0" + "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/util/readers" +) + +func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + r.URL.Query().Set("allow_remote", "true") + r.URL.Query().Set("allow_redirect", "true") + return r0.DownloadMedia(r, rctx, user) +} + +func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { + query := r.URL.Query() + query.Set("allow_remote", "false") + query.Set("allow_redirect", "true") // we override how redirects work in the response + r.URL.RawQuery = query.Encode() + r = _routers.ForceSetParam("server", r.Host, r) + + res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{}) + if dl, ok := res.(*_responses.DownloadResponse); ok { + return &_responses.DownloadResponse{ + ContentType: "multipart/mixed", + Filename: "", + SizeBytes: 0, + Data: readers.NewMultipartReader( + &readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))}, + &readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data}, + ), + TargetDisposition: "attachment", + } + } else if rd, ok := res.(*_responses.RedirectResponse); ok { + return &_responses.DownloadResponse{ + ContentType: "multipart/mixed", + Filename: "", + SizeBytes: 0, + Data: readers.NewMultipartReader( + &readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))}, + &readers.MultipartPart{Location: rd.ToUrl}, + ), + TargetDisposition: "attachment", + } + } else { + return res + } +} diff --git a/api/v1/thumbnail.go b/api/v1/thumbnail.go new file mode 100644 index 00000000..438a6fcf --- /dev/null +++ b/api/v1/thumbnail.go @@ -0,0 +1,54 @@ +package v1 + +import ( + "bytes" + "net/http" + + "github.com/t2bot/matrix-media-repo/api/_apimeta" + "github.com/t2bot/matrix-media-repo/api/_responses" + "github.com/t2bot/matrix-media-repo/api/_routers" + "github.com/t2bot/matrix-media-repo/api/r0" + "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/util/readers" +) + +func ClientThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + r.URL.Query().Set("allow_remote", "true") + r.URL.Query().Set("allow_redirect", "true") + return r0.ThumbnailMedia(r, rctx, user) +} + +func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { + query := r.URL.Query() + query.Set("allow_remote", "false") + query.Set("allow_redirect", "true") // we override how redirects work in the response + r.URL.RawQuery = query.Encode() + r = _routers.ForceSetParam("server", r.Host, r) + + res := r0.ThumbnailMedia(r, rctx, _apimeta.UserInfo{}) + if dl, ok := res.(*_responses.DownloadResponse); ok { + return &_responses.DownloadResponse{ + ContentType: "multipart/mixed", + Filename: "", + SizeBytes: 0, + Data: readers.NewMultipartReader( + &readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))}, + &readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data}, + ), + TargetDisposition: "attachment", + } + } else if rd, ok := res.(*_responses.RedirectResponse); ok { + return &_responses.DownloadResponse{ + ContentType: "multipart/mixed", + Filename: "", + SizeBytes: 0, + Data: readers.NewMultipartReader( + &readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))}, + &readers.MultipartPart{Location: rd.ToUrl}, + ), + TargetDisposition: "attachment", + } + } else { + return res + } +} diff --git a/cmd/utilities/generate_signing_key/main.go b/cmd/utilities/generate_signing_key/main.go index 3367164d..99afc745 100644 --- a/cmd/utilities/generate_signing_key/main.go +++ b/cmd/utilities/generate_signing_key/main.go @@ -1,13 +1,8 @@ package main import ( - "crypto/ed25519" - "crypto/rand" "flag" - "fmt" "os" - "sort" - "strings" "github.com/sirupsen/logrus" "github.com/t2bot/matrix-media-repo/cmd/utilities/_common" @@ -27,16 +22,7 @@ func main() { if *inputFile != "" { key, err = decodeKey(*inputFile) } else { - keyVersion := makeKeyVersion() - - var priv ed25519.PrivateKey - _, priv, err = ed25519.GenerateKey(nil) - priv = priv[len(priv)-32:] - - key = &homeserver_interop.SigningKey{ - PrivateKey: priv, - KeyVersion: keyVersion, - } + key, err = homeserver_interop.GenerateSigningKey() } if err != nil { logrus.Fatal(err) @@ -47,28 +33,6 @@ func main() { _common.EncodeSigningKeys([]*homeserver_interop.SigningKey{key}, *outputFormat, *outputFile) } -func makeKeyVersion() string { - buf := make([]byte, 2) - chars := strings.Split("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", "") - for i := 0; i < len(chars); i++ { - sort.Slice(chars, func(i int, j int) bool { - c, err := rand.Read(buf) - - // "should never happen" clauses - if err != nil { - panic(err) - } - if c != len(buf) || c != 2 { - panic(fmt.Sprintf("crypto rand read %d bytes, expected %d", c, len(buf))) - } - - return buf[0] < buf[1] - }) - } - - return strings.Join(chars[:6], "") -} - func decodeKey(fileName string) (*homeserver_interop.SigningKey, error) { f, err := os.Open(fileName) if err != nil { diff --git a/cmd/workers/media_repo/reloads.go b/cmd/workers/media_repo/reloads.go index dbf201c5..4dbde321 100644 --- a/cmd/workers/media_repo/reloads.go +++ b/cmd/workers/media_repo/reloads.go @@ -10,6 +10,7 @@ import ( "github.com/t2bot/matrix-media-repo/database" "github.com/t2bot/matrix-media-repo/errcache" "github.com/t2bot/matrix-media-repo/limits" + "github.com/t2bot/matrix-media-repo/matrix" "github.com/t2bot/matrix-media-repo/metrics" "github.com/t2bot/matrix-media-repo/pgo_internal" "github.com/t2bot/matrix-media-repo/plugins" @@ -29,6 +30,7 @@ func setupReloads() { reloadPluginsOnChan(globals.PluginReloadChan) reloadPoolOnChan(globals.PoolReloadChan) reloadErrorCachesOnChan(globals.ErrorCacheReloadChan) + reloadMatrixCachesOnChan(globals.MatrixCachesReloadChan) reloadPGOOnChan(globals.PGOReloadChan) reloadBucketsOnChan(globals.BucketsReloadChan) } @@ -55,6 +57,8 @@ func stopReloads() { globals.PoolReloadChan <- false logrus.Debug("Stopping ErrorCacheReloadChan") globals.ErrorCacheReloadChan <- false + logrus.Debug("Stopping MatrixCachesReloadChan") + globals.MatrixCachesReloadChan <- false logrus.Debug("Stopping PGOReloadChan") globals.PGOReloadChan <- false logrus.Debug("Stopping BucketsReloadChan") @@ -207,6 +211,20 @@ func reloadErrorCachesOnChan(reloadChan chan bool) { }() } +func reloadMatrixCachesOnChan(reloadChan chan bool) { + go func() { + defer close(reloadChan) + for { + shouldReload := <-reloadChan + if shouldReload { + matrix.FlushSigningKeyCache() + } else { + return // received stop + } + } + }() +} + func reloadPGOOnChan(reloadChan chan bool) { go func() { defer close(reloadChan) diff --git a/common/config/access.go b/common/config/access.go index 6b1bab44..73a1e5a3 100644 --- a/common/config/access.go +++ b/common/config/access.go @@ -153,6 +153,7 @@ func reloadConfig() (*MainRepoConfig, map[string]*DomainRepoConfig, error) { dc.ClientServerApi = d.ClientServerApi dc.BackoffAt = d.BackoffAt dc.AdminApiKind = d.AdminApiKind + dc.SigningKeyPath = d.SigningKeyPath m, err := objToMapYaml(dc) if err != nil { @@ -222,6 +223,15 @@ func GetDomain(domain string) *DomainRepoConfig { return domains[domain] } +func AddDomainForTesting(domain string, config *DomainRepoConfig) { + Get() // Ensure the "main" config was loaded first + if config == nil { + c := NewDefaultDomainConfig() + config = &c + } + domains[domain] = config +} + func DomainConfigFrom(c MainRepoConfig) DomainRepoConfig { // HACK: We should be better at this kind of inheritance dc := NewDefaultDomainConfig() @@ -265,7 +275,7 @@ func UniqueDatastores() []DatastoreConfig { func PrintDomainInfo() { logrus.Info("Domains loaded:") for _, d := range domains { - logrus.Info(fmt.Sprintf("\t%s (%s)", d.Name, d.ClientServerApi)) + logrus.Info(fmt.Sprintf("\t%s (%s | Signing Key Path=%s)", d.Name, d.ClientServerApi, d.SigningKeyPath)) } } diff --git a/common/config/conf_domain.go b/common/config/conf_domain.go index f7784f43..5443efce 100644 --- a/common/config/conf_domain.go +++ b/common/config/conf_domain.go @@ -16,6 +16,7 @@ func NewDefaultDomainConfig() DomainRepoConfig { ClientServerApi: "https://UNDEFINED", BackoffAt: 10, AdminApiKind: "matrix", + SigningKeyPath: "", }, Downloads: DownloadsConfig{ MaxSizeBytes: 104857600, // 100mb diff --git a/common/config/models_main.go b/common/config/models_main.go index 4f4a89f8..6b80d272 100644 --- a/common/config/models_main.go +++ b/common/config/models_main.go @@ -16,6 +16,7 @@ type HomeserverConfig struct { ClientServerApi string `yaml:"csApi"` BackoffAt int `yaml:"backoffAt"` AdminApiKind string `yaml:"adminApiKind"` + SigningKeyPath string `yaml:"signingKeyPath"` } type DatabaseConfig struct { diff --git a/common/config/watch.go b/common/config/watch.go index 6ba7f910..4ad281cd 100644 --- a/common/config/watch.go +++ b/common/config/watch.go @@ -61,10 +61,13 @@ func onFileChanged() { PrintDomainInfo() CheckDeprecations() - logrus.Info("Reloading pool & cache configuration") + logrus.Info("Reloading pool & error cache configurations") globals.PoolReloadChan <- true globals.ErrorCacheReloadChan <- true + logrus.Info("Reloading matrix caches") + globals.MatrixCachesReloadChan <- true + bindAddressChange := configNew.General.BindAddress != configNow.General.BindAddress bindPortChange := configNew.General.Port != configNow.General.Port forwardAddressChange := configNew.General.TrustAnyForward != configNow.General.TrustAnyForward diff --git a/common/errorcodes.go b/common/errorcodes.go index 2ba7fca2..a04aeff3 100644 --- a/common/errorcodes.go +++ b/common/errorcodes.go @@ -13,6 +13,7 @@ const ErrCodeBadRequest = "M_BAD_REQUEST" const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED" const ErrCodeUnknown = "M_UNKNOWN" const ErrCodeForbidden = "M_FORBIDDEN" +const ErrCodeUnauthorized = "M_UNAUTHORIZED" const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED" const ErrCodeCannotOverwrite = "M_CANNOT_OVERWRITE_MEDIA" const ErrCodeNotYetUploaded = "M_NOT_YET_UPLOADED" diff --git a/common/globals/reload.go b/common/globals/reload.go index c4c942b2..251742ba 100644 --- a/common/globals/reload.go +++ b/common/globals/reload.go @@ -10,5 +10,6 @@ var CacheReplaceChan = make(chan bool) var PluginReloadChan = make(chan bool) var PoolReloadChan = make(chan bool) var ErrorCacheReloadChan = make(chan bool) +var MatrixCachesReloadChan = make(chan bool) var PGOReloadChan = make(chan bool) var BucketsReloadChan = make(chan bool) diff --git a/config.sample.yaml b/config.sample.yaml index 6798fe11..3cf6152a 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -92,6 +92,11 @@ homeservers: # to "matrix", most functionality requiring the admin API will not work. adminApiKind: "synapse" + # The signing key to use for authorizing outbound federation requests. If not specified, + # requests will not be authorized. See https://docs.t2bot.io/matrix-media-repo/v1.3.5/installation/signing-key/ + # for details. + #signingKeyPath: "/data/example.org.key" + # Options for controlling how access tokens work with the media repo. It is recommended that if # you are going to use these options that the `/logout` and `/logout/all` client-server endpoints # be proxied through this process. They will also be called on the homeserver, and the response diff --git a/database/db.go b/database/db.go index 5931a481..ffae2bc4 100644 --- a/database/db.go +++ b/database/db.go @@ -61,6 +61,8 @@ func Reload() { GetInstance() } +// GetAccessorForTests +// Deprecated: For tests only. func GetAccessorForTests() *sql.DB { return GetInstance().conn } diff --git a/dev/homeserver.nginx.conf b/dev/homeserver.nginx.conf index 10f39142..8a975aa1 100644 --- a/dev/homeserver.nginx.conf +++ b/dev/homeserver.nginx.conf @@ -9,6 +9,16 @@ server { proxy_pass http://host.docker.internal:8001; } + location /_matrix/client/versions { + proxy_set_header Host localhost; + proxy_pass http://host.docker.internal:8001; + } + + location /_matrix/client/unstable/org.matrix.msc3916 { + proxy_set_header Host localhost; + proxy_pass http://host.docker.internal:8001; + } + location /_matrix { proxy_pass http://media_repo_synapse:8008; } diff --git a/homeserver_interop/signing_key.go b/homeserver_interop/signing_key.go index 29c0af81..8a1d887c 100644 --- a/homeserver_interop/signing_key.go +++ b/homeserver_interop/signing_key.go @@ -2,9 +2,50 @@ package homeserver_interop import ( "crypto/ed25519" + "crypto/rand" + "fmt" + "sort" + "strings" ) type SigningKey struct { PrivateKey ed25519.PrivateKey KeyVersion string } + +func GenerateSigningKey() (*SigningKey, error) { + keyVersion := makeKeyVersion() + + _, priv, err := ed25519.GenerateKey(nil) + priv = priv[len(priv)-32:] + if err != nil { + return nil, err + } + + return &SigningKey{ + PrivateKey: priv, + KeyVersion: keyVersion, + }, nil +} + +func makeKeyVersion() string { + buf := make([]byte, 2) + chars := strings.Split("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", "") + for i := 0; i < len(chars); i++ { + sort.Slice(chars, func(i int, j int) bool { + c, err := rand.Read(buf) + + // "should never happen" clauses + if err != nil { + panic(err) + } + if c != len(buf) || c != 2 { + panic(fmt.Errorf("crypto rand read %d bytes, expected %d", c, len(buf))) + } + + return buf[0] < buf[1] + }) + } + + return strings.Join(chars[:6], "") +} diff --git a/matrix/breakers.go b/matrix/breakers.go index 5f6ee6a9..911f92f2 100644 --- a/matrix/breakers.go +++ b/matrix/breakers.go @@ -50,10 +50,6 @@ func getFederationBreaker(hostname string) *circuit.Breaker { } func doBreakerRequest(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string, method string, path string, resp interface{}) error { - if accessToken == "" { - return ErrInvalidToken - } - hs, cb := getBreakerAndConfig(serverName) var replyError error diff --git a/matrix/errors.go b/matrix/errors.go index 965b1851..fb6a7458 100644 --- a/matrix/errors.go +++ b/matrix/errors.go @@ -7,12 +7,12 @@ import ( "github.com/t2bot/matrix-media-repo/common" ) -type errorResponse struct { +type ErrorResponse struct { ErrorCode string `json:"errcode"` Message string `json:"error"` } -func (e errorResponse) Error() string { +func (e ErrorResponse) Error() string { return fmt.Sprintf("code=%s message=%s", e.ErrorCode, e.Message) } @@ -22,7 +22,7 @@ func filterError(err error) (error, error) { } // Unknown token errors should be filtered out explicitly to ensure we don't break on bad requests - var httpErr *errorResponse + var httpErr *ErrorResponse if errors.As(err, &httpErr) { // We send back our own version of errors to ensure we can filter them out elsewhere if httpErr.ErrorCode == common.ErrCodeUnknownToken { @@ -34,3 +34,15 @@ func filterError(err error) (error, error) { return err, err } + +type ServerNotAllowedError struct { + error + ServerName string +} + +func MakeServerNotAllowedError(serverName string) ServerNotAllowedError { + return ServerNotAllowedError{ + error: errors.New("server " + serverName + " is not allowed"), + ServerName: serverName, + } +} diff --git a/matrix/requests.go b/matrix/requests.go index d075d28b..740ea432 100644 --- a/matrix/requests.go +++ b/matrix/requests.go @@ -10,12 +10,16 @@ import ( "io" "net" "net/http" + "net/url" "os" "time" "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/database" ) +const NoSigningKey = "" + // Based in part on https://github.com/matrix-org/gomatrix/blob/072b39f7fa6b40257b4eead8c958d71985c28bdd/client.go#L180-L243 func doRequest(ctx rcontext.RequestContext, method string, urlStr string, body interface{}, result interface{}, accessToken string, ipAddr string) error { ctx.Log.Debugf("Calling %s %s", method, urlStr) @@ -60,7 +64,7 @@ func doRequest(ctx rcontext.RequestContext, method string, urlStr string, body i return err } if res.StatusCode != http.StatusOK { - mtxErr := &errorResponse{} + mtxErr := &ErrorResponse{} err = json.Unmarshal(contents, mtxErr) if err == nil && mtxErr.ErrorCode != "" { return mtxErr @@ -78,14 +82,14 @@ func doRequest(ctx rcontext.RequestContext, method string, urlStr string, body i return nil } -func FederatedGet(url string, realHost string, ctx rcontext.RequestContext) (*http.Response, error) { - ctx.Log.Debug("Doing federated GET to " + url + " with host " + realHost) +func FederatedGet(ctx rcontext.RequestContext, reqUrl string, realHost string, useSigningKeyPath string) (*http.Response, error) { + ctx.Log.Debug("Doing federated GET to " + reqUrl + " with host " + realHost) cb := getFederationBreaker(realHost) var resp *http.Response replyError := cb.CallContext(ctx, func() error { - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest(http.MethodGet, reqUrl, nil) if err != nil { return err } @@ -95,6 +99,23 @@ func FederatedGet(url string, realHost string, ctx rcontext.RequestContext) (*ht req.Header.Set("User-Agent", "matrix-media-repo") req.Host = realHost + if useSigningKeyPath != NoSigningKey { + ctx.Log.Debug("Reading signing key and adding authentication headers") + key, err := getLocalSigningKey(useSigningKeyPath) + if err != nil { + return err + } + parsed, err := url.Parse(reqUrl) + if err != nil { + return err + } + auth, err := CreateXMatrixHeader(ctx.Request.Host, realHost, http.MethodGet, parsed.RequestURI(), &database.AnonymousJson{}, key.Key, key.Version) + if err != nil { + return err + } + req.Header.Set("Authorization", auth) + } + var client *http.Client if os.Getenv("MEDIA_REPO_UNSAFE_FEDERATION") != "true" { // This is how we verify the certificate is valid for the host we expect. diff --git a/matrix/requests_info.go b/matrix/requests_info.go new file mode 100644 index 00000000..c2cc1de0 --- /dev/null +++ b/matrix/requests_info.go @@ -0,0 +1,17 @@ +package matrix + +import "github.com/t2bot/matrix-media-repo/common/rcontext" + +type ClientVersionsResponse struct { + Versions []string `json:"versions"` + UnstableFeatures map[string]bool `json:"unstable_features"` +} + +func ClientVersions(ctx rcontext.RequestContext, serverName string, accessToken string, appserviceUserId string, ipAddr string) (*ClientVersionsResponse, error) { + response := &ClientVersionsResponse{} + err := doBreakerRequest(ctx, serverName, accessToken, appserviceUserId, ipAddr, "GET", "/_matrix/client/versions", response) + if err != nil { + return nil, err + } + return response, nil +} diff --git a/matrix/requests_signing.go b/matrix/requests_signing.go new file mode 100644 index 00000000..54e09b71 --- /dev/null +++ b/matrix/requests_signing.go @@ -0,0 +1,201 @@ +package matrix + +import ( + "crypto/ed25519" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" + "github.com/t2bot/go-typed-singleflight" + "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/util" +) + +type signingKey struct { + Key string `json:"key"` +} + +type ServerKeyResult struct { + ServerName string `json:"server_name"` + ValidUntilTs int64 `json:"valid_until_ts"` + VerifyKeys map[string]signingKey `json:"verify_keys"` // unpadded base64 + OldVerifyKeys map[string]signingKey `json:"old_verify_keys"` // unpadded base64 + Signatures map[string]map[string]string `json:"signatures"` // unpadded base64; > +} + +type ServerSigningKeys map[string]ed25519.PublicKey + +var signingKeySf = new(typedsf.Group[ServerSigningKeys]) +var signingKeyCache = cache.New(cache.NoExpiration, 30*time.Second) +var signingKeyRWLock = new(sync.RWMutex) + +// TestsOnlyInjectSigningKey +// Deprecated: For tests only. +func TestsOnlyInjectSigningKey(serverName string, httpFederationUrl string) error { + resp, err := http.Get(httpFederationUrl + "/_matrix/key/v2/server") + if err != nil { + return err + } + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + raw := database.AnonymousJson{} + if err = decoder.Decode(&raw); err != nil { + return err + } + keyInfo := new(ServerKeyResult) + if err = raw.ApplyTo(keyInfo); err != nil { + return err + } + + // Convert keys to something useful, and check signatures + serverKeys, err := CheckSigningKeySignatures(serverName, keyInfo, raw) + if err != nil { + return err + } + + // Cache & return (unlock is deferred) + signingKeyRWLock.Lock() + defer signingKeyRWLock.Unlock() + cacheUntil := time.Until(time.UnixMilli(keyInfo.ValidUntilTs)) / 2 + signingKeyCache.Set(serverName, &serverKeys, cacheUntil) + + return nil +} + +func querySigningKeyCache(serverName string) ServerSigningKeys { + if val, ok := signingKeyCache.Get(serverName); ok { + ptr := val.(*ServerSigningKeys) + return *ptr + } + return nil +} + +func QuerySigningKeys(serverName string) (ServerSigningKeys, error) { + signingKeyRWLock.RLock() + keys := querySigningKeyCache(serverName) + signingKeyRWLock.RUnlock() + if keys != nil { + return keys, nil + } + + keys, err, _ := signingKeySf.Do(serverName, func() (ServerSigningKeys, error) { + ctx := rcontext.Initial().LogWithFields(logrus.Fields{ + "keysForServer": serverName, + }) + + signingKeyRWLock.Lock() + defer signingKeyRWLock.Unlock() + + // check cache once more, just in case the locks overlapped + cachedKeys := querySigningKeyCache(serverName) + if keys != nil { + return cachedKeys, nil + } + + // now we can try to get the keys from the source + url, hostname, err := GetServerApiUrl(serverName) + if err != nil { + return nil, err + } + + keysUrl := url + "/_matrix/key/v2/server" + keysResponse, err := FederatedGet(ctx, keysUrl, hostname, NoSigningKey) + if keysResponse != nil { + defer keysResponse.Body.Close() + } + if err != nil { + return nil, err + } + + decoder := json.NewDecoder(keysResponse.Body) + raw := database.AnonymousJson{} + if err = decoder.Decode(&raw); err != nil { + return nil, err + } + keyInfo := new(ServerKeyResult) + if err = raw.ApplyTo(keyInfo); err != nil { + return nil, err + } + + // Check validity before we go much further + if keyInfo.ServerName != serverName { + return nil, fmt.Errorf("got keys for '%s' but expected '%s'", keyInfo.ServerName, serverName) + } + if keyInfo.ValidUntilTs <= util.NowMillis() { + return nil, errors.New("returned server keys are expired") + } + cacheUntil := time.Until(time.UnixMilli(keyInfo.ValidUntilTs)) / 2 + if cacheUntil <= (6 * time.Second) { + return nil, errors.New("returned server keys would expire too quickly") + } + + // Convert keys to something useful, and check signatures + serverKeys, err := CheckSigningKeySignatures(serverName, keyInfo, raw) + if err != nil { + return nil, err + } + + // Cache & return (unlock was deferred) + signingKeyCache.Set(serverName, &serverKeys, cacheUntil) + return serverKeys, nil + }) + return keys, err +} + +func CheckSigningKeySignatures(serverName string, keyInfo *ServerKeyResult, raw database.AnonymousJson) (ServerSigningKeys, error) { + serverKeys := make(ServerSigningKeys) + for keyId, keyObj := range keyInfo.VerifyKeys { + b, err := util.DecodeUnpaddedBase64String(keyObj.Key) + if err != nil { + return nil, errors.Join(fmt.Errorf("bad base64 for key ID '%s' for '%s'", keyId, serverName), err) + } + + serverKeys[keyId] = b + } + + if len(keyInfo.Signatures) == 0 || len(keyInfo.Signatures[serverName]) == 0 { + return nil, fmt.Errorf("missing signatures from '%s'", serverName) + } + delete(raw, "signatures") + canonical, err := util.EncodeCanonicalJson(raw) + if err != nil { + return nil, err + } + for domain, sig := range keyInfo.Signatures { + if domain != serverName { + return nil, fmt.Errorf("unexpected signature from '%s' (expected '%s')", domain, serverName) + } + + for keyId, b64 := range sig { + signatureBytes, err := util.DecodeUnpaddedBase64String(b64) + if err != nil { + return nil, errors.Join(fmt.Errorf("bad base64 signature for key ID '%s' for '%s'", keyId, serverName), err) + } + + key, ok := serverKeys[keyId] + if !ok { + return nil, fmt.Errorf("unknown key ID '%s' for signature from '%s'", keyId, serverName) + } + + if !ed25519.Verify(key, canonical, signatureBytes) { + return nil, fmt.Errorf("invalid signature '%s' from key ID '%s' for '%s'", b64, keyId, serverName) + } + } + } + + // Ensure *all* keys have signed the response + for keyId, _ := range serverKeys { + if _, ok := keyInfo.Signatures[serverName][keyId]; !ok { + return nil, fmt.Errorf("missing signature from key '%s'", keyId) + } + } + + return serverKeys, nil +} diff --git a/matrix/server_discovery.go b/matrix/server_discovery.go index 02e5ce37..43ee1996 100644 --- a/matrix/server_discovery.go +++ b/matrix/server_discovery.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "os" "strconv" "strings" "sync" @@ -36,6 +37,12 @@ func GetServerApiUrl(hostname string) (string, string, error) { logrus.Debug("Getting server API URL for " + hostname) + scheme := "https" + if os.Getenv("MEDIA_REPO_HTTP_ONLY_FEDERATION") == "true" { + logrus.Warnf("Making non-https request to hostname %s because MEDIA_REPO_HTTP_ONLY_FEDERATION is set to true", hostname) + scheme = "http" + } + // Check to see if we've cached this hostname at all setupCache() record, found := apiUrlCacheInstance.Get(hostname) @@ -58,7 +65,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { // Step 1 of the discovery process: if the hostname is an IP, use that with explicit or default port logrus.Debug("Testing if " + h + " is an IP address") if is.IP(h) { - url := fmt.Sprintf("https://%s", net.JoinHostPort(h, p)) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(h, p)) server := cachedServer{url, hostname} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (IP address)") @@ -68,7 +75,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { // Step 2: if the hostname is not an IP address, and an explicit port is given, use that logrus.Debug("Testing if a default port was used. Using default = ", defPort) if !defPort { - url := fmt.Sprintf("https://%s", net.JoinHostPort(h, p)) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(h, p)) server := cachedServer{url, h} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (explicit port)") @@ -78,7 +85,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { // Step 3: if the hostname is not an IP address and no explicit port is given, do .well-known // Note that we have sprawling branches here because we need to fall through to step 4 if parsing fails logrus.Debug("Doing .well-known lookup on " + h) - r, err := http.Get(fmt.Sprintf("https://%s/.well-known/matrix/server", h)) + r, err := http.Get(fmt.Sprintf("%s://%s/.well-known/matrix/server", scheme, h)) if r != nil { defer r.Body.Close() } @@ -98,7 +105,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { // Step 3a: if the delegated host is an IP address, use that (regardless of port) logrus.Debug("Checking if WK host is an IP: " + wkHost) if is.IP(wkHost) { - url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort)) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(wkHost, wkPort)) server := cachedServer{url, wk.ServerAddr} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; IP address)") @@ -109,7 +116,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { logrus.Debug("Checking if WK is using default port? ", wkDefPort) if !wkDefPort { wkHost = net.JoinHostPort(wkHost, wkPort) - url := fmt.Sprintf("https://%s", wkHost) + url := fmt.Sprintf("%s://%s", scheme, wkHost) server := cachedServer{url, wkHost} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; explicit port)") @@ -126,7 +133,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { if realAddr[len(realAddr)-1:] == "." { realAddr = realAddr[0 : len(realAddr)-1] } - url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) server := cachedServer{url, wkHost} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; SRV)") @@ -144,7 +151,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { if realAddr[len(realAddr)-1:] == "." { realAddr = realAddr[0 : len(realAddr)-1] } - url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) server := cachedServer{url, wkHost} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; SRV-Deprecated)") @@ -153,7 +160,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { // Step 3d: use the delegated host as-is logrus.Debug("Using .well-known as-is for ", wkHost) - url := fmt.Sprintf("https://%s", net.JoinHostPort(wkHost, wkPort)) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(wkHost, wkPort)) server := cachedServer{url, wkHost} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (WK; fallback)") @@ -176,7 +183,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { if realAddr[len(realAddr)-1:] == "." { realAddr = realAddr[0 : len(realAddr)-1] } - url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) server := cachedServer{url, h} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (SRV)") @@ -193,7 +200,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { if realAddr[len(realAddr)-1:] == "." { realAddr = realAddr[0 : len(realAddr)-1] } - url := fmt.Sprintf("https://%s", net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(realAddr, strconv.Itoa(int(addrs[0].Port)))) server := cachedServer{url, h} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (SRV-Deprecated)") @@ -202,7 +209,7 @@ func GetServerApiUrl(hostname string) (string, string, error) { // Step 6: use the target host as-is logrus.Debug("Using host as-is: ", hostname) - url := fmt.Sprintf("https://%s", net.JoinHostPort(h, p)) + url := fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(h, p)) server := cachedServer{url, h} apiUrlCacheInstance.Set(hostname, server, cache.DefaultExpiration) logrus.Debug("Server API URL for " + hostname + " is " + url + " (fallback)") diff --git a/matrix/signing_key_cache.go b/matrix/signing_key_cache.go new file mode 100644 index 00000000..f3b8bc63 --- /dev/null +++ b/matrix/signing_key_cache.go @@ -0,0 +1,43 @@ +package matrix + +import ( + "crypto/ed25519" + "os" + "time" + + "github.com/patrickmn/go-cache" + "github.com/t2bot/matrix-media-repo/homeserver_interop/mmr" +) + +type LocalSigningKey struct { + Key ed25519.PrivateKey + Version string +} + +var localSigningKeyCache = cache.New(5*time.Minute, 10*time.Minute) + +func FlushSigningKeyCache() { + localSigningKeyCache.Flush() +} + +func getLocalSigningKey(fromPath string) (*LocalSigningKey, error) { + if val, ok := localSigningKeyCache.Get(fromPath); ok { + return val.(*LocalSigningKey), nil + } + + f, err := os.Open(fromPath) + defer f.Close() + if err != nil { + return nil, err + } + key, err := mmr.DecodeSigningKey(f) + if err != nil { + return nil, err + } + sk := &LocalSigningKey{ + Key: key.PrivateKey, + Version: key.KeyVersion, + } + localSigningKeyCache.Set(fromPath, sk, cache.DefaultExpiration) + return sk, nil +} diff --git a/matrix/xmatrix.go b/matrix/xmatrix.go new file mode 100644 index 00000000..6b9e890e --- /dev/null +++ b/matrix/xmatrix.go @@ -0,0 +1,104 @@ +package matrix + +import ( + "crypto/ed25519" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/util" +) + +var ErrNoXMatrixAuth = errors.New("no X-Matrix auth headers") +var ErrWrongDestination = errors.New("wrong destination") + +func ValidateXMatrixAuth(request *http.Request, expectNoContent bool) (string, error) { + if !expectNoContent { + panic("development error: X-Matrix auth validation can only be done with an empty body for now") + } + + auths, err := util.GetXMatrixAuth(request.Header.Values("Authorization")) + if err != nil { + return "", err + } + if len(auths) == 0 { + return "", ErrNoXMatrixAuth + } + + keys, err := QuerySigningKeys(auths[0].Origin) + if err != nil { + return "", err + } + + uri := request.RequestURI + if strings.HasSuffix(uri, "?") { + uri = uri[:len(uri)-1] + } + + err = ValidateXMatrixAuthHeader(request.Method, uri, &database.AnonymousJson{}, auths, keys, request.Host) + if err != nil { + return "", err + } + return auths[0].Origin, nil +} + +func ValidateXMatrixAuthHeader(requestMethod string, requestUri string, content any, headers []util.XMatrixAuth, originKeys ServerSigningKeys, destinationHost string) error { + if len(headers) == 0 { + return ErrNoXMatrixAuth + } + + obj := map[string]interface{}{ + "method": requestMethod, + "uri": requestUri, + "origin": headers[0].Origin, + "destination": headers[0].Destination, + "content": content, + } + canonical, err := util.EncodeCanonicalJson(obj) + if err != nil { + return err + } + + for i, h := range headers { + if h.Origin != obj["origin"] { + return errors.New("auth is from multiple servers") + } + if h.Destination != obj["destination"] { + return errors.New("auth is for multiple servers") + } + if h.Destination != "" && (!util.IsServerOurs(h.Destination) || destinationHost != h.Destination) { + return ErrWrongDestination + } + + if key, ok := (originKeys)[h.KeyId]; ok { + if !ed25519.Verify(key, canonical, h.Signature) { + return fmt.Errorf("failed signatures on '%s', header %d", h.KeyId, i) + } + } else { + return fmt.Errorf("unknown key '%s'", h.KeyId) + } + } + + return nil +} + +func CreateXMatrixHeader(origin string, destination string, requestMethod string, requestUri string, content any, key ed25519.PrivateKey, keyVersion string) (string, error) { + obj := map[string]interface{}{ + "method": requestMethod, + "uri": requestUri, + "origin": origin, + "destination": destination, + "content": content, + } + canonical, err := util.EncodeCanonicalJson(obj) + if err != nil { + return "", err + } + + b := ed25519.Sign(key, canonical) + sig := util.EncodeUnpaddedBase64ToString(b) + + return fmt.Sprintf("X-Matrix origin=\"%s\",destination=\"%s\",key=\"ed25519:%s\",sig=\"%s\"", origin, destination, keyVersion, sig), nil +} diff --git a/pipelines/_steps/download/try_download.go b/pipelines/_steps/download/try_download.go index bda7aa62..3c2e959c 100644 --- a/pipelines/_steps/download/try_download.go +++ b/pipelines/_steps/download/try_download.go @@ -1,14 +1,18 @@ package download import ( + "encoding/json" "errors" "fmt" "io" "mime" + "mime/multipart" "net/http" "net/url" "strconv" + "strings" + "github.com/getsentry/sentry-go" "github.com/prometheus/client_golang/prometheus" "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/rcontext" @@ -25,6 +29,7 @@ import ( type downloadResult struct { r io.ReadCloser + metadata *database.AnonymousJson filename string contentType string err error @@ -55,12 +60,52 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d return } - downloadUrl := fmt.Sprintf("%s/_matrix/media/v3/download/%s/%s?allow_remote=false&allow_redirect=true", baseUrl, url.PathEscape(origin), url.PathEscape(mediaId)) - resp, err := matrix.FederatedGet(downloadUrl, realHost, ctx) - metrics.MediaDownloaded.With(prometheus.Labels{"origin": origin}).Inc() - if err != nil { - errFn(err) - return + var resp *http.Response + var downloadUrl string + usesMultipartFormat := false + if ctx.Config.SigningKeyPath != "" { + downloadUrl = fmt.Sprintf("%s/_matrix/federation/v1/media/download/%s", baseUrl, url.PathEscape(mediaId)) + resp, err = matrix.FederatedGet(ctx, downloadUrl, realHost, ctx.Config.SigningKeyPath) + metrics.MediaDownloaded.With(prometheus.Labels{"origin": origin}).Inc() + if err != nil { + errFn(err) + return + } + if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized { + errFn(matrix.MakeServerNotAllowedError(ctx.Request.Host)) + return + } else if resp.StatusCode == http.StatusNotFound { + decoder := json.NewDecoder(resp.Body) + resp2 := resp // copy response in case we clear it out later + defer resp2.Body.Close() + mxerr := &matrix.ErrorResponse{} + if err = decoder.Decode(&mxerr); err != nil { + // we probably got not-json - ignore and move on + ctx.Log.Debugf("Ignoring JSON decoding error on download error %d: %v", resp.StatusCode, err) + resp = nil // indicate we want to use fallback + } else { + if mxerr.ErrorCode == "M_UNRECOGNIZED" { + ctx.Log.Debugf("Destination doesn't support MSC3916") + resp = nil // indicate we want to use fallback + } + } + } else if resp.StatusCode == http.StatusOK { + usesMultipartFormat = true + } + } else { + // Yes, we are deliberately loud about this. People should configure this. + ctx.Log.Warn("No signing key is configured for this domain! See `signingKeyPath` in the sample config for details.") + } + + // Try fallback (unauthenticated) + if resp == nil { + downloadUrl = fmt.Sprintf("%s/_matrix/media/v3/download/%s/%s?allow_remote=false&allow_redirect=true", baseUrl, url.PathEscape(origin), url.PathEscape(mediaId)) + resp, err = matrix.FederatedGet(ctx, downloadUrl, realHost, matrix.NoSigningKey) + metrics.MediaDownloaded.With(prometheus.Labels{"origin": origin}).Inc() + if err != nil { + errFn(err) + return + } } if resp.StatusCode == http.StatusNotFound { @@ -85,24 +130,92 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d return } - r := resp.Body if ctx.Config.Downloads.MaxSizeBytes > 0 { - r = readers.LimitReaderWithOverrunError(resp.Body, ctx.Config.Downloads.MaxSizeBytes) + resp.Body = readers.LimitReaderWithOverrunError(resp.Body, ctx.Config.Downloads.MaxSizeBytes) + } + + contentType := resp.Header.Get("Content-Type") // we default Content-Type after we inspect for multiparts + + metadata := &database.AnonymousJson{} + mediaPart := util.MatrixMediaPartFromResponse(resp) + if usesMultipartFormat { + if !strings.HasPrefix(contentType, "multipart/mixed;") { + errFn(fmt.Errorf("expected multipart/mixed, got %s", contentType)) + return + } + + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + errFn(err) + return + } + + partReader := multipart.NewReader(resp.Body, params["boundary"]) + + // The first part should always be the metadata + jsonPart, err := partReader.NextPart() + if err != nil { + errFn(err) + return + } + partType := jsonPart.Header.Get("Content-Type") + if partType == "" || partType == "application/json" { + decoder := json.NewDecoder(jsonPart) + err = decoder.Decode(&metadata) + if err != nil { + errFn(err) + return + } + } else { + errFn(fmt.Errorf("expected application/json as the first part, got %s instead", partType)) + } + + ctx.Log.Debugf("Got metadata: %v", metadata) + + // The second part should always be the media itself + bodyPart, err := partReader.NextPart() + if err != nil { + errFn(err) + return + } + mediaPart = util.MatrixMediaPartFromMimeMultipart(bodyPart) + contentType = mediaPart.Header.Get("Content-Type") // Content-Type should really be the media content type + + locationHeader := mediaPart.Header.Get("Location") + if locationHeader != "" { + // the media part body won't have anything for us - go `GET` the URL. + ctx.Log.Debugf("Redirecting to %s", locationHeader) + + err = mediaPart.Body.Close() + if err != nil { + sentry.CaptureException(errors.Join(errors.New("non-fatal error closing redirected MSC3916 body"), err)) + ctx.Log.Debug("Non-fatal error closing redirected MSC3916 body: ", err) + } + + resp, err = http.DefaultClient.Get(locationHeader) + if err != nil { + errFn(err) + return + } + mediaPart = util.MatrixMediaPartFromResponse(resp) + contentType = mediaPart.Header.Get("Content-Type") + } } - contentType := resp.Header.Get("Content-Type") + // Default the Content-Type if we haven't already if contentType == "" { contentType = "application/octet-stream" // binary } fileName := "download" - _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) + _, params, err := mime.ParseMediaType(mediaPart.Header.Get("Content-Disposition")) if err == nil && params["filename"] != "" { fileName = params["filename"] } ch <- downloadResult{ - r: r, + r: mediaPart.Body, + metadata: metadata, filename: fileName, contentType: contentType, err: nil, @@ -117,6 +230,7 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d } // At this point, res.r is our http response body. + // TODO: Do something with res.metadata (MSC3911) return datastore_op.PutAndReturnStream(ctx, origin, mediaId, res.r, res.contentType, res.filename, datastores.RemoteMediaKind) } diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index 3a600875..fa45c514 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -14,6 +14,7 @@ import ( "github.com/t2bot/matrix-media-repo/common/rcontext" "github.com/t2bot/matrix-media-repo/database" "github.com/t2bot/matrix-media-repo/limits" + "github.com/t2bot/matrix-media-repo/matrix" "github.com/t2bot/matrix-media-repo/pipelines/_steps/download" "github.com/t2bot/matrix-media-repo/pipelines/_steps/meta" "github.com/t2bot/matrix-media-repo/pipelines/_steps/quarantine" @@ -141,6 +142,14 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do cancel() return nil, r, err } + var notAllowedErr *matrix.ServerNotAllowedError + if errors.As(err, ¬AllowedErr) { + if notAllowedErr.ServerName != ctx.Request.Host { + ctx.Log.Debug("'Not allowed' error is for another server - retrying") + cancel() + return Execute(ctx, origin, mediaId, opts) + } + } if err != nil { cancel() return nil, nil, err diff --git a/test/matrix_resolve_test.go b/test/matrix_resolve_test.go new file mode 100644 index 00000000..e62da699 --- /dev/null +++ b/test/matrix_resolve_test.go @@ -0,0 +1,26 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/t2bot/matrix-media-repo/matrix" +) + +func doResolve(t *testing.T, origin string, expectedAddress string, expectedHost string) { + url, host, err := matrix.GetServerApiUrl(origin) + assert.NoError(t, err, origin) + assert.Equal(t, expectedAddress, url, origin) + assert.Equal(t, expectedHost, host, origin) +} + +func TestResolveMatrix(t *testing.T) { + doResolve(t, "2.s.resolvematrix.dev:7652", "https://2.s.resolvematrix.dev:7652", "2.s.resolvematrix.dev") + doResolve(t, "3b.s.resolvematrix.dev", "https://wk.3b.s.resolvematrix.dev:7753", "wk.3b.s.resolvematrix.dev:7753") + doResolve(t, "3c.s.resolvematrix.dev", "https://srv.wk.3c.s.resolvematrix.dev:7754", "wk.3c.s.resolvematrix.dev") + doResolve(t, "3d.s.resolvematrix.dev", "https://wk.3d.s.resolvematrix.dev:8448", "wk.3d.s.resolvematrix.dev") + doResolve(t, "4.s.resolvematrix.dev", "https://srv.4.s.resolvematrix.dev:7855", "4.s.resolvematrix.dev") + doResolve(t, "5.s.resolvematrix.dev", "https://5.s.resolvematrix.dev:8448", "5.s.resolvematrix.dev") + doResolve(t, "3c.msc4040.s.resolvematrix.dev", "https://srv.wk.3c.msc4040.s.resolvematrix.dev:7053", "wk.3c.msc4040.s.resolvematrix.dev") + doResolve(t, "4.msc4040.s.resolvematrix.dev", "https://srv.4.msc4040.s.resolvematrix.dev:7054", "4.msc4040.s.resolvematrix.dev") +} diff --git a/test/msc3916_downloads_suite_test.go b/test/msc3916_downloads_suite_test.go new file mode 100644 index 00000000..f2f968f2 --- /dev/null +++ b/test/msc3916_downloads_suite_test.go @@ -0,0 +1,292 @@ +package test + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/t2bot/matrix-media-repo/common/config" + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/homeserver_interop" + "github.com/t2bot/matrix-media-repo/matrix" + "github.com/t2bot/matrix-media-repo/test/test_internals" + "github.com/t2bot/matrix-media-repo/util" +) + +type MSC3916DownloadsSuite struct { + suite.Suite + deps *test_internals.ContainerDeps + keyServer *test_internals.HostedFile + keyServerKey *homeserver_interop.SigningKey +} + +func (s *MSC3916DownloadsSuite) SetupSuite() { + err := os.Setenv("MEDIA_REPO_HTTP_ONLY_FEDERATION", "true") + if err != nil { + s.T().Fatal(err) + } + + deps, err := test_internals.MakeTestDeps() + if err != nil { + log.Fatal(err) + } + s.deps = deps + + s.keyServer, s.keyServerKey = test_internals.MakeKeyServer(deps) +} + +func (s *MSC3916DownloadsSuite) TearDownSuite() { + err := os.Unsetenv("MEDIA_REPO_HTTP_ONLY_FEDERATION") + if err != nil { + s.T().Fatal(err) + } + if s.deps != nil { + if s.T().Failed() { + s.deps.Debug() + } + s.deps.Teardown() + } +} + +func (s *MSC3916DownloadsSuite) TestClientDownloads() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + client2 := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // this client isn't authed + UserId: "", // this client isn't authed + } + + contentType, img, err := test_internals.MakeTestImage(512, 512) + assert.NoError(t, err) + fname := "image" + util.ExtensionForContentType(contentType) + + res, err := client1.Upload(fname, contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + assert.NotEmpty(t, mediaId) + + raw, err := client2.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + raw, err = client2.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s/whatever.png", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + + raw, err = client1.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) + test_internals.AssertIsTestImage(t, raw.Body) + raw, err = client1.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s/whatever.png", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) + test_internals.AssertIsTestImage(t, raw.Body) +} + +func (s *MSC3916DownloadsSuite) TestFederationDownloads() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + remoteClient := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // this client isn't authed over the CS API + UserId: "", // this client isn't authed over the CS API + } + + contentType, img, err := test_internals.MakeTestImage(512, 512) + assert.NoError(t, err) + fname := "image" + util.ExtensionForContentType(contentType) + + res, err := client1.Upload(fname, contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + assert.NotEmpty(t, mediaId) + + // Verify the federation download *fails* when lacking auth + uri := fmt.Sprintf("/_matrix/federation/v1/media/download/%s", mediaId) + raw, err := remoteClient.DoRaw("GET", uri, nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + + // Now add the X-Matrix auth and try again + header, err := matrix.CreateXMatrixHeader(s.keyServer.PublicHostname, remoteClient.ServerName, "GET", uri, &database.AnonymousJson{}, s.keyServerKey.PrivateKey, s.keyServerKey.KeyVersion) + assert.NoError(t, err) + remoteClient.AuthHeaderOverride = header + raw, err = remoteClient.DoRaw("GET", uri, nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) +} + +func (s *MSC3916DownloadsSuite) TestFederationMakesAuthedDownloads() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + + origin := "" + mediaId := "abc123" + err := matrix.TestsOnlyInjectSigningKey(s.deps.Homeservers[0].ServerName, s.deps.Homeservers[0].ExternalClientServerApiUrl) + assert.NoError(t, err) + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, fmt.Sprintf("/_matrix/federation/v1/media/download/%s", mediaId), r.URL.Path) + origin, err := matrix.ValidateXMatrixAuth(r, true) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + w.Header().Set("Content-Type", "multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p") + _, _ = w.Write([]byte("--gc0p4Jq0M2Yt08jU534c0p\nContent-Type: application/json\n\n{}\n\n--gc0p4Jq0M2Yt08jU534c0p\nContent-Type: text/plain\n\nThis media is plain text. Maybe somebody used it as a paste bin.\n\n--gc0p4Jq0M2Yt08jU534c0p")) + })) + defer testServer.Close() + + u, _ := url.Parse(testServer.URL) + origin = fmt.Sprintf("%s:%s", "host.docker.internal", u.Port()) + config.AddDomainForTesting("host.docker.internal", nil) // no port for config lookup + + raw, err := client1.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) +} + +func (s *MSC3916DownloadsSuite) TestFederationFollowsRedirects() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + + origin := "" + mediaId := "abc123" + fileContents := "hello world! This is a test file" + err := matrix.TestsOnlyInjectSigningKey(s.deps.Homeservers[0].ServerName, s.deps.Homeservers[0].ExternalClientServerApiUrl) + assert.NoError(t, err) + + // Mock CDN (2nd hop) + testServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/cdn/file", r.URL.Path) + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(fileContents)) + })) + defer testServer2.Close() + u, _ := url.Parse(testServer2.URL) + //goland:noinspection HttpUrlsUsage + redirectUrl := fmt.Sprintf("http://%s:%s/cdn/file", "host.docker.internal", u.Port()) + + // Mock homeserver (1st hop) + testServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, fmt.Sprintf("/_matrix/federation/v1/media/download/%s", mediaId), r.URL.Path) + origin, err := matrix.ValidateXMatrixAuth(r, true) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + w.Header().Set("Content-Type", "multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p") + _, _ = w.Write([]byte(fmt.Sprintf("--gc0p4Jq0M2Yt08jU534c0p\nContent-Type: application/json\n\n{}\n\n--gc0p4Jq0M2Yt08jU534c0p\nLocation: %s\n\n-gc0p4Jq0M2Yt08jU534c0p", redirectUrl))) + })) + defer testServer1.Close() + + u, _ = url.Parse(testServer1.URL) + origin = fmt.Sprintf("%s:%s", "host.docker.internal", u.Port()) + config.AddDomainForTesting("host.docker.internal", nil) // no port for config lookup + + raw, err := client1.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) + + b, err := io.ReadAll(raw.Body) + assert.NoError(t, err) + assert.Equal(t, fileContents, string(b)) +} + +func (s *MSC3916DownloadsSuite) TestFederationProducesRedirects() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + remoteClient := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // this client isn't authed over the CS API + UserId: "", // this client isn't authed over the CS API + } + + contentType, img, err := test_internals.MakeTestImage(512, 512) + assert.NoError(t, err) + fname := "image" + util.ExtensionForContentType(contentType) + + res, err := client1.Upload(fname, contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + assert.NotEmpty(t, mediaId) + + // Verify the federation download *fails* when lacking auth + uri := fmt.Sprintf("/_matrix/federation/v1/media/download/%s", mediaId) + header, err := matrix.CreateXMatrixHeader(s.keyServer.PublicHostname, remoteClient.ServerName, "GET", uri, &database.AnonymousJson{}, s.keyServerKey.PrivateKey, s.keyServerKey.KeyVersion) + assert.NoError(t, err) + remoteClient.AuthHeaderOverride = header + raw, err := remoteClient.DoRaw("GET", uri, nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) + + // TODO: Need to actually test that redirects are properly formed, and set up the test suite to produce them +} + +func (s *MSC3916DownloadsSuite) TestFederationMakesAuthedDownloadsAndFallsBack() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + + origin := "" + mediaId := "abc123" + fileContents := "hello world! This is a test file" + err := matrix.TestsOnlyInjectSigningKey(s.deps.Homeservers[0].ServerName, s.deps.Homeservers[0].ExternalClientServerApiUrl) + assert.NoError(t, err) + + reqNum := 0 + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if reqNum == 0 { + origin, err := matrix.ValidateXMatrixAuth(r, true) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + assert.Equal(t, fmt.Sprintf("/_matrix/federation/v1/media/download/%s", mediaId), r.URL.Path) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("{\"errcode\":\"M_UNRECOGNIZED\"}")) + reqNum++ + } else { + assert.Equal(t, fmt.Sprintf("/_matrix/media/v3/download/%s/%s", origin, mediaId), r.URL.Path) + } + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(fileContents)) + })) + defer testServer.Close() + + u, _ := url.Parse(testServer.URL) + origin = fmt.Sprintf("%s:%s", "host.docker.internal", u.Port()) + config.AddDomainForTesting("host.docker.internal", nil) // no port for config lookup + + raw, err := client1.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/download/%s/%s", origin, mediaId), nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) + + b, err := io.ReadAll(raw.Body) + assert.NoError(t, err) + assert.Equal(t, fileContents, string(b)) +} + +func TestMSC3916DownloadsSuite(t *testing.T) { + suite.Run(t, new(MSC3916DownloadsSuite)) +} diff --git a/test/msc3916_misc_client_endpoints_suite_test.go b/test/msc3916_misc_client_endpoints_suite_test.go new file mode 100644 index 00000000..1a0801fd --- /dev/null +++ b/test/msc3916_misc_client_endpoints_suite_test.go @@ -0,0 +1,95 @@ +package test + +import ( + "log" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/t2bot/matrix-media-repo/test/test_internals" +) + +type MSC3916MiscClientEndpointsSuite struct { + suite.Suite + deps *test_internals.ContainerDeps + htmlPage *test_internals.HostedFile +} + +func (s *MSC3916MiscClientEndpointsSuite) SetupSuite() { + deps, err := test_internals.MakeTestDeps() + if err != nil { + log.Fatal(err) + } + s.deps = deps + + file, err := test_internals.ServeFile("index.html", deps, "

This is a test file

") + if err != nil { + log.Fatal(err) + } + s.htmlPage = file +} + +func (s *MSC3916MiscClientEndpointsSuite) TearDownSuite() { + if s.htmlPage != nil { + if s.T().Failed() { + staticLogs, err := s.htmlPage.Logs() + s.deps.DumpDebugLogs(staticLogs, err, -1, s.htmlPage.PublicUrl) + } + s.htmlPage.Teardown() + } + if s.deps != nil { + if s.T().Failed() { + s.deps.Debug() + } + s.deps.Teardown() + } +} + +func (s *MSC3916MiscClientEndpointsSuite) TestPreviewUrlRequiresAuth() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + client2 := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // no auth on this client + UserId: "", // no auth on this client + } + + qs := url.Values{ + "url": []string{s.htmlPage.PublicUrl}, + } + raw, err := client2.DoRaw("GET", "/_matrix/client/v1/media/preview_url", qs, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + + raw, err = client1.DoRaw("GET", "/_matrix/client/v1/media/preview_url", qs, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) +} + +func (s *MSC3916MiscClientEndpointsSuite) TestConfigRequiresAuth() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + client2 := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // no auth on this client + UserId: "", // no auth on this client + } + + raw, err := client2.DoRaw("GET", "/_matrix/client/v1/media/config", nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + + raw, err = client1.DoRaw("GET", "/_matrix/client/v1/media/config", nil, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) +} + +func TestMSC3916MiscClientEndpointsSuite(t *testing.T) { + suite.Run(t, new(MSC3916MiscClientEndpointsSuite)) +} diff --git a/test/msc3916_thumbnails_suite_test.go b/test/msc3916_thumbnails_suite_test.go new file mode 100644 index 00000000..4ff43a17 --- /dev/null +++ b/test/msc3916_thumbnails_suite_test.go @@ -0,0 +1,131 @@ +package test + +import ( + "fmt" + "log" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/homeserver_interop" + "github.com/t2bot/matrix-media-repo/matrix" + "github.com/t2bot/matrix-media-repo/test/test_internals" + "github.com/t2bot/matrix-media-repo/util" +) + +type MSC3916ThumbnailsSuite struct { + suite.Suite + deps *test_internals.ContainerDeps + keyServer *test_internals.HostedFile + keyServerKey *homeserver_interop.SigningKey +} + +func (s *MSC3916ThumbnailsSuite) SetupSuite() { + deps, err := test_internals.MakeTestDeps() + if err != nil { + log.Fatal(err) + } + s.deps = deps + + s.keyServer, s.keyServerKey = test_internals.MakeKeyServer(deps) +} + +func (s *MSC3916ThumbnailsSuite) TearDownSuite() { + if s.deps != nil { + if s.T().Failed() { + s.deps.Debug() + } + s.deps.Teardown() + } +} + +func (s *MSC3916ThumbnailsSuite) TestClientThumbnails() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + client2 := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // this client isn't authed + UserId: "", // this client isn't authed + } + + contentType, img, err := test_internals.MakeTestImage(512, 512) + assert.NoError(t, err) + fname := "image" + util.ExtensionForContentType(contentType) + + res, err := client1.Upload(fname, contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + assert.NotEmpty(t, mediaId) + + qs := url.Values{ + "width": []string{"96"}, + "height": []string{"96"}, + "method": []string{"scale"}, + } + + raw, err := client2.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/thumbnail/%s/%s", origin, mediaId), qs, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + + raw, err = client1.DoRaw("GET", fmt.Sprintf("/_matrix/client/v1/media/thumbnail/%s/%s", origin, mediaId), qs, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) + //test_internals.AssertIsTestImage(t, raw.Body) // we can't verify that the resulting image is correct +} + +func (s *MSC3916ThumbnailsSuite) TestFederationThumbnails() { + t := s.T() + + client1 := s.deps.Homeservers[0].UnprivilegedUsers[0].WithCsUrl(s.deps.Machines[0].HttpUrl) + remoteClient := &test_internals.MatrixClient{ + ClientServerUrl: s.deps.Machines[0].HttpUrl, + ServerName: s.deps.Homeservers[0].ServerName, + AccessToken: "", // this client isn't authed over the CS API + UserId: "", // this client isn't authed over the CS API + } + + contentType, img, err := test_internals.MakeTestImage(512, 512) + assert.NoError(t, err) + fname := "image" + util.ExtensionForContentType(contentType) + + res, err := client1.Upload(fname, contentType, img) + assert.NoError(t, err) + assert.NotEmpty(t, res.MxcUri) + + origin, mediaId, err := util.SplitMxc(res.MxcUri) + assert.NoError(t, err) + assert.Equal(t, client1.ServerName, origin) + assert.NotEmpty(t, mediaId) + + // Verify the federation download *fails* when lacking auth + uri := fmt.Sprintf("/_matrix/federation/v1/media/thumbnail/%s", mediaId) + qs := url.Values{ + "width": []string{"96"}, + "height": []string{"96"}, + "method": []string{"scale"}, + } + raw, err := remoteClient.DoRaw("GET", uri, qs, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, raw.StatusCode) + + // Now add the X-Matrix auth and try again + header, err := matrix.CreateXMatrixHeader(s.keyServer.PublicHostname, remoteClient.ServerName, "GET", fmt.Sprintf("%s?%s", uri, qs.Encode()), &database.AnonymousJson{}, s.keyServerKey.PrivateKey, s.keyServerKey.KeyVersion) + assert.NoError(t, err) + remoteClient.AuthHeaderOverride = header + raw, err = remoteClient.DoRaw("GET", uri, qs, "", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, raw.StatusCode) +} + +func TestMSC3916ThumbnailsSuite(t *testing.T) { + suite.Run(t, new(MSC3916ThumbnailsSuite)) +} diff --git a/test/signing_keys_test.go b/test/signing_keys_test.go new file mode 100644 index 00000000..b23ad54b --- /dev/null +++ b/test/signing_keys_test.go @@ -0,0 +1,73 @@ +package test + +import ( + "crypto/ed25519" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/matrix" + "github.com/t2bot/matrix-media-repo/util" +) + +func TestFailInjectedKeys(t *testing.T) { + raw := database.AnonymousJson{ + "old_verify_keys": database.AnonymousJson{}, + "server_name": "x.resolvematrix.dev", + "signatures": database.AnonymousJson{ + "x.resolvematrix.dev": database.AnonymousJson{ + "ed25519:injected": "FB93YAF+fOPyWcsx285Q/xFzRiG5sr7/u1iX9XWaIcOwDyDDwx7daS1eYxuM9PfosWE5vqUyTsCxmB40JTzdCw", + }, + }, + "valid_until_ts": 1701055573679, + "verify_keys": database.AnonymousJson{ + "ed25519:AY4k3ADlto8": database.AnonymousJson{"key": "VF7dl9W/tFWAjZSXm42Ef22k3v4WKBYLXZF9I7ErU00"}, + "ed25519:injected": database.AnonymousJson{"key": "w48CLiV1IkWoEbqJLFmniGUYtxwT+c2zm87X8oEpRO8"}, + }, + } + keyInfo := new(matrix.ServerKeyResult) + err := raw.ApplyTo(keyInfo) + if err != nil { + t.Fatal(err) + } + + _, err = matrix.CheckSigningKeySignatures("x.resolvematrix.dev", keyInfo, raw) + assert.Error(t, err) + assert.Equal(t, "missing signature from key 'ed25519:AY4k3ADlto8'", err.Error()) +} + +func TestRegularKeys(t *testing.T) { + raw := database.AnonymousJson{ + "old_verify_keys": database.AnonymousJson{}, + "server_name": "x.resolvematrix.dev", + "signatures": database.AnonymousJson{ + "x.resolvematrix.dev": database.AnonymousJson{ + "ed25519:AY4k3ADlto8": "3WlsmHFTVjywCoDYyrtx3ies+VufTuBuw1Prlgmoqh+a4XrJT+isEwhTX+I5FBvtJTKTt6vLH3gaP7BA6712CA", + }, + }, + "valid_until_ts": 1701057124839, + "verify_keys": database.AnonymousJson{ + "ed25519:AY4k3ADlto8": database.AnonymousJson{"key": "VF7dl9W/tFWAjZSXm42Ef22k3v4WKBYLXZF9I7ErU00"}, + }, + } + keyInfo := new(matrix.ServerKeyResult) + err := raw.ApplyTo(keyInfo) + if err != nil { + t.Fatal(err) + } + + keys, err := matrix.CheckSigningKeySignatures("x.resolvematrix.dev", keyInfo, raw) + assert.NoError(t, err) + for keyId, keyVal := range keys { + if b64, ok := keyInfo.VerifyKeys[keyId]; !ok { + t.Errorf("got key for '%s' but wasn't expecting it", keyId) + } else { + keySelf, err := util.DecodeUnpaddedBase64String(b64.Key) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, ed25519.PublicKey(keySelf), keyVal) + } + } +} diff --git a/test/templates/mmr.config.yaml b/test/templates/mmr.config.yaml index 40fbf6d6..965490d5 100644 --- a/test/templates/mmr.config.yaml +++ b/test/templates/mmr.config.yaml @@ -18,6 +18,7 @@ homeservers: csApi: "{{.ClientServerApiUrl}}" backoffAt: 10 adminApiKind: "synapse" + signingKeyPath: "{{.SigningKeyPath}}" {{end}} redis: enabled: true @@ -40,3 +41,20 @@ datastores: ssl: false rateLimit: enabled: false # we've got tests which intentionally spam +urlPreviews: + enabled: true + maxPageSizeBytes: 10485760 + previewUnsafeCertificates: false + numWords: 50 + maxLength: 200 + numTitleWords: 30 + maxTitleLength: 150 + filePreviewTypes: + - "image/*" + numWorkers: 10 + disallowedNetworks: [] + allowedNetworks: ["0.0.0.0/0"] + expireAfterDays: 0 + defaultLanguage: "en-US,en" + userAgent: "matrix-media-repo" + oEmbed: true diff --git a/test/templates/synapse.homeserver.yaml b/test/templates/synapse.homeserver.yaml index 750730c6..556dd611 100644 --- a/test/templates/synapse.homeserver.yaml +++ b/test/templates/synapse.homeserver.yaml @@ -25,7 +25,7 @@ registration_shared_secret: "l,jbms,sR_Z82JNP2,sv-~^5bXqFTV-T=j,,~=OKZ8I_Tardk;" report_stats: false macaroon_secret_key: "KV*8qANyBE28e*pZ-9RP+u86~i8+.j9IZEKU8Vb4+jdIoe~ncw" form_secret: "mQrUxtt6^F3uQ3nVrGdg7yAK64p*#Uf@2n=e9y8ggLbhy3-QIy" -signing_key_path: "/app/signing.key" +signing_key_path: "/data/signing.key" enable_media_repo: false enable_registration: true enable_registration_without_verification: true diff --git a/test/test_internals/deps.go b/test/test_internals/deps.go index 18ba27dd..06999673 100644 --- a/test/test_internals/deps.go +++ b/test/test_internals/deps.go @@ -7,22 +7,27 @@ import ( "log" "os" "path" + "strings" "time" "github.com/t2bot/matrix-media-repo/common/assets" "github.com/t2bot/matrix-media-repo/common/config" + "github.com/t2bot/matrix-media-repo/homeserver_interop" + "github.com/t2bot/matrix-media-repo/homeserver_interop/mmr" + "github.com/t2bot/matrix-media-repo/homeserver_interop/synapse" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" ) type ContainerDeps struct { - ctx context.Context - pgContainer *postgres.PostgresContainer - redisContainer testcontainers.Container - minioDep *MinioDep - depNet *NetworkDep - mmrExtConfigPath string + ctx context.Context + pgContainer *postgres.PostgresContainer + redisContainer testcontainers.Container + minioDep *MinioDep + depNet *NetworkDep + mmrExtConfigPath string + mmrSigningKeyPath string Homeservers []*SynapseDep Machines []*mmrContainer @@ -37,12 +42,56 @@ func MakeTestDeps() (*ContainerDeps, error) { return nil, err } + // Create a shared signing key for the MMR instances + signingKeyFile, err := os.CreateTemp(os.TempDir(), "mmr-signing-key") + if err != nil { + return nil, err + } + signingKey, err := homeserver_interop.GenerateSigningKey() + if err != nil { + return nil, err + } + b, err := mmr.EncodeSigningKey(signingKey) + if err != nil { + return nil, err + } + _, err = signingKeyFile.Write(b) + if err != nil { + return nil, err + } + err = signingKeyFile.Close() + if err != nil { + return nil, err + } + + // And use that same signing key for Synapse + synapseSigningKeyFile, err := os.CreateTemp(os.TempDir(), "mmr-synapse-signing-key") + if err != nil { + return nil, err + } + b, err = synapse.EncodeSigningKey(signingKey) + if err != nil { + return nil, err + } + _, err = synapseSigningKeyFile.Write(b) + if err != nil { + return nil, err + } + err = synapseSigningKeyFile.Close() + if err != nil { + return nil, err + } + err = os.Chmod(synapseSigningKeyFile.Name(), 0777) // XXX: Not great, but works. + if err != nil { + return nil, err + } + // Start two synapses for testing - syn1, err := MakeSynapse("first.example.org", depNet) + syn1, err := MakeSynapse("first.example.org", depNet, synapseSigningKeyFile.Name()) if err != nil { return nil, err } - syn2, err := MakeSynapse("second.example.org", depNet) + syn2, err := MakeSynapse("second.example.org", depNet, synapseSigningKeyFile.Name()) if err != nil { return nil, err } @@ -113,14 +162,16 @@ func MakeTestDeps() (*ContainerDeps, error) { // Start two MMRs for testing tmplArgs := mmrTmplArgs{ - Homeservers: []mmrHomeserverTmplArgs{ + Homeservers: []*mmrHomeserverTmplArgs{ { ServerName: syn1.ServerName, ClientServerApiUrl: syn1.InternalClientServerApiUrl, + SigningKeyPath: strings.ReplaceAll(signingKeyFile.Name(), "\\", "\\\\"), }, { ServerName: syn2.ServerName, ClientServerApiUrl: syn2.InternalClientServerApiUrl, + SigningKeyPath: strings.ReplaceAll(signingKeyFile.Name(), "\\", "\\\\"), }, }, RedisAddr: fmt.Sprintf("%s:%d", redisIp, 6379), // we're behind the network for redis @@ -148,20 +199,21 @@ func MakeTestDeps() (*ContainerDeps, error) { assets.SetupAssets(config.DefaultAssetsPath) return &ContainerDeps{ - ctx: ctx, - pgContainer: pgContainer, - redisContainer: redisContainer, - minioDep: minioDep, - mmrExtConfigPath: tmpPath, - Homeservers: []*SynapseDep{syn1, syn2}, - Machines: mmrs, - depNet: depNet, + ctx: ctx, + pgContainer: pgContainer, + redisContainer: redisContainer, + minioDep: minioDep, + mmrExtConfigPath: tmpPath, + mmrSigningKeyPath: signingKeyFile.Name(), + Homeservers: []*SynapseDep{syn1, syn2}, + Machines: mmrs, + depNet: depNet, }, nil } func (c *ContainerDeps) Teardown() { - for _, mmr := range c.Machines { - mmr.Teardown() + for _, machine := range c.Machines { + machine.Teardown() } for _, hs := range c.Homeservers { hs.Teardown() @@ -173,28 +225,37 @@ func (c *ContainerDeps) Teardown() { log.Fatalf("Error shutting down mmr-postgres container: %s", err.Error()) } c.minioDep.Teardown() - c.depNet.Teardown() if err := os.Remove(c.mmrExtConfigPath); err != nil && !os.IsNotExist(err) { log.Fatalf("Error cleaning up MMR-External config file '%s': %s", c.mmrExtConfigPath, err.Error()) } + if err := os.Remove(c.mmrSigningKeyPath); err != nil && !os.IsNotExist(err) { + log.Fatalf("Error cleaning up MMR-Signing Key file '%s': %s", c.mmrSigningKeyPath, err.Error()) + } + + // XXX: We should be shutting this down, but it appears testcontainers leaves something attached :( + //c.depNet.Teardown() } func (c *ContainerDeps) Debug() { for i, m := range c.Machines { logs, err := m.Logs() - if err != nil { - log.Fatal(err) - } - b, err := io.ReadAll(logs) - if err != nil { - log.Fatal(err) - } - fmt.Printf("[MMR Deps] Logs from index %d (%s)", i, m.HttpUrl) - fmt.Println() - fmt.Println(string(b)) - err = logs.Close() - if err != nil { - log.Fatal(err) - } + c.DumpDebugLogs(logs, err, i, m.HttpUrl) + } +} + +func (c *ContainerDeps) DumpDebugLogs(logs io.ReadCloser, err error, i int, url string) { + if err != nil { + log.Fatal(err) + } + b, err := io.ReadAll(logs) + if err != nil { + log.Fatal(err) + } + fmt.Printf("[MMR Deps] Logs from index %d (%s)", i, url) + fmt.Println() + fmt.Println(string(b)) + err = logs.Close() + if err != nil { + log.Fatal(err) } } diff --git a/test/test_internals/deps_mmr.go b/test/test_internals/deps_mmr.go index a3d3b9af..866a8790 100644 --- a/test/test_internals/deps_mmr.go +++ b/test/test_internals/deps_mmr.go @@ -11,6 +11,7 @@ import ( "strings" "text/template" + "github.com/docker/docker/api/types/container" "github.com/docker/go-connections/nat" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -19,10 +20,12 @@ import ( type mmrHomeserverTmplArgs struct { ServerName string ClientServerApiUrl string + SigningKeyPath string + PublicBaseUrl string } type mmrTmplArgs struct { - Homeservers []mmrHomeserverTmplArgs + Homeservers []*mmrHomeserverTmplArgs RedisAddr string PgConnectionString string S3Endpoint string @@ -71,6 +74,17 @@ func writeMmrConfig(tmplArgs mmrTmplArgs) (string, error) { } func makeMmrInstances(ctx context.Context, count int, depNet *NetworkDep, tmplArgs mmrTmplArgs) ([]*mmrContainer, error) { + // We need to relocate the signing key paths for a Docker mount + additionalMounts := make([]testcontainers.ContainerMount, 0) + for i, hs := range tmplArgs.Homeservers { + if hs.SigningKeyPath != "" { + inContainerName := fmt.Sprintf("/data/hs%d.key", i) + additionalMounts = append(additionalMounts, testcontainers.BindMount(hs.SigningKeyPath, testcontainers.ContainerMountTarget(inContainerName))) + hs.SigningKeyPath = inContainerName + } + } + + // ... then we can write the config and get the temp file path for it intTmpName, err := writeMmrConfig(tmplArgs) if err != nil { return nil, err @@ -95,14 +109,18 @@ func makeMmrInstances(ctx context.Context, count int, depNet *NetworkDep, tmplAr KeepImage: true, }, ExposedPorts: []string{"8000/tcp"}, - Mounts: []testcontainers.ContainerMount{ + Mounts: append([]testcontainers.ContainerMount{ testcontainers.BindMount(intTmpName, "/data/media-repo.yaml"), - }, + }, additionalMounts...), Env: map[string]string{ - "MACHINE_ID": strconv.Itoa(i), + "MACHINE_ID": strconv.Itoa(i), + "MEDIA_REPO_HTTP_ONLY_FEDERATION": "true", }, Networks: []string{depNet.NetId}, WaitingFor: wait.ForHTTP("/healthz").WithPort(p), + HostConfigModifier: func(c *container.HostConfig) { + c.ExtraHosts = append(c.ExtraHosts, "host.docker.internal:host-gateway") + }, }, Started: true, }) diff --git a/test/test_internals/deps_synapse.go b/test/test_internals/deps_synapse.go index f73b1799..d5bc7645 100644 --- a/test/test_internals/deps_synapse.go +++ b/test/test_internals/deps_synapse.go @@ -42,7 +42,7 @@ type SynapseDep struct { UnprivilegedUsers []*MatrixClient // uses ExternalClientServerApiUrl } -func MakeSynapse(domainName string, depNet *NetworkDep) (*SynapseDep, error) { +func MakeSynapse(domainName string, depNet *NetworkDep, signingKeyFilePath string) (*SynapseDep, error) { ctx := context.Background() // Start postgresql database @@ -117,6 +117,7 @@ func MakeSynapse(domainName string, depNet *NetworkDep) (*SynapseDep, error) { ExposedPorts: []string{"8008/tcp"}, Mounts: []testcontainers.ContainerMount{ testcontainers.BindMount(f.Name(), "/data/homeserver.yaml"), + testcontainers.BindMount(signingKeyFilePath, "/data/signing.key"), testcontainers.BindMount(path.Join(cwd, ".", "test", "templates", "synapse.log.config"), "/data/log.config"), testcontainers.BindMount(d, "/app"), }, diff --git a/test/test_internals/inline_dep_host_file.go b/test/test_internals/inline_dep_host_file.go new file mode 100644 index 00000000..953c0971 --- /dev/null +++ b/test/test_internals/inline_dep_host_file.go @@ -0,0 +1,115 @@ +package test_internals + +import ( + "context" + "fmt" + "io" + "log" + "os" + "path" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +type HostedFile struct { + upstream *ContainerDeps + nginx testcontainers.Container + tempDirectoryPath string + + PublicUrl string + PublicHostname string +} + +func ServeFile(fileName string, deps *ContainerDeps, contents string) (*HostedFile, error) { + container, writeFn, err := LazyServeFile(fileName, deps) + if writeFn != nil { + err2 := writeFn(contents) + if err2 != nil { + return nil, err2 + } + } + return container, err +} + +func LazyServeFile(fileName string, deps *ContainerDeps) (*HostedFile, func(string) error, error) { + tmp, err := os.MkdirTemp(os.TempDir(), "mmr-nginx") + if err != nil { + return nil, nil, err + } + + err = os.Chmod(tmp, 0755) + if err != nil { + return nil, nil, err + } + + err = os.MkdirAll(path.Join(tmp, path.Dir(fileName)), 0755) + if err != nil { + return nil, nil, err + } + + writeFn := func(contents string) error { + f, err := os.Create(path.Join(tmp, fileName)) + if err != nil { + return err + } + defer func(f *os.File) { + _ = f.Close() + }(f) + + _, err = f.Write([]byte(contents)) + if err != nil { + return err + } + + err = f.Close() + if err != nil { + return err + } + + return nil + } + + nginx, err := testcontainers.GenericContainer(deps.ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + Image: "docker.io/library/nginx:latest", + ExposedPorts: []string{"80/tcp"}, + Mounts: []testcontainers.ContainerMount{ + testcontainers.BindMount(tmp, "/usr/share/nginx/html"), + }, + Networks: []string{deps.depNet.NetId}, + WaitingFor: wait.ForListeningPort("80/tcp"), + }, + Started: true, + }) + if err != nil { + return nil, nil, err + } + + nginxIp, err := nginx.ContainerIP(deps.ctx) + if err != nil { + return nil, nil, err + } + + //goland:noinspection HttpUrlsUsage + return &HostedFile{ + upstream: deps, + nginx: nginx, + tempDirectoryPath: tmp, + PublicUrl: fmt.Sprintf("http://%s:%d/%s", nginxIp, 80, fileName), + PublicHostname: fmt.Sprintf("%s:%d", nginxIp, 80), + }, writeFn, nil +} + +func (f *HostedFile) Teardown() { + if err := f.nginx.Terminate(f.upstream.ctx); err != nil { + log.Fatalf("Error shutting down nginx container: %s", err.Error()) + } + if err := os.RemoveAll(f.tempDirectoryPath); err != nil { + log.Fatalf("Error cleaning up temporarily hosted file: %s", err.Error()) + } +} + +func (f *HostedFile) Logs() (io.ReadCloser, error) { + return f.nginx.Logs(context.Background()) +} diff --git a/test/test_internals/util_client.go b/test/test_internals/util_client.go index 443ca507..e95c799e 100644 --- a/test/test_internals/util_client.go +++ b/test/test_internals/util_client.go @@ -11,10 +11,11 @@ import ( ) type MatrixClient struct { - AccessToken string - ClientServerUrl string - UserId string - ServerName string + AccessToken string + ClientServerUrl string + UserId string + ServerName string + AuthHeaderOverride string } func (c *MatrixClient) WithCsUrl(newUrl string) *MatrixClient { @@ -80,10 +81,14 @@ func (c *MatrixClient) DoRaw(method string, endpoint string, qs url.Values, cont if contentType != "" { req.Header.Set("Content-Type", contentType) } + if c.AccessToken != "" { req.Header.Set("Authorization", "Bearer "+c.AccessToken) } + if c.AuthHeaderOverride != "" { + req.Header.Set("Authorization", c.AuthHeaderOverride) + } - log.Printf("[HTTP] [Auth=%s] [Host=%s] %s %s", c.AccessToken, c.ServerName, req.Method, req.URL.String()) + log.Printf("[HTTP] [Auth=%s] [Host=%s] %s %s", req.Header.Get("Authorization"), c.ServerName, req.Method, req.URL.String()) return http.DefaultClient.Do(req) } diff --git a/test/test_internals/util_keyserver.go b/test/test_internals/util_keyserver.go new file mode 100644 index 00000000..e8aaa914 --- /dev/null +++ b/test/test_internals/util_keyserver.go @@ -0,0 +1,60 @@ +package test_internals + +import ( + "bytes" + "crypto/ed25519" + "encoding/json" + "log" + + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/homeserver_interop" + "github.com/t2bot/matrix-media-repo/homeserver_interop/mmr" + "github.com/t2bot/matrix-media-repo/util" +) + +func MakeKeyServer(deps *ContainerDeps) (*HostedFile, *homeserver_interop.SigningKey) { + // We'll use a pre-computed signing key for simplicity + signingKey, err := mmr.DecodeSigningKey(bytes.NewReader([]byte(`-----BEGIN MMR PRIVATE KEY----- +Key-ID: ed25519:e5d0oC +Version: 1 + +PJt0OaIImDJk8P/PDb4TNQHgI/1AA1C+AaQaABxAcgc= +-----END MMR PRIVATE KEY----- +`))) + if err != nil { + log.Fatal(err) + } + keyServerKey := signingKey + // Create a /_matrix/key/v2/server response file (signed JSON) + keyServer, writeFn, err := LazyServeFile("_matrix/key/v2/server", deps) + if err != nil { + log.Fatal(err) + } + serverKey := database.AnonymousJson{ + "old_verify_keys": database.AnonymousJson{}, + "server_name": keyServer.PublicHostname, + "valid_until_ts": util.NowMillis() + (60 * 60 * 1000), // +1hr + "verify_keys": database.AnonymousJson{ + "ed25519:e5d0oC": database.AnonymousJson{ + "key": "TohekYXzLx7VzV8FtLQlI3XsSdPv1CjhVYY5rZmFCvU", + }, + }, + } + canonical, err := util.EncodeCanonicalJson(serverKey) + signature := util.EncodeUnpaddedBase64ToString(ed25519.Sign(signingKey.PrivateKey, canonical)) + serverKey["signatures"] = database.AnonymousJson{ + keyServer.PublicHostname: database.AnonymousJson{ + "ed25519:e5d0oC": signature, + }, + } + b, err := json.Marshal(serverKey) + if err != nil { + log.Fatal(err) + } + err = writeFn(string(b)) + if err != nil { + log.Fatal(err) + } + + return keyServer, keyServerKey +} diff --git a/test/xmatrix_header_test.go b/test/xmatrix_header_test.go new file mode 100644 index 00000000..5572f0b8 --- /dev/null +++ b/test/xmatrix_header_test.go @@ -0,0 +1,60 @@ +package test + +import ( + "crypto/ed25519" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/t2bot/matrix-media-repo/common/config" + "github.com/t2bot/matrix-media-repo/database" + "github.com/t2bot/matrix-media-repo/matrix" + "github.com/t2bot/matrix-media-repo/util" +) + +func TestXMatrixAuthHeader(t *testing.T) { + config.AddDomainForTesting("localhost", nil) + + pub, priv, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + + header, err := matrix.CreateXMatrixHeader("localhost:8008", "localhost", "GET", "/_matrix/media/v3/download/example.org/abc", &database.AnonymousJson{}, priv, "0") + if err != nil { + t.Fatal(err) + } + + auths, err := util.GetXMatrixAuth([]string{header}) + if err != nil { + t.Fatal(err) + } + + keys := make(matrix.ServerSigningKeys) + keys["ed25519:0"] = pub + err = matrix.ValidateXMatrixAuthHeader("GET", "/_matrix/media/v3/download/example.org/abc", &database.AnonymousJson{}, auths, keys, "localhost") + assert.NoError(t, err) +} + +func TestXMatrixAuthDestinationMismatch(t *testing.T) { + config.AddDomainForTesting("localhost", nil) + + pub, priv, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + + header, err := matrix.CreateXMatrixHeader("localhost:8008", "localhost:1234", "GET", "/_matrix/media/v3/download/example.org/abc", &database.AnonymousJson{}, priv, "0") + if err != nil { + t.Fatal(err) + } + + auths, err := util.GetXMatrixAuth([]string{header}) + if err != nil { + t.Fatal(err) + } + + keys := make(matrix.ServerSigningKeys) + keys["ed25519:0"] = pub + err = matrix.ValidateXMatrixAuthHeader("GET", "/_matrix/media/v3/download/example.org/abc", &database.AnonymousJson{}, auths, keys, "localhost:1234") + assert.ErrorIs(t, err, matrix.ErrWrongDestination) +} diff --git a/util/canonical_json.go b/util/canonical_json.go index c514f311..939da9f6 100644 --- a/util/canonical_json.go +++ b/util/canonical_json.go @@ -5,7 +5,7 @@ import ( "encoding/json" ) -func EncodeCanonicalJson(obj map[string]interface{}) ([]byte, error) { +func EncodeCanonicalJson(obj any) ([]byte, error) { b, err := json.Marshal(obj) if err != nil { return nil, err diff --git a/util/http.go b/util/http.go index 2188feb3..0892f661 100644 --- a/util/http.go +++ b/util/http.go @@ -1,11 +1,19 @@ package util import ( + "fmt" "net/http" "net/url" "strings" ) +type XMatrixAuth struct { + Origin string + Destination string + KeyId string + Signature []byte +} + func GetAccessTokenFromRequest(request *http.Request) string { token := request.Header.Get("Authorization") @@ -40,3 +48,71 @@ func GetLogSafeUrl(r *http.Request) string { copyUrl.RawQuery = GetLogSafeQueryString(r) return copyUrl.String() } + +func GetXMatrixAuth(headers []string) ([]XMatrixAuth, error) { + auths := make([]XMatrixAuth, 0) + for _, h := range headers { + if !strings.HasPrefix(h, "X-Matrix ") { + continue + } + + paramCsv := h[len("X-Matrix "):] + params := make(map[string]string) + isKey := true + keyName := "" + keyValue := "" + escape := false + for _, c := range paramCsv { + if c == ',' && isKey { + params[strings.TrimSpace(strings.ToLower(keyName))] = keyValue + keyName = "" + keyValue = "" + continue + } + if c == '=' { + isKey = false + continue + } + + if isKey { + keyName = fmt.Sprintf("%s%s", keyName, string(c)) + } else { + if c == '\\' && !escape { + escape = true + continue + } + if c == '"' && !escape { + escape = false + if len(keyValue) > 0 { + isKey = true + } + continue + } + if escape { + escape = false + } + keyValue = fmt.Sprintf("%s%s", keyValue, string(c)) + } + } + if len(keyName) > 0 && isKey { + params[strings.TrimSpace(strings.ToLower(keyName))] = keyValue + } + + sig, err := DecodeUnpaddedBase64String(params["sig"]) + if err != nil { + return nil, err + } + auth := XMatrixAuth{ + Origin: params["origin"], + Destination: params["destination"], + KeyId: params["key"], + Signature: sig, + } + if auth.Origin == "" || auth.KeyId == "" || len(auth.Signature) == 0 { + continue + } + auths = append(auths, auth) + } + + return auths, nil +} diff --git a/util/matrix_media_part.go b/util/matrix_media_part.go new file mode 100644 index 00000000..689113fb --- /dev/null +++ b/util/matrix_media_part.go @@ -0,0 +1,27 @@ +package util + +import ( + "io" + "mime/multipart" + "net/http" + "net/textproto" +) + +type MatrixMediaPart struct { + Header textproto.MIMEHeader + Body io.ReadCloser +} + +func MatrixMediaPartFromResponse(r *http.Response) *MatrixMediaPart { + return &MatrixMediaPart{ + Header: textproto.MIMEHeader(r.Header), + Body: r.Body, + } +} + +func MatrixMediaPartFromMimeMultipart(p *multipart.Part) *MatrixMediaPart { + return &MatrixMediaPart{ + Header: p.Header, + Body: p, + } +} diff --git a/util/readers/multipart_reader.go b/util/readers/multipart_reader.go new file mode 100644 index 00000000..ca8ba1ad --- /dev/null +++ b/util/readers/multipart_reader.go @@ -0,0 +1,63 @@ +package readers + +import ( + "bytes" + "io" + "mime/multipart" + "net/textproto" + "net/url" + + "github.com/alioygur/is" +) + +type MultipartPart struct { + ContentType string + FileName string + Location string + Reader io.ReadCloser +} + +func NewMultipartReader(parts ...*MultipartPart) io.ReadCloser { + r, w := io.Pipe() + go func() { + mpw := multipart.NewWriter(w) + + for _, part := range parts { + headers := textproto.MIMEHeader{} + if part.ContentType != "" { + headers.Set("Content-Type", part.ContentType) + } + if part.FileName != "" { + if is.ASCII(part.FileName) { + headers.Set("Content-Disposition", "attachment; filename="+url.QueryEscape(part.FileName)) + } else { + headers.Set("Content-Disposition", "attachment; filename*=utf-8''"+url.QueryEscape(part.FileName)) + } + } + if part.Location != "" { + headers.Set("Location", part.Location) + part.Reader = io.NopCloser(bytes.NewReader(make([]byte, 0))) + } + + partW, err := mpw.CreatePart(headers) + if err != nil { + _ = w.CloseWithError(err) + return + } + if _, err = io.Copy(partW, part.Reader); err != nil { + _ = w.CloseWithError(err) + return + } + if err = part.Reader.Close(); err != nil { + _ = w.CloseWithError(err) + return + } + } + + if err := mpw.Close(); err != nil { + _ = w.CloseWithError(err) + } + _ = w.Close() + }() + return MakeCloser(r) +}