Skip to content

Commit

Permalink
gateway: ensure llb digests are deterministic when sent by frontends
Browse files Browse the repository at this point in the history
This ensures different valid protobuf serializations that are sent by
frontends will be rewritten into digests that are normalized for the
buildkit solver.

The most recent example of this is that older frontends would generate
protobuf with gogo and the newer buildkit is using the google protobuf
library. These produce different serializations and cause the solver to
think that identical operations are actually different.

This is done by rewriting the incoming definition sent by the llb bridge
forwarder when a gateway calls solve with a protobuf definition.

Signed-off-by: Jonathan A. Sternberg <jonathan.sternberg@docker.com>
  • Loading branch information
jsternberg committed Nov 14, 2024
1 parent c9a17ff commit ea52c76
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 117 deletions.
2 changes: 1 addition & 1 deletion client/llb/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (m *DiffOp) Marshal(ctx context.Context, constraints *Constraints) (digest.

proto.Op = &pb.Op_Diff{Diff: op}

dt, err := deterministicMarshal(proto)
dt, err := proto.Marshal()
if err != nil {
return "", nil, nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion client/llb/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ func (e *ExecOp) Marshal(ctx context.Context, c *Constraints) (digest.Digest, []
peo.Mounts = append(peo.Mounts, pm)
}

dt, err := deterministicMarshal(pop)
dt, err := pop.Marshal()
if err != nil {
return "", nil, nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion client/llb/fileop.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ func (f *FileOp) Marshal(ctx context.Context, c *Constraints) (digest.Digest, []
})
}

dt, err := deterministicMarshal(pop)
dt, err := pop.Marshal()
if err != nil {
return "", nil, nil, nil, err
}
Expand Down
5 changes: 0 additions & 5 deletions client/llb/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/containerd/platforms"
"github.com/moby/buildkit/solver/pb"
digest "github.com/opencontainers/go-digest"
"google.golang.org/protobuf/proto"
)

// Definition is the LLB definition structure with per-vertex metadata entries
Expand Down Expand Up @@ -147,7 +146,3 @@ type marshalCacheResult struct {
md *pb.OpMetadata
srcs []*SourceLocation
}

func deterministicMarshal[Message proto.Message](m Message) ([]byte, error) {
return proto.MarshalOptions{Deterministic: true}.Marshal(m)
}
2 changes: 1 addition & 1 deletion client/llb/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (m *MergeOp) Marshal(ctx context.Context, constraints *Constraints) (digest
}
pop.Op = &pb.Op_Merge{Merge: op}

dt, err := deterministicMarshal(pop)
dt, err := pop.Marshal()
if err != nil {
return "", nil, nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion client/llb/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (s *SourceOp) Marshal(ctx context.Context, constraints *Constraints) (diges
proto.Platform = nil
}

dt, err := deterministicMarshal(proto)
dt, err := proto.Marshal()
if err != nil {
return "", nil, nil, nil, err
}
Expand Down
220 changes: 128 additions & 92 deletions solver/llbsolver/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
digest "github.com/opencontainers/go-digest"
ocispecs "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
)

