diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 0551a9f759dc3..83f056404d9a7 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -243,7 +243,6 @@ Status TreeEnsembleCommon::Init( node.falsenode_inc_or_n_weights.weight = 0; // nodes_falsenodeids[i] if not a leaf } - if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { node.flags |= static_cast(MissingTrack::kTrue); } @@ -640,22 +639,22 @@ void TreeEnsembleCommon::ComputeAgg(concur } } // namespace detail -#define TREE_FIND_VALUE(CMP) \ - if (has_missing_tracks_) { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ +#define TREE_FIND_VALUE(CMP) \ + if (has_missing_tracks_) { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ root = (val CMP root->value_or_unique_weight || \ - (root->is_missing_track_true() && _isnan_(val))) \ - ? root->truenode_inc_or_first_weight.ptr \ - : root->falsenode_inc_or_n_weights.ptr; \ - } \ - } else { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ + (root->is_missing_track_true() && _isnan_(val))) \ + ? root->truenode_inc_or_first_weight.ptr \ + : root->falsenode_inc_or_n_weights.ptr; \ + } \ + } else { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ root = val CMP root->value_or_unique_weight \ - ? root->truenode_inc_or_first_weight.ptr \ - : root->falsenode_inc_or_n_weights.ptr; \ - } \ + ? root->truenode_inc_or_first_weight.ptr \ + : root->falsenode_inc_or_n_weights.ptr; \ + } \ } inline bool _isnan_(float x) { return std::isnan(x); } @@ -675,9 +674,9 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( while (root->is_not_leaf()) { val = x_data[root->feature_id]; root = (val <= root->value_or_unique_weight || - (root->is_missing_track_true() && _isnan_(val))) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; } } else { while (root->is_not_leaf()) { @@ -706,48 +705,43 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } else { // Different rules to compare to node thresholds. ThresholdType threshold; - auto mode{root->mode()}; while (1) { val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; - switch (mode) { + switch (root->mode()) { case NODE_MODE::BRANCH_LEQ: root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; break; case NODE_MODE::BRANCH_LT: root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; break; case NODE_MODE::BRANCH_GTE: root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; break; case NODE_MODE::BRANCH_GT: root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; break; case NODE_MODE::BRANCH_EQ: root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; break; case NODE_MODE::BRANCH_NEQ: root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight.ptr - : root->falsenode_inc_or_n_weights.ptr; + ? root->truenode_inc_or_first_weight.ptr + : root->falsenode_inc_or_n_weights.ptr; break; case NODE_MODE::LEAF: return root; } - mode = root->mode(); - if (mode == NODE_MODE::LEAF) { - return root; - } } } return root;