diff --git a/coding.go b/coding.go index bdc1dcd..952c1aa 100644 --- a/coding.go +++ b/coding.go @@ -23,6 +23,14 @@ const _ = pb.DoNotUpgradeFileEverItWillChangeYourHashes // for now, we use a PBNode intermediate thing. // because native go objects are nice. +// pbLinkSlice is a slice of pb.PBLink, similar to LinkSlice but for sorting the +// PB form +type pbLinkSlice []*pb.PBLink + +func (pbls pbLinkSlice) Len() int { return len(pbls) } +func (pbls pbLinkSlice) Swap(a, b int) { pbls[a], pbls[b] = pbls[b], pbls[a] } +func (pbls pbLinkSlice) Less(a, b int) bool { return *pbls[a].Name < *pbls[b].Name } + // unmarshal decodes raw data into a *Node instance. // The conversion uses an intermediate PBNode. func unmarshal(encodedBytes []byte) (*ProtoNode, error) { @@ -41,6 +49,9 @@ func fromImmutableNode(encoded *immutableProtoNode) *ProtoNode { n.data = n.encoded.PBNode.Data.Must().Bytes() } numLinks := n.encoded.PBNode.Links.Length() + // links may not be sorted after deserialization, but we don't change + // them until we mutate this node since we're representing the current, + // as-serialized state n.links = make([]*format.Link, numLinks) linkAllocs := make([]format.Link, numLinks) for i := int64(0); i < numLinks; i++ { @@ -63,9 +74,12 @@ func fromImmutableNode(encoded *immutableProtoNode) *ProtoNode { return n } func (n *ProtoNode) marshalImmutable() (*immutableProtoNode, error) { + // ensure links are sorted, but don't modify existing link slice + links := n.Links() + sort.Stable(LinkSlice(links)) nd, err := qp.BuildMap(dagpb.Type.PBNode, 2, func(ma ipld.MapAssembler) { - qp.MapEntry(ma, "Links", qp.List(int64(len(n.links)), func(la ipld.ListAssembler) { - for _, link := range n.links { + qp.MapEntry(ma, "Links", qp.List(int64(len(links)), func(la ipld.ListAssembler) { + for _, link := range links { qp.ListEntry(la, qp.Map(3, func(ma ipld.MapAssembler) { if link.Cid.Defined() { qp.MapEntry(ma, "Hash", qp.Link(cidlink.Link{Cid: link.Cid})) @@ -113,7 +127,6 @@ func (n *ProtoNode) GetPBNode() *pb.PBNode { pbn.Links = make([]*pb.PBLink, len(n.links)) } - sort.Stable(LinkSlice(n.links)) // keep links sorted for i, l := range n.links { pbn.Links[i] = &pb.PBLink{} pbn.Links[i].Name = &l.Name @@ -123,6 +136,10 @@ func (n *ProtoNode) GetPBNode() *pb.PBNode { } } + // Ensure links are sorted prior to encode. They may not have come sorted if + // we deserialized a badly encoded form that didn't have links already sorted. + sort.Stable(pbLinkSlice(pbn.Links)) + if len(n.data) > 0 { pbn.Data = n.data } @@ -132,7 +149,6 @@ func (n *ProtoNode) GetPBNode() *pb.PBNode { // EncodeProtobuf returns the encoded raw data version of a Node instance. // It may use a cached encoded version, unless the force flag is given. func (n *ProtoNode) EncodeProtobuf(force bool) ([]byte, error) { - sort.Stable(LinkSlice(n.links)) // keep links sorted if n.encoded == nil || force { n.cached = cid.Undef var err error diff --git a/merkledag_test.go b/merkledag_test.go index 17a05c6..449526f 100644 --- a/merkledag_test.go +++ b/merkledag_test.go @@ -16,6 +16,7 @@ import ( . "github.com/ipfs/go-merkledag" mdpb "github.com/ipfs/go-merkledag/pb" + pb "github.com/ipfs/go-merkledag/pb" dstest "github.com/ipfs/go-merkledag/test" blocks "github.com/ipfs/go-block-format" @@ -745,6 +746,164 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) { } } +func TestLinkSorting(t *testing.T) { + az := "az" + aaaa := "aaaa" + bbbb := "bbbb" + cccc := "cccc" + + azBlk := NewRawNode([]byte(az)) + aaaaBlk := NewRawNode([]byte(aaaa)) + bbbbBlk := NewRawNode([]byte(bbbb)) + ccccBlk := NewRawNode([]byte(cccc)) + pbn := &pb.PBNode{ + Links: []*pb.PBLink{ + {Hash: bbbbBlk.Cid().Bytes(), Name: &bbbb}, + {Hash: azBlk.Cid().Bytes(), Name: &az}, + {Hash: aaaaBlk.Cid().Bytes(), Name: &aaaa}, + {Hash: ccccBlk.Cid().Bytes(), Name: &cccc}, + }, + } + byts, err := pbn.Marshal() + if err != nil { + t.Fatal(err) + } + + verifyUnsortedNode := func(t *testing.T, node *ProtoNode) { + links := node.Links() + if len(links) != 4 { + t.Errorf("wrong number of links, expected 4 but got %d", len(links)) + } + if links[0].Name != bbbb { + t.Errorf("expected link 0 to be 'bbbb', got %s", links[0].Name) + } + if links[1].Name != az { + t.Errorf("expected link 0 to be 'az', got %s", links[1].Name) + } + if links[2].Name != aaaa { + t.Errorf("expected link 0 to be 'aaaa', got %s", links[2].Name) + } + if links[3].Name != cccc { + t.Errorf("expected link 0 to be 'cccc', got %s", links[3].Name) + } + } + verifySortedNode := func(t *testing.T, node *ProtoNode) { + links := node.Links() + if len(links) != 4 { + t.Errorf("wrong number of links, expected 4 but got %d", len(links)) + } + if links[0].Name != aaaa { + t.Errorf("expected link 0 to be 'aaaa', got %s", links[0].Name) + } + if links[1].Name != az { + t.Errorf("expected link 0 to be 'az', got %s", links[1].Name) + } + if links[2].Name != bbbb { + t.Errorf("expected link 0 to be 'bbbb', got %s", links[2].Name) + } + if links[3].Name != cccc { + t.Errorf("expected link 0 to be 'cccc', got %s", links[3].Name) + } + } + + t.Run("decode", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + verifyUnsortedNode(t, node) + }) + + t.Run("RawData() should not mutate, should return original form", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + rawData := node.RawData() + verifyUnsortedNode(t, node) + if !bytes.Equal(rawData, byts) { + t.Error("RawData() did not return original bytes") + } + }) + + t.Run("Size() should not mutate", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + sz, err := node.Size() + if err != nil { + t.Fatal(err) + } + if sz != 182 { + t.Errorf("expected size to be 182, got %d", sz) + } + verifyUnsortedNode(t, node) + }) + + t.Run("GetPBNode() should not mutate, returned PBNode should be sorted", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + rtPBNode := node.GetPBNode() + rtByts, err := rtPBNode.Marshal() + if err != nil { + t.Fatal(err) + } + verifyUnsortedNode(t, node) + rtNode, err := DecodeProtobuf(rtByts) + if err != nil { + t.Fatal(err) + } + verifySortedNode(t, rtNode) + }) + + t.Run("add and remove link should mutate", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + someCid, _ := cid.Cast([]byte{1, 85, 0, 5, 0, 1, 2, 3, 4}) + if err = node.AddRawLink("foo", &ipld.Link{ + Size: 10, + Cid: someCid, + }); err != nil { + t.Fatal(err) + } + if err = node.RemoveNodeLink("foo"); err != nil { + t.Fatal(err) + } + verifySortedNode(t, node) + }) + + t.Run("update link should not mutate, returned ProtoNode should be sorted", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + newNode, err := node.UpdateNodeLink("self", node) + if err != nil { + t.Fatal(err) + } + if err = newNode.RemoveNodeLink("self"); err != nil { + t.Fatal(err) + } + verifySortedNode(t, newNode) + verifyUnsortedNode(t, node) + }) + + t.Run("SetLinks() should mutate", func(t *testing.T) { + node, err := DecodeProtobuf(byts) + if err != nil { + t.Fatal(err) + } + links := node.Links() // clone + node.SetLinks(links) + verifySortedNode(t, node) + }) +} + func TestProgressIndicator(t *testing.T) { testProgressIndicator(t, 5) } diff --git a/node.go b/node.go index cafd9c3..4856565 100644 --- a/node.go +++ b/node.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sort" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" @@ -115,7 +116,15 @@ func NodeWithData(d []byte) *ProtoNode { return &ProtoNode{data: d} } -// AddNodeLink adds a link to another node. +// AddNodeLink adds a link to another node. The link will be added in +// sorted order. +// +// If sorting has not already been applied to this node (because +// it was deserialized from a form that did not have sorted links), the links +// list will be sorted. If a ProtoNode was deserialized from a badly encoded +// form that did not already have its links sorted, calling AddNodeLink and then +// RemoveNodeLink for the same link, will not result in an identically encoded +// form as the links will have been sorted. func (n *ProtoNode) AddNodeLink(name string, that format.Node) error { lnk, err := format.MakeLink(that) if err != nil { @@ -129,7 +138,15 @@ func (n *ProtoNode) AddNodeLink(name string, that format.Node) error { return nil } -// AddRawLink adds a copy of a link to this node +// AddRawLink adds a copy of a link to this node. The link will be added in +// sorted order. +// +// If sorting has not already been applied to this node (because +// it was deserialized from a form that did not have sorted links), the links +// list will be sorted. If a ProtoNode was deserialized from a badly encoded +// form that did not already have its links sorted, calling AddRawLink and then +// RemoveNodeLink for the same link, will not result in an identically encoded +// form as the links will have been sorted. func (n *ProtoNode) AddRawLink(name string, l *format.Link) error { n.encoded = nil n.links = append(n.links, &format.Link{ @@ -137,11 +154,13 @@ func (n *ProtoNode) AddRawLink(name string, l *format.Link) error { Size: l.Size, Cid: l.Cid, }) - + sort.Stable(LinkSlice(n.links)) return nil } -// RemoveNodeLink removes a link on this node by the given name. +// RemoveNodeLink removes a link on this node by the given name. If there are +// no links with this name, ErrLinkNotFound will be returned. If there are more +// thank one link with this name, they will all be removed. func (n *ProtoNode) RemoveNodeLink(name string) error { n.encoded = nil @@ -244,7 +263,12 @@ func (n *ProtoNode) SetData(d []byte) { } // UpdateNodeLink return a copy of the node with the link name set to point to -// that. If a link of the same name existed, it is removed. +// that. The link will be added in sorted order. If a link of the same name +// existed, it is removed. +// +// If sorting has not already been applied to this node (because +// it was deserialized from a form that did not have sorted links), the links +// list will be sorted in the returned copy. func (n *ProtoNode) UpdateNodeLink(name string, that *ProtoNode) (*ProtoNode, error) { newnode := n.Copy().(*ProtoNode) _ = newnode.RemoveNodeLink(name) // ignore error @@ -309,6 +333,9 @@ func (n *ProtoNode) UnmarshalJSON(b []byte) error { } n.data = s.Data + // links may not be sorted after deserialization, but we don't change + // them until we mutate this node since we're representing the current, + // as-serialized state n.links = s.Links return nil } @@ -358,14 +385,19 @@ func (n *ProtoNode) Multihash() mh.Multihash { return n.cached.Hash() } -// Links returns the node links. +// Links returns a copy of the node's links. func (n *ProtoNode) Links() []*format.Link { - return n.links + links := make([]*format.Link, len(n.links)) + copy(links, n.links) + return links } -// SetLinks replaces the node links with the given ones. +// SetLinks replaces the node links with a copy of the provided links. Sorting +// will be applied to the list. func (n *ProtoNode) SetLinks(links []*format.Link) { - n.links = links + n.links = make([]*format.Link, len(links)) + copy(n.links, links) + sort.Stable(LinkSlice(n.links)) } // Resolve is an alias for ResolveLink.