Skip to content

Commit

Permalink
Support depth for local clones.
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmer committed Oct 29, 2022
1 parent 7fa53c9 commit 44631bc
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 32 deletions.
3 changes: 3 additions & 0 deletions NEWS
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
0.20.47 UNRELEASED

* Support ``depth`` for local clones,
(Jelmer Vernooij)

* Fix Repo.reset_index.
Previously, it instead took the union with the given tree.
(Christian Sattler, #1072)
Expand Down
20 changes: 12 additions & 8 deletions dulwich/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _read_side_band64k_data(pkt_seq, channel_callbacks):


def _handle_upload_pack_head(
proto, capabilities, graph_walker, wants, can_read, depth
proto, capabilities, graph_walker, wants, can_read, depth: Optional[int] = None
):
"""Handle the head of a 'git-upload-pack' request.
Expand Down Expand Up @@ -721,8 +721,10 @@ def send_pack(self, path, update_refs, generate_pack_data, progress=None):
"""
raise NotImplementedError(self.send_pack)

def clone(self, path, target_path, mkdir: bool = True, bare=False, origin="origin",
checkout=None, branch=None, progress=None, depth=None):
def clone(self, path, target_path,
mkdir: bool = True, bare=False, origin="origin",
checkout=None, branch=None, progress=None,
depth: Optional[int] = None):
"""Clone a repository."""
from .refs import _set_origin_head, _set_default_branch, _set_head

Expand Down Expand Up @@ -758,6 +760,7 @@ def clone(self, path, target_path, mkdir: bool = True, bare=False, origin="origi

ref_message = b"clone: from " + encoded_path
result = self.fetch(path, target, progress=progress, depth=depth)

_import_remote_refs(
target.refs, origin, result.refs, message=ref_message)

Expand Down Expand Up @@ -857,7 +860,7 @@ def fetch_pack(
graph_walker,
pack_data,
progress=None,
depth=None,
depth: Optional[int] = None,
):
"""Retrieve a pack from a git smart server.
Expand Down Expand Up @@ -1125,7 +1128,7 @@ def fetch_pack(
graph_walker,
pack_data,
progress=None,
depth=None,
depth: Optional[int] = None,
):
"""Retrieve a pack from a git smart server.
Expand Down Expand Up @@ -1483,7 +1486,8 @@ def progress(x):

return SendPackResult(new_refs, ref_status=ref_status)

def fetch(self, path, target, determine_wants=None, progress=None, depth=None):
def fetch(self, path, target, determine_wants=None, progress=None,
depth: Optional[int] = None):
"""Fetch into a target repository.
Args:
Expand Down Expand Up @@ -1515,7 +1519,7 @@ def fetch_pack(
graph_walker,
pack_data,
progress=None,
depth=None,
depth: Optional[int] = None,
):
"""Retrieve a pack from a git smart server.
Expand Down Expand Up @@ -2056,7 +2060,7 @@ def fetch_pack(
graph_walker,
pack_data,
progress=None,
depth=None,
depth: Optional[int] = None,
):
"""Retrieve a pack from a git smart server.
Expand Down
11 changes: 6 additions & 5 deletions dulwich/object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ def __init__(
def add_todo(self, entries):
self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])

def next(self):
def __next__(self):
while True:
if not self.objects_to_send:
return None
Expand All @@ -1407,7 +1407,7 @@ def next(self):
self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
return (sha, name)

__next__ = next
next = __next__


class ObjectStoreGraphWalker(object):
Expand All @@ -1418,7 +1418,7 @@ class ObjectStoreGraphWalker(object):
get_parents: Function to retrieve parents in the local repo
"""

def __init__(self, local_heads, get_parents, shallow=None):
def __init__(self, local_heads, get_parents, shallow=None, update_shallow=None):
"""Create a new instance.
Args:
Expand All @@ -1431,6 +1431,7 @@ def __init__(self, local_heads, get_parents, shallow=None):
if shallow is None:
shallow = set()
self.shallow = shallow
self.update_shallow = update_shallow

def ack(self, sha):
"""Ack that a revision and its ancestors are present in the source."""
Expand Down Expand Up @@ -1458,7 +1459,7 @@ def ack(self, sha):

ancestors = new_ancestors

def next(self):
def __next__(self):
"""Iterate over ancestors of heads in the target."""
if self.heads:
ret = self.heads.pop()
Expand All @@ -1471,7 +1472,7 @@ def next(self):
return ret
return None

__next__ = next
next = __next__


def commit_tree_changes(object_store, tree, changes):
Expand Down
45 changes: 29 additions & 16 deletions dulwich/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
MemoryObjectStore,
BaseObjectStore,
ObjectStoreGraphWalker,
find_shallow,
)
from dulwich.objects import (
check_hexsha,
Expand Down Expand Up @@ -437,7 +438,8 @@ def open_index(self) -> "Index":
"""
raise NotImplementedError(self.open_index)

def fetch(self, target, determine_wants=None, progress=None, depth=None):
def fetch(self, target, determine_wants=None, progress=None,
depth: Optional[int] = None):
"""Fetch objects into another repository.
Args:
Expand All @@ -450,9 +452,10 @@ def fetch(self, target, determine_wants=None, progress=None, depth=None):
"""
if determine_wants is None:
determine_wants = target.object_store.determine_wants_all
graph_walker = target.get_graph_walker()
count, pack_data = self.fetch_pack_data(
determine_wants,
target.get_graph_walker(),
graph_walker,
progress=progress,
depth=depth,
)
Expand All @@ -464,8 +467,9 @@ def fetch_pack_data(
determine_wants,
graph_walker,
progress,
*,
get_tagged=None,
depth=None,
depth: Optional[int] = None,
):
"""Fetch the pack data required for a set of revisions.
Expand All @@ -484,7 +488,7 @@ def fetch_pack_data(
"""
# TODO(jelmer): Fetch pack data directly, don't create objects first.
objects = self.fetch_objects(
determine_wants, graph_walker, progress, get_tagged, depth=depth
determine_wants, graph_walker, progress, get_tagged=get_tagged, depth=depth
)
return pack_objects_to_data(objects)

Expand All @@ -493,8 +497,9 @@ def fetch_objects(
determine_wants,
graph_walker,
progress,
*,
get_tagged=None,
depth=None,
depth: Optional[int] = None,
):
"""Fetch the missing objects required for a set of revisions.
Expand All @@ -511,9 +516,6 @@ def fetch_objects(
depth: Shallow fetch depth
Returns: iterator over objects, with __len__ implemented
"""
if depth not in (None, 0):
raise NotImplementedError("depth not supported yet")

refs = {}
for ref, sha in self.get_refs().items():
try:
Expand All @@ -534,14 +536,23 @@ def fetch_objects(
if not isinstance(wants, list):
raise TypeError("determine_wants() did not return a list")

shallows = getattr(graph_walker, "shallow", frozenset())
unshallows = getattr(graph_walker, "unshallow", frozenset())
current_shallow = set(graph_walker.shallow)

if depth not in (None, 0):
shallow, not_shallow = find_shallow(
self.object_store, wants, depth)
graph_walker.shallow.update(shallow - not_shallow)
new_shallow = graph_walker.shallow - current_shallow
unshallow = graph_walker.unshallow = not_shallow & current_shallow
graph_walker.update_shallow(new_shallow, unshallow)
else:
unshallow = getattr(graph_walker, "unshallow", frozenset())

if wants == []:
# TODO(dborowitz): find a way to short-circuit that doesn't change
# this interface.

if shallows or unshallows:
if graph_walker.shallow or unshallow:
# Do not send a pack in shallow short-circuit path
return None

Expand All @@ -554,12 +565,13 @@ def fetch_objects(

# Deal with shallow requests separately because the haves do
# not reflect what objects are missing
if shallows or unshallows:
if graph_walker.shallow or unshallow:
# TODO: filter the haves commits from iter_shas. the specific
# commits aren't missing.
haves = []

parents_provider = ParentsProvider(self.object_store, shallows=shallows)
parents_provider = ParentsProvider(
self.object_store, shallows=current_shallow)

def get_parents(commit):
return parents_provider.get_parents(commit.id, commit)
Expand All @@ -568,7 +580,7 @@ def get_parents(commit):
self.object_store.find_missing_objects(
haves,
wants,
self.get_shallow(),
graph_walker.shallow,
progress,
get_tagged,
get_parents=get_parents,
Expand Down Expand Up @@ -613,7 +625,8 @@ def get_graph_walker(
]
parents_provider = ParentsProvider(self.object_store)
return ObjectStoreGraphWalker(
heads, parents_provider.get_parents, shallow=self.get_shallow()
heads, parents_provider.get_parents, shallow=self.get_shallow(),
update_shallow=self.update_shallow,
)

def get_refs(self) -> Dict[bytes, bytes]:
Expand Down Expand Up @@ -1472,7 +1485,7 @@ def clone(
checkout=None,
branch=None,
progress=None,
depth=None,
depth: Optional[int] = None,
):
"""Clone this repository.
Expand Down
3 changes: 3 additions & 0 deletions dulwich/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,9 @@ def _handle_shallow_request(self, wants):
new_shallow = self.shallow - self.client_shallow
unshallow = self.unshallow = not_shallow & self.client_shallow

self.update_shallow(new_shallow, unshallow)

def update_shallow(self, new_shallow: List[bytes], unshallow: List[bytes]):
for sha in sorted(new_shallow):
self.proto.write_pkt_line(format_shallow_line(sha))
for sha in sorted(unshallow):
Expand Down
6 changes: 3 additions & 3 deletions dulwich/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def check_heads(heads, **kwargs):
self.assertEqual({}, heads)
return []

ret = self.client.fetch_pack(b"bla", check_heads, None, None, None)
ret = self.client.fetch_pack(b"bla", check_heads, None, None)
self.assertEqual({}, ret.refs)
self.assertEqual({}, ret.symrefs)
self.assertEqual(self.rout.getvalue(), b"0000")
Expand All @@ -197,7 +197,7 @@ def test_fetch_pack_none(self):
b"0000"
)
self.rin.seek(0)
ret = self.client.fetch_pack(b"bla", lambda heads, **kwargs: [], None, None, None)
ret = self.client.fetch_pack(b"bla", lambda heads, **kwargs: [], None, None)
self.assertEqual(
{b"HEAD": b"55dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7"}, ret.refs
)
Expand Down Expand Up @@ -891,7 +891,7 @@ def test_fetch_empty(self):
s = open_repo("a.git")
self.addCleanup(tear_down_repo, s)
out = BytesIO()
walker = {}
walker = MemoryRepo().get_graph_walker()
ret = c.fetch_pack(
s.path, lambda heads, **kwargs: [], graph_walker=walker, pack_data=out.write
)
Expand Down
37 changes: 37 additions & 0 deletions dulwich/tests/test_porcelain.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,43 @@ def test_fetch_symref(self):
target_repo.refs.get_symrefs(),
)

def test_local_depth(self):
f1_1 = make_object(Blob, data=b"f1")
commit_spec = [[1], [2, 1], [3, 1, 2]]
trees = {
1: [(b"f1", f1_1), (b"f2", f1_1)],
2: [(b"f1", f1_1), (b"f2", f1_1)],
3: [(b"f1", f1_1), (b"f2", f1_1)],
}

c1, c2, c3 = build_commit_graph(self.repo.object_store, commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c3.id
self.repo.refs[b"refs/tags/foo"] = c3.id
target_path = tempfile.mkdtemp()
errstream = BytesIO()
self.addCleanup(shutil.rmtree, target_path)
r = porcelain.clone(
self.repo.path, target_path, checkout=False, errstream=errstream,
depth=1
)
self.addCleanup(r.close)
self.assertEqual(r.path, target_path)
target_repo = Repo(target_path)
self.assertEqual([c3.id], [w.commit.id for w in target_repo.get_walker()])
self.assertEqual(0, len(target_repo.open_index()))
self.assertEqual(c3.id, target_repo.refs[b"refs/tags/foo"])
self.assertNotIn(b"f1", os.listdir(target_path))
self.assertNotIn(b"f2", os.listdir(target_path))
c = r.get_config()
encoded_path = self.repo.path
if not isinstance(encoded_path, bytes):
encoded_path = encoded_path.encode("utf-8")
self.assertEqual(encoded_path, c.get((b"remote", b"origin"), b"url"))
self.assertEqual(
b"+refs/heads/*:refs/remotes/origin/*",
c.get((b"remote", b"origin"), b"fetch"),
)


class InitTests(TestCase):
def test_non_bare(self):
Expand Down

0 comments on commit 44631bc

Please sign in to comment.