diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index aba9a798cf786..b9f3050e59c5b 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -64,6 +64,18 @@ enum MissingTrack : uint8_t { kFalse = 0 }; +template +struct TreeNodeElement; + +template +union PtrOrWeight { + TreeNodeElement* ptr; + struct WeightData { + int32_t weight; + int32_t n_weights; + } weight_data; +}; + template struct TreeNodeElement { int feature_id; @@ -71,24 +83,19 @@ struct TreeNodeElement { // Stores the node threshold or the weights if the tree has one target. T value_or_unique_weight; - // onnx specification says hitrates is used to store information about the node, + // The onnx specification says hitrates is used to store information about the node, // but this information is not used for inference. // T hitrates; - // True node, false node are obtained by computing `this + truenode_inc_or_first_weight`, - // `this + falsenode_inc_or_n_weights` if the node is not a leaf. - // In case of a leaf, these attributes are used to indicate the position of the weight - // in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, - // the weight is also stored in `value_or_unique_weight`. - // This implementation assumes a tree has less than 2^31 nodes, - // and the total number of leave in the set of trees is below 2^31. - // A node cannot point to itself. - int32_t truenode_inc_or_first_weight; - // In case of a leaf, the following attribute indicates the number of weights - // in array `TreeEnsembleCommon::weights_`. If not a leaf, it indicates - // `this + falsenode_inc_or_n_weights` is the false node. - // A node cannot point to itself. - int32_t falsenode_inc_or_n_weights; + // PtrOrWeight acts as a tagged union, with the "tag" being whether the node is a leaf or not (see `is_not_leaf`). + + // If it is not a leaf, it is a pointer to the true child node when traversing the decision tree. The false branch is + // always 1 position away from the TreeNodeElement in practice in `TreeEnsembleCommon::nodes_` so it is not stored. + + // If it is a leaf, it contains `weight` and `n_weights` attributes which are used to indicate the position of the + // weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also + // stored in `value_or_unique_weight`. + PtrOrWeight truenode_or_weight; uint8_t flags; inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } @@ -189,8 +196,8 @@ class TreeAggregatorSum : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { ORT_ENFORCE(it->i < (int64_t)predictions.size()); predictions[onnxruntime::narrow(it->i)].score += it->value; predictions[onnxruntime::narrow(it->i)].has_score = 1; @@ -292,8 +299,8 @@ class TreeAggregatorMin : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) ? it->value @@ -349,8 +356,8 @@ class TreeAggregatorMax : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) ? it->value diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 161bb2b0820eb..8f847fe66aa73 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -85,6 +85,13 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { template void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const; + + private: + size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, + const std::vector& nodes_values_as_tensor, const std::vector& node_values, + const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, + int64_t tree_id, const InlinedVector& node_tree_ids); }; template @@ -186,7 +193,7 @@ Status TreeEnsembleCommon::Init( max_tree_depth_ = 1000; ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); - // additional members + // Additional members size_t limit; uint32_t i; InlinedVector cmodes; @@ -195,18 +202,14 @@ Status TreeEnsembleCommon::Init( int fpos = -1; for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); - if (cmodes[i] == NODE_MODE::LEAF) - continue; + if (cmodes[i] == NODE_MODE::LEAF) continue; if (fpos == -1) { fpos = static_cast(i); continue; } - if (cmodes[i] != cmodes[fpos]) - same_mode_ = false; + if (cmodes[i] != cmodes[fpos]) same_mode_ = false; } - // filling nodes - n_nodes_ = nodes_treeids.size(); limit = static_cast(n_nodes_); InlinedVector node_tree_ids; @@ -214,156 +217,185 @@ Status TreeEnsembleCommon::Init( nodes_.clear(); nodes_.reserve(limit); roots_.clear(); - std::unordered_map idi; - idi.reserve(limit); + std::unordered_map node_tree_ids_map; + node_tree_ids_map.reserve(limit); + + InlinedVector truenode_ids, falsenode_ids; + truenode_ids.reserve(limit); + falsenode_ids.reserve(limit); max_feature_id_ = 0; + // Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids for (i = 0; i < limit; ++i) { - TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), - static_cast(nodes_nodeids[i])}; - TreeNodeElement node; - node.feature_id = static_cast(nodes_featureids[i]); - if (node.feature_id > max_feature_id_) { - max_feature_id_ = node.feature_id; - } - node.value_or_unique_weight = nodes_values_as_tensor.empty() - ? static_cast(nodes_values[i]) - : nodes_values_as_tensor[i]; - - /* hitrates is not used for inference, they are ignored. - if (nodes_hitrates_as_tensor.empty()) { - node.hitrates = static_cast(i < nodes_hitrates.size() ? nodes_hitrates[i] : -1); - } else { - node.hitrates = i < nodes_hitrates_as_tensor.size() ? nodes_hitrates_as_tensor[i] : -1; - } */ - - node.flags = static_cast(cmodes[i]); - node.truenode_inc_or_first_weight = 0; // nodes_truenodeids[i] if not a leaf - node.falsenode_inc_or_n_weights = 0; // nodes_falsenodeids[i] if not a leaf - - if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { - node.flags |= static_cast(MissingTrack::kTrue); - } - auto p = idi.insert(std::pair(node_tree_id, i)); + TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), static_cast(nodes_nodeids[i])}; + auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i)); if (!p.second) { ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); } - nodes_.emplace_back(node); node_tree_ids.emplace_back(node_tree_id); } - InlinedVector truenode_ids, falsenode_ids; - truenode_ids.reserve(limit); - falsenode_ids.reserve(limit); TreeNodeElementId coor; - i = 0; - for (auto it = nodes_.begin(); it != nodes_.end(); ++it, ++i) { - if (!it->is_not_leaf()) { + for (i = 0; i < limit; ++i) { + if (cmodes[i] == NODE_MODE::LEAF) { truenode_ids.push_back(0); falsenode_ids.push_back(0); - continue; - } - - TreeNodeElementId& node_tree_id = node_tree_ids[i]; - coor.tree_id = node_tree_id.tree_id; - coor.node_id = static_cast(nodes_truenodeids[i]); - ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + } else { + TreeNodeElementId& node_tree_id = node_tree_ids[i]; + coor.tree_id = node_tree_id.tree_id; + coor.node_id = static_cast(nodes_truenodeids[i]); + ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + + auto found = node_tree_ids_map.find(coor); + if (found == node_tree_ids_map.end()) { + ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); + } + if (found->second == truenode_ids.size()) { + ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (truenode)."); + } + truenode_ids.emplace_back(found->second); - auto found = idi.find(coor); - if (found == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); - } - if (found->second == truenode_ids.size()) { - ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (truenode)."); + coor.node_id = static_cast(nodes_falsenodeids[i]); + ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + found = node_tree_ids_map.find(coor); + if (found == node_tree_ids_map.end()) { + ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); + } + if (found->second == falsenode_ids.size()) { + ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (falsenode)."); + } + falsenode_ids.emplace_back(found->second); + // We could also check that truenode_ids[truenode_ids.size() - 1] != falsenode_ids[falsenode_ids.size() - 1]). + // It is valid but no training algorithm would produce a tree where left and right nodes are the same. } - truenode_ids.emplace_back(found->second); + } - coor.node_id = static_cast(nodes_falsenodeids[i]); - ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); - found = idi.find(coor); - if (found == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); - } - if (found->second == falsenode_ids.size()) { - ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (falsenode)."); + // Let's construct nodes_ such that the false branch is always the next element in nodes_. + // updated_mapping will translates the old position of each node to the new node position in nodes_. + std::vector updated_mapping(nodes_treeids.size(), 0); + int64_t previous_tree_id = -1; + for (i = 0; i < n_nodes_; ++i) { + if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { + // New tree. + int64_t tree_id = node_tree_ids[i].tree_id; + size_t root_position = + AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, + nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + roots_.push_back(&nodes_[root_position]); + previous_tree_id = tree_id; } - falsenode_ids.emplace_back(found->second); - // We could also check that truenode_ids[truenode_ids.size() - 1] != falsenode_ids[falsenode_ids.size() - 1]). - // It is valid but no training algorithm would produce a tree where left and right nodes are the same. } - // sort targets + n_trees_ = roots_.size(); + if (((int64_t)nodes_.size()) != n_nodes_) { + ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ")."); + } + + // Sort targets InlinedVector> indices; indices.reserve(target_class_nodeids.size()); for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - indices.emplace_back(std::pair( - TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, - i)); + indices.emplace_back( + std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i)); } + std::sort(indices.begin(), indices.end()); - // Initialize the leaves. TreeNodeElementId ind; SparseValue w; size_t indi; for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { ind = indices[indi].first; i = indices[indi].second; - auto found = idi.find(ind); - if (found == idi.end()) { + auto found = node_tree_ids_map.find(ind); + if (found == node_tree_ids_map.end()) { ORT_THROW("Unable to find node ", ind.tree_id, "-", ind.node_id, " (weights)."); } - TreeNodeElement& leaf = nodes_[found->second]; + TreeNodeElement& leaf = nodes_[updated_mapping[found->second]]; if (leaf.is_not_leaf()) { // An exception should be raised in that case. But this case may happen in // models converted with an old version of onnxmltools. These weights are ignored. // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); continue; } - w.i = target_class_ids[i]; - w.value = target_class_weights_as_tensor.empty() - ? static_cast(target_class_weights[i]) - : target_class_weights_as_tensor[i]; - if (leaf.falsenode_inc_or_n_weights == 0) { - leaf.truenode_inc_or_first_weight = static_cast(weights_.size()); + w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i]) + : target_class_weights_as_tensor[i]; + if (leaf.truenode_or_weight.weight_data.n_weights == 0) { + leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; } - ++leaf.falsenode_inc_or_n_weights; + ++leaf.truenode_or_weight.weight_data.n_weights; weights_.push_back(w); } - // Initialize all the nodes but the leaves. - int64_t previous = -1; - for (i = 0, limit = static_cast(n_nodes_); i < limit; ++i) { - if ((previous == -1) || (previous != node_tree_ids[i].tree_id)) - roots_.push_back(&(nodes_[idi[node_tree_ids[i]]])); - previous = node_tree_ids[i].tree_id; - if (!nodes_[i].is_not_leaf()) { - if (nodes_[i].falsenode_inc_or_n_weights == 0) { - ORT_THROW("Target is missing for leaf ", ind.tree_id, "-", ind.node_id, "."); - } - continue; - } - ORT_ENFORCE(truenode_ids[i] != i); // That would mean the left node is itself, leading to an infinite loop. - nodes_[i].truenode_inc_or_first_weight = static_cast(truenode_ids[i] - i); - ORT_ENFORCE(falsenode_ids[i] != i); // That would mean the right node is itself, leading to an infinite loop. - nodes_[i].falsenode_inc_or_n_weights = static_cast(falsenode_ids[i] - i); - } - - n_trees_ = roots_.size(); has_missing_tracks_ = false; - for (auto itm = nodes_missing_value_tracks_true.begin(); - itm != nodes_missing_value_tracks_true.end(); ++itm) { + for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) { if (*itm) { has_missing_tracks_ = true; break; } } + return Status::OK(); } +template +size_t TreeEnsembleCommon::AddNodes( + const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, + const std::vector& nodes_values_as_tensor, const std::vector& node_values, + const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, + const InlinedVector& node_tree_ids) { + // Validate this index maps to the same tree_id as the one we should be building. + if (node_tree_ids[i].tree_id != tree_id) { + ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i); + } + + if (updated_mapping[i] != 0) { + // In theory we should not accept any cycles, however in practice LGBM conversion implements set membership via a + // series of "Equals" nodes, with the true branches directed at the same child node (a cycle). + // We may instead seek to formalize set membership in the future. + return updated_mapping[i]; + } + + size_t node_pos = nodes_.size(); + updated_mapping[i] = node_pos; + + TreeNodeElement node; + node.flags = static_cast(cmodes[i]); + node.feature_id = static_cast(nodes_featureids[i]); + if (node.feature_id > max_feature_id_) { + max_feature_id_ = node.feature_id; + } + node.value_or_unique_weight = + nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { + node.flags |= static_cast(MissingTrack::kTrue); + } + nodes_.push_back(std::move(node)); + if (nodes_[node_pos].is_not_leaf()) { + size_t false_branch = + AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + if (false_branch != node_pos + 1) { + ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", + static_cast(nodes_[node_pos].flags)); + } + size_t true_branch = + AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. + // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; + nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; + } else { + nodes_[node_pos].truenode_or_weight.weight_data.weight = 0; + nodes_[node_pos].truenode_or_weight.weight_data.n_weights = 0; + } + return node_pos; +} + template Status TreeEnsembleCommon::compute(OpKernelContext* ctx, const Tensor* X, @@ -637,22 +669,19 @@ void TreeEnsembleCommon::ComputeAgg(concur } } // namespace detail -#define TREE_FIND_VALUE(CMP) \ - if (has_missing_tracks_) { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ - root += (val CMP root->value_or_unique_weight || \ - (root->is_missing_track_true() && _isnan_(val))) \ - ? root->truenode_inc_or_first_weight \ - : root->falsenode_inc_or_n_weights; \ - } \ - } else { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ - root += val CMP root->value_or_unique_weight \ - ? root->truenode_inc_or_first_weight \ - : root->falsenode_inc_or_n_weights; \ - } \ +#define TREE_FIND_VALUE(CMP) \ + if (has_missing_tracks_) { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ + root = (val CMP root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) \ + ? root->truenode_or_weight.ptr \ + : root + 1; \ + } \ + } else { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ + root = val CMP root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; \ + } \ } inline bool _isnan_(float x) { return std::isnan(x); } @@ -671,15 +700,14 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( if (has_missing_tracks_) { while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root += (val <= root->value_or_unique_weight || - (root->is_missing_track_true() && _isnan_(val))) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = (val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; } } else { while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root += val <= root->value_or_unique_weight ? root->truenode_inc_or_first_weight : root->falsenode_inc_or_n_weights; + root = val <= root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; } } break; @@ -703,42 +731,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } else { // Different rules to compare to node thresholds. ThresholdType threshold; - while (root->is_not_leaf()) { + while (1) { val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; switch (root->mode()) { case NODE_MODE::BRANCH_LEQ: - root += val <= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_LT: - root += val < threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_GTE: - root += val >= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_GT: - root += val > threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_EQ: - root += val == threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_NEQ: - root += val != threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::LEAF: - break; + return root; } } }