Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Sep 11, 2023
1 parent 82dcb34 commit 4bf61f7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct TreeNodeElement {
// PtrOrWeight acts as a tagged union, with the "tag" being whether the node is a leaf or not (see `is_not_leaf`).

// If it is not a leaf, it is a pointer to the true child node when traversing the decision tree. The false branch is
// always 1 position away from the TreeNodeElement in practice in `TreeEnsembleCommon::nodes_` and so it is not stored.
// always 1 position away from the TreeNodeElement in practice in `TreeEnsembleCommon::nodes_` so it is not stored.

// If it is a leaf, it contains `weight` and `n_weights` attributes which are used to indicate the position of the
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
Expand Down
119 changes: 59 additions & 60 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids, const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values, const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
};

template <typename InputType, typename ThresholdType, typename OutputType>
Expand Down Expand Up @@ -198,14 +202,12 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
int fpos = -1;
for (i = 0, limit = nodes_modes.size(); i < limit; ++i) {
cmodes.push_back(MakeTreeNodeMode(nodes_modes[i]));
if (cmodes[i] == NODE_MODE::LEAF)
continue;
if (cmodes[i] == NODE_MODE::LEAF) continue;
if (fpos == -1) {
fpos = static_cast<int>(i);
continue;
}
if (cmodes[i] != cmodes[fpos])
same_mode_ = false;
if (cmodes[i] != cmodes[fpos]) same_mode_ = false;
}

n_nodes_ = nodes_treeids.size();
Expand All @@ -225,8 +227,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(

// Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids
for (i = 0; i < limit; ++i) {
TreeNodeElementId node_tree_id{static_cast<int>(nodes_treeids[i]),
static_cast<int>(nodes_nodeids[i])};
TreeNodeElementId node_tree_id{static_cast<int>(nodes_treeids[i]), static_cast<int>(nodes_nodeids[i])};
auto p = node_tree_ids_map.insert(std::pair<TreeNodeElementId, size_t>(node_tree_id, i));
if (!p.second) {
ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there.");
Expand Down Expand Up @@ -277,7 +278,9 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) {
// New tree.
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position = AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
Expand All @@ -292,9 +295,8 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(std::pair<TreeNodeElementId, uint32_t>(
TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]},
i));
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());
Expand All @@ -318,9 +320,8 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
continue;
}
w.i = target_class_ids[i];
w.value = target_class_weights_as_tensor.empty()
? static_cast<ThresholdType>(target_class_weights[i])
: target_class_weights_as_tensor[i];
w.value = target_class_weights_as_tensor.empty() ? static_cast<ThresholdType>(target_class_weights[i])
: target_class_weights_as_tensor[i];
if (leaf.truenode_or_weight.weight_data.n_weights == 0) {
leaf.truenode_or_weight.weight_data.weight = static_cast<int32_t>(weights_.size());
leaf.value_or_unique_weight = w.value;
Expand All @@ -330,8 +331,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}

has_missing_tracks_ = false;
for (auto itm = nodes_missing_value_tracks_true.begin();
itm != nodes_missing_value_tracks_true.end(); ++itm) {
for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) {
if (*itm) {
has_missing_tracks_ = true;
break;
Expand All @@ -342,7 +342,12 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}

template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids, const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values, const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids) {
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
Expand All @@ -351,7 +356,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const
if (updated_mapping[i] != 0) {
// In theory we should not accept any cycles, however in practice LGBM conversion implements set membership via a
// series of "Equals" nodes, with the true branches directed at the same child node (a cycle).
// We may instead seek to formalise set membership in the future.
// We may instead seek to formalize set membership in the future.
return updated_mapping[i];
}

Expand All @@ -364,19 +369,23 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
node.value_or_unique_weight = nodes_values_as_tensor.empty()
? static_cast<ThresholdType>(node_values[i])
: nodes_values_as_tensor[i];
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
size_t false_branch = AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
size_t false_branch =
AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", static_cast<int>(nodes_[node_pos].flags));
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch = AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
size_t true_branch =
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
Expand Down Expand Up @@ -660,22 +669,19 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
}
} // namespace detail

#define TREE_FIND_VALUE(CMP) \
if (has_missing_tracks_) { \
while (root->is_not_leaf()) { \
val = x_data[root->feature_id]; \
root = (val CMP root->value_or_unique_weight || \
(root->is_missing_track_true() && _isnan_(val))) \
? root->truenode_or_weight.ptr \
: root + 1; \
} \
} else { \
while (root->is_not_leaf()) { \
val = x_data[root->feature_id]; \
root = val CMP root->value_or_unique_weight \
? root->truenode_or_weight.ptr \
: root + 1; \
} \
// TREE_FIND_VALUE(CMP): walks a tree from `root` down to a leaf, using the comparison operator token CMP
// (substituted textually, e.g. <=, <, >=) to test the feature value against each node's threshold.
//
// The macro is not self-contained: it reads and writes names that must already be in scope where it is expanded:
//   - `root`                : pointer to the current node; reassigned on every step and left on a node for which
//                             is_not_leaf() is false.
//   - `val`                 : scratch variable that receives x_data[root->feature_id] on every step.
//   - `x_data`              : indexable feature values of the sample being evaluated.
//   - `has_missing_tracks_` : selects which of the two loops below runs.
//
// On each step the walk follows `truenode_or_weight.ptr` when `val CMP value_or_unique_weight` holds, and moves to
// the adjacent element (`root + 1`) otherwise. When `has_missing_tracks_` is set, a node also takes the true branch
// if is_missing_track_true() holds for it and `val` is NaN; the other loop skips that extra check.
#define TREE_FIND_VALUE(CMP) \
if (has_missing_tracks_) { \
while (root->is_not_leaf()) { \
val = x_data[root->feature_id]; \
root = (val CMP root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) \
? root->truenode_or_weight.ptr \
: root + 1; \
} \
} else { \
while (root->is_not_leaf()) { \
val = x_data[root->feature_id]; \
root = val CMP root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; \
} \
}

// Reports whether the single-precision value `x` is NaN, by delegating to std::isnan.
inline bool _isnan_(float x) {
  const bool value_is_nan = std::isnan(x);
  return value_is_nan;
}
Expand All @@ -694,8 +700,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (val <= root->value_or_unique_weight ||
(root->is_missing_track_true() && _isnan_(val)))
root = (val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
Expand Down Expand Up @@ -731,34 +736,28 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
threshold = root->value_or_unique_weight;
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
root = val <= threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_LT:
root = val < threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_GTE:
root = val >= threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_GT:
root = val > threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_EQ:
root = val == threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_NEQ:
root = val != threshold || (root->is_missing_track_true() && _isnan_(val))
? root->truenode_or_weight.ptr
: root + 1;
root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::LEAF:
return root;
Expand Down

0 comments on commit 4bf61f7

Please sign in to comment.