diff --git a/internal/storage/mysqlstore/mysql.go b/internal/storage/mysqlstore/mysql.go index 1077920b..6aa6120b 100644 --- a/internal/storage/mysqlstore/mysql.go +++ b/internal/storage/mysqlstore/mysql.go @@ -115,6 +115,27 @@ func (s *store) GetBySubjectDigest(ctx context.Context, request *archivist.GetBy return &archivist.GetBySubjectDigestResponse{Object: results}, err } +func (s *store) withTx(ctx context.Context, fn func(tx *ent.Tx) error) error { + tx, err := s.client.Tx(ctx) + if err != nil { + return err + } + + if err := fn(tx); err != nil { + if err := tx.Rollback(); err != nil { + return fmt.Errorf("unable to rollback transaction: %w", err) + } + + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("unable to commit transaction: %w", err) + } + + return nil +} + // attestation.Collection from go-witness will try to parse each of the attestations by calling their factory functions, // which require the attestations to be registered in the go-witness library. We don't really care about the actual attestation // data for the purposes here, so just leave it as a raw message. @@ -157,89 +178,90 @@ func (s *store) Store(ctx context.Context, request *archivist.StoreRequest) (*em return nil, err } - tx, err := s.client.Tx(ctx) - dsse, err := tx.Dsse.Create(). - SetPayloadType(envelope.PayloadType). - SetGitbomSha256(gb.Identity()). - Save(ctx) - if err != nil { - return nil, err - } - - for _, sig := range envelope.Signatures { - _, err = tx.Signature.Create(). - SetKeyID(sig.KeyID). - SetSignature(base64.StdEncoding.EncodeToString(sig.Signature)). - SetDsse(dsse). + err = s.withTx(ctx, func(tx *ent.Tx) error { + dsse, err := tx.Dsse.Create(). + SetPayloadType(envelope.PayloadType). + SetGitbomSha256(gb.Identity()). Save(ctx) if err != nil { - return nil, err + return err } - } - for hashFn, digest := range payloadDigestSet { - hashName, err := cryptoutil.HashToString(hashFn) - if err != nil { - return nil, err + for _, sig := range envelope.Signatures { + _, err = tx.Signature.Create(). + SetKeyID(sig.KeyID). + SetSignature(base64.StdEncoding.EncodeToString(sig.Signature)). + SetDsse(dsse). + Save(ctx) + if err != nil { + return err + } } - if _, err := tx.PayloadDigest.Create(). - SetDsse(dsse). - SetAlgorithm(hashName). - SetValue(digest). - Save(ctx); err != nil { - return nil, err - } - } + for hashFn, digest := range payloadDigestSet { + hashName, err := cryptoutil.HashToString(hashFn) + if err != nil { + return err + } - stmt, err := tx.Statement.Create(). - SetPredicate(payload.PredicateType). - AddDsse(dsse). - Save(ctx) - if err != nil { - return nil, err - } + if _, err := tx.PayloadDigest.Create(). + SetDsse(dsse). + SetAlgorithm(hashName). + SetValue(digest). + Save(ctx); err != nil { + return err + } + } - for _, subject := range payload.Subject { - storedSubject, err := tx.Subject.Create(). - SetName(subject.Name). - SetStatement(stmt). + stmt, err := tx.Statement.Create(). + SetPredicate(payload.PredicateType). + AddDsse(dsse). Save(ctx) if err != nil { - return nil, err + return err } - for algorithm, value := range subject.Digest { - if err := tx.SubjectDigest.Create(). - SetAlgorithm(algorithm). - SetValue(value).SetSubject(storedSubject). - Exec(ctx); err != nil { - return nil, err + for _, subject := range payload.Subject { + storedSubject, err := tx.Subject.Create(). + SetName(subject.Name). + SetStatement(stmt). + Save(ctx) + if err != nil { + return err + } + + for algorithm, value := range subject.Digest { + if err := tx.SubjectDigest.Create(). + SetAlgorithm(algorithm). + SetValue(value).SetSubject(storedSubject). + Exec(ctx); err != nil { + return err + } } } - } - collection, err := tx.AttestationCollection.Create(). - SetStatementID(stmt.ID). - SetName(parsedCollection.Name). - Save(ctx) - if err != nil { - return nil, err - } + collection, err := tx.AttestationCollection.Create(). + SetStatementID(stmt.ID). + SetName(parsedCollection.Name). + Save(ctx) + if err != nil { + return err + } - for _, a := range parsedCollection.Attestations { - if err := tx.Attestation.Create(). - SetAttestationCollectionID(collection.ID). - SetType(a.Type). - Exec(ctx); err != nil { - return nil, err + for _, a := range parsedCollection.Attestations { + if err := tx.Attestation.Create(). + SetAttestationCollectionID(collection.ID). + SetType(a.Type). + Exec(ctx); err != nil { + return err + } } - } - err = tx.Commit() + return nil + }) if err != nil { - logrus.Errorf("unable to commit transaction: %+v", err) + logrus.Errorf("unable to store metadata: %+v", err) return nil, err } @@ -253,6 +275,6 @@ func (s *store) Store(ctx context.Context, request *archivist.StoreRequest) (*em fmt.Println("object stored") } - // ******************************************************************************** + return &emptypb.Empty{}, nil }