Skip to content

Commit

Permalink
refactor: extend provenance graph
Browse files Browse the repository at this point in the history
  • Loading branch information
doctrino committed Dec 19, 2024
1 parent 8af07f4 commit 2931ac4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
29 changes: 28 additions & 1 deletion cognite/neat/_session/_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def __call__(self) -> Any:

def _generate_dm_provenance_di_graph_and_types(self) -> nx.DiGraph:
di_graph = nx.DiGraph()
hex_colored_types = _generate_hex_color_per_type(["Agent", "Entity", "Activity"])
hex_colored_types = _generate_hex_color_per_type(["Agent", "Entity", "Activity", "Export", "Pruned"])

for change in self._state.rule_store.provenance:
source = self._shorten_id(change.source_entity.id_)
Expand Down Expand Up @@ -313,6 +313,33 @@ def _generate_dm_provenance_di_graph_and_types(self) -> nx.DiGraph:
di_graph.add_edge(source, agent, label="used", color="grey")
di_graph.add_edge(agent, target, label="generated", color="grey")

for source_id, exports in self._state.rule_store.exports_by_source_entity_id.items():
source_shorten = self._shorten_id(source_id)
for export in exports:
export_id = self._shorten_id(export.target_entity.id_)
di_graph.add_node(
export_id,
label=export_id,
type="Export",
title="Export",
color=hex_colored_types["Export"],
)
di_graph.add_edge(source_shorten, export_id, label="exported", color="grey")

for pruned_lists in self._state.rule_store.pruned_by_source_entity_id.values():
for prune_path in pruned_lists:
for change in prune_path:
source = self._shorten_id(change.source_entity.id_)
target = self._shorten_id(change.target_entity.id_)
di_graph.add_node(
target,
label=target,
type="Pruned",
title="Pruned",
color=hex_colored_types["Pruned"],
)
di_graph.add_edge(source, target, label="pruned", color="grey")

return di_graph

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions cognite/neat/_store/_rules_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class OutcomeEntity(Entity):
class NeatRulesStore:
def __init__(self) -> None:
self.provenance = Provenance()
self.exports_by_target_entity_id: dict[rdflib.URIRef, list[Change]] = defaultdict(list)
self.pruned_by_target_entity_id: dict[rdflib.URIRef, list[Provenance]] = defaultdict(list)
self.exports_by_source_entity_id: dict[rdflib.URIRef, list[Change]] = defaultdict(list)
self.pruned_by_source_entity_id: dict[rdflib.URIRef, list[Provenance]] = defaultdict(list)
self._iteration_by_id: dict[Hashable, int] = {}

def calculate_provenance_hash(self, shorten: bool = True) -> str:
Expand Down Expand Up @@ -124,7 +124,7 @@ def export(self, exporter: BaseExporter[T_VerifiedRules, T_Export]) -> T_Export:
description=exporter.description,
source_entity=source_entity,
)
self.exports_by_target_entity_id[source_entity.id_].append(change)
self.exports_by_source_entity_id[source_entity.id_].append(change)
return result

def export_to_file(self, exporter: BaseExporter, path: Path) -> None:
Expand Down Expand Up @@ -164,7 +164,7 @@ def export_to_file(self, exporter: BaseExporter, path: Path) -> None:
description=exporter.description,
source_entity=source_entity,
)
self.exports_by_target_entity_id[source_entity.id_].append(change)
self.exports_by_source_entity_id[source_entity.id_].append(change)

def export_to_cdf(self, exporter: CDFExporter, client: NeatClient, dry_run: bool) -> UploadResultList:
last_change = self.provenance[-1]
Expand Down Expand Up @@ -205,7 +205,7 @@ def export_to_cdf(self, exporter: CDFExporter, client: NeatClient, dry_run: bool
description=exporter.description,
source_entity=source_entity,
)
self.exports_by_target_entity_id[source_entity.id_].append(change)
self.exports_by_source_entity_id[source_entity.id_].append(change)
return result

def prune_until_compatible(self, transformer: RulesTransformer) -> list[Change]:
Expand All @@ -231,7 +231,7 @@ def prune_until_compatible(self, transformer: RulesTransformer) -> list[Change]:
return []
self.provenance = self.provenance[: -len(pruned_candidates)]
pruned_candidates.reverse()
self.pruned_by_target_entity_id[self.provenance[-1].target_entity.id_].append(Provenance(pruned_candidates))
self.pruned_by_source_entity_id[self.provenance[-1].target_entity.id_].append(Provenance(pruned_candidates))
return pruned_candidates

def _export(self, action: Callable[[Any], Any], agent: Agent, description: str) -> Any:
Expand Down

0 comments on commit 2931ac4

Please sign in to comment.