diff --git a/src/solver/expressions/include/antares/solver/expressions/visitors/AstDOTStyleVisitor.h b/src/solver/expressions/include/antares/solver/expressions/visitors/AstDOTStyleVisitor.h index 3b5ce8410e..8ca2e509c5 100644 --- a/src/solver/expressions/include/antares/solver/expressions/visitors/AstDOTStyleVisitor.h +++ b/src/solver/expressions/include/antares/solver/expressions/visitors/AstDOTStyleVisitor.h @@ -134,6 +134,9 @@ class AstDOTStyleVisitor: public NodeVisitor void visit(const Nodes::ComponentVariableNode* node, std::ostream& os) override; void visit(const Nodes::ComponentParameterNode* node, std::ostream& os) override; + void computeNumberNodesPerType(); + void makeLegend(std::ostream& os); + /** * @brief Retrieves a unique ID for a given node. * @@ -181,7 +184,14 @@ class AstDOTStyleVisitor: public NodeVisitor * * This map is used to keep track of assigned IDs for each node in the AST. */ - std::map> nodeIds_; + std::map nodeIds_; + + /** + * @brief A map associating a number of instances to a type name. + * + * This map is used to keep track of assigned IDs for each node in the AST. + */ + std::map nbNodesPerType_; /** * @brief Counter for generating unique node IDs. diff --git a/src/solver/expressions/visitors/AstDOTStyleVisitor.cpp b/src/solver/expressions/visitors/AstDOTStyleVisitor.cpp index 78a8e0be57..a1e3f6fbde 100644 --- a/src/solver/expressions/visitors/AstDOTStyleVisitor.cpp +++ b/src/solver/expressions/visitors/AstDOTStyleVisitor.cpp @@ -18,13 +18,11 @@ ** You should have received a copy of the Mozilla Public Licence 2.0 ** along with Antares_Simulator. If not, see . */ - #include "antares/solver/expressions/visitors/AstDOTStyleVisitor.h" #include -#include -#include +#include "antares/solver/expressions/nodes/ExpressionsNodes.h" namespace Antares::Solver::Visitors { @@ -42,6 +40,16 @@ static constexpr BoxStyle ComponentVariableStyle{"goldenrod", "octagon", "filled static constexpr BoxStyle PortFieldStyle{"olive", "component", "filled, solid"}; } // namespace NodeStyle +void makeLegendTitle(std::ostream& os) +{ + os << "subgraph cluster_legend {\n" + << "label = \"Legend\";\n" + << "style = dashed;\n" + << "fontsize = 16;\n" + << "color = lightgrey;\n" + << "node [shape=plaintext];\n\n"; +} + void ProcessElementLegend(const std::string& element_name, size_t size, std::ostream& os) { os << "legend_" << element_name << " [ label =\" " << element_name << ": " << size << "\"]\n"; @@ -52,34 +60,20 @@ void AddFiliation(std::ostream& os, const std::string& parent_id, const std::str os << "legend_" << parent_id << " -> " << "legend_" << child_id << " [style=invis];\n"; } -void GetLegend(const std::map>& nodeIds, - std::ostream& os) +void AstDOTStyleVisitor::makeLegend(std::ostream& os) { - os << R"raw(subgraph cluster_legend { -label = "Legend"; -style = dashed; -fontsize = 16; -color = lightgrey; -node [shape=plaintext]; - -)raw"; - - auto order_nb_type = nodeIds.size(); - if (order_nb_type > 1) + if (nbNodesPerType_.empty()) { - for (auto it = nodeIds.begin(), next_it = std::next(it); next_it != nodeIds.end(); - ++it, ++next_it) - { - ProcessElementLegend(it->first, it->second.size(), os); - AddFiliation(os, it->first, next_it->first); - } - ProcessElementLegend(nodeIds.rbegin()->first, nodeIds.rbegin()->second.size(), os); + return; } - else if (order_nb_type == 1) + + ProcessElementLegend(nbNodesPerType_.begin()->first, nbNodesPerType_.begin()->second, os); + for (auto it = std::next(nbNodesPerType_.begin()); it != nbNodesPerType_.end(); ++it) { - ProcessElementLegend(nodeIds.begin()->first, nodeIds.begin()->second.size(), os); + auto prev_it = std::prev(it); + AddFiliation(os, prev_it->first, it->first); + ProcessElementLegend(it->first, it->second, os); } - os << "}\n"; } @@ -186,17 +180,19 @@ std::string AstDOTStyleVisitor::name() const unsigned int AstDOTStyleVisitor::getNodeID(const Nodes::Node* node) { - const auto& node_name = node->name(); - if (nodeIds_.find(node_name) == nodeIds_.end()) + if (nodeIds_.find(node) == nodeIds_.end()) { - nodeIds_[node->name()][node] = ++nodeCount_; + nodeIds_[node] = ++nodeCount_; } - else if (auto& id_map = nodeIds_[node_name]; id_map.find(node) == id_map.end()) + return nodeIds_[node]; +} + +void AstDOTStyleVisitor::computeNumberNodesPerType() +{ + for (const auto& [node, _]: nodeIds_) { - id_map[node] = ++nodeCount_; + nbNodesPerType_[node->name()]++; } - - return nodeIds_[node->name()][node]; } void AstDOTStyleVisitor::emitNode(unsigned int id, @@ -238,13 +234,19 @@ void AstDOTStyleVisitor::NewTreeGraph(std::ostream& os, const std::string& tree_ void AstDOTStyleVisitor::EndTreeGraph(std::ostream& os) { + computeNumberNodesPerType(); + // Graph title showing the total number of nodes os << "label=\"AST Diagram(Total nodes : " << nodeCount_ << ")\"\n"; os << "labelloc = \"t\"\n"; - GetLegend(nodeIds_, os); + + makeLegendTitle(os); + makeLegend(os); os << "}\n"; + nodeCount_ = 0; nodeIds_.clear(); + nbNodesPerType_.clear(); } void AstDOTStyleVisitor::operator()(std::ostream& os, Nodes::Node* root) diff --git a/src/tests/src/solver/expressions/test_AstDOTStyleVisitor.cpp b/src/tests/src/solver/expressions/test_AstDOTStyleVisitor.cpp index 1cf239cef8..09f80008f3 100644 --- a/src/tests/src/solver/expressions/test_AstDOTStyleVisitor.cpp +++ b/src/tests/src/solver/expressions/test_AstDOTStyleVisitor.cpp @@ -21,8 +21,6 @@ #define WIN32_LEAN_AND_MEAN -#include - #include #include @@ -149,15 +147,16 @@ legend_VariableNode [ label =" VariableNode: 1"] Registry registry_; }; -BOOST_FIXTURE_TEST_CASE(tree_with_all_type_node, Fixture) +BOOST_FIXTURE_TEST_CASE( + dot_visitor_is_run_on_complex_expression___resulting_dot_content_as_expected, + Fixture) { - std::ostringstream os; + std::ostringstream dotContentStream; AstDOTStyleVisitor astGraphVisitor; - astGraphVisitor(os, makeExpression()); + astGraphVisitor(dotContentStream, makeExpression()); - // read the content of os - BOOST_CHECK_EQUAL(expectedDotContent(), os.str()); + BOOST_CHECK_EQUAL(dotContentStream.str(), expectedDotContent()); } BOOST_FIXTURE_TEST_CASE(AstDOTStyleVisitor_name, Registry)