Skip to content

Commit

Permalink
Add switch expression
Browse files Browse the repository at this point in the history
Issue #67
  • Loading branch information
rakhimov committed Apr 25, 2017
1 parent 5a384ae commit 3c835f3
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 1 deletion.
17 changes: 17 additions & 0 deletions share/input.rng
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@
<define name="conditional-operation">
<choice>
<ref name="if-then-else-operation"/>
<ref name="switch-operation"/>
</choice>
</define>

Expand All @@ -852,6 +853,22 @@
</element>
</define>

<define name="switch-operation">
<element name="switch">
<zeroOrMore>
<ref name="case-operation"/>
</zeroOrMore>
<ref name="expression"/>
</element>
</define>

<define name="case-operation">
<element name="case">
<ref name="expression"/>
<ref name="expression"/>
</element>
</define>

<!-- *********************************************************** -->
<!-- V.2.6. Built-ins -->
<!-- *********************************************************** -->
Expand Down
22 changes: 22 additions & 0 deletions src/expression/conditional.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,27 @@ Interval Ite::interval() noexcept {
std::max(then_interval.upper(), else_interval.upper()));
}

Switch::Switch(std::vector<Case> cases, Expression* default_value)
: ExpressionFormula({default_value}),
cases_(std::move(cases)),
default_value_(*default_value) {
for (auto& case_arm : cases_) {
Expression::AddArg(&case_arm.condition);
Expression::AddArg(&case_arm.value);
}
}

Interval Switch::interval() noexcept {
Interval default_interval = default_value_.interval();
double min_value = default_interval.lower();
double max_value = default_interval.upper();
for (auto& case_arm : cases_) {
Interval case_interval = case_arm.value.interval();
min_value = std::min(min_value, case_interval.lower());
max_value = std::max(max_value, case_interval.upper());
}
return Interval::closed(min_value, max_value);
}

} // namespace mef
} // namespace scram
32 changes: 32 additions & 0 deletions src/expression/conditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#ifndef SCRAM_SRC_EXPRESSION_CONDITIONAL_H_
#define SCRAM_SRC_EXPRESSION_CONDITIONAL_H_

#include <vector>

#include "src/expression.h"

namespace scram {
Expand All @@ -45,6 +47,36 @@ class Ite : public ExpressionFormula<Ite> {
}
};

/// Switch-Case conditional operations.
class Switch : public ExpressionFormula<Switch> {
public:
/// Individual cases in the switch-case operation.
struct Case {
Expression& condition; ///< The case condition.
Expression& value; ///< The value to evaluated if the condition is true.
};

/// @param[in] cases The collection of cases to evaluate.
/// @param[in] default_value The default value if all cases are false.
Switch(std::vector<Case> cases, Expression* default_value);

Interval interval() noexcept override;

/// Computes the switch-case expression with the given evaluator.
template <typename F>
double Compute(F&& eval) noexcept {
for (Case& case_arm : cases_) {
if (eval(&case_arm.condition))
return eval(&case_arm.value);
}
return eval(&default_value_);
}

private:
std::vector<Case> cases_; ///< Ordered collection of cases.
Expression& default_value_; ///< The default case value.
};

} // namespace mef
} // namespace scram

Expand Down
24 changes: 23 additions & 1 deletion src/initializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,27 @@ std::unique_ptr<Expression> Initializer::Extract<PeriodicTest>(
}
}

/// Specialization for Switch-Case operation extraction.
template <>
std::unique_ptr<Expression> Initializer::Extract<Switch>(
const xmlpp::NodeSet& args,
const std::string& base_path,
Initializer* init) {
assert(!args.empty());
Expression* default_value =
init->GetExpression(XmlElement(args.back()), base_path);
std::vector<Switch::Case> cases;
auto it_end = std::prev(args.end());
for (auto it = args.begin(); it != it_end; ++it) {
xmlpp::NodeSet nodes = (*it)->find("./*");
assert(nodes.size() == 2);
cases.push_back(
{*init->GetExpression(XmlElement(nodes.front()), base_path),
*init->GetExpression(XmlElement(nodes.back()), base_path)});
}
return std::make_unique<Switch>(std::move(cases), default_value);
}

const Initializer::ExtractorMap Initializer::kExpressionExtractors_ = {
{"exponential", &Extract<Exponential>},
{"GLM", &Extract<Glm>},
Expand Down Expand Up @@ -836,7 +857,8 @@ const Initializer::ExtractorMap Initializer::kExpressionExtractors_ = {
{"gt", &Extract<Gt>},
{"leq", &Extract<Leq>},
{"geq", &Extract<Geq>},
{"ite", &Extract<Ite>}};
{"ite", &Extract<Ite>},
{"switch", &Extract<Switch>}};

Expression* Initializer::GetExpression(const xmlpp::Element* expr_element,
const std::string& base_path) {
Expand Down
20 changes: 20 additions & 0 deletions tests/expression_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,26 @@ TEST(ExpressionTest, Ite) {
<< dev->interval();
}

TEST(ExpressionTest, Switch) {
OpenExpression arg_one(1);
OpenExpression arg_two(42, 42, 32, 52);
OpenExpression arg_three(10, 10, 5, 15);
std::unique_ptr<Expression> dev;
ASSERT_NO_THROW(
dev = std::make_unique<Switch>(
std::vector<Switch::Case>{{arg_one, arg_two}}, &arg_three));
EXPECT_DOUBLE_EQ(42, dev->value());
arg_one.mean = 0;
EXPECT_DOUBLE_EQ(10, dev->value());
arg_one.mean = 0.5;
EXPECT_DOUBLE_EQ(42, dev->value());

EXPECT_TRUE(Interval::closed(5, 52) == dev->interval())
<< dev->interval();

EXPECT_DOUBLE_EQ(10, Switch({}, &arg_three).value());
}

} // namespace test
} // namespace mef
} // namespace scram
19 changes: 19 additions & 0 deletions tests/input/fta/correct_expressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -360,5 +360,24 @@ The input tries to utilize all the functionality including optional cases.
<float value="-42"/>
</ite>
</define-parameter>
<define-parameter name="switch">
<switch>
<case>
<eq>
<parameter name="if-then-else"/>
<int value="1"/>
</eq>
<float value="1.0e-4"/>
</case>
<case>
<eq>
<parameter name="Modulo"/>
<int value="2"/>
</eq>
<float value="2.5e-4"/>
</case>
<float value="1.0e-3"/>
</switch>
</define-parameter>
</model-data>
</opsa-mef>

0 comments on commit 3c835f3

Please sign in to comment.