Skip to content

Commit

Permalink
Merge pull request mlcommons#57 from astra-sim/schema_update_0.004
Browse files Browse the repository at this point in the history
Update et_feeder for compatibility with Chakra schema v0.0.4
  • Loading branch information
srinivas212 authored Nov 16, 2023
2 parents 6c29de9 + 224643f commit b892cbe
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 20 deletions.
30 changes: 17 additions & 13 deletions et_feeder/et_feeder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using namespace std;
using namespace Chakra;

ETFeeder::ETFeeder(string filename)
: trace_(filename), window_size_(4096), et_complete_(false) {
: trace_(filename), window_size_(4096 * 256), et_complete_(false) {
readGlobalMetadata();
readNextWindow();
}

Expand All @@ -30,7 +31,7 @@ bool ETFeeder::hasNodesToIssue() {

shared_ptr<ETFeederNode> ETFeeder::getNextIssuableNode() {
if (dep_free_node_queue_.size() != 0) {
shared_ptr<ETFeederNode> node = dep_free_node_queue_.front();
shared_ptr<ETFeederNode> node = dep_free_node_queue_.top();
dep_free_node_id_set_.erase(node->getChakraNode()->id());
dep_free_node_queue_.pop();
return node;
Expand All @@ -53,38 +54,41 @@ void ETFeeder::freeChildrenNodes(uint64_t node_id) {
shared_ptr<ETFeederNode> node = dep_graph_[node_id];
for (auto child: node->getChildren()) {
auto child_chakra = child->getChakraNode();
for (auto it = child_chakra->mutable_parent()->begin();
it != child_chakra->mutable_parent()->end();
for (auto it = child_chakra->mutable_data_deps()->begin();
it != child_chakra->mutable_data_deps()->end();
++it) {
if (*it == node_id) {
child_chakra->mutable_parent()->erase(it);
child_chakra->mutable_data_deps()->erase(it);
break;
}
}
if (child_chakra->parent().size() == 0) {
if (child_chakra->data_deps().size() == 0) {
dep_free_node_id_set_.emplace(child_chakra->id());
dep_free_node_queue_.emplace(child);
}
}
}

void ETFeeder::readGlobalMetadata() {
shared_ptr<ChakraProtoMsg::GlobalMetadata> pkt_msg = make_shared<ChakraProtoMsg::GlobalMetadata>();
trace_.read(*pkt_msg);
}

shared_ptr<ETFeederNode> ETFeeder::readNode() {
shared_ptr<ETFeederNode> node = make_shared<ETFeederNode>();
shared_ptr<ChakraProtoMsg::Node> pkt_msg = make_shared<ChakraProtoMsg::Node>();

if (!trace_.read(*pkt_msg)) {
return nullptr;
}
node->setChakraNode(pkt_msg);
shared_ptr<ETFeederNode> node = make_shared<ETFeederNode>(pkt_msg);

bool dep_unresolved = false;
for (int i = 0; i < pkt_msg->parent_size(); ++i) {
auto parent_node = dep_graph_.find(pkt_msg->parent(i));
for (int i = 0; i < pkt_msg->data_deps_size(); ++i) {
auto parent_node = dep_graph_.find(pkt_msg->data_deps(i));
if (parent_node != dep_graph_.end()) {
parent_node->second->addChild(node);
} else {
dep_unresolved = true;
node->addDepUnresolvedParentID(pkt_msg->parent(i));
node->addDepUnresolvedParentID(pkt_msg->data_deps(i));
}
}

Expand Down Expand Up @@ -139,7 +143,7 @@ void ETFeeder::readNextWindow() {
uint64_t node_id = node_id_node.first;
shared_ptr<ETFeederNode> node = node_id_node.second;
if ((dep_free_node_id_set_.count(node_id) == 0)
&& (node->getChakraNode()->parent().size() == 0)) {
&& (node->getChakraNode()->data_deps().size() == 0)) {
dep_free_node_id_set_.emplace(node_id);
dep_free_node_queue_.emplace(node);
}
Expand Down
11 changes: 10 additions & 1 deletion et_feeder/et_feeder.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "third_party/utils/protoio.hh"
#include "et_feeder/et_feeder_node.h"

namespace Chakra {
struct CompareNodes: public std::binary_function<std::shared_ptr<ETFeederNode>, std::shared_ptr<ETFeederNode>, bool>
{
bool operator()(const std::shared_ptr<ETFeederNode> lhs, const std::shared_ptr<ETFeederNode> rhs) const
{
return lhs->getChakraNode()->id() > rhs->getChakraNode()->id();
}
};

class ETFeeder {
public:
Expand All @@ -24,6 +32,7 @@ class ETFeeder {
void freeChildrenNodes(uint64_t node_id);

private:
void readGlobalMetadata();
std::shared_ptr<ETFeederNode> readNode();
void readNextWindow();
void resolveDep();
Expand All @@ -34,7 +43,7 @@ class ETFeeder {

std::unordered_map<uint64_t, std::shared_ptr<ETFeederNode>> dep_graph_{};
std::unordered_set<uint64_t> dep_free_node_id_set_{};
std::queue<std::shared_ptr<ETFeederNode>> dep_free_node_queue_{};
std::priority_queue<std::shared_ptr<ETFeederNode>, std::vector<std::shared_ptr<ETFeederNode>>, CompareNodes> dep_free_node_queue_{};
std::unordered_set<std::shared_ptr<ETFeederNode>> dep_unresolved_node_set_{};
};

Expand Down
232 changes: 228 additions & 4 deletions et_feeder/et_feeder_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,41 @@
using namespace std;
using namespace Chakra;

shared_ptr<ChakraProtoMsg::Node> ETFeederNode::getChakraNode() {
return node_;
ETFeederNode::ETFeederNode(std::shared_ptr<ChakraProtoMsg::Node> node) {
this->node_= node;
this->id_ = node->id();
this->name_ = node->name();
this->runtime_ = node->duration_micros();
this->is_cpu_op_ = true;
for (int i = 0; i < node->attr_size(); i++) {
string attr_name = node->attr(i).name();
if (attr_name == "is_cpu_op") {
assign_attr_val(node, i, (void *)(&is_cpu_op_));
} else if (attr_name == "num_ops") {
assign_attr_val(node, i, (void *)(&num_ops_));
} else if (attr_name == "tensor_size") {
assign_attr_val(node, i, (void *)(&tensor_size_));
} else if (attr_name == "comm_type") {
assign_attr_val(node, i, (void *)(&comm_type_));
} else if (attr_name == "involved_dim") {
assign_attr_val(node, i, (void *)(&involved_dim_));
involved_dim_size_ = node->attr(i).bool_list().values_size();
} else if (attr_name == "comm_priority") {
assign_attr_val(node, i, (void *)(&comm_priority_));
} else if (attr_name == "comm_size") {
assign_attr_val(node, i, (void *)(&comm_size_));
} else if (attr_name == "comm_src") {
assign_attr_val(node, i, (void *)(&comm_src_));
} else if (attr_name == "comm_dst") {
assign_attr_val(node, i, (void *)(&comm_dst_));
} else if (attr_name == "comm_tag") {
assign_attr_val(node, i, (void *)(&comm_tag_));
}
}
}

void ETFeederNode::setChakraNode(shared_ptr<ChakraProtoMsg::Node> node) {
node_ = node;
shared_ptr<ChakraProtoMsg::Node> ETFeederNode::getChakraNode() {
return node_;
}

void ETFeederNode::addChild(shared_ptr<ETFeederNode> node) {
Expand Down Expand Up @@ -37,3 +66,198 @@ void ETFeederNode::setDepUnresolvedParentIDs(
vector<uint64_t> const& dep_unresolved_parent_ids) {
dep_unresolved_parent_ids_ = dep_unresolved_parent_ids;
}

void ETFeederNode::assign_attr_val(shared_ptr<ChakraProtoMsg::Node> node, int i, void *member) {
auto attr = node->attr(i);
switch(attr.value_case()) {
case ChakraProtoMsg::AttributeProto::kDoubleVal:
*((double *)member) = attr.double_val();
break;
case ChakraProtoMsg::AttributeProto::kDoubleList:
for (const auto& val : attr.double_list().values()) {
(*((std::vector<double> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kFloatVal:
*((float *)member) = attr.float_val();
break;
case ChakraProtoMsg::AttributeProto::kFloatList:
for (const auto& val : attr.float_list().values()) {
(*((std::vector<float> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kInt32Val:
*((int32_t *)member) = attr.int32_val();
break;
case ChakraProtoMsg::AttributeProto::kInt32List:
for (const auto& val : attr.int32_list().values()) {
(*((std::vector<int32_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kInt64Val:
*((int64_t *)member) = attr.int64_val();
break;
case ChakraProtoMsg::AttributeProto::kInt64List:
for (const auto& val : attr.int64_list().values()) {
(*((std::vector<int64_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kUint32Val:
*((uint32_t *)member) = attr.uint32_val();
break;
case ChakraProtoMsg::AttributeProto::kUint32List:
for (const auto& val : attr.uint32_list().values()) {
(*((std::vector<uint32_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kUint64Val:
*((uint64_t *)member) = attr.uint64_val();
break;
case ChakraProtoMsg::AttributeProto::kUint64List:
for (const auto& val : attr.uint64_list().values()) {
(*((std::vector<uint64_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kSint32Val:
*((int32_t *)member) = attr.sint32_val();
break;
case ChakraProtoMsg::AttributeProto::kSint32List:
for (const auto& val : attr.sint32_list().values()) {
(*((std::vector<int32_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kSint64Val:
*((int64_t *)member) = attr.sint64_val();
break;
case ChakraProtoMsg::AttributeProto::kSint64List:
for (const auto& val : attr.sint64_list().values()) {
(*((std::vector<int64_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kFixed32Val:
*((uint32_t *)member) = attr.fixed32_val();
break;
case ChakraProtoMsg::AttributeProto::kFixed32List:
for (const auto& val : attr.fixed32_list().values()) {
(*((std::vector<uint32_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kFixed64Val:
*((uint64_t *)member) = attr.fixed64_val();
break;
case ChakraProtoMsg::AttributeProto::kFixed64List:
for (const auto& val : attr.fixed64_list().values()) {
(*((std::vector<uint64_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kSfixed32Val:
*((int32_t *)member) = attr.sfixed32_val();
break;
case ChakraProtoMsg::AttributeProto::kSfixed32List:
for (const auto& val : attr.sfixed32_list().values()) {
(*((std::vector<int32_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kSfixed64Val:
*((int64_t *)member) = attr.sfixed64_val();
break;
case ChakraProtoMsg::AttributeProto::kSfixed64List:
for (const auto& val : attr.sfixed64_list().values()) {
(*((std::vector<int64_t> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kBoolVal:
*((bool *)member) = attr.bool_val();
break;
case ChakraProtoMsg::AttributeProto::kBoolList:
for (const auto& val : attr.bool_list().values()) {
(*((std::vector<bool> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kStringVal:
*((std::string *)member) = attr.string_val();
break;
case ChakraProtoMsg::AttributeProto::kStringList:
for (const auto& val : attr.string_list().values()) {
(*((std::vector<std::string> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::kBytesVal:
*((std::string *)member) = attr.bytes_val();
break;
case ChakraProtoMsg::AttributeProto::kBytesList:
for (const auto& val : attr.bytes_list().values()) {
(*((std::vector<std::string> *)member)).push_back(val);
}
break;
case ChakraProtoMsg::AttributeProto::VALUE_NOT_SET:
default:
std::cerr << "undefined attr type in chakra node" << std::endl;
exit(EXIT_FAILURE);
break;
}
}

uint64_t ETFeederNode::id() {
return id_;
}

string ETFeederNode::name() {
return name_;
}

bool ETFeederNode::is_cpu_op() {
return is_cpu_op_;
}

ChakraProtoMsg::NodeType ETFeederNode::type() {
return node_->type();
}

uint64_t ETFeederNode::runtime() {
return runtime_;
}

uint64_t ETFeederNode::num_ops() {
return num_ops_;
}

uint32_t ETFeederNode::tensor_loc() {
return tensor_loc_;
}

uint64_t ETFeederNode::tensor_size() {
return tensor_size_;
}

ChakraProtoMsg::CollectiveCommType ETFeederNode::comm_type() {
return comm_type_;
}

uint32_t ETFeederNode::involved_dim_size() {
return involved_dim_size_;
}

bool ETFeederNode::involved_dim(int i) {
return involved_dim_[i];
}

uint32_t ETFeederNode::comm_priority() {
return comm_priority_;
}

uint64_t ETFeederNode::comm_size() {
return comm_size_;
}

uint32_t ETFeederNode::comm_src() {
return comm_src_;
}

uint32_t ETFeederNode::comm_dst() {
return comm_dst_;
}

uint32_t ETFeederNode::comm_tag() {
return comm_tag_;
}
Loading

0 comments on commit b892cbe

Please sign in to comment.