Skip to content

Commit

Permalink
Various minor rust improvements (#1390)
Browse files Browse the repository at this point in the history
* Avoid copying names in sorted_tree_items
* Add some more tests for sorted_tree_items
* Make sure we're actually testing the rust implementations and not
accidentally the Python ones
  • Loading branch information
jelmer authored Oct 17, 2024
2 parents cd30df4 + 7b881b3 commit 15d6c81
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 83 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pythontest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ jobs:
- name: Build
run: |
python setup.py build_ext -i
env:
RUSTFLAGS: "-D warnings"
- name: codespell
run: |
pip install --upgrade codespell
Expand Down
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion crates/diff-tree/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ fn entry_path_cmp(entry1: &Bound<PyAny>, entry2: &Bound<PyAny>) -> PyResult<Orde
}

#[pyfunction]
fn _merge_entries(py: Python, path: &[u8], tree1: &Bound<PyAny>, tree2: &Bound<PyAny>) -> PyResult<PyObject> {
fn _merge_entries(
py: Python,
path: &[u8],
tree1: &Bound<PyAny>,
tree2: &Bound<PyAny>,
) -> PyResult<PyObject> {
let entries1 = tree_entries(path, tree1, py)?;
let entries2 = tree_entries(path, tree2, py)?;

Expand Down
80 changes: 35 additions & 45 deletions crates/objects/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
*/

use memchr::memchr;
use std::borrow::Cow;

use pyo3::exceptions::PyTypeError;
use pyo3::import_exception;
Expand All @@ -30,6 +29,7 @@ import_exception!(dulwich.errors, ObjectFormatException);

const S_IFDIR: u32 = 0o40000;

#[inline]
fn bytehex(byte: u8) -> u8 {
match byte {
0..=9 => byte + b'0',
Expand Down Expand Up @@ -58,36 +58,19 @@ fn parse_tree(
let mut entries = Vec::new();
let strict = strict.unwrap_or(false);
while !text.is_empty() {
let mode_end = match memchr(b' ', text) {
Some(e) => e,
None => {
return Err(ObjectFormatException::new_err((
"Missing terminator for mode",
)));
}
};
let mode_end = memchr(b' ', text)
.ok_or_else(|| ObjectFormatException::new_err(("Missing terminator for mode",)))?;
let text_str = String::from_utf8_lossy(&text[..mode_end]).to_string();
let mode = match u32::from_str_radix(text_str.as_str(), 8) {
Ok(m) => m,
Err(e) => {
return Err(ObjectFormatException::new_err((format!(
"invalid mode: {}",
e
),)));
}
};
let mode = u32::from_str_radix(text_str.as_str(), 8)
.map_err(|e| ObjectFormatException::new_err((format!("invalid mode: {}", e),)))?;
if strict && text[0] == b'0' {
return Err(ObjectFormatException::new_err((
"Illegal leading zero on mode",
)));
}
text = &text[mode_end + 1..];
let namelen = match memchr(b'\0', text) {
Some(nl) => nl,
None => {
return Err(ObjectFormatException::new_err(("Missing trailing \\0",)));
}
};
let namelen = memchr(b'\0', text)
.ok_or_else(|| ObjectFormatException::new_err(("Missing trailing \\0",)))?;
let name = &text[..namelen];
if namelen + 20 >= text.len() {
return Err(ObjectFormatException::new_err(("SHA truncated",)));
Expand All @@ -104,14 +87,20 @@ fn parse_tree(
Ok(entries)
}

fn name_with_suffix(mode: u32, name: &[u8]) -> Cow<[u8]> {
if mode & S_IFDIR != 0 {
let mut v = name.to_vec();
v.push(b'/');
v.into()
} else {
name.into()
fn cmp_with_suffix(a: (u32, &[u8]), b: (u32, &[u8])) -> std::cmp::Ordering {
let len = std::cmp::min(a.1.len(), b.1.len());
let cmp = a.1[..len].cmp(&b.1[..len]);
if cmp != std::cmp::Ordering::Equal {
return cmp;
}

let c1 =
a.1.get(len)
.map_or_else(|| if a.0 & S_IFDIR != 0 { b'/' } else { 0 }, |&c| c);
let c2 =
b.1.get(len)
.map_or_else(|| if b.0 & S_IFDIR != 0 { b'/' } else { 0 }, |&c| c);
c1.cmp(&c2)
}

/// Iterate over a tree entries dictionary.
Expand All @@ -125,23 +114,24 @@ fn name_with_suffix(mode: u32, name: &[u8]) -> Cow<[u8]> {
///
/// # Returns: Iterator over (name, mode, hexsha)
#[pyfunction]
fn sorted_tree_items(py: Python, entries: &Bound<PyDict>, name_order: bool) -> PyResult<Vec<PyObject>> {
let mut qsort_entries = Vec::new();
for (name, e) in entries.iter() {
let (mode, sha): (u32, Vec<u8>) = match e.extract() {
Ok(o) => o,
Err(e) => {
return Err(PyTypeError::new_err((format!("invalid type: {}", e),)));
}
};
qsort_entries.push((name.extract::<Vec<u8>>().unwrap(), mode, sha));
}
fn sorted_tree_items(
py: Python,
entries: &Bound<PyDict>,
name_order: bool,
) -> PyResult<Vec<PyObject>> {
let mut qsort_entries = entries
.iter()
.map(|(name, value)| -> PyResult<(Vec<u8>, u32, Vec<u8>)> {
let value = value
.extract::<(u32, Vec<u8>)>()
.map_err(|e| PyTypeError::new_err((format!("invalid type: {}", e),)))?;
Ok((name.extract::<Vec<u8>>().unwrap(), value.0, value.1))
})
.collect::<PyResult<Vec<(Vec<u8>, u32, Vec<u8>)>>>()?;
if name_order {
qsort_entries.sort_by(|a, b| a.0.cmp(&b.0));
} else {
qsort_entries.sort_by(|a, b| {
name_with_suffix(a.1, a.0.as_slice()).cmp(&name_with_suffix(b.1, b.0.as_slice()))
});
qsort_entries.sort_by(|a, b| cmp_with_suffix((a.1, a.0.as_slice()), (b.1, b.0.as_slice())));
}
let objectsm = py.import_bound("dulwich.objects")?;
let tree_entry_cls = objectsm.getattr("TreeEntry")?;
Expand Down
34 changes: 23 additions & 11 deletions crates/pack/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
* License, Version 2.0.
*/

use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyList,PyBytes};
use pyo3::exceptions::{PyValueError, PyTypeError};
use pyo3::types::{PyBytes, PyList};

pyo3::import_exception!(dulwich.errors, ApplyDeltaError);

Expand All @@ -39,7 +39,13 @@ fn py_is_sha(sha: &PyObject, py: Python) -> PyResult<bool> {
}

#[pyfunction]
fn bisect_find_sha(py: Python, start: i32, end: i32, sha: Py<PyBytes>, unpack_name: PyObject) -> PyResult<Option<i32>> {
fn bisect_find_sha(
py: Python,
start: i32,
end: i32,
sha: Py<PyBytes>,
unpack_name: PyObject,
) -> PyResult<Option<i32>> {
// Convert sha_obj to a byte slice
let sha = sha.as_bytes(py);
let sha_len = sha.len();
Expand Down Expand Up @@ -99,7 +105,10 @@ fn get_delta_header_size(delta: &[u8], index: &mut usize, length: usize) -> usiz
size
}

fn py_chunked_as_string<'a>(py: Python<'a>, py_buf: &'a PyObject) -> PyResult<std::borrow::Cow<'a, [u8]>> {
fn py_chunked_as_string<'a>(
py: Python<'a>,
py_buf: &'a PyObject,
) -> PyResult<std::borrow::Cow<'a, [u8]>> {
if let Ok(py_list) = py_buf.extract::<Bound<PyList>>(py) {
let mut buf = Vec::new();
for chunk in py_list.iter() {
Expand All @@ -108,14 +117,19 @@ fn py_chunked_as_string<'a>(py: Python<'a>, py_buf: &'a PyObject) -> PyResult<st
} else if let Ok(chunk) = chunk.extract::<Vec<u8>>() {
buf.extend(chunk);
} else {
return Err(PyTypeError::new_err(format!("chunk is not a byte string, but a {:?}", chunk.get_type().name())));
return Err(PyTypeError::new_err(format!(
"chunk is not a byte string, but a {:?}",
chunk.get_type().name()
)));
}
}
Ok(buf.into())
} else if py_buf.extract::<Bound<PyBytes>>(py).is_ok() {
Ok(std::borrow::Cow::Borrowed(py_buf.extract::<&[u8]>(py)?))
} else {
Err(PyTypeError::new_err("buf is not a string or a list of chunks"))
Err(PyTypeError::new_err(
"buf is not a string or a list of chunks",
))
}
}

Expand Down Expand Up @@ -168,10 +182,7 @@ fn apply_delta(py: Python, py_src_buf: PyObject, py_delta: PyObject) -> PyResult
cp_size = 0x10000;
}

if cp_off + cp_size < cp_size
|| cp_off + cp_size > src_size
|| cp_size > dest_size
{
if cp_off + cp_size < cp_size || cp_off + cp_size > src_size || cp_size > dest_size {
break;
}

Expand All @@ -187,7 +198,8 @@ fn apply_delta(py: Python, py_src_buf: PyObject, py_delta: PyObject) -> PyResult
return Err(ApplyDeltaError::new_err("Not enough space to copy"));
}

out[outindex..outindex + cmd as usize].copy_from_slice(&delta[index..index + cmd as usize]);
out[outindex..outindex + cmd as usize]
.copy_from_slice(&delta[index..index + cmd as usize]);
outindex += cmd as usize;
index += cmd as usize;
} else {
Expand Down
18 changes: 13 additions & 5 deletions dulwich/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,19 +1027,19 @@ def sorted_tree_items(entries, name_order: bool):
yield TreeEntry(name, mode, hexsha)


def key_entry(entry) -> bytes:
def key_entry(entry: Tuple[bytes, Tuple[int, ObjectID]]) -> bytes:
"""Sort key for tree entry.
Args:
entry: (name, value) tuple
"""
(name, value) = entry
if stat.S_ISDIR(value[0]):
(name, (mode, _sha)) = entry
if stat.S_ISDIR(mode):
name += b"/"
return name


def key_entry_name_order(entry):
def key_entry_name_order(entry: Tuple[bytes, Tuple[int, ObjectID]]) -> bytes:
"""Sort key for tree entry in name order."""
return entry[0]

Expand Down Expand Up @@ -1667,6 +1667,14 @@ def _get_extra(self):
_sorted_tree_items_py = sorted_tree_items
try:
# Try to import Rust versions
from dulwich._objects import parse_tree, sorted_tree_items # type: ignore
from dulwich._objects import (
parse_tree as _parse_tree_rs,
)
from dulwich._objects import (
sorted_tree_items as _sorted_tree_items_rs,
)
except ImportError:
pass
else:
parse_tree = _parse_tree_rs
sorted_tree_items = _sorted_tree_items_rs
Loading

0 comments on commit 15d6c81

Please sign in to comment.