Skip to content

Commit

Permalink
First round of requiring a content token everywhere
Browse files Browse the repository at this point in the history
This currently doesn't build due to some functions not getting the content token treatment. Upon reflection, it may be best to just work on this after rewriting the middle layer.

Part of #103
  • Loading branch information
turt2live committed Jun 15, 2018
1 parent dbc6d2b commit ad0f783
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 45 deletions.
2 changes: 2 additions & 0 deletions migrations/5_add_visibility_to_media_down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE media DROP COLUMN visibility;
ALTER TABLE media DROP COLUMN content_token_hash;
2 changes: 2 additions & 0 deletions migrations/5_add_visibility_to_media_up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE media ADD COLUMN visibility TEXT NOT NULL DEFAULT 'public';
ALTER TABLE media ADD COLUMN content_token_hash TEXT NULL DEFAULT NULL;
26 changes: 22 additions & 4 deletions src/github.com/turt2live/matrix-media-repo/api/r0/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io"
"io/ioutil"
"net/http"
"strconv"

"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api"
Expand All @@ -12,7 +13,8 @@ import (
)

type MediaUploadedResponse struct {
ContentUri string `json:"content_uri"`
ContentUri string `json:"content_uri"`
ContentToken string `json:"content_token,omitempty"`
}

func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
Expand All @@ -21,8 +23,24 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac
filename = "upload.bin"
}

isPublic := r.URL.Query().Get("public")
visibility := "public"
if isPublic != "" {
parsedFlag, err := strconv.ParseBool(isPublic)
if err != nil {
return api.InternalServerError("public flag does not appear to be a boolean")
}

if parsedFlag {
visibility = "public"
} else {
visibility = "private"
}
}

log = log.WithFields(logrus.Fields{
"filename": filename,
"filename": filename,
"visibility": visibility,
})

contentType := r.Header.Get("Content-Type")
Expand All @@ -38,7 +56,7 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac
return api.RequestTooLarge()
}

media, err := svc.UploadMedia(r.Body, contentType, filename, user.UserId, r.Host)
media, unhashedContentToken, err := svc.UploadMedia(r.Body, contentType, filename, visibility, user.UserId, r.Host)
if err != nil {
io.Copy(ioutil.Discard, r.Body) // Ditch the entire request
defer r.Body.Close()
Expand All @@ -51,5 +69,5 @@ func UploadMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) interfac
return api.InternalServerError("Unexpected Error")
}

return &MediaUploadedResponse{media.MxcUri()}
return &MediaUploadedResponse{media.MxcUri(), unhashedContentToken}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"strconv"

"github.com/pkg/errors"
"github.com/ryanuber/go-glob"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common"
Expand All @@ -32,14 +33,14 @@ func (s *mediaService) GetMediaDirect(server string, mediaId string) (*types.Med
return s.store.Get(server, mediaId)
}

