From 3bbd9f693c8c45e59f1d2049359d42f19cea8a1b Mon Sep 17 00:00:00 2001
From: rakhimov
Date: Tue, 25 Apr 2017 01:13:48 -0700
Subject: [PATCH] Refactor numerical expressions
Use template member specialization instead of free template functions.
Issue #72
---
src/expression.h | 59 ++++++----------
src/expression/numerical.cc | 63 ++++-------------
src/expression/numerical.h | 130 +++++++++++++++++++++---------------
3 files changed, 108 insertions(+), 144 deletions(-)
diff --git a/src/expression.h b/src/expression.h
index fdad037686..78547dd945 100644
--- a/src/expression.h
+++ b/src/expression.h
@@ -138,7 +138,7 @@ class Expression : private boost::noncopyable {
bool sampled_; ///< Indication if the expression is already sampled.
};
-/// CRTP for Expressions with a same formula to evaluate and sample.
+/// CRTP for Expressions with the same formula to evaluate and sample.
///
/// @tparam T The Expression type with Compute function.
template
@@ -167,33 +167,6 @@ class ExpressionFormula : public Expression {
template
class NaryExpression;
-/// Validates expressions of specific type with its given arguments.
-template
-void ValidateExpression(const std::vector& /*args*/) {}
-
-/// Get the validation interval for unary expression T with a given argument.
-template
-Interval GetInterval(Expression* arg) {
- Interval arg_interval = arg->interval();
- double max_value = T()(arg_interval.upper());
- double min_value = T()(arg_interval.lower());
- auto min_max = std::minmax(max_value, min_value);
- return Interval::closed(min_max.first, min_max.second);
-}
-
-/// Get the validation interval for binary expression with given arguments.
-template
-Interval GetInterval(Expression* arg_one, Expression* arg_two) {
- Interval interval_one = arg_one->interval();
- Interval interval_two = arg_two->interval();
- double max_max = T()(interval_one.upper(), interval_two.upper());
- double max_min = T()(interval_one.upper(), interval_two.lower());
- double min_max = T()(interval_one.lower(), interval_two.upper());
- double min_min = T()(interval_one.lower(), interval_two.lower());
- auto interval_pair = std::minmax({max_max, max_min, min_max, min_min});
- return Interval::closed(interval_pair.first, interval_pair.second);
-}
-
/// Unary expression.
template
class NaryExpression : public ExpressionFormula> {
@@ -203,11 +176,15 @@ class NaryExpression : public ExpressionFormula> {
: ExpressionFormula>({expression}),
expression_(*expression) {}
- void Validate() const override {
- return ValidateExpression(Expression::args());
- }
+ void Validate() const override {}
- Interval interval() noexcept override { return GetInterval(&expression_); }
+ Interval interval() noexcept override {
+ Interval arg_interval = expression_.interval();
+ double max_value = T()(arg_interval.upper());
+ double min_value = T()(arg_interval.lower());
+ auto min_max = std::minmax(max_value, min_value);
+ return Interval::closed(min_max.first, min_max.second);
+ }
/// Computes the expression value with a given argument value extractor.
template
@@ -227,13 +204,17 @@ class NaryExpression : public ExpressionFormula> {
explicit NaryExpression(Expression* arg_one, Expression* arg_two)
: ExpressionFormula>({arg_one, arg_two}) {}
- void Validate() const override {
- return ValidateExpression(Expression::args());
- }
+ void Validate() const override {}
Interval interval() noexcept override {
- return GetInterval(Expression::args().front(),
- Expression::args().back());
+ Interval interval_one = Expression::args().front()->interval();
+ Interval interval_two = Expression::args().back()->interval();
+ double max_max = T()(interval_one.upper(), interval_two.upper());
+ double max_min = T()(interval_one.upper(), interval_two.lower());
+ double min_max = T()(interval_one.lower(), interval_two.upper());
+ double min_min = T()(interval_one.lower(), interval_two.lower());
+ auto interval_pair = std::minmax({max_max, max_min, min_max, min_min});
+ return Interval::closed(interval_pair.first, interval_pair.second);
}
/// Computes the expression value with a given argument value extractor.
@@ -259,9 +240,7 @@ class NaryExpression : public ExpressionFormula> {
throw InvalidArgument("Expression requires 2 or more arguments.");
}
- void Validate() const override {
- return ValidateExpression(Expression::args());
- }
+ void Validate() const override {}
Interval interval() noexcept override {
auto it = Expression::args().begin();
diff --git a/src/expression/numerical.cc b/src/expression/numerical.cc
index ebcdedd7a4..8987521a91 100644
--- a/src/expression/numerical.cc
+++ b/src/expression/numerical.cc
@@ -20,15 +20,14 @@
#include "numerical.h"
-#include "src/error.h"
-
namespace scram {
namespace mef {
+/// @cond Doxygen_With_Smart_Using_Declaration
template <>
-void ValidateExpression>(const std::vector& args) {
- auto it = args.begin();
- for (++it; it != args.end(); ++it) {
+void Div::Validate() const {
+ auto it = Expression::args().begin();
+ for (++it; it != Expression::args().end(); ++it) {
const auto& expr = *it;
Interval arg_interval = expr->interval();
if (expr->value() == 0 || Contains(arg_interval, 0))
@@ -37,40 +36,9 @@ void ValidateExpression>(const std::vector& args) {
}
template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 1);
- EnsureWithin(args.front(), Interval::closed(-1, 1),
- "Arc cos");
-}
-
-template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 1);
- EnsureWithin(args.front(), Interval::closed(-1, 1),
- "Arc sin");
-}
-
-template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 1);
- EnsurePositive(args.front(), "Natural Logarithm");
-}
-
-template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 1);
- EnsurePositive(args.front(), "Decimal Logarithm");
-}
-
-template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 2);
- auto* arg_two = args.back();
+void Mod::Validate() const {
+ assert(args().size() == 2);
+ auto* arg_two = args().back();
int arg_value = arg_two->value();
if (arg_value == 0)
throw InvalidArgument("Modulo second operand must not be 0.");
@@ -82,24 +50,17 @@ void ValidateExpression>(
}
template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 2);
- auto* arg_one = args.front();
- auto* arg_two = args.back();
+void Pow::Validate() const {
+ assert(args().size() == 2);
+ auto* arg_one = args().front();
+ auto* arg_two = args().back();
if (arg_one->value() == 0 && arg_two->value() <= 0)
throw InvalidArgument("0 to power 0 or less is undefined.");
if (Contains(arg_one->interval(), 0) && !IsPositive(arg_two->interval()))
throw InvalidArgument("Power expression 'base' sample range contains 0;"
"positive exponent is required.");
}
-
-template <>
-void ValidateExpression>(
- const std::vector& args) {
- assert(args.size() == 1);
- EnsureNonNegative(args.front(), "Square root argument");
-}
+/// @endcond
Mean::Mean(std::vector args) : ExpressionFormula(std::move(args)) {
if (Expression::args().size() < 2)
diff --git a/src/expression/numerical.h b/src/expression/numerical.h
index ac603efbde..adfbeaf34f 100644
--- a/src/expression/numerical.h
+++ b/src/expression/numerical.h
@@ -28,6 +28,7 @@
#include
#include "constant.h"
+#include "src/error.h"
#include "src/expression.h"
namespace scram {
@@ -57,59 +58,6 @@ struct Bifunctor { // Nasty abuse of terminology :(. Haskellers will hate this.
template
using BifunctorExpression = NaryExpression, 2>;
-/// Validation specialization for math functions.
-/// @{
-template <>
-void ValidateExpression>(const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-template <>
-void ValidateExpression>(
- const std::vector& args);
-/// @}
-
-/// Interval specialization for math functions.
-/// @{
-template <>
-inline Interval GetInterval>(Expression* /*arg*/) {
- return Interval::closed(0, ConstantExpression::kPi.value());
-}
-template <>
-inline Interval GetInterval>(Expression* /*arg*/) {
- double half_pi = ConstantExpression::kPi.value() / 2;
- return Interval::closed(-half_pi, half_pi);
-}
-template <>
-inline Interval GetInterval>(Expression* /*arg*/) {
- double half_pi = ConstantExpression::kPi.value() / 2;
- return Interval::closed(-half_pi, half_pi);
-}
-template <>
-inline Interval GetInterval>(Expression* /*arg*/) {
- return Interval::closed(-1, 1);
-}
-template <>
-inline Interval GetInterval>(Expression* /*arg*/) {
- return Interval::closed(-1, 1);
-}
-/// @}
-
using Neg = NaryExpression, 1>; ///< Negation.
using Add = NaryExpression, -1>; ///< Sum operation.
using Sub = NaryExpression, -1>; ///< Subtraction from the first.
@@ -169,6 +117,82 @@ class Mean : public ExpressionFormula {
}
};
+/// @cond Doxygen_With_Smart_Using_Declaration
+/// Validation specialization for math functions.
+/// @{
+template <>
+void Div::Validate() const;
+
+template <>
+inline void Acos::Validate() const {
+ assert(args().size() == 1);
+ EnsureWithin(args().front(), Interval::closed(-1, 1),
+ "Arc cos");
+}
+
+template <>
+inline void Asin::Validate() const {
+ assert(args().size() == 1);
+ EnsureWithin(args().front(), Interval::closed(-1, 1),
+ "Arc sin");
+}
+
+template <>
+inline void Log::Validate() const {
+ assert(args().size() == 1);
+ EnsurePositive(args().front(), "Natural Logarithm");
+}
+
+template <>
+inline void Log10::Validate() const {
+ assert(args().size() == 1);
+ EnsurePositive(args().front(), "Decimal Logarithm");
+}
+
+template <>
+void Mod::Validate() const;
+
+template <>
+void Pow::Validate() const;
+
+template <>
+inline void Sqrt::Validate() const {
+ assert(args().size() == 1);
+ EnsureNonNegative(args().front(), "Square root argument");
+}
+/// @}
+
+/// Interval specialization for math functions.
+/// @{
+template <>
+inline Interval Acos::interval() noexcept {
+ return Interval::closed(0, ConstantExpression::kPi.value());
+}
+
+template <>
+inline Interval Asin::interval() noexcept {
+ double half_pi = ConstantExpression::kPi.value() / 2;
+ return Interval::closed(-half_pi, half_pi);
+}
+
+template <>
+inline Interval Atan::interval() noexcept {
+ double half_pi = ConstantExpression::kPi.value() / 2;
+ return Interval::closed(-half_pi, half_pi);
+}
+
+template <>
+inline Interval Cos::interval() noexcept {
+ return Interval::closed(-1, 1);
+}
+
+template <>
+inline Interval Sin::interval() noexcept {
+ return Interval::closed(-1, 1);
+}
+/// @}
+/// @endcond
+
} // namespace mef
} // namespace scram