Skip to content

Commit

Permalink
nest floating point addition and multiply operations (#794)
Browse files Browse the repository at this point in the history
* nest floating point addition and multiply operations

* fix mood
  • Loading branch information
water111 authored Aug 31, 2021
1 parent c108340 commit 41507f1
Show file tree
Hide file tree
Showing 31 changed files with 576 additions and 865 deletions.
6 changes: 6 additions & 0 deletions decompiler/IR2/Form.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class SimpleExpressionElement : public FormElement {
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects);
void update_from_stack_float_2_nestable(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects);
void update_from_stack_float_1(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
Expand Down
152 changes: 96 additions & 56 deletions decompiler/IR2/FormExpressionAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ void SimpleExpressionElement::update_from_stack_div_s(const Env& env,
}
}

/*!
* Update a two-argument form that uses two floats.
*/
void SimpleExpressionElement::update_from_stack_float_2(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
Expand All @@ -678,8 +681,6 @@ void SimpleExpressionElement::update_from_stack_float_2(const Env& env,
bool allow_side_effects) {
if (is_float_type(env, m_my_idx, m_expr.get_arg(0).var()) &&
is_float_type(env, m_my_idx, m_expr.get_arg(1).var())) {
// todo - check the order here

auto args = pop_to_forms({m_expr.get_arg(0).var(), m_expr.get_arg(1).var()}, env, pool, stack,
allow_side_effects);
auto new_form = pool.alloc_element<GenericElement>(GenericOperator::make_fixed(kind),
Expand All @@ -694,69 +695,55 @@ void SimpleExpressionElement::update_from_stack_float_2(const Env& env,
}
}

void SimpleExpressionElement::update_from_stack_float_1(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects) {
if (is_float_type(env, m_my_idx, m_expr.get_arg(0).var())) {
auto args = pop_to_forms({m_expr.get_arg(0).var()}, env, pool, stack, allow_side_effects);
auto new_form =
pool.alloc_element<GenericElement>(GenericOperator::make_fixed(kind), args.at(0));
result->push_back(new_form);
} else {
throw std::runtime_error(fmt::format("Floating point division attempted on invalid types."));
}
}

void SimpleExpressionElement::update_from_stack_si_1(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects) {
auto in_type = env.get_types_before_op(m_my_idx).get(m_expr.get_arg(0).var().reg()).typespec();
auto arg = pop_to_forms({m_expr.get_arg(0).var()}, env, pool, stack, allow_side_effects).at(0);
result->push_back(pool.alloc_element<GenericElement>(
GenericOperator::make_fixed(kind),
make_cast_if_needed(arg, in_type, TypeSpec("int"), pool, env)));
}

namespace {
std::vector<Form*> get_addition_elements(Form* in) {
std::vector<Form*> get_math_op_elements(Form* in, FixedOperatorKind kind) {
auto gen_elt = in->try_as_element<GenericElement>();
if (gen_elt && gen_elt->op().is_fixed(FixedOperatorKind::ADDITION)) {
if (gen_elt && gen_elt->op().is_fixed(kind)) {
return gen_elt->elts();
} else {
return {in};
}
}

FormElement* make_and_compact_addition(Form* arg0,
Form* arg1,
const std::optional<TypeSpec>& arg0_cast,
const std::optional<TypeSpec>& arg1_cast,
FormPool& pool,
const Env& env) {
FormElement* make_and_compact_math_op(Form* arg0,
Form* arg1,
const std::optional<TypeSpec>& arg0_cast,
const std::optional<TypeSpec>& arg1_cast,
FormPool& pool,
const Env& env,
FixedOperatorKind operator_kind,
bool inline_first,
bool inline_second) {
if (!arg1_cast) {
auto arg0_elts = get_addition_elements(arg0);
std::vector<Form*> arg0_elts;
if (inline_first) {
arg0_elts = get_math_op_elements(arg0, operator_kind);
} else {
arg0_elts = {arg0};
}

assert(!arg0_elts.empty());
if (arg0_cast) {
arg0_elts.front() = cast_form(arg0_elts.front(), *arg0_cast, pool, env);
}

// it's fine to only cast the first thing here - the rest are already cast properly.
auto arg1_elts = get_addition_elements(arg1);
std::vector<Form*> arg1_elts;
if (inline_second) {
arg1_elts = get_math_op_elements(arg1, operator_kind);
} else {
arg1_elts = {arg1};
}

assert(!arg1_elts.empty());
if (arg1_cast) {
arg1_elts.front() = cast_form(arg1_elts.front(), *arg1_cast, pool, env);
}

// add all together
arg0_elts.insert(arg0_elts.end(), arg1_elts.begin(), arg1_elts.end());
return pool.alloc_element<GenericElement>(
GenericOperator::make_fixed(FixedOperatorKind::ADDITION), arg0_elts);
return pool.alloc_element<GenericElement>(GenericOperator::make_fixed(operator_kind),
arg0_elts);
} else {
if (arg0_cast) {
arg0 = cast_form(arg0, *arg0_cast, pool, env);
Expand All @@ -765,12 +752,69 @@ FormElement* make_and_compact_addition(Form* arg0,
if (arg1_cast) {
arg1 = cast_form(arg1, *arg1_cast, pool, env);
}
return pool.alloc_element<GenericElement>(
GenericOperator::make_fixed(FixedOperatorKind::ADDITION), arg0, arg1);
return pool.alloc_element<GenericElement>(GenericOperator::make_fixed(operator_kind), arg0,
arg1);
}
}
} // namespace

/*!
* Update a two-argument form that uses two floats.
* This is for operations like * and + that can be nested
* (* (* a b)) -> (* a b c)
* Note that we only apply this to the _first_ argument to keep the order of operations the same.
*/
void SimpleExpressionElement::update_from_stack_float_2_nestable(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects) {
if (is_float_type(env, m_my_idx, m_expr.get_arg(0).var()) &&
is_float_type(env, m_my_idx, m_expr.get_arg(1).var())) {
auto args = pop_to_forms({m_expr.get_arg(0).var(), m_expr.get_arg(1).var()}, env, pool, stack,
allow_side_effects);
auto new_form =
make_and_compact_math_op(args.at(0), args.at(1), {}, {}, pool, env, kind, true, false);
result->push_back(new_form);
} else {
auto type0 = env.get_types_before_op(m_my_idx).get(m_expr.get_arg(0).var().reg());
auto type1 = env.get_types_before_op(m_my_idx).get(m_expr.get_arg(1).var().reg());
throw std::runtime_error(fmt::format(
"[OP: {}] - Floating point math attempted on invalid types: {} and {} in op {}.", m_my_idx,
type0.print(), type1.print(), to_string(env)));
}
}

void SimpleExpressionElement::update_from_stack_float_1(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects) {
if (is_float_type(env, m_my_idx, m_expr.get_arg(0).var())) {
auto args = pop_to_forms({m_expr.get_arg(0).var()}, env, pool, stack, allow_side_effects);
auto new_form =
pool.alloc_element<GenericElement>(GenericOperator::make_fixed(kind), args.at(0));
result->push_back(new_form);
} else {
throw std::runtime_error(fmt::format("Floating point division attempted on invalid types."));
}
}

void SimpleExpressionElement::update_from_stack_si_1(const Env& env,
FixedOperatorKind kind,
FormPool& pool,
FormStack& stack,
std::vector<FormElement*>* result,
bool allow_side_effects) {
auto in_type = env.get_types_before_op(m_my_idx).get(m_expr.get_arg(0).var().reg()).typespec();
auto arg = pop_to_forms({m_expr.get_arg(0).var()}, env, pool, stack, allow_side_effects).at(0);
result->push_back(pool.alloc_element<GenericElement>(
GenericOperator::make_fixed(kind),
make_cast_if_needed(arg, in_type, TypeSpec("int"), pool, env)));
}

void SimpleExpressionElement::update_from_stack_add_i(const Env& env,
FormPool& pool,
FormStack& stack,
Expand Down Expand Up @@ -1050,11 +1094,7 @@ void SimpleExpressionElement::update_from_stack_add_i(const Env& env,
}
}

if (false && ((arg0_i && arg1_i) || (arg0_u && arg1_u))) {
auto new_form = pool.alloc_element<GenericElement>(
GenericOperator::make_fixed(FixedOperatorKind::ADDITION), args.at(0), args.at(1));
result->push_back(new_form);
} else if (arg0_ptr) {
if (arg0_ptr) {
auto new_form = pool.alloc_element<GenericElement>(
GenericOperator::make_fixed(FixedOperatorKind::ADDITION_PTR), args.at(0), args.at(1));
result->push_back(new_form);
Expand All @@ -1076,8 +1116,8 @@ void SimpleExpressionElement::update_from_stack_add_i(const Env& env,
arg1_cast = TypeSpec(arg0_i ? "int" : "uint");
}

result->push_back(
make_and_compact_addition(args.at(0), args.at(1), arg0_cast, arg1_cast, pool, env));
result->push_back(make_and_compact_math_op(args.at(0), args.at(1), arg0_cast, arg1_cast, pool,
env, FixedOperatorKind::ADDITION, true, true));
}
}

Expand Down Expand Up @@ -1960,12 +2000,12 @@ void SimpleExpressionElement::update_from_stack(const Env& env,
allow_side_effects);
break;
case SimpleExpression::Kind::MUL_S:
update_from_stack_float_2(env, FixedOperatorKind::MULTIPLICATION, pool, stack, result,
allow_side_effects);
update_from_stack_float_2_nestable(env, FixedOperatorKind::MULTIPLICATION, pool, stack,
result, allow_side_effects);
break;
case SimpleExpression::Kind::ADD_S:
update_from_stack_float_2(env, FixedOperatorKind::ADDITION, pool, stack, result,
allow_side_effects);
update_from_stack_float_2_nestable(env, FixedOperatorKind::ADDITION, pool, stack, result,
allow_side_effects);
break;
case SimpleExpression::Kind::MAX_S:
update_from_stack_float_2(env, FixedOperatorKind::FMAX, pool, stack, result,
Expand Down
Loading

0 comments on commit 41507f1

Please sign in to comment.