👷 Create pre-merge workflow #140

Merged 3 commits on Jan 28, 2024
Changes from all commits
29 changes: 29 additions & 0 deletions .github/workflows/pre-merge.yml
@@ -0,0 +1,29 @@
name: Pre-merge

on:
pull_request:
branches: [ "main", "ci-test" ]

env:
CARGO_TERM_COLOR: always
RUSTFLAGS: "-Dwarnings"

jobs:
rust:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
with:
prefix-key: "rust-dependencies"
- name: Build
run: cargo build --verbose
- name: Clippy
run: cargo clippy --no-deps
- name: Format check
run: cargo fmt --check
- name: Run tests
run: cargo test --verbose
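
Because the job exports RUSTFLAGS: "-Dwarnings", every compiler warning is promoted to a hard error, so the cargo steps fail on code that would only produce warnings in a default local build (and --no-deps limits Clippy to this crate's own code rather than its dependencies). As a rough illustration only, not part of this PR, a snippet like the following builds locally with a warning but fails this workflow:

// Hypothetical example: the `unused_variables` lint below becomes an error
// when RUSTFLAGS="-Dwarnings", so the Build step of the pre-merge job fails.
fn main() {
    let unused = 42; // warning: unused variable `unused`
    println!("warnings are denied in CI");
}
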
10 changes: 5 additions & 5 deletions python/deepdecipher/__init__.py
@@ -1,14 +1,14 @@
from .deepdecipher import (
log_init,
start_server,
Database,
DataType,
DataTypeHandle,
Index,
ModelHandle,
ModelMetadata,
DataTypeHandle,
DataType,
ServiceHandle,
ServiceProvider,
Index,
log_init,
start_server,
)

deepdecipher.setup_keyboard_interrupt()
4 changes: 4 additions & 0 deletions rustfmt.toml
@@ -0,0 +1,4 @@
unstable_features = true
imports_granularity = "Crate"
group_imports = "StdExternalCrate"
format_strings = true
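
These options drive most of the import and string-literal reshuffling in the Rust files below: unstable_features = true opts in to the nightly-only options, imports_granularity = "Crate" merges `use` items from the same crate into a single statement, group_imports = "StdExternalCrate" orders imports as standard library, then external crates, then crate-local (crate/super/self), and format_strings = true lets rustfmt wrap long string literals with `\` continuations. A rough before/after sketch with illustrative imports, not taken from this repository:

// Before `cargo fmt` with this configuration:
use crate::data::ModelHandle;
use anyhow::Result;
use std::collections::HashMap;
use crate::data::Metadata;

// After: std first, then external crates, then crate-local imports,
// with the two `crate::data` items merged into one `use`.
use std::collections::HashMap;

use anyhow::Result;

use crate::data::{Metadata, ModelHandle};
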
3 changes: 1 addition & 2 deletions src/data/data_objects/metadata_object.rs
@@ -1,9 +1,8 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};

use crate::data::{Metadata, ModelHandle};

use super::{data_object, DataObject};
use crate::data::{Metadata, ModelHandle};

#[derive(Clone, Serialize, Deserialize)]
pub struct MetadataObject {
79 changes: 58 additions & 21 deletions src/data/data_objects/neuron2graph.rs
@@ -12,9 +12,8 @@ use graphviz_rust::{
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::data::SimilarNeurons;

use super::{data_object, DataObject};
use crate::data::SimilarNeurons;

fn id_to_str(id: &Id) -> &str {
match id {
@@ -24,9 +23,13 @@ fn id_to_usize(id: &Id) -> Result<usize> {

fn id_to_usize(id: &Id) -> Result<usize> {
let id_string = id_to_str(id);
id_string.parse::<usize>().with_context(|| format!(
"Could not parse node id {} as usize. It is assumed that all N2G graphs only use positive integer node ids.", id_string
))
id_string.parse::<usize>().with_context(|| {
format!(
"Could not parse node id {} as usize. It is assumed that all N2G graphs only use \
positive integer node ids.",
id_string
)
})
}

fn dot_node_to_id_label_importance(node: &DotNode) -> Result<(usize, String, f32)> {
@@ -40,15 +43,26 @@ fn dot_node_to_id_label_importance(node: &DotNode) -> Result<(usize, String, f32
.find(|Attribute(key, _)| id_to_str(key) == "label")
.with_context(|| format!("Node with id {id} has no attribute 'label'."))?;
// Assume that the `fillcolor` attribute is a 9 character string with '"' enclosing a hexadecimal color code.
let color_str = get_attribute(attributes.as_slice(), "fillcolor").with_context(|| format!(
"Node {id} has no attribute 'fillcolor'. It is assumed that all N2G nodes have a 'fillcolor' attribute that signifies their importance."
))?;
let importance_hex = color_str.get(4..6).with_context(|| format!(
"The 'fillcolor' attribute of node {id} is insufficiently long. It is expected to be 9 characters long."
))?;
let importance = 1.-u8::from_str_radix(importance_hex, 16).with_context(|| format!(
"The green part of the 'fillcolor' attribute of node {id} is not a valid hexadecimal number."
))? as f32 / 255.0;
let color_str = get_attribute(attributes.as_slice(), "fillcolor").with_context(|| {
format!(
"Node {id} has no attribute 'fillcolor'. It is assumed that all N2G nodes have a \
'fillcolor' attribute that signifies their importance."
)
})?;
let importance_hex = color_str.get(4..6).with_context(|| {
format!(
"The 'fillcolor' attribute of node {id} is insufficiently long. It is expected to be \
9 characters long."
)
})?;
let importance = 1.
- u8::from_str_radix(importance_hex, 16).with_context(|| {
format!(
"The green part of the 'fillcolor' attribute of node {id} is not a valid \
hexadecimal number."
)
})? as f32
/ 255.0;

let label = id_to_str(label_id).to_string();
Ok((id, label, importance))
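
To make the colour arithmetic above concrete: for an assumed 9-character fillcolor value such as "#ff80ff" (quotes included), bytes 4..6 are the green channel, and a darker green yields a higher importance. A standalone sketch with a made-up value:

// Illustrative only, not part of this PR: mirrors the importance computation above.
fn main() {
    let color_str = "\"#ff80ff\""; // assumed format: quoted hex colour, 9 characters
    let importance_hex = &color_str[4..6]; // green channel -> "80"
    let green = u8::from_str_radix(importance_hex, 16).unwrap() as f32;
    let importance = 1.0 - green / 255.0; // less green => more important
    assert!((importance - 0.498).abs() < 0.01);
}
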
@@ -70,7 +84,12 @@ fn subgraph_to_nodes(subgraph: &Subgraph) -> Result<Vec<(usize, String, f32)>> {
let id_str = id_to_str(id);
let id: usize = id_str
.strip_prefix("cluster_")
.with_context(|| format!("It is assumed that all N2G subgraphs have ids starting with 'cluster_'. Subgraph id: {id_str}"))?
.with_context(|| {
format!(
"It is assumed that all N2G subgraphs have ids starting with 'cluster_'. Subgraph \
id: {id_str}"
)
})?
.parse::<usize>()
.with_context(|| format!("Failed to parse subgraph id '{id_str}' as usize."))?;
let nodes = statements
@@ -92,16 +111,34 @@ fn dot_edge_to_ids(
) -> Result<(usize, usize)> {
match edge_ty {
EdgeTy::Pair(Vertex::N(NodeId(node_id1, _)), Vertex::N(NodeId(node_id2, _))) => {
let id1 = id_to_usize(node_id1).with_context(|| format!("Failed to parse first id for edge {edge_ty:?}."))?;
let id2 = id_to_usize(node_id2).with_context(|| format!("Failed to parse second id for edge {edge_ty:?}."))?;
let id1 = id_to_usize(node_id1)
.with_context(|| format!("Failed to parse first id for edge {edge_ty:?}."))?;
let id2 = id_to_usize(node_id2)
.with_context(|| format!("Failed to parse second id for edge {edge_ty:?}."))?;
match get_attribute(attributes, "dir") {
Some("back") => Ok((id2, id1)),
None => bail!("No direction attribute found for edge {id1}->{id2}. It is assumed that all N2G graphs only use edges with direction 'back'."),
_ => bail!("Only edges with direction 'back' or 'forward' are supported. It is assumed that all N2G graphs only use edges with direction 'back' or 'forward'. Edge: {:?}", edge_ty)
None => bail!(
"No direction attribute found for edge {id1}->{id2}. It is assumed that all \
N2G graphs only use edges with direction 'back'."
),
_ => bail!(
"Only edges with direction 'back' or 'forward' are supported. It is assumed \
that all N2G graphs only use edges with direction 'back' or 'forward'. Edge: \
{:?}",
edge_ty
),
}
}
EdgeTy::Pair(_, _) => bail!("Only edges between individual nodes are supported. It is assumed that N2G does not use edges between subgraphs. Edge: {:?}", edge_ty),
EdgeTy::Chain(_) => bail!("Only pair edges are supported. It is assumed that all N2G graphs only use pair edges. Edge: {:?}", edge_ty)
EdgeTy::Pair(_, _) => bail!(
"Only edges between individual nodes are supported. It is assumed that N2G does not \
use edges between subgraphs. Edge: {:?}",
edge_ty
),
EdgeTy::Chain(_) => bail!(
"Only pair edges are supported. It is assumed that all N2G graphs only use pair \
edges. Edge: {:?}",
edge_ty
),
}
}

1 change: 0 additions & 1 deletion src/data/data_objects/neuroscope/neuroscope_page.rs
@@ -4,7 +4,6 @@ use anyhow::{bail, Context, Result};
use itertools::Itertools;
use regex::Regex;
use serde::{Deserialize, Serialize};

use utoipa::ToSchema;

use crate::data::{
42 changes: 28 additions & 14 deletions src/data/database/data_types/json.rs
@@ -1,5 +1,7 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};
use crate::{
data::{
data_objects::{DataObject, JsonData},
Expand All @@ -8,10 +10,6 @@ use crate::{
Index,
};

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};

use anyhow::{bail, Context, Result};

pub struct Json {
model: ModelHandle,
data_type: DataTypeHandle,
@@ -81,26 +79,42 @@ impl Json {
pub async fn layer_page(&self, layer_index: u32) -> Result<JsonData> {
let model_name = self.model.name();
let data_type_name = self.data_type.name();
let raw_data = self.model
.layer_data( &self.data_type, layer_index)
.await.with_context(|| {
format!("Failed to get '{data_type_name}' layer data for layer {layer_index} in model '{model_name}'.")
let raw_data = self
.model
.layer_data(&self.data_type, layer_index)
.await
.with_context(|| {
format!(
"Failed to get '{data_type_name}' layer data for layer {layer_index} in model \
'{model_name}'."
)
})?
.with_context(|| {
format!("Database has no '{data_type_name}' layer data for layer {layer_index} in model '{model_name}'.")
format!(
"Database has no '{data_type_name}' layer data for layer {layer_index} in \
model '{model_name}'."
)
})?;
JsonData::from_binary(raw_data.as_slice())
}
pub async fn neuron_page(&self, layer_index: u32, neuron_index: u32) -> Result<JsonData> {
let model_name = self.model.name();
let data_type_name = self.data_type.name();
let raw_data = self.model
.neuron_data( &self.data_type, layer_index, neuron_index)
.await.with_context(|| {
format!("Failed to get '{data_type_name}' neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
let raw_data = self
.model
.neuron_data(&self.data_type, layer_index, neuron_index)
.await
.with_context(|| {
format!(
"Failed to get '{data_type_name}' neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})?
.with_context(|| {
format!("Database has no '{data_type_name}' neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
format!(
"Database has no '{data_type_name}' neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})?;
JsonData::from_binary(raw_data.as_slice())
}
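
Both accessors now chain with_context twice: the first call annotates a database error, and the second converts a missing row (None) into an error, which works because anyhow implements Context for both Result and Option. A minimal sketch of that pattern with made-up names, assuming only the anyhow crate:

use anyhow::{Context, Result};

// Hypothetical stand-in for a database lookup that can fail or find nothing.
fn lookup(key: &str) -> Result<Option<Vec<u8>>> {
    Ok(if key == "known" { Some(vec![1, 2, 3]) } else { None })
}

fn fetch(key: &str) -> Result<Vec<u8>> {
    lookup(key)
        .with_context(|| format!("Failed to query data for '{key}'."))? // annotates an Err
        .with_context(|| format!("Database has no data for '{key}'.")) // turns None into an Err
}

fn main() -> Result<()> {
    assert_eq!(fetch("known")?, vec![1, 2, 3]);
    assert!(fetch("missing").is_err());
    Ok(())
}
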
18 changes: 13 additions & 5 deletions src/data/database/data_types/neuron2graph.rs
@@ -1,13 +1,12 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};
use crate::data::{
data_objects::{DataObject, Graph},
DataTypeHandle, ModelHandle,
};

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};

pub struct Neuron2Graph {
model: ModelHandle,
data_type: DataTypeHandle,
@@ -47,12 +46,21 @@ impl ModelDataType for Neuron2Graph {
impl Neuron2Graph {
pub async fn neuron_graph(&self, layer_index: u32, neuron_index: u32) -> Result<Graph> {
let model_name = self.model.name();
let raw_data = self.model
let raw_data = self
.model
.neuron_data(&self.data_type, layer_index, neuron_index)
.await?
.with_context(|| {
format!("Database has no neuron2graph data for neuron l{layer_index}n{neuron_index} in model '{model_name}'")
format!(
"Database has no neuron2graph data for neuron l{layer_index}n{neuron_index} \
in model '{model_name}'"
)
})?;
Graph::from_binary(raw_data).with_context(|| format!("Failed to unpack neuron2graph graph for neuron l{layer_index}n{neuron_index} in model '{model_name}'."))
Graph::from_binary(raw_data).with_context(|| {
format!(
"Failed to unpack neuron2graph graph for neuron l{layer_index}n{neuron_index} in \
model '{model_name}'."
)
})
}
}
36 changes: 23 additions & 13 deletions src/data/database/data_types/neuron_explainer.rs
@@ -1,17 +1,16 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;

use super::{
data_type::{DataValidationError, ModelDataType},
DataTypeDiscriminants,
};
use crate::data::{
data_objects::{DataObject, NeuronExplainerPage},
database::ModelHandle,
DataTypeHandle,
};

use super::{
data_type::{DataValidationError, ModelDataType},
DataTypeDiscriminants,
};

pub struct NeuronExplainer {
model: ModelHandle,
data_type: DataTypeHandle,
@@ -55,14 +54,25 @@ impl NeuronExplainer {
neuron_index: u32,
) -> Result<Option<NeuronExplainerPage>> {
let model_name = self.model.name();
let raw_data = self.model
.neuron_data( &self.data_type, layer_index, neuron_index)
.await.with_context(|| {
format!("Failed to get neuron explainer neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
})?;
raw_data.map(|raw_data| NeuronExplainerPage::from_binary(raw_data.as_slice())
let raw_data = self
.model
.neuron_data(&self.data_type, layer_index, neuron_index)
.await
.with_context(|| {
format!("Failed to deserialize neuron explainer neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
})).transpose()
format!(
"Failed to get neuron explainer neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})?;
raw_data
.map(|raw_data| {
NeuronExplainerPage::from_binary(raw_data.as_slice()).with_context(|| {
format!(
"Failed to deserialize neuron explainer neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})
})
.transpose()
}
}
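
The closing map/transpose converts the Option<Vec<u8>> returned by the database into a Result<Option<NeuronExplainerPage>>: from_binary runs only when data was found, and transpose flips the resulting Option<Result<_>> so a deserialization failure still surfaces as an error. A small sketch of the same idea with placeholder types:

use anyhow::{Context, Result};

// Placeholder decoder standing in for NeuronExplainerPage::from_binary.
fn decode(raw: &[u8]) -> Result<String> {
    String::from_utf8(raw.to_vec()).context("Failed to decode bytes as UTF-8.")
}

fn main() -> Result<()> {
    let found: Option<Vec<u8>> = Some(b"neuron".to_vec());
    let missing: Option<Vec<u8>> = None;

    // Option<Result<String>> -> Result<Option<String>> via transpose.
    let decoded = found.map(|raw| decode(&raw)).transpose()?;
    assert_eq!(decoded.as_deref(), Some("neuron"));
    assert!(missing.map(|raw| decode(&raw)).transpose()?.is_none());
    Ok(())
}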