From 3bbd9f693c8c45e59f1d2049359d42f19cea8a1b Mon Sep 17 00:00:00 2001 From: rakhimov Date: Tue, 25 Apr 2017 01:13:48 -0700 Subject: [PATCH] Refactor numerical expressions Use template member specialization instead of free template functions. Issue #72 --- src/expression.h | 59 ++++++---------- src/expression/numerical.cc | 63 ++++------------- src/expression/numerical.h | 130 +++++++++++++++++++++--------------- 3 files changed, 108 insertions(+), 144 deletions(-) diff --git a/src/expression.h b/src/expression.h index fdad037686..78547dd945 100644 --- a/src/expression.h +++ b/src/expression.h @@ -138,7 +138,7 @@ class Expression : private boost::noncopyable { bool sampled_; ///< Indication if the expression is already sampled. }; -/// CRTP for Expressions with a same formula to evaluate and sample. +/// CRTP for Expressions with the same formula to evaluate and sample. /// /// @tparam T The Expression type with Compute function. template @@ -167,33 +167,6 @@ class ExpressionFormula : public Expression { template class NaryExpression; -/// Validates expressions of specific type with its given arguments. -template -void ValidateExpression(const std::vector& /*args*/) {} - -/// Get the validation interval for unary expression T with a given argument. -template -Interval GetInterval(Expression* arg) { - Interval arg_interval = arg->interval(); - double max_value = T()(arg_interval.upper()); - double min_value = T()(arg_interval.lower()); - auto min_max = std::minmax(max_value, min_value); - return Interval::closed(min_max.first, min_max.second); -} - -/// Get the validation interval for binary expression with given arguments. -template -Interval GetInterval(Expression* arg_one, Expression* arg_two) { - Interval interval_one = arg_one->interval(); - Interval interval_two = arg_two->interval(); - double max_max = T()(interval_one.upper(), interval_two.upper()); - double max_min = T()(interval_one.upper(), interval_two.lower()); - double min_max = T()(interval_one.lower(), interval_two.upper()); - double min_min = T()(interval_one.lower(), interval_two.lower()); - auto interval_pair = std::minmax({max_max, max_min, min_max, min_min}); - return Interval::closed(interval_pair.first, interval_pair.second); -} - /// Unary expression. template class NaryExpression : public ExpressionFormula> { @@ -203,11 +176,15 @@ class NaryExpression : public ExpressionFormula> { : ExpressionFormula>({expression}), expression_(*expression) {} - void Validate() const override { - return ValidateExpression(Expression::args()); - } + void Validate() const override {} - Interval interval() noexcept override { return GetInterval(&expression_); } + Interval interval() noexcept override { + Interval arg_interval = expression_.interval(); + double max_value = T()(arg_interval.upper()); + double min_value = T()(arg_interval.lower()); + auto min_max = std::minmax(max_value, min_value); + return Interval::closed(min_max.first, min_max.second); + } /// Computes the expression value with a given argument value extractor. template @@ -227,13 +204,17 @@ class NaryExpression : public ExpressionFormula> { explicit NaryExpression(Expression* arg_one, Expression* arg_two) : ExpressionFormula>({arg_one, arg_two}) {} - void Validate() const override { - return ValidateExpression(Expression::args()); - } + void Validate() const override {} Interval interval() noexcept override { - return GetInterval(Expression::args().front(), - Expression::args().back()); + Interval interval_one = Expression::args().front()->interval(); + Interval interval_two = Expression::args().back()->interval(); + double max_max = T()(interval_one.upper(), interval_two.upper()); + double max_min = T()(interval_one.upper(), interval_two.lower()); + double min_max = T()(interval_one.lower(), interval_two.upper()); + double min_min = T()(interval_one.lower(), interval_two.lower()); + auto interval_pair = std::minmax({max_max, max_min, min_max, min_min}); + return Interval::closed(interval_pair.first, interval_pair.second); } /// Computes the expression value with a given argument value extractor. @@ -259,9 +240,7 @@ class NaryExpression : public ExpressionFormula> { throw InvalidArgument("Expression requires 2 or more arguments."); } - void Validate() const override { - return ValidateExpression(Expression::args()); - } + void Validate() const override {} Interval interval() noexcept override { auto it = Expression::args().begin(); diff --git a/src/expression/numerical.cc b/src/expression/numerical.cc index ebcdedd7a4..8987521a91 100644 --- a/src/expression/numerical.cc +++ b/src/expression/numerical.cc @@ -20,15 +20,14 @@ #include "numerical.h" -#include "src/error.h" - namespace scram { namespace mef { +/// @cond Doxygen_With_Smart_Using_Declaration template <> -void ValidateExpression>(const std::vector& args) { - auto it = args.begin(); - for (++it; it != args.end(); ++it) { +void Div::Validate() const { + auto it = Expression::args().begin(); + for (++it; it != Expression::args().end(); ++it) { const auto& expr = *it; Interval arg_interval = expr->interval(); if (expr->value() == 0 || Contains(arg_interval, 0)) @@ -37,40 +36,9 @@ void ValidateExpression>(const std::vector& args) { } template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 1); - EnsureWithin(args.front(), Interval::closed(-1, 1), - "Arc cos"); -} - -template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 1); - EnsureWithin(args.front(), Interval::closed(-1, 1), - "Arc sin"); -} - -template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 1); - EnsurePositive(args.front(), "Natural Logarithm"); -} - -template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 1); - EnsurePositive(args.front(), "Decimal Logarithm"); -} - -template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 2); - auto* arg_two = args.back(); +void Mod::Validate() const { + assert(args().size() == 2); + auto* arg_two = args().back(); int arg_value = arg_two->value(); if (arg_value == 0) throw InvalidArgument("Modulo second operand must not be 0."); @@ -82,24 +50,17 @@ void ValidateExpression>( } template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 2); - auto* arg_one = args.front(); - auto* arg_two = args.back(); +void Pow::Validate() const { + assert(args().size() == 2); + auto* arg_one = args().front(); + auto* arg_two = args().back(); if (arg_one->value() == 0 && arg_two->value() <= 0) throw InvalidArgument("0 to power 0 or less is undefined."); if (Contains(arg_one->interval(), 0) && !IsPositive(arg_two->interval())) throw InvalidArgument("Power expression 'base' sample range contains 0;" "positive exponent is required."); } - -template <> -void ValidateExpression>( - const std::vector& args) { - assert(args.size() == 1); - EnsureNonNegative(args.front(), "Square root argument"); -} +/// @endcond Mean::Mean(std::vector args) : ExpressionFormula(std::move(args)) { if (Expression::args().size() < 2) diff --git a/src/expression/numerical.h b/src/expression/numerical.h index ac603efbde..adfbeaf34f 100644 --- a/src/expression/numerical.h +++ b/src/expression/numerical.h @@ -28,6 +28,7 @@ #include #include "constant.h" +#include "src/error.h" #include "src/expression.h" namespace scram { @@ -57,59 +58,6 @@ struct Bifunctor { // Nasty abuse of terminology :(. Haskellers will hate this. template using BifunctorExpression = NaryExpression, 2>; -/// Validation specialization for math functions. -/// @{ -template <> -void ValidateExpression>(const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -template <> -void ValidateExpression>( - const std::vector& args); -/// @} - -/// Interval specialization for math functions. -/// @{ -template <> -inline Interval GetInterval>(Expression* /*arg*/) { - return Interval::closed(0, ConstantExpression::kPi.value()); -} -template <> -inline Interval GetInterval>(Expression* /*arg*/) { - double half_pi = ConstantExpression::kPi.value() / 2; - return Interval::closed(-half_pi, half_pi); -} -template <> -inline Interval GetInterval>(Expression* /*arg*/) { - double half_pi = ConstantExpression::kPi.value() / 2; - return Interval::closed(-half_pi, half_pi); -} -template <> -inline Interval GetInterval>(Expression* /*arg*/) { - return Interval::closed(-1, 1); -} -template <> -inline Interval GetInterval>(Expression* /*arg*/) { - return Interval::closed(-1, 1); -} -/// @} - using Neg = NaryExpression, 1>; ///< Negation. using Add = NaryExpression, -1>; ///< Sum operation. using Sub = NaryExpression, -1>; ///< Subtraction from the first. @@ -169,6 +117,82 @@ class Mean : public ExpressionFormula { } }; +/// @cond Doxygen_With_Smart_Using_Declaration +/// Validation specialization for math functions. +/// @{ +template <> +void Div::Validate() const; + +template <> +inline void Acos::Validate() const { + assert(args().size() == 1); + EnsureWithin(args().front(), Interval::closed(-1, 1), + "Arc cos"); +} + +template <> +inline void Asin::Validate() const { + assert(args().size() == 1); + EnsureWithin(args().front(), Interval::closed(-1, 1), + "Arc sin"); +} + +template <> +inline void Log::Validate() const { + assert(args().size() == 1); + EnsurePositive(args().front(), "Natural Logarithm"); +} + +template <> +inline void Log10::Validate() const { + assert(args().size() == 1); + EnsurePositive(args().front(), "Decimal Logarithm"); +} + +template <> +void Mod::Validate() const; + +template <> +void Pow::Validate() const; + +template <> +inline void Sqrt::Validate() const { + assert(args().size() == 1); + EnsureNonNegative(args().front(), "Square root argument"); +} +/// @} + +/// Interval specialization for math functions. +/// @{ +template <> +inline Interval Acos::interval() noexcept { + return Interval::closed(0, ConstantExpression::kPi.value()); +} + +template <> +inline Interval Asin::interval() noexcept { + double half_pi = ConstantExpression::kPi.value() / 2; + return Interval::closed(-half_pi, half_pi); +} + +template <> +inline Interval Atan::interval() noexcept { + double half_pi = ConstantExpression::kPi.value() / 2; + return Interval::closed(-half_pi, half_pi); +} + +template <> +inline Interval Cos::interval() noexcept { + return Interval::closed(-1, 1); +} + +template <> +inline Interval Sin::interval() noexcept { + return Interval::closed(-1, 1); +} +/// @} +/// @endcond + } // namespace mef } // namespace scram