Skip to content

Commit

Permalink
Visitor AST into DOT : trial for more clarity (#2426)
Browse files Browse the repository at this point in the history
Co-authored-by: Florian OMNES <florian.omnes@rte-france.com>
  • Loading branch information
guilpier-code and flomnes authored Sep 26, 2024
1 parent d174a1b commit bc8e74c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class AstDOTStyleVisitor: public NodeVisitor<void, std::ostream&>
void visit(const Nodes::ComponentVariableNode* node, std::ostream& os) override;
void visit(const Nodes::ComponentParameterNode* node, std::ostream& os) override;

void computeNumberNodesPerType();
void makeLegend(std::ostream& os);

/**
* @brief Retrieves a unique ID for a given node.
*
Expand Down Expand Up @@ -181,7 +184,14 @@ class AstDOTStyleVisitor: public NodeVisitor<void, std::ostream&>
*
* This map is used to keep track of assigned IDs for each node in the AST.
*/
std::map<std::string, std::map<const Nodes::Node*, unsigned int>> nodeIds_;
std::map<const Nodes::Node*, unsigned int> nodeIds_;

/**
* @brief A map associating a number of instances to a type name.
*
* This map is used to keep track of assigned IDs for each node in the AST.
*/
std::map<std::string, unsigned int> nbNodesPerType_;

/**
* @brief Counter for generating unique node IDs.
Expand Down
70 changes: 36 additions & 34 deletions src/solver/expressions/visitors/AstDOTStyleVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
** You should have received a copy of the Mozilla Public Licence 2.0
** along with Antares_Simulator. If not, see <https://opensource.org/license/mpl-2-0/>.
*/

#include "antares/solver/expressions/visitors/AstDOTStyleVisitor.h"

#include <algorithm>
#include <set>

#include <antares/solver/expressions/nodes/ExpressionsNodes.h>
#include "antares/solver/expressions/nodes/ExpressionsNodes.h"

namespace Antares::Solver::Visitors
{
Expand All @@ -42,6 +40,16 @@ static constexpr BoxStyle ComponentVariableStyle{"goldenrod", "octagon", "filled
static constexpr BoxStyle PortFieldStyle{"olive", "component", "filled, solid"};
} // namespace NodeStyle

void makeLegendTitle(std::ostream& os)
{
os << "subgraph cluster_legend {\n"
<< "label = \"Legend\";\n"
<< "style = dashed;\n"
<< "fontsize = 16;\n"
<< "color = lightgrey;\n"
<< "node [shape=plaintext];\n\n";
}

void ProcessElementLegend(const std::string& element_name, size_t size, std::ostream& os)
{
os << "legend_" << element_name << " [ label =\" " << element_name << ": " << size << "\"]\n";
Expand All @@ -52,34 +60,20 @@ void AddFiliation(std::ostream& os, const std::string& parent_id, const std::str
os << "legend_" << parent_id << " -> " << "legend_" << child_id << " [style=invis];\n";
}

void GetLegend(const std::map<std::string, std::map<const Nodes::Node*, unsigned int>>& nodeIds,
std::ostream& os)
void AstDOTStyleVisitor::makeLegend(std::ostream& os)
{
os << R"raw(subgraph cluster_legend {
label = "Legend";
style = dashed;
fontsize = 16;
color = lightgrey;
node [shape=plaintext];
)raw";

auto order_nb_type = nodeIds.size();
if (order_nb_type > 1)
if (nbNodesPerType_.empty())
{
for (auto it = nodeIds.begin(), next_it = std::next(it); next_it != nodeIds.end();
++it, ++next_it)
{
ProcessElementLegend(it->first, it->second.size(), os);
AddFiliation(os, it->first, next_it->first);
}
ProcessElementLegend(nodeIds.rbegin()->first, nodeIds.rbegin()->second.size(), os);
return;
}
else if (order_nb_type == 1)

ProcessElementLegend(nbNodesPerType_.begin()->first, nbNodesPerType_.begin()->second, os);
for (auto it = std::next(nbNodesPerType_.begin()); it != nbNodesPerType_.end(); ++it)
{
ProcessElementLegend(nodeIds.begin()->first, nodeIds.begin()->second.size(), os);
auto prev_it = std::prev(it);
AddFiliation(os, prev_it->first, it->first);
ProcessElementLegend(it->first, it->second, os);
}

os << "}\n";
}

Expand Down Expand Up @@ -186,17 +180,19 @@ std::string AstDOTStyleVisitor::name() const

unsigned int AstDOTStyleVisitor::getNodeID(const Nodes::Node* node)
{
const auto& node_name = node->name();
if (nodeIds_.find(node_name) == nodeIds_.end())
if (nodeIds_.find(node) == nodeIds_.end())
{
nodeIds_[node->name()][node] = ++nodeCount_;
nodeIds_[node] = ++nodeCount_;
}
else if (auto& id_map = nodeIds_[node_name]; id_map.find(node) == id_map.end())
return nodeIds_[node];
}

void AstDOTStyleVisitor::computeNumberNodesPerType()
{
for (const auto& [node, _]: nodeIds_)
{
id_map[node] = ++nodeCount_;
nbNodesPerType_[node->name()]++;
}

return nodeIds_[node->name()][node];
}

void AstDOTStyleVisitor::emitNode(unsigned int id,
Expand Down Expand Up @@ -238,13 +234,19 @@ void AstDOTStyleVisitor::NewTreeGraph(std::ostream& os, const std::string& tree_

void AstDOTStyleVisitor::EndTreeGraph(std::ostream& os)
{
computeNumberNodesPerType();

// Graph title showing the total number of nodes
os << "label=\"AST Diagram(Total nodes : " << nodeCount_ << ")\"\n";
os << "labelloc = \"t\"\n";
GetLegend(nodeIds_, os);

makeLegendTitle(os);
makeLegend(os);
os << "}\n";

nodeCount_ = 0;
nodeIds_.clear();
nbNodesPerType_.clear();
}

void AstDOTStyleVisitor::operator()(std::ostream& os, Nodes::Node* root)
Expand Down
13 changes: 6 additions & 7 deletions src/tests/src/solver/expressions/test_AstDOTStyleVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

#define WIN32_LEAN_AND_MEAN

#include <variant>

#include <boost/test/unit_test.hpp>

#include <antares/solver/expressions/Registry.hxx>
Expand Down Expand Up @@ -149,15 +147,16 @@ legend_VariableNode [ label =" VariableNode: 1"]
Registry<Node> registry_;
};

BOOST_FIXTURE_TEST_CASE(tree_with_all_type_node, Fixture)
BOOST_FIXTURE_TEST_CASE(
dot_visitor_is_run_on_complex_expression___resulting_dot_content_as_expected,
Fixture)
{
std::ostringstream os;
std::ostringstream dotContentStream;

AstDOTStyleVisitor astGraphVisitor;
astGraphVisitor(os, makeExpression());
astGraphVisitor(dotContentStream, makeExpression());

// read the content of os
BOOST_CHECK_EQUAL(expectedDotContent(), os.str());
BOOST_CHECK_EQUAL(dotContentStream.str(), expectedDotContent());
}

BOOST_FIXTURE_TEST_CASE(AstDOTStyleVisitor_name, Registry<Node>)
Expand Down

0 comments on commit bc8e74c

Please sign in to comment.