Skip to content

Commit

Permalink
Refactor numerical expressions
Browse files Browse the repository at this point in the history
Use template member specialization instead of free template functions.

Issue #72
  • Loading branch information
rakhimov committed Apr 25, 2017
1 parent 379b32e commit 3bbd9f6
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 144 deletions.
59 changes: 19 additions & 40 deletions src/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class Expression : private boost::noncopyable {
bool sampled_; ///< Indication if the expression is already sampled.
};

/// CRTP for Expressions with a same formula to evaluate and sample.
/// CRTP for Expressions with the same formula to evaluate and sample.
///
/// @tparam T The Expression type with Compute function.
template <class T>
Expand Down Expand Up @@ -167,33 +167,6 @@ class ExpressionFormula : public Expression {
template <typename T, int N>
class NaryExpression;

/// Validates expressions of specific type with its given arguments.
template <typename T>
void ValidateExpression(const std::vector<Expression*>& /*args*/) {}

/// Get the validation interval for unary expression T with a given argument.
template <typename T>
Interval GetInterval(Expression* arg) {
Interval arg_interval = arg->interval();
double max_value = T()(arg_interval.upper());
double min_value = T()(arg_interval.lower());
auto min_max = std::minmax(max_value, min_value);
return Interval::closed(min_max.first, min_max.second);
}

/// Get the validation interval for binary expression with given arguments.
template <typename T>
Interval GetInterval(Expression* arg_one, Expression* arg_two) {
Interval interval_one = arg_one->interval();
Interval interval_two = arg_two->interval();
double max_max = T()(interval_one.upper(), interval_two.upper());
double max_min = T()(interval_one.upper(), interval_two.lower());
double min_max = T()(interval_one.lower(), interval_two.upper());
double min_min = T()(interval_one.lower(), interval_two.lower());
auto interval_pair = std::minmax({max_max, max_min, min_max, min_min});
return Interval::closed(interval_pair.first, interval_pair.second);
}

/// Unary expression.
template <typename T>
class NaryExpression<T, 1> : public ExpressionFormula<NaryExpression<T, 1>> {
Expand All @@ -203,11 +176,15 @@ class NaryExpression<T, 1> : public ExpressionFormula<NaryExpression<T, 1>> {
: ExpressionFormula<NaryExpression<T, 1>>({expression}),
expression_(*expression) {}

void Validate() const override {
return ValidateExpression<T>(Expression::args());
}
void Validate() const override {}

Interval interval() noexcept override { return GetInterval<T>(&expression_); }
Interval interval() noexcept override {
Interval arg_interval = expression_.interval();
double max_value = T()(arg_interval.upper());
double min_value = T()(arg_interval.lower());
auto min_max = std::minmax(max_value, min_value);
return Interval::closed(min_max.first, min_max.second);
}

/// Computes the expression value with a given argument value extractor.
template <typename F>
Expand All @@ -227,13 +204,17 @@ class NaryExpression<T, 2> : public ExpressionFormula<NaryExpression<T, 2>> {
explicit NaryExpression(Expression* arg_one, Expression* arg_two)
: ExpressionFormula<NaryExpression<T, 2>>({arg_one, arg_two}) {}

void Validate() const override {
return ValidateExpression<T>(Expression::args());
}
void Validate() const override {}

Interval interval() noexcept override {
return GetInterval<T>(Expression::args().front(),
Expression::args().back());
Interval interval_one = Expression::args().front()->interval();
Interval interval_two = Expression::args().back()->interval();
double max_max = T()(interval_one.upper(), interval_two.upper());
double max_min = T()(interval_one.upper(), interval_two.lower());
double min_max = T()(interval_one.lower(), interval_two.upper());
double min_min = T()(interval_one.lower(), interval_two.lower());
auto interval_pair = std::minmax({max_max, max_min, min_max, min_min});
return Interval::closed(interval_pair.first, interval_pair.second);
}

/// Computes the expression value with a given argument value extractor.
Expand All @@ -259,9 +240,7 @@ class NaryExpression<T, -1> : public ExpressionFormula<NaryExpression<T, -1>> {
throw InvalidArgument("Expression requires 2 or more arguments.");
}

void Validate() const override {
return ValidateExpression<T>(Expression::args());
}
void Validate() const override {}

Interval interval() noexcept override {
auto it = Expression::args().begin();
Expand Down
63 changes: 12 additions & 51 deletions src/expression/numerical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@

#include "numerical.h"

#include "src/error.h"

namespace scram {
namespace mef {

/// @cond Doxygen_With_Smart_Using_Declaration
template <>
void ValidateExpression<std::divides<>>(const std::vector<Expression*>& args) {
auto it = args.begin();
for (++it; it != args.end(); ++it) {
void Div::Validate() const {
auto it = Expression::args().begin();
for (++it; it != Expression::args().end(); ++it) {
const auto& expr = *it;
Interval arg_interval = expr->interval();
if (expr->value() == 0 || Contains(arg_interval, 0))
Expand All @@ -37,40 +36,9 @@ void ValidateExpression<std::divides<>>(const std::vector<Expression*>& args) {
}

template <>
void ValidateExpression<Functor<&std::acos>>(
const std::vector<Expression*>& args) {
assert(args.size() == 1);
EnsureWithin<InvalidArgument>(args.front(), Interval::closed(-1, 1),
"Arc cos");
}

template <>
void ValidateExpression<Functor<&std::asin>>(
const std::vector<Expression*>& args) {
assert(args.size() == 1);
EnsureWithin<InvalidArgument>(args.front(), Interval::closed(-1, 1),
"Arc sin");
}

template <>
void ValidateExpression<Functor<&std::log>>(
const std::vector<Expression*>& args) {
assert(args.size() == 1);
EnsurePositive<InvalidArgument>(args.front(), "Natural Logarithm");
}

template <>
void ValidateExpression<Functor<&std::log10>>(
const std::vector<Expression*>& args) {
assert(args.size() == 1);
EnsurePositive<InvalidArgument>(args.front(), "Decimal Logarithm");
}

template <>
void ValidateExpression<std::modulus<int>>(
const std::vector<Expression*>& args) {
assert(args.size() == 2);
auto* arg_two = args.back();
void Mod::Validate() const {
assert(args().size() == 2);
auto* arg_two = args().back();
int arg_value = arg_two->value();
if (arg_value == 0)
throw InvalidArgument("Modulo second operand must not be 0.");
Expand All @@ -82,24 +50,17 @@ void ValidateExpression<std::modulus<int>>(
}

template <>
void ValidateExpression<Bifunctor<&std::pow>>(
const std::vector<Expression*>& args) {
assert(args.size() == 2);
auto* arg_one = args.front();
auto* arg_two = args.back();
void Pow::Validate() const {
assert(args().size() == 2);
auto* arg_one = args().front();
auto* arg_two = args().back();
if (arg_one->value() == 0 && arg_two->value() <= 0)
throw InvalidArgument("0 to power 0 or less is undefined.");
if (Contains(arg_one->interval(), 0) && !IsPositive(arg_two->interval()))
throw InvalidArgument("Power expression 'base' sample range contains 0;"
"positive exponent is required.");
}

template <>
void ValidateExpression<Functor<&std::sqrt>>(
const std::vector<Expression*>& args) {
assert(args.size() == 1);
EnsureNonNegative<InvalidArgument>(args.front(), "Square root argument");
}
/// @endcond

Mean::Mean(std::vector<Expression*> args) : ExpressionFormula(std::move(args)) {
if (Expression::args().size() < 2)
Expand Down
130 changes: 77 additions & 53 deletions src/expression/numerical.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <vector>

#include "constant.h"
#include "src/error.h"
#include "src/expression.h"

namespace scram {
Expand Down Expand Up @@ -57,59 +58,6 @@ struct Bifunctor { // Nasty abuse of terminology :(. Haskellers will hate this.
template <double (*F)(double, double)>
using BifunctorExpression = NaryExpression<Bifunctor<F>, 2>;

/// Validation specialization for math functions.
/// @{
template <>
void ValidateExpression<std::divides<>>(const std::vector<Expression*>& args);
template <>
void ValidateExpression<Functor<&std::acos>>(
const std::vector<Expression*>& args);
template <>
void ValidateExpression<Functor<&std::asin>>(
const std::vector<Expression*>& args);
template <>
void ValidateExpression<Functor<&std::log>>(
const std::vector<Expression*>& args);
template <>
void ValidateExpression<Functor<&std::log10>>(
const std::vector<Expression*>& args);
template <>
void ValidateExpression<std::modulus<int>>(
const std::vector<Expression*>& args);
template <>
void ValidateExpression<Bifunctor<&std::pow>>(
const std::vector<Expression*>& args);
template <>
void ValidateExpression<Functor<&std::sqrt>>(
const std::vector<Expression*>& args);
/// @}

/// Interval specialization for math functions.
/// @{
template <>
inline Interval GetInterval<Functor<&std::acos>>(Expression* /*arg*/) {
return Interval::closed(0, ConstantExpression::kPi.value());
}
template <>
inline Interval GetInterval<Functor<&std::asin>>(Expression* /*arg*/) {
double half_pi = ConstantExpression::kPi.value() / 2;
return Interval::closed(-half_pi, half_pi);
}
template <>
inline Interval GetInterval<Functor<&std::atan>>(Expression* /*arg*/) {
double half_pi = ConstantExpression::kPi.value() / 2;
return Interval::closed(-half_pi, half_pi);
}
template <>
inline Interval GetInterval<Functor<&std::cos>>(Expression* /*arg*/) {
return Interval::closed(-1, 1);
}
template <>
inline Interval GetInterval<Functor<&std::sin>>(Expression* /*arg*/) {
return Interval::closed(-1, 1);
}
/// @}

using Neg = NaryExpression<std::negate<>, 1>; ///< Negation.
using Add = NaryExpression<std::plus<>, -1>; ///< Sum operation.
using Sub = NaryExpression<std::minus<>, -1>; ///< Subtraction from the first.
Expand Down Expand Up @@ -169,6 +117,82 @@ class Mean : public ExpressionFormula<Mean> {
}
};

/// @cond Doxygen_With_Smart_Using_Declaration
/// Validation specialization for math functions.
/// @{
template <>
void Div::Validate() const;

template <>
inline void Acos::Validate() const {
assert(args().size() == 1);
EnsureWithin<InvalidArgument>(args().front(), Interval::closed(-1, 1),
"Arc cos");
}

template <>
inline void Asin::Validate() const {
assert(args().size() == 1);
EnsureWithin<InvalidArgument>(args().front(), Interval::closed(-1, 1),
"Arc sin");
}

template <>
inline void Log::Validate() const {
assert(args().size() == 1);
EnsurePositive<InvalidArgument>(args().front(), "Natural Logarithm");
}

template <>
inline void Log10::Validate() const {
assert(args().size() == 1);
EnsurePositive<InvalidArgument>(args().front(), "Decimal Logarithm");
}

template <>
void Mod::Validate() const;

template <>
void Pow::Validate() const;

template <>
inline void Sqrt::Validate() const {
assert(args().size() == 1);
EnsureNonNegative<InvalidArgument>(args().front(), "Square root argument");
}
/// @}

/// Interval specialization for math functions.
/// @{
template <>
inline Interval Acos::interval() noexcept {
return Interval::closed(0, ConstantExpression::kPi.value());
}

template <>
inline Interval Asin::interval() noexcept {
double half_pi = ConstantExpression::kPi.value() / 2;
return Interval::closed(-half_pi, half_pi);
}

template <>
inline Interval Atan::interval() noexcept {
double half_pi = ConstantExpression::kPi.value() / 2;
return Interval::closed(-half_pi, half_pi);
}

template <>
inline Interval Cos::interval() noexcept {
return Interval::closed(-1, 1);
}

template <>
inline Interval Sin::interval() noexcept {
return Interval::closed(-1, 1);
}
/// @}
/// @endcond

} // namespace mef
} // namespace scram

Expand Down

0 comments on commit 3bbd9f6

Please sign in to comment.