Skip to content

Commit

Permalink
chore: add error handling to properly rollback transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
mikhailswift committed Jul 5, 2022
1 parent aa95769 commit e81bb1c
Showing 1 changed file with 86 additions and 64 deletions.
150 changes: 86 additions & 64 deletions internal/storage/mysqlstore/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,27 @@ func (s *store) GetBySubjectDigest(ctx context.Context, request *archivist.GetBy
return &archivist.GetBySubjectDigestResponse{Object: results}, err
}

func (s *store) withTx(ctx context.Context, fn func(tx *ent.Tx) error) error {
tx, err := s.client.Tx(ctx)
if err != nil {
return err
}

if err := fn(tx); err != nil {
if err := tx.Rollback(); err != nil {
return fmt.Errorf("unable to rollback transaction: %w", err)
}

return err
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("unable to commit transaction: %w", err)
}

return nil
}

// attestation.Collection from go-witness will try to parse each of the attestations by calling their factory functions,
// which require the attestations to be registered in the go-witness library. We don't really care about the actual attestation
// data for the purposes here, so just leave it as a raw message.
Expand Down Expand Up @@ -157,89 +178,90 @@ func (s *store) Store(ctx context.Context, request *archivist.StoreRequest) (*em
return nil, err
}

tx, err := s.client.Tx(ctx)
dsse, err := tx.Dsse.Create().
SetPayloadType(envelope.PayloadType).
SetGitbomSha256(gb.Identity()).
Save(ctx)
if err != nil {
return nil, err
}

for _, sig := range envelope.Signatures {
_, err = tx.Signature.Create().
SetKeyID(sig.KeyID).
SetSignature(base64.StdEncoding.EncodeToString(sig.Signature)).
SetDsse(dsse).
err = s.withTx(ctx, func(tx *ent.Tx) error {
dsse, err := tx.Dsse.Create().
SetPayloadType(envelope.PayloadType).
SetGitbomSha256(gb.Identity()).
Save(ctx)
if err != nil {
return nil, err
return err
}
}

for hashFn, digest := range payloadDigestSet {
hashName, err := cryptoutil.HashToString(hashFn)
if err != nil {
return nil, err
for _, sig := range envelope.Signatures {
_, err = tx.Signature.Create().
SetKeyID(sig.KeyID).
SetSignature(base64.StdEncoding.EncodeToString(sig.Signature)).
SetDsse(dsse).
Save(ctx)
if err != nil {
return err
}
}

if _, err := tx.PayloadDigest.Create().
SetDsse(dsse).
SetAlgorithm(hashName).
SetValue(digest).
Save(ctx); err != nil {
return nil, err
}
}
for hashFn, digest := range payloadDigestSet {
hashName, err := cryptoutil.HashToString(hashFn)
if err != nil {
return err
}

stmt, err := tx.Statement.Create().
SetPredicate(payload.PredicateType).
AddDsse(dsse).
Save(ctx)
if err != nil {
return nil, err
}
if _, err := tx.PayloadDigest.Create().
SetDsse(dsse).
SetAlgorithm(hashName).
SetValue(digest).
Save(ctx); err != nil {
return err
}
}

for _, subject := range payload.Subject {
storedSubject, err := tx.Subject.Create().
SetName(subject.Name).
SetStatement(stmt).
stmt, err := tx.Statement.Create().
SetPredicate(payload.PredicateType).
AddDsse(dsse).
Save(ctx)
if err != nil {
return nil, err
return err
}

for algorithm, value := range subject.Digest {
if err := tx.SubjectDigest.Create().
SetAlgorithm(algorithm).
SetValue(value).SetSubject(storedSubject).
Exec(ctx); err != nil {
return nil, err
for _, subject := range payload.Subject {
storedSubject, err := tx.Subject.Create().
SetName(subject.Name).
SetStatement(stmt).
Save(ctx)
if err != nil {
return err
}

for algorithm, value := range subject.Digest {
if err := tx.SubjectDigest.Create().
SetAlgorithm(algorithm).
SetValue(value).SetSubject(storedSubject).
Exec(ctx); err != nil {
return err
}
}
}
}

collection, err := tx.AttestationCollection.Create().
SetStatementID(stmt.ID).
SetName(parsedCollection.Name).
Save(ctx)
if err != nil {
return nil, err
}
collection, err := tx.AttestationCollection.Create().
SetStatementID(stmt.ID).
SetName(parsedCollection.Name).
Save(ctx)
if err != nil {
return err
}

for _, a := range parsedCollection.Attestations {
if err := tx.Attestation.Create().
SetAttestationCollectionID(collection.ID).
SetType(a.Type).
Exec(ctx); err != nil {
return nil, err
for _, a := range parsedCollection.Attestations {
if err := tx.Attestation.Create().
SetAttestationCollectionID(collection.ID).
SetType(a.Type).
Exec(ctx); err != nil {
return err
}
}
}

err = tx.Commit()
return nil
})

if err != nil {
logrus.Errorf("unable to commit transaction: %+v", err)
logrus.Errorf("unable to store metadata: %+v", err)
return nil, err
}

Expand All @@ -253,6 +275,6 @@ func (s *store) Store(ctx context.Context, request *archivist.StoreRequest) (*em

fmt.Println("object stored")
}
// ********************************************************************************

return &emptypb.Empty{}, nil
}

0 comments on commit e81bb1c

Please sign in to comment.