type vertex struct {
Expand Down Expand Up @@ -199,108 +198,33 @@ func newVertex(dgst digest.Digest, op *pb.Op, opMeta *pb.OpMetadata, load func(d
return vtx, nil
}

func recomputeDigests(ctx context.Context, all map[digest.Digest]*pb.Op, visited map[digest.Digest]digest.Digest, dgst digest.Digest) (digest.Digest, error) {
if dgst, ok := visited[dgst]; ok {
return dgst, nil
}
op, ok := all[dgst]
if !ok {
return "", errors.Errorf("invalid missing input digest %s", dgst)
}

var mutated bool
for _, input := range op.Inputs {
select {
case <-ctx.Done():
return "", context.Cause(ctx)
default:
}

iDgst, err := recomputeDigests(ctx, all, visited, digest.Digest(input.Digest))
if err != nil {
return "", err
}
if digest.Digest(input.Digest) != iDgst {
mutated = true
input.Digest = string(iDgst)
}
}

if !mutated {
visited[dgst] = dgst
return dgst, nil
}

dt, err := deterministicMarshal(op)
if err != nil {
return "", err
}
newDgst := digest.FromBytes(dt)
visited[dgst] = newDgst
all[newDgst] = op
delete(all, dgst)
return newDgst, nil
}

// loadLLB loads LLB.
// fn is executed sequentially.
func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEvaluator, fn func(digest.Digest, *pb.Op, func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error)) (solver.Edge, error) {
if len(def.Def) == 0 {
return solver.Edge{}, errors.New("invalid empty definition")
}

allOps := make(map[digest.Digest]*pb.Op)
mutatedDigests := make(map[digest.Digest]digest.Digest) // key: old, val: new

var lastDgst digest.Digest

for _, dt := range def.Def {
var op pb.Op
if err := op.UnmarshalVT(dt); err != nil {
return solver.Edge{}, errors.Wrap(err, "failed to parse llb proto op")
}
dgst := digest.FromBytes(dt)
mutate := func(ctx context.Context, op *pb.Op) (err error) {
if polEngine != nil {
mutated, err := polEngine.Evaluate(ctx, op.GetSource())
if err != nil {
return solver.Edge{}, errors.Wrap(err, "error evaluating the source policy")
}
if mutated {
dtMutated, err := deterministicMarshal(&op)
if err != nil {
return solver.Edge{}, err
}
dgstMutated := digest.FromBytes(dtMutated)
mutatedDigests[dgst] = dgstMutated
dgst = dgstMutated
if _, err = polEngine.Evaluate(ctx, op.GetSource()); err != nil {
err = errors.Wrap(err, "error evaluating the source policy")
}
}
allOps[dgst] = &op
lastDgst = dgst
return err
}

for dgst := range allOps {
_, err := recomputeDigests(ctx, allOps, mutatedDigests, dgst)
if err != nil {
return solver.Edge{}, err
}
}

if len(allOps) < 2 {
return solver.Edge{}, errors.Errorf("invalid LLB with %d vertexes", len(allOps))
def, dgstMap, err := recomputeDigests(ctx, def, mutate)
if err != nil {
return solver.Edge{}, err
}

for {
newDgst, ok := mutatedDigests[lastDgst]
if !ok || newDgst == lastDgst {
break
}
lastDgst = newDgst
if len(def.Def) < 2 {
return solver.Edge{}, errors.Errorf("invalid LLB with %d vertexes", len(def.Def))
}
lastDgst := digest.FromBytes(def.Def[len(def.Def)-1])

lastOp := allOps[lastDgst]
delete(allOps, lastDgst)
if len(lastOp.Inputs) == 0 {
lastOp := dgstMap.OpFor(lastDgst)
if lastOp == nil || len(lastOp.Inputs) == 0 {
return solver.Edge{}, errors.Errorf("invalid LLB with no inputs on last vertex")
}
dgst := lastOp.Inputs[0].Digest
Expand All @@ -312,8 +236,9 @@ func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEval
if v, ok := cache[dgst]; ok {
return v, nil
}
op, ok := allOps[dgst]
if !ok {

op := dgstMap.OpFor(dgst)
if op == nil {
return nil, errors.Errorf("invalid missing input digest %s", dgst)
}

Expand Down Expand Up @@ -401,6 +326,117 @@ func fileOpName(actions []*pb.FileAction) string {
return strings.Join(names, ", ")
}

func deterministicMarshal[Message proto.Message](m Message) ([]byte, error) {
return proto.MarshalOptions{Deterministic: true}.Marshal(m)
type opMutator func(context.Context, *pb.Op) error

func recomputeDigests(ctx context.Context, def *pb.Definition, mutators ...opMutator) (*pb.Definition, *digestMapping, error) {
dm, err := newDigestMapping(def)
if err != nil {
return nil, nil, err
}
if err := dm.Rewrite(ctx, mutators); err != nil {
return nil, nil, err
}
return dm.out, dm, nil
}

type digestMapping struct {
in, out *pb.Definition
mapping map[digest.Digest]digest.Digest
indexByDigest map[digest.Digest]int
opByDigest map[digest.Digest]*pb.Op
}

func newDigestMapping(def *pb.Definition) (*digestMapping, error) {
dm := &digestMapping{
in: def,
out: def.CloneVT(),
mapping: map[digest.Digest]digest.Digest{},
indexByDigest: map[digest.Digest]int{},
opByDigest: map[digest.Digest]*pb.Op{},
}
for i, in := range def.Def {
dgst := digest.FromBytes(in)

op := new(pb.Op)
if err := op.UnmarshalVT(in); err != nil {
return nil, errors.Wrap(err, "failed to parse llb proto op")
}
dm.opByDigest[dgst] = op
dm.indexByDigest[dgst] = i
}
return dm, nil
}

func (dm *digestMapping) Rewrite(ctx context.Context, mutators []opMutator) error {
for dgst := range dm.indexByDigest {
if _, err := dm.rewrite(ctx, dgst, mutators); err != nil {
return err
}
}
return nil
}

func (dm *digestMapping) OpFor(dgst digest.Digest) *pb.Op {
return dm.opByDigest[dgst]
}

func (dm *digestMapping) IndexFor(dgst digest.Digest) int {
index, ok := dm.indexByDigest[dgst]
if !ok {
return -1
}
return index
}

func (dm *digestMapping) rewrite(ctx context.Context, dgst digest.Digest, mutators []opMutator) (digest.Digest, error) {
if dgst, ok := dm.mapping[dgst]; ok {
return dgst, nil
}

op, ok := dm.opByDigest[dgst]
if !ok {
return "", errors.Errorf("invalid missing input digest %s", dgst)
}

for _, mutator := range mutators {
if err := mutator(ctx, op); err != nil {
return "", err
}
}

// Recompute input digests.
for _, input := range op.Inputs {
select {
case <-ctx.Done():
return "", context.Cause(ctx)
default:
}

iDgst, err := dm.rewrite(ctx, digest.Digest(input.Digest), mutators)
if err != nil {
return "", err
}
if digest.Digest(input.Digest) != iDgst {
input.Digest = string(iDgst)
}
}

// Must use deterministic marshal here so the digest is consistent.
data, err := op.Marshal()
if err != nil {
return "", err
}
newDgst := digest.FromBytes(data)

dm.mapping[dgst] = newDgst
if dgst != newDgst {
// Ensure the indices also map to the new digest.
dm.indexByDigest[newDgst] = dm.indexByDigest[dgst]
dm.opByDigest[newDgst] = dm.opByDigest[dgst]
dm.mapping[newDgst] = newDgst
}

index := dm.indexByDigest[dgst]
dm.out.Def[index] = data
return newDgst, nil
}
42 changes: 28 additions & 14 deletions solver/llbsolver/vertex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,26 @@ import (
digest "github.com/opencontainers/go-digest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)

func TestRecomputeDigests(t *testing.T) {
const (
busyboxLatest = "docker-image://docker.io/library/busybox:latest"
busyboxPinned = "docker-image://docker.io/library/busybox:1.31.1"
)
op1 := &pb.Op{
Op: &pb.Op_Source{
Source: &pb.SourceOp{
Identifier: "docker-image://docker.io/library/busybox:latest",
Identifier: busyboxLatest,
},
},
}
oldData, err := op1.Marshal()
require.NoError(t, err)
oldDigest := digest.FromBytes(oldData)

op1.GetOp().(*pb.Op_Source).Source.Identifier = "docker-image://docker.io/library/busybox:1.31.1"
op1.GetOp().(*pb.Op_Source).Source.Identifier = busyboxPinned
newData, err := op1.Marshal()
require.NoError(t, err)
newDigest := digest.FromBytes(newData)
Expand All @@ -36,20 +41,29 @@ func TestRecomputeDigests(t *testing.T) {
require.NoError(t, err)
op2Digest := digest.FromBytes(op2Data)

all := map[digest.Digest]*pb.Op{
newDigest: op1,
op2Digest: op2,
// Construct a definition with the marshaled data.
def := &pb.Definition{
Def: [][]byte{oldData, op2Data},
}
visited := map[digest.Digest]digest.Digest{oldDigest: newDigest}
def, dgstMap, err := recomputeDigests(context.Background(), def, func(ctx context.Context, op *pb.Op) error {
if source := op.GetSource(); source != nil {
source.Identifier = busyboxPinned
}
return nil
})
require.NoError(t, err)
require.Len(t, def.Def, 2)

assert.Equal(t, op1, dgstMap.OpFor(newDigest))
assert.True(t, proto.Equal(op1, dgstMap.OpFor(newDigest)))
require.Equal(t, newDigest, dgstMap.mapping[oldDigest])
require.Equal(t, op1, dgstMap.OpFor(newDigest))

updated, err := recomputeDigests(context.Background(), all, visited, op2Digest)
lastDgst := digest.FromBytes(def.Def[len(def.Def)-1])
op2.Inputs[0].Digest = string(newDigest)
assert.True(t, proto.Equal(op2, dgstMap.OpFor(lastDgst)))
err = op2.UnmarshalVT(def.Def[len(def.Def)-1])
require.NoError(t, err)
require.Len(t, visited, 2)
require.Len(t, all, 2)
assert.Equal(t, op1, all[newDigest])
require.Equal(t, newDigest, visited[oldDigest])
require.Equal(t, op1, all[newDigest])
assert.Equal(t, op2, all[updated])
require.Equal(t, newDigest, digest.Digest(op2.Inputs[0].Digest))
assert.NotEqual(t, op2Digest, updated)
assert.NotEqual(t, op2Digest, lastDgst)
}
Loading

0 comments on commit ea52c76

Please sign in to comment.