Skip to content

Commit

Permalink
Support to save and load graph view states
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelma committed Sep 19, 2023
1 parent 92fa20f commit bbee0ff
Show file tree
Hide file tree
Showing 8 changed files with 518 additions and 320 deletions.
12 changes: 8 additions & 4 deletions spinetoolbox/spine_db_editor/graphics_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Classes for drawing graphics items on graph view's QGraphicsScene.
"""
from enum import Enum, auto
from PySide6.QtCore import Qt, Signal, Slot, QLineF, QRectF, QPointF, QObject
from PySide6.QtCore import Qt, Signal, Slot, QLineF, QRectF, QPointF, QObject, QByteArray
from PySide6.QtSvgWidgets import QGraphicsSvgItem
from PySide6.QtWidgets import (
QGraphicsItem,
Expand All @@ -26,7 +26,7 @@
QMenu,
)
from PySide6.QtSvg import QSvgRenderer
from PySide6.QtGui import QPen, QBrush, QPainterPath, QPalette, QGuiApplication, QAction, QColor
from PySide6.QtGui import QPen, QBrush, QPainterPath, QPalette, QGuiApplication, QAction

from spinetoolbox.helpers import DB_ITEM_SEPARATOR, color_from_index
from spinetoolbox.widgets.custom_qwidgets import TitleWidgetAction
Expand Down Expand Up @@ -882,11 +882,12 @@ class Anchor(Enum):
Anchor.BR: Qt.SizeFDiagCursor,
}

def __init__(self, file_path, parent=None):
def __init__(self, svg, parent=None):
super().__init__(parent)
self._renderer = QSvgRenderer()
self._svg_item = _ResizableQGraphicsSvgItem(self)
_loading_ok = self._renderer.load(file_path)
self.svg = svg
_loading_ok = self._renderer.load(QByteArray(self.svg))
self._svg_item.setCacheMode(QGraphicsItem.CacheMode.NoCache) # Needed for the exported pdf to be vector
self._svg_item.setSharedRenderer(self._renderer)
self._scaling_factor = 1
Expand Down Expand Up @@ -945,9 +946,12 @@ def _do_resize(self, rect, strong):
self._place_resizers()

def fit_rect(self, rect):
if not isinstance(rect, QRectF):
rect = QRectF(*rect)
self._do_resize(rect, True)

def fit_coordinates(self, p1, p2, scen1, scen2):
# NOTE: not in use at the moment
size = self._renderer.defaultSize()
x1, y1 = p1
x2, y2 = p2
Expand Down
6 changes: 6 additions & 0 deletions spinetoolbox/spine_db_editor/ui/spine_db_editor_window.ui
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,12 @@
<header>spinetoolbox/spine_db_editor/widgets/custom_qwidgets.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>ProgressBarWidget</class>
<extends>QWidget</extends>
<header>spinetoolbox/spine_db_editor/widgets/custom_qwidgets.h</header>
<container>1</container>
</customwidget>
</customwidgets>
<resources>
<include location="../../ui/resources/resources_icons.qrc"/>
Expand Down
329 changes: 199 additions & 130 deletions spinetoolbox/spine_db_editor/widgets/custom_qgraphicsviews.py

Large diffs are not rendered by default.

127 changes: 98 additions & 29 deletions spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import sys
import itertools
import json
from time import monotonic
from PySide6.QtCore import Slot, QTimer, QThreadPool
from spinedb_api import from_database
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(self, *args, **kwargs):
self._time_line_index = None
self.entity_items = []
self.arc_items = []
self.selected_tree_inds = {}
self._selected_item_type_db_map_ids = {}
self.db_map_entity_id_sets = []
self.entity_inds = []
self.element_inds = []
Expand Down Expand Up @@ -329,7 +330,7 @@ def _handle_entity_graph_visibility_changed(self, visible):
@Slot(dict)
def _handle_entity_tree_selection_changed_in_graph(self, selected):
"""Stores the given selection of entity tree indexes and builds graph."""
self.selected_tree_inds = selected
self._selected_item_type_db_map_ids = self._get_selected_item_type_db_map_ids(selected)
self.added_db_map_entity_ids.clear()
self._extending_graph = True
self.build_graph()
Expand All @@ -339,8 +340,71 @@ def rebuild_graph(self, _checked=False):
self.db_map_entity_id_sets.clear()
self.build_graph()

def save_graph_data(self, name):
db_map_data = {}
for db_map in self.db_maps:
graph_data = {
"type": "graph_data",
"selected_item_type_ids": {
item_type: db_map_ids.get(db_map, [])
for item_type, db_map_ids in self._selected_item_type_db_map_ids.items()
},
"pruned_entity_ids": self.ui.graphicsView.get_pruned_entity_ids(db_map),
"pos_x_parameter": self.ui.graphicsView.pos_x_parameter,
"pos_y_parameter": self.ui.graphicsView.pos_y_parameter,
"name_parameter": self.ui.graphicsView.name_parameter,
"color_parameter": self.ui.graphicsView.color_parameter,
"arc_width_parameter": self.ui.graphicsView.arc_width_parameter,
"bg_svg": self.ui.graphicsView.get_bg_svg(),
"bg_rect": self.ui.graphicsView.get_bg_rect(),
"properties": self.ui.graphicsView.get_all_properties(),
}
db_map_data[db_map] = [{"name": name, "value": json.dumps(graph_data)}]
self.db_mngr.add_metadata(db_map_data)
# TODO: also add entity_metadata so it sticks

def get_db_map_graph_data_by_name(self):
db_map_graph_data_by_name = {}
for db_map in self.db_maps:
for metadata_item in self.db_mngr.get_items(db_map, "metadata"):
try:
graph_data = json.loads(metadata_item["value"])
except json.decoder.JSONDecodeError:
continue
if isinstance(graph_data, dict) and graph_data.get("type") == "graph_data":
db_map_graph_data_by_name.setdefault(metadata_item["name"], {})[db_map] = graph_data
return db_map_graph_data_by_name

def load_graph_data(self, db_map_graph_data):
if not db_map_graph_data:
self.msg_error.emit("Invalid graph data")
self._selected_item_type_db_map_ids = {}
for db_map, gd in db_map_graph_data.items():
for item_type, ids in gd["selected_item_type_ids"].items():
self._selected_item_type_db_map_ids.setdefault(item_type, {})[db_map] = ids
self.ui.graphicsView.set_pruned_entity_ids(db_map, gd["pruned_entity_ids"])
graph_data = db_map_graph_data[self.first_db_map]
self.ui.graphicsView.pos_x_parameter = graph_data["pos_x_parameter"]
self.ui.graphicsView.pos_y_parameter = graph_data["pos_y_parameter"]
self.ui.graphicsView.name_parameter = graph_data["name_parameter"]
self.ui.graphicsView.color_parameter = graph_data["color_parameter"]
self.ui.graphicsView.arc_width_parameter = graph_data["arc_width_parameter"]
self.ui.graphicsView.set_bg_svg(graph_data["bg_svg"])
self.ui.graphicsView.set_bg_rect(graph_data["bg_rect"])
self.ui.graphicsView.set_many_properties(graph_data["properties"])
self.build_graph()

def remove_graph_data(self, name):
db_map_typed_ids = {}
for db_map in self.db_maps:
metadata_item = next((x for x in self.db_mngr.get_items(db_map, "metadata") if x["name"] == name), None)
if metadata_item is None:
continue
db_map_typed_ids[db_map] = {"metadata": {metadata_item["id"]}}
self.db_mngr.remove_items(db_map_typed_ids)

def build_graph(self, persistent=False):
"""Builds the graph.
"""Builds graph from selection in the entity tree.
Args:
persistent (bool, optional): If True, elements in the current graph (if any) retain their position
Expand Down Expand Up @@ -388,7 +452,23 @@ def _complete_graph(self, layout_gen_id, x, y):
self.ui.graphicsView.reset_zoom()
else:
self.ui.graphicsView.apply_zoom()
self.ui.graphicsView.fit_bg_coordinates()

@staticmethod
def _get_selected_item_type_db_map_ids(selected_tree_inds):
"""Returns a dict mapping item type to db_map to selected ids-
Returns:
dict
"""
if "root" in selected_tree_inds:
return {"root": None}
item_type_db_map_ids = {}
for item_type, indexes in selected_tree_inds.items():
for index in indexes:
item = index.model().item_from_index(index)
for db_map, id_ in item.db_map_ids.items():
item_type_db_map_ids.setdefault(item_type, {}).setdefault(db_map, []).append(id_)
return item_type_db_map_ids

def _get_selected_db_map_entity_ids(self):
"""Returns a set of ids corresponding to selected entities in the trees.
Expand All @@ -397,24 +477,19 @@ def _get_selected_db_map_entity_ids(self):
set: selected object ids
set: selected relationship ids
"""
if "root" in self.selected_tree_inds:
if "root" in self._selected_item_type_db_map_ids:
return set((db_map, x["id"]) for db_map in self.db_maps for x in self.db_mngr.get_items(db_map, "entity"))
db_map_entity_ids = set()
for index in self.selected_tree_inds.get("entity", {}):
item = index.model().item_from_index(index)
db_map_entity_ids |= set(item.db_map_ids.items())
for index in self.selected_tree_inds.get("entity_class", {}):
item = index.model().item_from_index(index)
for db_map, ids in self._selected_item_type_db_map_ids.get("entity", {}).items():
db_map_entity_ids |= {(db_map, id_) for id_ in ids}
for db_map, ids in self._selected_item_type_db_map_ids.get("entity_class", {}).items():
db_map_entity_ids |= set(
(db_map, x["id"])
for db_map, id_ in item.db_map_ids.items()
for x in self.db_mngr.get_items(db_map, "entity")
if x["class_id"] == id_
(db_map, x["id"]) for x in self.db_mngr.get_items(db_map, "entity") if x["class_id"] in ids
)
return db_map_entity_ids

def _get_db_map_entities_for_graph(self, db_map_entity_ids):
cond = any if self.ui.graphicsView.auto_expand_entities else all
cond = any if self.ui.graphicsView.get_property("auto_expand_entities") else all
return [
(db_map, x)
for db_map in self.db_maps
Expand All @@ -427,18 +502,12 @@ def _get_db_map_entities_for_graph(self, db_map_entity_ids):

def _update_graph_data(self):
"""Updates data for graph according to selection in trees."""
pruned_db_map_entity_ids = {
id_ for ids in self.ui.graphicsView.pruned_db_map_entity_ids.values() for id_ in ids
}
pruned_db_map_entity_ids = set(self.ui.graphicsView.get_pruned_db_map_entity_ids())
db_map_entity_ids = self._get_selected_db_map_entity_ids()
db_map_entity_ids |= self.added_db_map_entity_ids
db_map_entity_ids -= pruned_db_map_entity_ids
db_map_entities = self._get_db_map_entities_for_graph(db_map_entity_ids)
max_ent_dim = (
self.ui.graphicsView.max_entity_dimension
if not self.ui.graphicsView.disable_max_relationship_dimension
else sys.maxsize
)
max_ent_dim_count = self.ui.graphicsView.get_property("max_entity_dimension_count")
db_map_element_id_lists = {}
for db_map, entity in db_map_entities:
if (db_map, entity["id"]) in pruned_db_map_entity_ids:
Expand All @@ -447,7 +516,7 @@ def _update_graph_data(self):
(db_map, id_) for id_ in entity["element_id_list"] if (db_map, id_) not in pruned_db_map_entity_ids
]
el_count = len(db_map_element_id_list)
if el_count != 0 and (el_count < 2 or el_count > max_ent_dim):
if el_count != 0 and (el_count < 2 or el_count > max_ent_dim_count):
continue
db_map_entity_ids.add((db_map, entity["id"]))
db_map_entity_ids.update(db_map_element_id_list)
Expand All @@ -468,7 +537,7 @@ def get_entity_key(self, db_map_entity_id):
db_map, entity_id = db_map_entity_id
entity = self.db_mngr.get_item(db_map, "entity", entity_id)
key = (entity["class_name"], entity["dimension_name_list"], entity["byname"])
if not self.ui.graphicsView.merge_dbs:
if not self.ui.graphicsView.get_property("merge_dbs"):
key += (db_map.codename,)
return key

Expand Down Expand Up @@ -588,9 +657,9 @@ def _make_layout_generator(self):
for db_map_entity_id in db_map_entity_ids
if fixed_positions[db_map_entity_id]
}
spread_factor = int(self.qsettings.value("appSettings/layoutAlgoSpreadFactor", defaultValue="100")) / 100
neg_weight_exp = int(self.qsettings.value("appSettings/layoutAlgoNegWeightExp", defaultValue="2"))
max_iters = int(self.qsettings.value("appSettings/layoutAlgoMaxIterations", defaultValue="12"))
spread_factor = self.ui.graphicsView.get_property("spread_factor") / 100
build_iters = self.ui.graphicsView.get_property("build_iters")
neg_weight_exp = self.ui.graphicsView.get_property("neg_weight_exp")
return GraphLayoutGeneratorRunnable(
self._layout_gen_id,
len(self.db_map_entity_id_sets),
Expand All @@ -599,7 +668,7 @@ def _make_layout_generator(self):
spread=spread_factor * self._ARC_LENGTH_HINT,
heavy_positions=heavy_positions,
weight_exp=-neg_weight_exp,
max_iters=max_iters,
max_iters=build_iters,
)

@staticmethod
Expand Down
Loading

0 comments on commit bbee0ff

Please sign in to comment.