diff --git a/include/solver.h b/include/solver.h index 15dcd1b..d59c25c 100644 --- a/include/solver.h +++ b/include/solver.h @@ -17,6 +17,17 @@ * @{ */ +/** + * @brief Map of equivalences to be applied to the expressions + */ +extern std::unordered_map equivalences; + +/** + * @brief Fill the map of equivalences to be applied to the expressions. + * Must be called before any other functions in this file. + */ +void preprocess(std::shared_ptr lhs, std::shared_ptr rhs); + /** * @brief Prove the expressions are equivalence using equivalence laws by solving the left hand side to the right hand side. * Required that the expressions are actually equivalent for this to work. diff --git a/src/solver.cpp b/src/solver.cpp index 5cfd21e..98fd828 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -1,4 +1,104 @@ /** * @file solver.cpp * @brief Implementation file for solver functions - */ \ No newline at end of file + */ + +#include "../include/solver.h" + +std::unordered_map equivalences = {}; + +void preprocess(std::shared_ptr lhs, std::shared_ptr rhs) { + // include laws + equivalences.clear(); + equivalences.insert(EquivLaws::laws.begin(), EquivLaws::laws.end()); + + // convert lhs and rhs to strings + std::string lhsString = lhs->toStringTree(); + std::string rhsString = rhs->toStringTree(); + + if (lhsString.find("->") != std::string::npos || rhsString.find("->") != std::string::npos) { + equivalences.insert(EquivLaws::implications.begin(), EquivLaws::implications.end()); + } + + if (lhsString.find("<=>") != std::string::npos || rhsString.find("<=>") != std::string::npos) { + equivalences.insert(EquivLaws::bidirectionalImplications.begin(), EquivLaws::bidirectionalImplications.end()); + } +} + +std::vector> proveEquivalence(std::shared_ptr lhs, std::shared_ptr rhs) +{ + if (lhs->compare(rhs)) + return {{"", "Given"}}; + + std::vector> steps; + + std::queue> queue; + std::unordered_map> visited; + + queue.push(lhs); + visited[lhs->toStringTree()] = {"", "Given"}; + + bool found = false; + + while (!queue.empty()) + { + std::shared_ptr expr = queue.front(); + queue.pop(); + + if (expr->compare(rhs)) + { + // found the rhs, now backtrack the visited map to get the steps + std::string currentExprString = expr->toStringTree(); + while (currentExprString != "") + { + steps.push_back({currentExprString, visited[currentExprString].second}); + currentExprString = visited[currentExprString].first; + } + std::reverse(steps.begin(), steps.end()); + break; + } + + generateNextSteps(expr, rhs, found, queue, visited); + } + return steps; +} + +void generateNextSteps(std::shared_ptr expr, std::shared_ptr end, bool &found, std::queue> &queue, std::unordered_map> &visited) +{ + for (auto equiv : equivalences) + { + if (found) + return; + + auto funct = equiv.first; + auto lawName = equiv.second; + + std::shared_ptr newExpr = expr->cloneTree(); + + if (funct(newExpr)) + { + std::string newExprString = newExpr->toStringTree(); + if (newExprString.length() > 100) + continue; + // ignore extremely long expressions + + if (visited.find(newExprString) == visited.end()) + { + visited[newExprString] = {expr->toStringTree(), lawName}; + if (newExpr->compareTree(end)) + found = true; + while (newExpr->getParent()) + newExpr = newExpr->getParent(); + queue.push(newExpr); + } + } + } + + if (!found) + { + if (expr->getLeft() && !expr->getLeft()->isVar()) + generateNextSteps(expr->getLeft(), end, found, queue, visited); + if (expr->getRight() && !expr->getRight()->isVar()) + generateNextSteps(expr->getRight(), end, found, queue, visited); + } +}