Skip to content

Commit

Permalink
feat!(pylace): Return PyResults for several methods vs unwrap
Browse files Browse the repository at this point in the history
Removed several unwraps and replaced them with `PyResult` returns.
Some could not be removed due to PyO3's limitation to non-failable type conversions PyO3/pyo3#1813.
  • Loading branch information
schmidmt committed Feb 12, 2024
1 parent 408d4de commit 23f651e
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 74 deletions.
18 changes: 9 additions & 9 deletions pylace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pylace/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ impl CategoricalParams {
_ => format!(
"[{}, ..., {}]",
self.weights[0],
self.weights.last().unwrap()
self.weights
.last()
.map(|x| x.to_string())
.unwrap_or_else(|| "-".to_string())
),
};

Expand Down
4 changes: 4 additions & 0 deletions pylace/src/df.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ pub(crate) fn to_py_array(
Ok(array.to_object(py))
}

// TODO: When https://github.com/PyO3/pyo3/issues/1813 is solved, implement a
// failable version.
impl IntoPy<PyObject> for PySeries {
fn into_py(self, py: Python<'_>) -> PyObject {
let s = self.0.rechunk();
Expand All @@ -181,6 +183,8 @@ impl IntoPy<PyObject> for PySeries {
}
}

// TODO: When https://github.com/PyO3/pyo3/issues/1813 is solved, implement a
// failable version.
impl IntoPy<PyObject> for PyDataFrame {
fn into_py(self, py: Python<'_>) -> PyObject {
let pyseries = self
Expand Down
76 changes: 50 additions & 26 deletions pylace/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use lace::prelude::ColMetadataList;
use lace::{EngineUpdateConfig, FType, HasStates, OracleT};
use polars::prelude::{DataFrame, NamedFrom, Series};
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyType};
use pyo3::{create_exception, prelude::*};
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;

Expand Down Expand Up @@ -112,18 +112,20 @@ impl CoreEngine {

/// Load a Engine from metadata
#[classmethod]
fn load(_cls: &PyType, path: PathBuf) -> CoreEngine {
fn load(_cls: &PyType, path: PathBuf) -> PyResult<CoreEngine> {
let (engine, rng) = {
let mut engine = lace::Engine::load(path).unwrap();
let rng = Xoshiro256Plus::from_rng(&mut engine.rng).unwrap();
let mut engine = lace::Engine::load(path)
.map_err(|e| EngineLoadError::new_err(e.to_string()))?;
let rng = Xoshiro256Plus::from_rng(&mut engine.rng)
.map_err(|e| EngineLoadError::new_err(e.to_string()))?;
(engine, rng)
};
Self {
Ok(Self {
col_indexer: Indexer::columns(&engine.codebook),
row_indexer: Indexer::rows(&engine.codebook),
rng,
engine,
}
})
}

/// Save the engine to `path`
Expand Down Expand Up @@ -491,13 +493,17 @@ impl CoreEngine {
let mut a = Vec::with_capacity(pairs.len());
let mut b = Vec::with_capacity(pairs.len());

utils::pairs_list_iter(pairs, indexer).for_each(|res| {
let (ix_a, ix_b) = res.unwrap();
let name_a = indexer.to_name[&ix_a].clone();
let name_b = indexer.to_name[&ix_b].clone();
a.push(name_a);
b.push(name_b);
});
utils::pairs_list_iter(pairs, indexer)
.map(|res| {
let (ix_a, ix_b) = res?;
let name_a = indexer.to_name[&ix_a].clone();
let name_b = indexer.to_name[&ix_b].clone();
a.push(name_a);
b.push(name_b);

Ok::<(), PyErr>(())
})
.collect::<PyResult<()>>()?;

let a = Series::new("A", a);
let b = Series::new("B", b);
Expand Down Expand Up @@ -1047,7 +1053,7 @@ impl CoreEngine {
transitions: Option<Vec<transition::StateTransition>>,
save_path: Option<PathBuf>,
update_handler: Option<PyObject>,
) {
) -> PyResult<()> {
use lace::update_handler::Timeout;
use std::time::Duration;

Expand Down Expand Up @@ -1084,11 +1090,13 @@ impl CoreEngine {
config,
(timeout, PyUpdateHandler::new(update_handler)),
)
.unwrap();
.map_err(|e| EngineUpdateError::new_err(e.to_string()))
} else {
self.engine.update(config, timeout).unwrap();
self.engine
.update(config, timeout)
.map_err(|e| EngineUpdateError::new_err(e.to_string()))
}
});
})
}

/// Append new rows to the table.
Expand Down Expand Up @@ -1132,15 +1140,22 @@ impl CoreEngine {
})?;

// must add new row names to indexer
let row_names = df_vals.row_names.unwrap();
(self.engine.n_rows()..).zip(row_names.iter()).for_each(
let row_names = df_vals.row_names.ok_or_else(|| {
PyValueError::new_err("Provided dataframe has no index (row names)")
})?;
(self.engine.n_rows()..).zip(row_names.iter()).map(
|(ix, name)| {
// row names passed to 'append' should not exist
assert!(!self.row_indexer.to_ix.contains_key(name));
self.row_indexer.to_ix.insert(name.to_owned(), ix);
self.row_indexer.to_name.insert(ix, name.to_owned());
if self.row_indexer.to_ix.contains_key(name) {
Err(PyValueError::new_err(
format!("Duplicate ids/indices cannot be inserted. Duplicate `{name}`")
))
} else {
self.row_indexer.to_ix.insert(name.to_owned(), ix);
self.row_indexer.to_name.insert(ix, name.to_owned());
Ok(())
}
},
);
).collect::<PyResult<()>>()?;

let data = parts_to_insert_values(
df_vals.col_names,
Expand Down Expand Up @@ -1217,7 +1232,11 @@ impl CoreEngine {

let data = parts_to_insert_values(
col_names,
df_vals.row_names.unwrap(),
df_vals.row_names.ok_or_else(|| {
PyValueError::new_err(
"Provided dataframe has no index (row names)",
)
})?,
df_vals.values,
);

Expand Down Expand Up @@ -1314,9 +1333,12 @@ pub fn infer_srs_metadata(
.map(metadata::ColumnMetadata)
}

create_exception!(lace, EngineLoadError, pyo3::exceptions::PyException);
create_exception!(lace, EngineUpdateError, pyo3::exceptions::PyException);

/// A Python module implemented in Rust.
#[pymodule]
fn core(_py: Python, m: &PyModule) -> PyResult<()> {
fn core(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Codebook>()?;
m.add_class::<CoreEngine>()?;
m.add_class::<CodebookBuilder>()?;
Expand All @@ -1334,5 +1356,7 @@ fn core(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<update_handler::PyEngineUpdateConfig>()?;
m.add_function(wrap_pyfunction!(infer_srs_metadata, m)?)?;
m.add_function(wrap_pyfunction!(metadata::codebook_from_df, m)?)?;
m.add("EngineLoadError", py.get_type::<EngineLoadError>())?;
m.add("EngineUpdateError", py.get_type::<EngineUpdateError>())?;
Ok(())
}
1 change: 0 additions & 1 deletion pylace/src/update_handler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::io::Write;
/// Update Handler and associated tooling for `CoreEngine.update` in Python.
use std::sync::{Arc, Mutex};

Expand Down
Loading

0 comments on commit 23f651e

Please sign in to comment.