func (s *mediaService) GetRemoteMediaDirect(server string, mediaId string) (*types.Media, error) {
return s.downloadRemoteMedia(server, mediaId)
func (s *mediaService) GetRemoteMediaDirect(server string, mediaId string, rawContentToken string) (*types.Media, error) {
return s.downloadRemoteMedia(server, mediaId, rawContentToken)
}

func (s *mediaService) downloadRemoteMedia(server string, mediaId string) (*types.Media, error) {
func (s *mediaService) downloadRemoteMedia(server string, mediaId string, rawContentToken string) (*types.Media, error) {
s.log.Info("Attempting to download remote media")

result := <-getResourceHandler().DownloadRemoteMedia(server, mediaId)
result := <-getResourceHandler().DownloadRemoteMedia(server, mediaId, rawContentToken)
return result.media, result.err
}

Expand Down Expand Up @@ -135,7 +136,7 @@ func (s *mediaService) PurgeRemoteMediaBefore(beforeTs int64) (int, error) {
return removed, nil
}

func (s *mediaService) UploadMedia(contents io.ReadCloser, contentType string, filename string, userId string, host string) (*types.Media, error) {
func (s *mediaService) UploadMedia(contents io.ReadCloser, contentType string, filename string, visibility string, userId string, host string) (*types.Media, string, error) {
defer contents.Close()
var data io.Reader
if config.Get().Uploads.MaxSizeBytes > 0 {
Expand All @@ -144,68 +145,93 @@ func (s *mediaService) UploadMedia(contents io.ReadCloser, contentType string, f
data = contents
}

return s.StoreMedia(data, contentType, filename, userId, host, "")
return s.StoreMedia(data, contentType, filename, visibility, userId, host, "", "")
}

func (s *mediaService) StoreMedia(contents io.Reader, contentType string, filename string, userId string, host string, mediaId string) (*types.Media, error) {
func (s *mediaService) StoreMedia(contents io.Reader, contentType string, filename string, visibility string, userId string, host string, mediaId string, unhashedContentToken string) (*types.Media, string, error) {
isGeneratedId := false
if mediaId == "" {
mediaId = generateMediaId()
isGeneratedId = true
}
isGeneratedToken := false
if unhashedContentToken == "" && visibility != "public" {
unhashedContentToken = generateContentToken()
isGeneratedToken = true
}
log := s.log.WithFields(logrus.Fields{
"mediaService_mediaId": mediaId,
"mediaService_host": host,
"mediaService_mediaIdIsGenerated": isGeneratedId,
"mediaService_mediaId": mediaId,
"mediaService_host": host,
"mediaService_mediaIdIsGenerated": isGeneratedId,
"mediaService_contentTokenIsGenerated": isGeneratedToken,
})

var hashedContentToken *string
if visibility != "public" {
hashed, err := util.HashString(unhashedContentToken)
if err != nil {
return nil, unhashedContentToken, err
}

hashedContentToken = &hashed
}

// Store the file in a temporary location
fileLocation, err := storage.PersistFile(contents, s.ctx, s.log)
if err != nil {
return nil, err
return nil, unhashedContentToken, err
}

// Check to make sure the file is allowed
fileMime, err := util.GetMimeType(fileLocation)
if err != nil {
s.log.Error("Error while checking content type of file: " + err.Error())
os.Remove(fileLocation) // attempt cleanup
return nil, err
return nil, unhashedContentToken, err
}

for _, allowedType := range config.Get().Uploads.AllowedTypes {
if !glob.Glob(allowedType, fileMime) {
s.log.Warn("Content type " + fileMime + " (reported as " + contentType + ") is not allowed to be uploaded")

os.Remove(fileLocation) // attempt cleanup
return nil, common.ErrMediaNotAllowed
return nil, unhashedContentToken, common.ErrMediaNotAllowed
}
}

hash, err := storage.GetFileHash(fileLocation)
if err != nil {
os.Remove(fileLocation) // attempt cleanup
return nil, err
return nil, unhashedContentToken, err
}

records, err := s.store.GetByHash(hash)
if err != nil {
os.Remove(fileLocation) // attempt cleanup
return nil, err
return nil, unhashedContentToken, err
}

// If there's at least one record, then we have a duplicate hash - try and process it
if len(records) > 0 {
// See if we one of the duplicate records is a match for the host and media ID. We'll otherwise use
// the last duplicate (should only be 1 anyways) as our starting point for a new record.

var media *types.Media
for i := 0; i < len(records); i++ {
media = records[i]

if media.Origin == host && (media.MediaId == mediaId || isGeneratedId) && media.ContentType == contentType && media.UserId == userId {
if media.Origin == host && media.MediaId == mediaId && media.ContentType == contentType && media.UserId == userId {
log.Info("User has uploaded this media before - returning unaltered media record")

if media.Visibility != visibility || (visibility != "public" && (media.ContentTokenHash == nil || hashedContentToken == nil || *media.ContentTokenHash != *hashedContentToken)) {
log.Warn("Media visibility or content token does not match what is already stored. Refusing to serve media.")

os.Remove(fileLocation) // attempt cleanup
return nil, unhashedContentToken, errors.New("media visibility or content token does not match what is already stored")
}

overwriteExistingOrDeleteTempFile(fileLocation, media)
return media, nil
return media, unhashedContentToken, nil
}

// The last media object will be used to create a new pointer
Expand All @@ -219,22 +245,26 @@ func (s *mediaService) StoreMedia(contents io.Reader, contentType string, filena
media.UploadName = filename
media.ContentType = contentType
media.CreationTs = util.NowMillis()
media.Visibility = visibility
if visibility != "public" {
media.ContentTokenHash = hashedContentToken
}

err = s.store.Insert(media)
if err != nil {
return nil, err
return nil, unhashedContentToken, err
}

overwriteExistingOrDeleteTempFile(fileLocation, media)
return media, nil
return media, unhashedContentToken, nil
}

// No duplicate hash, so we have to create a new record

fileSize, err := util.FileSize(fileLocation)
if err != nil {
os.Remove(fileLocation) // attempt cleanup
return nil, err
return nil, unhashedContentToken, err
}

log.Info("Persisting unique media record")
Expand All @@ -249,15 +279,19 @@ func (s *mediaService) StoreMedia(contents io.Reader, contentType string, filena
SizeBytes: fileSize,
Location: fileLocation,
CreationTs: util.NowMillis(),
Visibility: visibility,
}
if visibility != "public" {
media.ContentTokenHash = hashedContentToken
}

err = s.store.Insert(media)
if err != nil {
os.Remove(fileLocation) // attempt cleanup
return nil, err
return nil, unhashedContentToken, err
}

return media, nil
return media, unhashedContentToken, nil
}

func generateMediaId() string {
Expand All @@ -269,6 +303,15 @@ func generateMediaId() string {
return str
}

func generateContentToken() string {
str, err := util.GenerateRandomString(128)
if err != nil {
panic(err)
}

return str
}

func overwriteExistingOrDeleteTempFile(tempFileLocation string, media *types.Media) {
// If the media's file exists, we'll delete the temp file
// If the media's file doesn't exist, we'll move the temp file to where the media expects it to be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ type mediaResourceHandler struct {
}

type downloadRequest struct {
origin string
mediaId string
origin string
mediaId string
rawContentToken string
}

type downloadResponse struct {
Expand Down Expand Up @@ -62,19 +63,24 @@ func downloadResourceWorkFn(request *resource_handler.WorkRequest) interface{} {
svc := New(ctx, log) // media_service (us)
defer downloaded.Contents.Close()

media, err := svc.StoreMedia(downloaded.Contents, downloaded.ContentType, downloaded.DesiredFilename, "", info.origin, info.mediaId)
visibility := "public"
if info.rawContentToken != "" {
visibility = "private"
}

media, _, err := svc.StoreMedia(downloaded.Contents, downloaded.ContentType, downloaded.DesiredFilename, visibility, "", info.origin, info.mediaId, info.rawContentToken)
if err != nil {
return &downloadResponse{err: err}
}

return &downloadResponse{media, err}
}

func (h *mediaResourceHandler) DownloadRemoteMedia(origin string, mediaId string) chan *downloadResponse {
func (h *mediaResourceHandler) DownloadRemoteMedia(origin string, mediaId string, rawContentToken string) chan *downloadResponse {
resultChan := make(chan *downloadResponse)
go func() {
reqId := "remote_download:" + origin + "_" + mediaId
result := <-h.resourceHandler.GetResource(reqId, &downloadRequest{origin, mediaId})
result := <-h.resourceHandler.GetResource(reqId, &downloadRequest{origin, mediaId, rawContentToken})
resultChan <- result.(*downloadResponse)
}()
return resultChan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
"github.com/turt2live/matrix-media-repo/types"
)

const selectMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts, quarantined FROM media WHERE origin = $1 and media_id = $2;"
const selectMediaByHash = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts, quarantined FROM media WHERE sha256_hash = $1;"
const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts, quarantined) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10);"
const selectOldMedia = "SELECT m.origin, m.media_id, m.upload_name, m.content_type, m.user_id, m.sha256_hash, m.size_bytes, m.location, m.creation_ts, quarantined FROM media AS m WHERE NOT(m.origin = ANY($1)) AND m.creation_ts < $2 AND (SELECT COUNT(*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.creation_ts >= $2) = 0 AND (SELECT COUNT(*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.origin = ANY($1)) = 0;"
const selectMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts, quarantined, visibility, content_token_hash FROM media WHERE origin = $1 and media_id = $2;"
const selectMediaByHash = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts, quarantined, visibility, content_token_hash FROM media WHERE sha256_hash = $1;"
const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, location, creation_ts, quarantined, visibility, content_token_hash) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12);"
const selectOldMedia = "SELECT m.origin, m.media_id, m.upload_name, m.content_type, m.user_id, m.sha256_hash, m.size_bytes, m.location, m.creation_ts, quarantined, visibility, content_token_hash FROM media AS m WHERE NOT(m.origin = ANY($1)) AND m.creation_ts < $2 AND (SELECT COUNT(*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.creation_ts >= $2) = 0 AND (SELECT COUNT(*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.origin = ANY($1)) = 0;"
const selectOrigins = "SELECT DISTINCT origin FROM media;"
const deleteMedia = "DELETE FROM media WHERE origin = $1 AND media_id = $2;"
const updateQuarantined = "UPDATE media SET quarantined = $3 WHERE origin = $1 AND media_id = $2;"
Expand Down Expand Up @@ -92,6 +92,8 @@ func (s *MediaStore) Insert(media *types.Media) (error) {
media.Location,
media.CreationTs,
media.Quarantined,
media.Visibility,
media.ContentTokenHash,
)
return err
}
Expand All @@ -116,6 +118,8 @@ func (s *MediaStore) GetByHash(hash string) ([]*types.Media, error) {
&obj.Location,
&obj.CreationTs,
&obj.Quarantined,
&obj.Visibility,
&obj.ContentTokenHash,
)
if err != nil {
return nil, err
Expand All @@ -139,6 +143,8 @@ func (s *MediaStore) Get(origin string, mediaId string) (*types.Media, error) {
&m.Location,
&m.CreationTs,
&m.Quarantined,
&m.Visibility,
&m.ContentTokenHash,
)
return m, err
}
Expand All @@ -163,6 +169,8 @@ func (s *MediaStore) GetOldMedia(exceptOrigins []string, beforeTs int64) ([]*typ
&obj.Location,
&obj.CreationTs,
&obj.Quarantined,
&obj.Visibility,
&obj.ContentTokenHash,
)
if err != nil {
return nil, err
Expand Down
22 changes: 12 additions & 10 deletions src/github.com/turt2live/matrix-media-repo/types/media.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@ package types
import "io"

type Media struct {
Origin string
MediaId string
UploadName string
ContentType string
UserId string
Sha256Hash string
SizeBytes int64
Location string
CreationTs int64
Quarantined bool
Origin string
MediaId string
UploadName string
ContentType string
UserId string
Sha256Hash string
SizeBytes int64
Location string
CreationTs int64
Quarantined bool
Visibility string // "public" or "private" at the moment
ContentTokenHash *string
}

type StreamedMedia struct {
Expand Down
Loading

0 comments on commit ad0f783

Please sign in to comment.