diff --git a/api/_apimeta/auth.go b/api/_apimeta/auth.go index d16df56f..613d5f94 100644 --- a/api/_apimeta/auth.go +++ b/api/_apimeta/auth.go @@ -20,6 +20,15 @@ type ServerInfo struct { ServerName string } +type AuthContext struct { + User UserInfo + Server ServerInfo +} + +func (a AuthContext) IsAuthenticated() bool { + return a.User.UserId != "" || a.Server.ServerName != "" +} + func GetRequestUserAdminStatus(r *http.Request, rctx rcontext.RequestContext, user UserInfo) (bool, bool) { isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr) diff --git a/api/r0/download.go b/api/r0/download.go index 2ebeaa1d..5360dcc1 100644 --- a/api/r0/download.go +++ b/api/r0/download.go @@ -18,7 +18,11 @@ import ( "github.com/t2bot/matrix-media-repo/common/rcontext" ) -func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { +func DownloadMediaUser(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + return DownloadMedia(r, rctx, _apimeta.AuthContext{User: user}) +} + +func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta.AuthContext) interface{} { server := _routers.GetParam("server", r) mediaId := _routers.GetParam("mediaId", r) filename := _routers.GetParam("filename", r) @@ -61,16 +65,20 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. } rctx = rctx.LogWithFields(logrus.Fields{ - "mediaId": mediaId, - "server": server, - "filename": filename, - "allowRemote": downloadRemote, - "allowRedirect": canRedirect, + "mediaId": mediaId, + "server": server, + "filename": filename, + "allowRemote": downloadRemote, + "allowRedirect": canRedirect, + "authUserId": auth.User.UserId, + "authServerName": auth.Server.ServerName, }) - if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) { - rctx.Log.Warn("Request blocked due to domain being ignored.") - return _responses.MediaBlocked() + if auth.User.UserId != "" { + if !util.IsGlobalAdmin(auth.User.UserId) && util.IsHostIgnored(server) { + rctx.Log.Warn("Request blocked due to domain being ignored.") + return _responses.MediaBlocked() + } } media, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{ @@ -78,11 +86,18 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. BlockForReadUntil: blockFor, CanRedirect: canRedirect, RecordOnly: recordOnly, + AuthProvided: auth.IsAuthenticated(), }) if err != nil { var redirect datastores.RedirectError if errors.Is(err, common.ErrMediaNotFound) { return _responses.NotFoundError() + } else if errors.Is(err, common.ErrRestrictedAuth) { + return _responses.ErrorResponse{ + Code: common.ErrCodeNotFound, + Message: "authentication is required to download this media", + InternalCode: common.ErrCodeUnauthorized, + } } else if errors.Is(err, common.ErrMediaTooLarge) { return _responses.RequestTooLarge() } else if errors.Is(err, common.ErrRateLimitExceeded) { diff --git a/api/r0/thumbnail.go b/api/r0/thumbnail.go index 9552bd8a..48ec314e 100644 --- a/api/r0/thumbnail.go +++ b/api/r0/thumbnail.go @@ -20,7 +20,11 @@ import ( "github.com/t2bot/matrix-media-repo/common/rcontext" ) -func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { +func ThumbnailMediaUser(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + return ThumbnailMedia(r, rctx, _apimeta.AuthContext{User: user}) +} + +func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, auth _apimeta.AuthContext) interface{} { server := _routers.GetParam("server", r) mediaId := _routers.GetParam("mediaId", r) allowRemote := r.URL.Query().Get("allow_remote") @@ -55,15 +59,19 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta } rctx = rctx.LogWithFields(logrus.Fields{ - "mediaId": mediaId, - "server": server, - "allowRemote": downloadRemote, - "allowRedirect": canRedirect, + "mediaId": mediaId, + "server": server, + "allowRemote": downloadRemote, + "allowRedirect": canRedirect, + "authUserId": auth.User.UserId, + "authServerName": auth.Server.ServerName, }) - if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) { - rctx.Log.Warn("Request blocked due to domain being ignored.") - return _responses.MediaBlocked() + if auth.User.UserId != "" { + if !util.IsGlobalAdmin(auth.User.UserId) && util.IsHostIgnored(server) { + rctx.Log.Warn("Request blocked due to domain being ignored.") + return _responses.MediaBlocked() + } } widthStr := r.URL.Query().Get("width") @@ -124,6 +132,7 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta BlockForReadUntil: blockFor, RecordOnly: false, // overridden CanRedirect: canRedirect, + AuthProvided: auth.IsAuthenticated(), }, Width: width, Height: height, @@ -134,6 +143,12 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta var redirect datastores.RedirectError if errors.Is(err, common.ErrMediaNotFound) { return _responses.NotFoundError() + } else if errors.Is(err, common.ErrRestrictedAuth) { + return _responses.ErrorResponse{ + Code: common.ErrCodeNotFound, + Message: "authentication is required to download this media", + InternalCode: common.ErrCodeUnauthorized, + } } else if errors.Is(err, common.ErrMediaTooLarge) { return _responses.RequestTooLarge() } else if errors.Is(err, common.ErrRateLimitExceeded) { diff --git a/api/routes.go b/api/routes.go index 5bbdbe63..ee2d87a7 100644 --- a/api/routes.go +++ b/api/routes.go @@ -33,10 +33,10 @@ func buildRoutes() http.Handler { // Standard (spec) features register([]string{"PUT"}, PrefixMedia, "upload/:server/:mediaId", mxV3, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaAsync), "upload_async", counter)) register([]string{"POST"}, PrefixMedia, "upload", mxSpecV3Transition, router, makeRoute(_routers.RequireAccessToken(r0.UploadMediaSync), "upload", counter)) - downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMedia), "download", counter) + downloadRoute := makeRoute(_routers.OptionalAccessToken(r0.DownloadMediaUser), "download", counter) register([]string{"GET", "HEAD"}, PrefixMedia, "download/:server/:mediaId/:filename", mxSpecV3Transition, router, downloadRoute) register([]string{"GET", "HEAD"}, PrefixMedia, "download/:server/:mediaId", mxSpecV3Transition, router, downloadRoute) - register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMedia), "thumbnail", counter)) + register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMediaUser), "thumbnail", counter)) previewUrlRoute := makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter) register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, previewUrlRoute) register([]string{"GET"}, PrefixMedia, "identicon/*seed", mxR0, router, makeRoute(_routers.OptionalAccessToken(r0.Identicon), "identicon", counter)) diff --git a/api/v1/download.go b/api/v1/download.go index b833059e..0e2e8e22 100644 --- a/api/v1/download.go +++ b/api/v1/download.go @@ -2,9 +2,10 @@ package v1 import ( "bytes" - "github.com/t2bot/matrix-media-repo/util/ids" "net/http" + "github.com/t2bot/matrix-media-repo/util/ids" + "github.com/t2bot/matrix-media-repo/api/_apimeta" "github.com/t2bot/matrix-media-repo/api/_responses" "github.com/t2bot/matrix-media-repo/api/_routers" @@ -16,7 +17,7 @@ import ( func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { r.URL.Query().Set("allow_remote", "true") r.URL.Query().Set("allow_redirect", "true") - return r0.DownloadMedia(r, rctx, user) + return r0.DownloadMedia(r, rctx, _apimeta.AuthContext{User: user}) } func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { @@ -26,7 +27,7 @@ func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, serv r.URL.RawQuery = query.Encode() r = _routers.ForceSetParam("server", r.Host, r) - res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{}) + res := r0.DownloadMedia(r, rctx, _apimeta.AuthContext{Server: server}) boundary, err := ids.NewUniqueId() if err != nil { rctx.Log.Error("Error generating boundary on response: ", err) diff --git a/api/v1/thumbnail.go b/api/v1/thumbnail.go index e0cb7442..6d00d1a9 100644 --- a/api/v1/thumbnail.go +++ b/api/v1/thumbnail.go @@ -2,9 +2,10 @@ package v1 import ( "bytes" - "github.com/t2bot/matrix-media-repo/util/ids" "net/http" + "github.com/t2bot/matrix-media-repo/util/ids" + "github.com/t2bot/matrix-media-repo/api/_apimeta" "github.com/t2bot/matrix-media-repo/api/_responses" "github.com/t2bot/matrix-media-repo/api/_routers" @@ -16,7 +17,7 @@ import ( func ClientThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { r.URL.Query().Set("allow_remote", "true") r.URL.Query().Set("allow_redirect", "true") - return r0.ThumbnailMedia(r, rctx, user) + return r0.ThumbnailMedia(r, rctx, _apimeta.AuthContext{User: user}) } func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { @@ -26,7 +27,7 @@ func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, ser r.URL.RawQuery = query.Encode() r = _routers.ForceSetParam("server", r.Host, r) - res := r0.ThumbnailMedia(r, rctx, _apimeta.UserInfo{}) + res := r0.ThumbnailMedia(r, rctx, _apimeta.AuthContext{Server: server}) boundary, err := ids.NewUniqueId() if err != nil { rctx.Log.Error("Error generating boundary on response: ", err) diff --git a/archival/entity_export.go b/archival/entity_export.go index ac5f870c..f7181ac2 100644 --- a/archival/entity_export.go +++ b/archival/entity_export.go @@ -44,6 +44,7 @@ func ExportEntityData(ctx rcontext.RequestContext, exportId string, entityId str FetchRemoteIfNeeded: false, BlockForReadUntil: 10 * time.Minute, RecordOnly: false, + AuthProvided: true, // it's for an export, so assume authentication }) if errors.Is(err, common.ErrMediaQuarantined) { ctx.Log.Warnf("%s is quarantined and will not be included in the export", mxc) diff --git a/common/config/conf_main.go b/common/config/conf_main.go index ab5e2473..8b7e3ac6 100644 --- a/common/config/conf_main.go +++ b/common/config/conf_main.go @@ -24,14 +24,15 @@ func NewDefaultMainConfig() MainRepoConfig { return MainRepoConfig{ MinimumRepoConfig: NewDefaultMinimumRepoConfig(), General: GeneralConfig{ - BindAddress: "127.0.0.1", - Port: 8000, - LogDirectory: "logs", - LogColors: false, - JsonLogs: false, - LogLevel: "info", - TrustAnyForward: false, - UseForwardedHost: true, + BindAddress: "127.0.0.1", + Port: 8000, + LogDirectory: "logs", + LogColors: false, + JsonLogs: false, + LogLevel: "info", + TrustAnyForward: false, + UseForwardedHost: true, + FreezeUnauthenticatedMedia: false, }, Database: DatabaseConfig{ Postgres: "postgres://your_username:your_password@localhost/database_name?sslmode=disable", diff --git a/common/config/models_main.go b/common/config/models_main.go index 6b80d272..fd047591 100644 --- a/common/config/models_main.go +++ b/common/config/models_main.go @@ -1,14 +1,15 @@ package config type GeneralConfig struct { - BindAddress string `yaml:"bindAddress"` - Port int `yaml:"port"` - LogDirectory string `yaml:"logDirectory"` - LogColors bool `yaml:"logColors"` - JsonLogs bool `yaml:"jsonLogs"` - LogLevel string `yaml:"logLevel"` - TrustAnyForward bool `yaml:"trustAnyForwardedAddress"` - UseForwardedHost bool `yaml:"useForwardedHost"` + BindAddress string `yaml:"bindAddress"` + Port int `yaml:"port"` + LogDirectory string `yaml:"logDirectory"` + LogColors bool `yaml:"logColors"` + JsonLogs bool `yaml:"jsonLogs"` + LogLevel string `yaml:"logLevel"` + TrustAnyForward bool `yaml:"trustAnyForwardedAddress"` + UseForwardedHost bool `yaml:"useForwardedHost"` + FreezeUnauthenticatedMedia bool `yaml:"freezeUnauthenticatedMedia"` } type HomeserverConfig struct { diff --git a/common/errors.go b/common/errors.go index 7a1999c4..e1207cc8 100644 --- a/common/errors.go +++ b/common/errors.go @@ -17,3 +17,4 @@ var ErrAlreadyUploaded = errors.New("already uploaded") var ErrMediaNotYetUploaded = errors.New("media not yet uploaded") var ErrMediaDimensionsTooSmall = errors.New("media is too small dimensionally") var ErrRateLimitExceeded = errors.New("rate limit exceeded") +var ErrRestrictedAuth = errors.New("authentication is required to download this media") diff --git a/config.sample.yaml b/config.sample.yaml index a657d0a2..5e037828 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -34,6 +34,17 @@ repo: # See https://github.com/t2bot/matrix-media-repo/issues/202 for more information. useForwardedHost: true + # If true, media uploaded or cached from that point forwards will require authentication in order to + # be accessed. Media uploaded or cached prior will remain accessible on the unauthenticated endpoints. + # If set to false after being set to true, media uploaded or cached while the flag was true will still + # only be accessible over authenticated endpoints, though future media will be accessible on both + # authenticated and unauthenticated media. + # + # This flag currently defaults to false. A future release, likely in August 2024, will remove this flag + # and have the same effect as it being true (always on). This flag is primarily intended for servers to + # opt-in to the behaviour early. + freezeUnauthenticatedMedia: false + # Options for dealing with federation federation: # On a per-host basis, the number of consecutive failures in calling the host before the diff --git a/database/db.go b/database/db.go index ffae2bc4..2a3da1a3 100644 --- a/database/db.go +++ b/database/db.go @@ -28,6 +28,7 @@ type Database struct { Tasks *tasksTableStatements Exports *exportsTableStatements ExportParts *exportPartsTableStatements + RestrictedMedia *restrictedMediaTableStatements } var instance *Database @@ -126,6 +127,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error if d.ExportParts, err = prepareExportPartsTables(d.conn); err != nil { return errors.New("failed to create export parts table accessor: " + err.Error()) } + if d.RestrictedMedia, err = prepareRestrictedMediaTables(d.conn); err != nil { + return errors.New("failed to create restricted media table accessor: " + err.Error()) + } instance = d return nil diff --git a/database/table_restricted_media.go b/database/table_restricted_media.go new file mode 100644 index 00000000..74b5d600 --- /dev/null +++ b/database/table_restricted_media.go @@ -0,0 +1,87 @@ +package database + +import ( + "database/sql" + "errors" + + "github.com/t2bot/matrix-media-repo/common/rcontext" +) + +type RestrictedCondition string + +const RestrictedRequiresAuth RestrictedCondition = "io.t2bot.requires_authentication" // Internal extension + +type DbRestrictedMedia struct { + Origin string + MediaId string + Condition RestrictedCondition + ConditionValue string +} + +const insertRestrictedMedia = "INSERT INTO restricted_media (origin, media_id, condition_type, condition_value) VALUES ($1, $2, $3, $4);" +const updateRestrictedMedia = "UPDATE restricted_media SET condition_type = $3, condition_value = $4 WHERE origin = $1 AND media_id = $2;" +const selectRestrictedMedia = "SELECT origin, media_id, condition_type, condition_value FROM restricted_media WHERE origin = $1 AND media_id = $2;" + +type restrictedMediaTableStatements struct { + insertRestrictedMedia *sql.Stmt + updateRestrictedMedia *sql.Stmt + selectRestrictedMedia *sql.Stmt +} + +type restrictedMediaTableWithContext struct { + statements *restrictedMediaTableStatements + ctx rcontext.RequestContext +} + +func prepareRestrictedMediaTables(db *sql.DB) (*restrictedMediaTableStatements, error) { + var err error + var stmts = &restrictedMediaTableStatements{} + + if stmts.insertRestrictedMedia, err = db.Prepare(insertRestrictedMedia); err != nil { + return nil, errors.New("error preparing insertRestrictedMedia: " + err.Error()) + } + if stmts.updateRestrictedMedia, err = db.Prepare(updateRestrictedMedia); err != nil { + return nil, errors.New("error preparing updateRestrictedMedia: " + err.Error()) + } + if stmts.selectRestrictedMedia, err = db.Prepare(selectRestrictedMedia); err != nil { + return nil, errors.New("error preparing selectRestrictedMedia: " + err.Error()) + } + + return stmts, nil +} + +func (s *restrictedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *restrictedMediaTableWithContext { + return &restrictedMediaTableWithContext{ + statements: s, + ctx: ctx, + } +} + +func (s *restrictedMediaTableWithContext) Insert(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error { + _, err := s.statements.insertRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue) + return err +} + +func (s *restrictedMediaTableWithContext) Update(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error { + _, err := s.statements.updateRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue) + return err +} + +func (s *restrictedMediaTableWithContext) GetAllForId(origin string, mediaId string) ([]*DbRestrictedMedia, error) { + results := make([]*DbRestrictedMedia, 0) + rows, err := s.statements.selectRestrictedMedia.QueryContext(s.ctx, origin, mediaId) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return results, nil + } + return nil, err + } + for rows.Next() { + val := &DbRestrictedMedia{} + if err = rows.Scan(&val.Origin, &val.MediaId, &val.Condition, &val.ConditionValue); err != nil { + return nil, err + } + results = append(results, val) + } + return results, nil +} diff --git a/migrations/29_create_media_restrictions_down.sql b/migrations/29_create_media_restrictions_down.sql new file mode 100644 index 00000000..7553179a --- /dev/null +++ b/migrations/29_create_media_restrictions_down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_restricted_media; +DROP TABLE IF EXISTS restricted_media; \ No newline at end of file diff --git a/migrations/29_create_media_restrictions_up.sql b/migrations/29_create_media_restrictions_up.sql new file mode 100644 index 00000000..ae5e5121 --- /dev/null +++ b/migrations/29_create_media_restrictions_up.sql @@ -0,0 +1,2 @@ +CREATE TABLE IF NOT EXISTS restricted_media (origin TEXT NOT NULL, media_id TEXT NOT NULL, condition_type TEXT NOT NULL, condition_value TEXT NOT NULL); +CREATE UNIQUE INDEX IF NOT EXISTS idx_restricted_media ON restricted_media (origin, media_id); \ No newline at end of file diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index fa45c514..e56715b8 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -18,6 +18,7 @@ import ( "github.com/t2bot/matrix-media-repo/pipelines/_steps/download" "github.com/t2bot/matrix-media-repo/pipelines/_steps/meta" "github.com/t2bot/matrix-media-repo/pipelines/_steps/quarantine" + "github.com/t2bot/matrix-media-repo/restrictions" "github.com/t2bot/matrix-media-repo/util/readers" "github.com/t2bot/matrix-media-repo/util/sfcache" ) @@ -34,6 +35,7 @@ type DownloadOpts struct { BlockForReadUntil time.Duration RecordOnly bool CanRedirect bool + AuthProvided bool } func (o DownloadOpts) String() string { @@ -41,6 +43,13 @@ func (o DownloadOpts) String() string { } func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) { + // Step 0: Check restrictions + if requiresAuth, err := restrictions.DoesMediaRequireAuth(ctx, origin, mediaId); err != nil { + return nil, nil, err + } else if requiresAuth && !opts.AuthProvided { + return nil, nil, common.ErrRestrictedAuth + } + // Step 1: Make our context a timeout context var cancel context.CancelFunc //goland:noinspection GoVetLostCancel - we handle the function in our custom cancelCloser struct diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go index c8d9d5ae..fb88816b 100644 --- a/pipelines/pipeline_thumbnail/pipeline.go +++ b/pipelines/pipeline_thumbnail/pipeline.go @@ -17,6 +17,7 @@ import ( "github.com/t2bot/matrix-media-repo/pipelines/_steps/quarantine" "github.com/t2bot/matrix-media-repo/pipelines/_steps/thumbnails" "github.com/t2bot/matrix-media-repo/pipelines/pipeline_download" + "github.com/t2bot/matrix-media-repo/restrictions" "github.com/t2bot/matrix-media-repo/util/readers" "github.com/t2bot/matrix-media-repo/util/sfcache" ) @@ -46,10 +47,18 @@ func (o ThumbnailOpts) ImpliedDownloadOpts() pipeline_download.DownloadOpts { FetchRemoteIfNeeded: o.FetchRemoteIfNeeded, BlockForReadUntil: o.BlockForReadUntil, RecordOnly: true, + AuthProvided: o.AuthProvided, } } func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts ThumbnailOpts) (*database.DbThumbnail, io.ReadCloser, error) { + // Step 0: Check restrictions + if requiresAuth, err := restrictions.DoesMediaRequireAuth(ctx, origin, mediaId); err != nil { + return nil, nil, err + } else if requiresAuth && !opts.AuthProvided { + return nil, nil, common.ErrRestrictedAuth + } + // Step 1: Fix the request parameters w, h, method, err1 := thumbnails.PickNewDimensions(ctx, opts.Width, opts.Height, opts.Method) if err1 != nil { diff --git a/pipelines/pipeline_upload/pipeline.go b/pipelines/pipeline_upload/pipeline.go index 369c899e..3ef21af2 100644 --- a/pipelines/pipeline_upload/pipeline.go +++ b/pipelines/pipeline_upload/pipeline.go @@ -14,6 +14,7 @@ import ( "github.com/t2bot/matrix-media-repo/pipelines/_steps/meta" "github.com/t2bot/matrix-media-repo/pipelines/_steps/quota" "github.com/t2bot/matrix-media-repo/pipelines/_steps/upload" + "github.com/t2bot/matrix-media-repo/restrictions" "github.com/t2bot/matrix-media-repo/util" "github.com/t2bot/matrix-media-repo/util/readers" ) @@ -142,6 +143,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re if err = database.GetInstance().Media.Prepare(ctx).Insert(newRecord); err != nil { return nil, err } + if config.Get().General.FreezeUnauthenticatedMedia { + if err = restrictions.SetMediaRequiresAuth(ctx, newRecord.Origin, newRecord.MediaId); err != nil { + return nil, err + } + } uploadDone(newRecord) return newRecord, nil } @@ -173,6 +179,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re } return nil, err } + if config.Get().General.FreezeUnauthenticatedMedia { + if err = restrictions.SetMediaRequiresAuth(ctx, newRecord.Origin, newRecord.MediaId); err != nil { + return nil, err + } + } uploadDone(newRecord) return newRecord, nil } diff --git a/restrictions/auth.go b/restrictions/auth.go new file mode 100644 index 00000000..ba6c9594 --- /dev/null +++ b/restrictions/auth.go @@ -0,0 +1,23 @@ +package restrictions + +import ( + "github.com/t2bot/matrix-media-repo/common/rcontext" + "github.com/t2bot/matrix-media-repo/database" +) + +func DoesMediaRequireAuth(ctx rcontext.RequestContext, origin string, mediaId string) (bool, error) { + restrictions, err := database.GetInstance().RestrictedMedia.Prepare(ctx).GetAllForId(origin, mediaId) + if err != nil { + return false, err + } + for _, restriction := range restrictions { + if restriction.Condition == database.RestrictedRequiresAuth { + return restriction.ConditionValue == "true", nil + } + } + return false, nil +} + +func SetMediaRequiresAuth(ctx rcontext.RequestContext, origin string, mediaId string) error { + return database.GetInstance().RestrictedMedia.Prepare(ctx).Insert(origin, mediaId, database.RestrictedRequiresAuth, "true") +}