Skip to content

Commit

Permalink
Fix handling of symrefs with protocol v2. Fixes #1389 (#1392)
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmer authored Oct 20, 2024
2 parents 15d6c81 + f1075d2 commit b1287d3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 48 deletions.
5 changes: 5 additions & 0 deletions NEWS
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
0.22.4 UNRELEASED

* Fix handling of symrefs with protocol v2.
(Jelmer Vernooij, #1389)

0.22.3 2024-10-15

* Improve wheel building in CI, so we can upload wheels for the next release.
Expand Down
128 changes: 80 additions & 48 deletions dulwich/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,33 @@ def read_server_capabilities(pkt_seq):
return set(server_capabilities)


def read_pkt_refs(pkt_seq, server_capabilities=None):
def read_pkt_refs_v2(
pkt_seq,
) -> Tuple[Dict[bytes, bytes], Dict[bytes, bytes], Dict[bytes, bytes]]:
refs = {}
symrefs = {}
peeled = {}
# Receive refs from server
for pkt in pkt_seq:
parts = pkt.rstrip(b"\n").split(b" ")
sha = parts[0]
if sha == b"unborn":
sha = None
ref = parts[1]
for part in parts[2:]:
if part.startswith(b"peeled:"):
peeled[ref] = part[7:]
elif part.startswith(b"symref-target:"):
symrefs[ref] = part[14:]
else:
logging.warning("unknown part in pkt-ref: %s", part)
refs[ref] = sha

return refs, symrefs, peeled


def read_pkt_refs_v1(pkt_seq) -> Tuple[Dict[bytes, bytes], Set[bytes]]:
server_capabilities = None
refs = {}
# Receive refs from server
for pkt in pkt_seq:
Expand All @@ -267,24 +293,13 @@ def read_pkt_refs(pkt_seq, server_capabilities=None):
raise GitProtocolError(ref.decode("utf-8", "replace"))
if server_capabilities is None:
(ref, server_capabilities) = extract_capabilities(ref)
else: # Git protocol-v2:
try:
symref, target = ref.split(b" ", 1)
except ValueError:
pass
else:
if symref and target and target[:14] == b"symref-target:":
server_capabilities.add(
b"%s=%s:%s"
% (CAPABILITY_SYMREF, symref, target.split(b":", 1)[1])
)
ref = symref
refs[ref] = sha

if len(refs) == 0:
return {}, set()
if refs == {CAPABILITIES_REF: ZERO_SHA}:
refs = {}
assert server_capabilities is not None
return refs, set(server_capabilities)


Expand Down Expand Up @@ -684,6 +699,26 @@ def progress(x):
pack_data(data)


def _extract_symrefs_and_agent(capabilities):
"""Extract symrefs and agent from capabilities.
Args:
capabilities: List of capabilities
Returns:
(symrefs, agent) tuple
"""
symrefs = {}
agent = None
for capability in capabilities:
k, v = parse_capability(capability)
if k == CAPABILITY_SYMREF:
(src, dst) = v.split(b":", 1)
symrefs[src] = dst
if k == CAPABILITY_AGENT:
agent = v
return (symrefs, agent)


# TODO(durin42): this doesn't correctly degrade if the server doesn't
# support some capabilities. This should work properly with servers
# that don't support multi_ack.
Expand Down Expand Up @@ -1012,11 +1047,7 @@ def _should_send_pack(new_refs):

def _negotiate_receive_pack_capabilities(self, server_capabilities):
negotiated_capabilities = self._send_capabilities & server_capabilities
agent = None
for capability in server_capabilities:
k, v = parse_capability(capability)
if k == CAPABILITY_AGENT:
agent = v
(agent, _symrefs) = _extract_symrefs_and_agent(server_capabilities)
(extract_capability_names(server_capabilities) - KNOWN_RECEIVE_CAPABILITIES)
# TODO(jelmer): warn about unknown capabilities
return negotiated_capabilities, agent
Expand Down Expand Up @@ -1069,23 +1100,16 @@ def progress(x):
def _negotiate_upload_pack_capabilities(self, server_capabilities):
(extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES)
# TODO(jelmer): warn about unknown capabilities
symrefs = {}
agent = None
fetch_capa = None
for capability in server_capabilities:
k, v = parse_capability(capability)
if k == CAPABILITY_SYMREF:
(src, dst) = v.split(b":", 1)
symrefs[src] = dst
if k == CAPABILITY_AGENT:
agent = v
if self.protocol_version == 2 and k == CAPABILITY_FETCH:
fetch_capa = CAPABILITY_FETCH
fetch_features = []
v = v.strip()
if b"shallow" in v.split(b" "):
v = v.strip().split(b" ")
if b"shallow" in v:
fetch_features.append(CAPABILITY_SHALLOW)
if b"filter" in v.split(b" "):
if b"filter" in v:
fetch_features.append(CAPABILITY_FILTER)
for i in range(len(fetch_features)):
if i == 0:
Expand All @@ -1094,6 +1118,8 @@ def _negotiate_upload_pack_capabilities(self, server_capabilities):
fetch_capa += b" "
fetch_capa += fetch_features[i]

(symrefs, agent) = _extract_symrefs_and_agent(server_capabilities)

negotiated_capabilities = self._fetch_capabilities & server_capabilities
if fetch_capa:
negotiated_capabilities.add(fetch_capa)
Expand Down Expand Up @@ -1196,7 +1222,7 @@ def send_pack(self, path, update_refs, generate_pack_data, progress=None):
proto, unused_can_read, stderr = self._connect(b"receive-pack", path)
with proto:
try:
old_refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
old_refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
except HangupException as exc:
raise _remote_error_from_stderr(stderr) from exc
(
Expand Down Expand Up @@ -1329,7 +1355,7 @@ def fetch_pack(
server_capabilities = read_server_capabilities(proto.read_pkt_seq())
refs = None
else:
refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
except HangupException as exc:
raise _remote_error_from_stderr(stderr) from exc
(
Expand All @@ -1345,9 +1371,7 @@ def fetch_pack(
for prefix in ref_prefix:
proto.write_pkt_line(b"ref-prefix " + prefix)
proto.write_pkt_line(None)
refs, server_capabilities = read_pkt_refs(
proto.read_pkt_seq(), server_capabilities
)
refs, symrefs, _peeled = read_pkt_refs_v2(proto.read_pkt_seq())

if refs is None:
proto.write_pkt_line(None)
Expand Down Expand Up @@ -1425,17 +1449,22 @@ def get_refs(self, path, protocol_version=None):
proto.write(b"0001") # delim-pkt
proto.write_pkt_line(b"symrefs")
proto.write_pkt_line(None)
with proto:
try:
refs, _symrefs, _peeled = read_pkt_refs_v2(proto.read_pkt_seq())
except HangupException as exc:
raise _remote_error_from_stderr(stderr) from exc
proto.write_pkt_line(None)
return refs
else:
server_capabilities = None # read_pkt_refs will find them
with proto:
try:
refs, server_capabilities = read_pkt_refs(
proto.read_pkt_seq(), server_capabilities
)
except HangupException as exc:
raise _remote_error_from_stderr(stderr) from exc
proto.write_pkt_line(None)
return refs
with proto:
try:
refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
except HangupException as exc:
raise _remote_error_from_stderr(stderr) from exc
proto.write_pkt_line(None)
(_symrefs, _agent) = _extract_symrefs_and_agent(server_capabilities)
return refs

def archive(
self,
Expand Down Expand Up @@ -2384,6 +2413,9 @@ def begin_protocol_v2(proto):
self.protocol_version = server_protocol_version
if self.protocol_version == 2:
server_capabilities, resp, read, proto = begin_protocol_v2(proto)
(refs, _symrefs, _peeled) = read_pkt_refs_v2(proto.read_pkt_seq())
return refs, server_capabilities, base_url

else:
server_capabilities = None # read_pkt_refs will find them
try:
Expand Down Expand Up @@ -2414,11 +2446,11 @@ def begin_protocol_v2(proto):
server_capabilities, resp, read, proto = begin_protocol_v2(
proto
)
(
refs,
server_capabilities,
) = read_pkt_refs(proto.read_pkt_seq(), server_capabilities)
return refs, server_capabilities, base_url
(
refs,
server_capabilities,
) = read_pkt_refs_v1(proto.read_pkt_seq())
return refs, server_capabilities, base_url
else:
self.protocol_version = 0 # dumb servers only support protocol v0
return read_info_refs(resp), set(), base_url
Expand Down
10 changes: 10 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SubprocessSSHVendor,
TCPGitClient,
TraditionalGitClient,
_extract_symrefs_and_agent,
_remote_error_from_stderr,
check_wants,
default_urllib3_manager,
Expand Down Expand Up @@ -1867,3 +1868,12 @@ def test_no_error_line(self):
]
),
)


class TestExtractAgentAndSymrefs(TestCase):
def test_extract_agent_and_symrefs(self):
(symrefs, agent) = _extract_symrefs_and_agent(
[b"agent=git/2.31.1", b"symref=HEAD:refs/heads/master"]
)
self.assertEqual(agent, b"git/2.31.1")
self.assertEqual(symrefs, {b"HEAD": b"refs/heads/master"})

0 comments on commit b1287d3

Please sign in to comment.