[TIR] Make conversion from Integer to int64_t explicit #12010

Merged 4 commits on Jul 6, 2022
Changes from all commits
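This PR removes the implicit operator int64_t() from tvm::Integer (see the include/tvm/ir/expr.h hunk below) and replaces it with an explicit IntValue() accessor, so every call site that used an Integer where a plain integer was expected now spells out the conversion. A minimal, self-contained sketch of the before/after pattern, using a simplified stand-in rather than the real TVM class:

#include <cassert>
#include <cstdint>

// Hypothetical, simplified stand-in for tvm::Integer; the real class wraps an
// IntImm node and checks for null before dereferencing.
class Integer {
 public:
  Integer(int64_t value) : value_(value) {}
  // Before this PR: implicit conversion, so `int64_t x = idx;` compiled silently.
  // operator int64_t() const { return value_; }
  // After this PR: the conversion has to be spelled out at every call site.
  int64_t IntValue() const { return value_; }

 private:
  int64_t value_;
};

int main() {
  Integer idx(3);
  // int64_t i = idx;            // no longer compiles without the implicit operator
  int64_t i = idx.IntValue();    // explicit conversion
  assert(i == 3);
  return 0;
}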
2 changes: 1 addition & 1 deletion include/tvm/ir/attrs.h
@@ -296,7 +296,7 @@ class DictAttrs : public Attrs {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0) != 0;
return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
}

TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
@@ -438,7 +438,7 @@ class Integer : public IntImm {
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
int64_t IntValue() const {
ICHECK(data_ != nullptr) << " Trying to reference a null Integer";
return (*this)->value;
}
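Most of the call-site updates in this PR follow one of two shapes: a required target attribute goes through .value().IntValue() (as in the topi CUDA schedules below), and an attribute with a default goes through .value_or(default).IntValue() (as in HasNonzeroAttr above and GetTargetNumCores in src/meta_schedule/utils.h). A hedged sketch of both shapes, with std::optional and a dummy GetAttr standing in for TVM's Optional and Target::GetAttr:

#include <cstdint>
#include <optional>

// Simplified stand-ins: std::optional plays the role of tvm::Optional and this
// Integer the role of tvm::Integer; GetAttr is a hypothetical lookup, not TVM's API.
struct Integer {
  int64_t value;
  int64_t IntValue() const { return value; }
};

std::optional<Integer> GetAttr(bool present) {
  if (present) return Integer{1024};
  return std::nullopt;
}

int main() {
  // Required attribute (e.g. "max_num_threads"): unwrap, then convert explicitly.
  int num_thread = GetAttr(true).value().IntValue();
  // Attribute with a default (e.g. "num-cores"): value_or(default), then convert.
  int num_cores = GetAttr(false).value_or(Integer{-1}).IntValue();
  return (num_thread == 1024 && num_cores == -1) ? 0 : 1;
}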
2 changes: 1 addition & 1 deletion include/tvm/relay/feature.h
@@ -68,7 +68,7 @@ class FeatureSet {
explicit FeatureSet(Feature ft) { bs_.set(static_cast<size_t>(ft)); }
explicit FeatureSet(const tvm::Array<tvm::Integer>& ft) {
for (Integer i : ft) {
(*this) += Feature(static_cast<int>(i));
*this += Feature(i.IntValue());
}
}
explicit operator Array<Integer>() const {
2 changes: 1 addition & 1 deletion include/tvm/topi/cuda/injective.h
@@ -48,7 +48,7 @@ namespace cuda {
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
auto fused = detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
auto target = Target::Current(false);
int num_thread = target->GetAttr<Integer>("max_num_threads").value();
int num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
IterVar bx, tx;
sch[out].split(fused, num_thread, &bx, &tx);
sch[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
2 changes: 1 addition & 1 deletion include/tvm/topi/cuda/pooling.h
@@ -57,7 +57,7 @@ inline Schedule schedule_pool(const Target& target, const Array<Tensor>& outs) {
if (padded_input->op->IsInstance<ComputeOpNode>()) {
s[padded_input].compute_inline();
}
int num_thread = target->GetAttr<Integer>("max_num_threads").value();
int num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
Tensor out;
Tensor OL;
if (detail::contains(s->outputs, pool->op)) {
2 changes: 1 addition & 1 deletion include/tvm/topi/cuda/reduction.h
@@ -80,7 +80,7 @@ Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch,
thread_y = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.y");
} else {
all_reduce = true;
num_thread = target->GetAttr<Integer>("max_num_threads").value();
num_thread = target->GetAttr<Integer>("max_num_threads").value().IntValue();
thread_x = tvm::te::thread_axis(Range(0, num_thread), "threadIdx.x");
}

16 changes: 8 additions & 8 deletions include/tvm/topi/detail/strided_slice.h
@@ -95,12 +95,12 @@ inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& isha
std::string slice_mode = "end") {
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[axes[i]]);
if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
begin_expr.push_back(make_const(dtype, begin_i));
} else {
auto idim = ishape[axes[i]];
auto idim = ishape[axes[i].IntValue()];
auto b_expr = make_const(dtype, begin[i]);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
@@ -129,8 +129,8 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
}

for (size_t i = 0; i < axes.size(); ++i) {
if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(ishape[axes[i]]);
if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
int64_t begin_i = GetConstInt(begin_canonicalized[i]);
int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
@@ -139,11 +139,11 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else if (use_any) {
out_shape.Set(axes[i], tvm::tir::Any());
out_shape.Set(axes[i].IntValue(), tvm::tir::Any());
} else {
out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype));
out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim", out_shape[i]->dtype));
}
}

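In the strided-slice helpers above and in topi/transform.h below, the axes array holds Integer values that were previously used directly as subscripts; with the implicit conversion gone, each subscript becomes axes[i].IntValue(). A small sketch of that indexing pattern, with std::vector standing in for tvm::Array and a simplified Integer:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only: std::vector stands in for tvm::Array, and this Integer for tvm::Integer.
struct Integer {
  int64_t value;
  int64_t IntValue() const { return value; }
};

int main() {
  std::vector<int64_t> ishape = {1, 224, 224, 3};
  std::vector<Integer> axes = {Integer{1}, Integer{2}};
  for (const Integer& axis : axes) {
    // Before: ishape[axis] leaned on the implicit operator int64_t().
    // After: the subscript has to be extracted explicitly.
    std::cout << ishape[axis.IntValue()] << "\n";
  }
  return 0;
}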
4 changes: 2 additions & 2 deletions include/tvm/topi/transform.h
@@ -790,8 +790,8 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
auto stride = make_const(strides[i].dtype(), strides_vec[i]);
PrimExpr ind = indices[axes[i]] * stride + begin_expr[i];
real_indices.Set(axes[i], ind);
PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
real_indices.Set(axes[i].IntValue(), ind);
}
return x(real_indices);
},
36 changes: 18 additions & 18 deletions src/auto_scheduler/transform_step.cc
@@ -501,18 +501,17 @@ Iterator FuseStepNode::ApplyToState(State* state) const {
if (i > 0) {
ICHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1);
}

if (i != fused_ids.size() - 1) {
const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages;
if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) !=
if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i].IntValue())) !=
iter_to_attached_stage.end()) {
LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some "
<< "stages. State before fusion:\n"
<< (*state);
}
}

const Iterator& it = stage->iters[fused_ids[i]];
const Iterator& it = stage->iters[fused_ids[i].IntValue()];
orig_iters.push_back(it);
new_name = new_name + it->name + "@";

@@ -543,9 +542,9 @@ Iterator FuseStepNode::ApplyToState(State* state) const {
new_iters.push_back(new_it);
} else {
new_iters.insert(new_iters.end(), stage->iters.begin(),
stage->iters.begin() + fused_ids.front());
stage->iters.begin() + fused_ids.front().IntValue());
new_iters.push_back(new_it);
new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1,
new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back().IntValue() + 1,
stage->iters.end());
}

@@ -561,7 +560,7 @@ Iterator FuseStepNode::ApplyToState(State* state) const {
// The original iterators in AttachMap will be updated with the new iterators
std::vector<IterKey> from_iters;
std::vector<IterKey> to_iters;
const size_t begin_id = fused_ids.front(), end_id = fused_ids.back();
const size_t begin_id = fused_ids.front().IntValue(), end_id = fused_ids.back().IntValue();
for (size_t i = 0; i < old_iter_size; ++i) {
if (i <= begin_id) {
continue;
@@ -587,7 +586,7 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,

Array<IterVar> to_fuse;
for (const auto& i : fused_ids) {
to_fuse.push_back(axes[i]);
to_fuse.push_back(axes[i.IntValue()]);
}
IterVar fused_axis;
stage.fuse(to_fuse, &fused_axis);
@@ -596,9 +595,9 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
if (fused_ids.empty()) {
new_axes.push_back(fused_axis);
} else {
new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front().IntValue());
new_axes.push_back(fused_axis);
new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back().IntValue() + 1, axes.end());
}

stage_to_axes->Set(stage, std::move(new_axes));
@@ -613,7 +612,7 @@ String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
std::stringstream to_fuse;

for (size_t i = 0; i < fused_ids.size(); ++i) {
to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint, op_name);
to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i].IntValue()]->var->name_hint,
op_name);
if (i != fused_ids.size() - 1) {
to_fuse << ", ";
}
@@ -773,7 +773,7 @@ void ReorderStepNode::ApplyToState(State* state) const {
const Stage& stage = (*state)->stages[stage_id];
Array<Iterator> iters;
for (auto x : after_ids) {
iters.push_back(stage->iters[x]);
iters.push_back(stage->iters[x.IntValue()]);
}
state->CopyOnWrite()->stages.Set(
stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
@@ -788,7 +788,7 @@ void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
Array<IterVar> new_axes;
new_axes.reserve(axes.size());
for (auto i : after_ids) {
new_axes.push_back(axes[i]);
new_axes.push_back(axes[i.IntValue()]);
}
stage.reorder(new_axes);

@@ -804,7 +804,7 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,

ss << "s[" << op_name << "].reorder(";
for (size_t i = 0; i < after_ids.size(); ++i) {
ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint, op_name);
ss << CleanName((*stage_to_axes)[stage][after_ids[i].IntValue()]->var->name_hint, op_name);
if (i != after_ids.size() - 1) {
ss << ", ";
}
@@ -1180,10 +1180,10 @@ Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
const Array<Step>& transform_steps) const {
PrimExpr ret(1);

for (int src_step_id : src_step_ids) {
for (auto src_step_id : src_step_ids) {
// Make sure the src_step_id is within the range of transform_steps.
ICHECK_LT(src_step_id, transform_steps.size());
auto ps = transform_steps[src_step_id].as<SplitStepNode>();
ICHECK_LT(src_step_id.IntValue(), transform_steps.size());
auto ps = transform_steps[src_step_id.IntValue()].as<SplitStepNode>();
ICHECK(ps != nullptr);
// Multiple the splitting factor on corresponding splitting level of src_steps.
if (ps->lengths[level] && ret.defined()) {
@@ -1572,7 +1572,7 @@ te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
const te::Stage& stage = (*stages)[stage_id];
Array<te::Operation> readers;
for (const auto& i : reader_stage_ids) {
readers.push_back((*stages)[i]->origin_op);
readers.push_back((*stages)[i.IntValue()]->origin_op);
}
auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers);

@@ -1591,7 +1591,7 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
auto stage = (*stages)[stage_id];
Array<te::Stage> reader_stages;
for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
reader_stages.push_back((*stages)[reader_stage_ids[i]]);
reader_stages.push_back((*stages)[reader_stage_ids[i].IntValue()]);
}
auto out = ApplyToSchedule(stages, stage_to_axes, schedule);

5 changes: 4 additions & 1 deletion src/contrib/ethosu/cascader/parts/ethosu.cc
@@ -181,7 +181,10 @@ TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.EthosuPart")
Array<BlockConfig> valid_block_configs, int weight_tensor_idx) {
std::vector<te::Tensor> vsubgraph_inputs(subgraph_inputs.begin(), subgraph_inputs.end());
std::vector<Propagator> vpropagators(propagators.begin(), propagators.end());
std::vector<int> voutput_quantum(output_quantum.begin(), output_quantum.end());
std::vector<int> voutput_quantum;
std::transform(output_quantum.begin(), output_quantum.end(),
std::back_inserter(voutput_quantum),
[](auto&& val) { return val.IntValue(); });
TESubgraph subgraph;
subgraph.input_tensors = vsubgraph_inputs;
subgraph.output_tensor = subgraph_output;
5 changes: 4 additions & 1 deletion src/meta_schedule/arg_info.cc
@@ -142,7 +142,10 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) {
LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
<< "\nThe error is: " << e.what();
}
return TensorInfo(DataType(dtype), ShapeTuple(shape.begin(), shape.end()));
std::vector<int64_t> s;
std::transform(shape.begin(), shape.end(), std::back_inserter(s),
[](Integer i) { return i.IntValue(); });
return TensorInfo(DataType(dtype), ShapeTuple(s.begin(), s.end()));
}

/******** Repr ********/
2 changes: 1 addition & 1 deletion src/meta_schedule/database/json_database.cc
@@ -198,7 +198,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
try {
const ArrayNode* arr = json_obj.as<ArrayNode>();
ICHECK_EQ(arr->size(), 2);
workload = workloads[Downcast<Integer>(arr->at(0))];
workload = workloads[Downcast<Integer>(arr->at(0)).IntValue()];
records[task_id] = TuningRecord::FromJSON(arr->at(1), workload);
} catch (std::runtime_error& e) {
LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1)
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -91,7 +91,7 @@ class RewriteUnboundBlockNode : public PostprocNode {
context->target.value()->GetAttr<Integer>("max_threads_per_block");
CHECK(max_threads_per_block.defined())
<< "ValueError: missing attribute `max_threads_per_block` in the target";
this->max_threads_per_block_ = max_threads_per_block.value();
this->max_threads_per_block_ = max_threads_per_block.value().IntValue();
}

// Inherited from PostprocNode
2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/verify_gpu_code.cc
@@ -125,7 +125,7 @@ class VerifyGPUCodeNode : public PostprocNode {
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)},
};
thread_warp_size_ = Extract(target, "thread_warp_size");
thread_warp_size_ = Extract(target, "thread_warp_size").IntValue();
}

bool Verify(const IRModule& mod) const {
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/auto_bind.cc
@@ -168,7 +168,7 @@ class AutoBindNode : public ScheduleRuleNode {
context->target.value()->GetAttr<Integer>("max_threads_per_block");
CHECK(max_threads_per_block.defined())
<< "ValueError: missing attribute `max_threads_per_block` in the target";
this->max_threads_per_block_ = max_threads_per_block.value();
this->max_threads_per_block_ = max_threads_per_block.value().IntValue();
}

// Inherited from ScheduleRuleNode
4 changes: 3 additions & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -118,7 +118,9 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
state->sch->GetSRef(state->block_rv), "meta_schedule.write_cache_level")) {
req = ReuseType::kMustReuse;
levels = std::vector<int>(ann.value().begin(), ann.value().end());
levels.clear();
std::transform(ann.value().begin(), ann.value().end(), std::back_inserter(levels),
[](auto&& v) { return v.IntValue(); });
}
std::vector<State> results;
if (req == ReuseType::kMayReuse) {
2 changes: 1 addition & 1 deletion src/meta_schedule/utils.h
@@ -328,7 +328,7 @@ struct ThreadedTraceApply {
* \return The number of cores.
*/
inline int GetTargetNumCores(const Target& target) {
int num_cores = target->GetAttr<Integer>("num-cores").value_or(-1);
int num_cores = target->GetAttr<Integer>("num-cores").value_or(-1).IntValue();
if (num_cores == -1) {
static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count");
ICHECK(f_cpu_count)
2 changes: 1 addition & 1 deletion src/parser/parser.cc
@@ -1540,7 +1540,7 @@ class Parser {
}
case TokenType::kBoolean: {
Consume(TokenType::kBoolean);
int64_t value = Downcast<tvm::Integer>(next->data);
int64_t value = Downcast<tvm::Integer>(next->data).IntValue();
Expr e = Constant(support::BoolToNDArray(value), next->span);
ICHECK(e->span.defined()) << "constant spans must be defined";
return e;
4 changes: 3 additions & 1 deletion src/parser/token.h
@@ -387,7 +387,9 @@ Token::Token(Span span, TokenType token_type, ObjectRef data) {

Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); }

int64_t Token::ToNumber() const { return Downcast<tvm::Integer>(this->operator->()->data); }
int64_t Token::ToNumber() const {
return Downcast<tvm::Integer>(this->operator->()->data).IntValue();
}

std::string Token::ToString() const { return Downcast<tvm::String>(this->operator->()->data); }

2 changes: 1 addition & 1 deletion src/relay/analysis/extract_fake_quantized_ops.cc
@@ -55,7 +55,7 @@ class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor {
if (op != dequantize_op_) {
if (fake_quantized_op_freqs_.find(op->name) != fake_quantized_op_freqs_.end()) {
fake_quantized_op_freqs_.Set(op->name,
int64_t(fake_quantized_op_freqs_.at(op->name)) + 1);
fake_quantized_op_freqs_.at(op->name).IntValue() + 1);
} else {
fake_quantized_op_freqs_.Set(op->name, 1);
}
2 changes: 1 addition & 1 deletion src/relay/analysis/extract_operators.cc
@@ -54,7 +54,7 @@ class OperatorExtractorWrapper : private MixedModeVisitor {
auto it = operator_freqs_.find(op->name);
ICHECK(it != operator_freqs_.end())
<< "Call's OpNode must be visited and registered before access";
operator_freqs_.Set(op->name, 1 + operator_freqs_.at(op->name));
operator_freqs_.Set(op->name, 1 + operator_freqs_.at(op->name).IntValue());
}

MixedModeVisitor::VisitExpr_(n);
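The two relay analysis passes above keep a map of operator frequencies as Integer values; after this change the read side of an increment goes through IntValue(), while the write side still relies on implicit construction from an integer. A sketch of that pattern, with std::map standing in for tvm::Map:

#include <cstdint>
#include <map>
#include <string>

// Sketch only: std::map stands in for tvm::Map<String, Integer>. Building an Integer
// from an integer stays implicit; only reading one back out needs IntValue().
struct Integer {
  int64_t value;
  Integer(int64_t v) : value(v) {}
  int64_t IntValue() const { return value; }
};

int main() {
  std::map<std::string, Integer> op_freqs;
  op_freqs.emplace("nn.conv2d", 0);
  // Before: int64_t(op_freqs.at("nn.conv2d")) + 1 used the implicit conversion.
  // After: read the count explicitly, then store the incremented value.
  op_freqs.insert_or_assign("nn.conv2d", Integer(op_freqs.at("nn.conv2d").IntValue() + 1));
  return op_freqs.at("nn.conv2d").IntValue() == 1 ? 0 : 1;
}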
4 changes: 3 additions & 1 deletion src/relay/backend/build_module.cc
@@ -334,7 +334,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (config_->optional_homogeneous_target.defined()) {
// This pass currently only supports the homogeneous case.
pass_seqs.push_back(transform::SplitArgs(
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1).value()));
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
.value()
.IntValue()));
}

// Always plan devices so the remaining passes don't need to distinguish homogeneous vs
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/ethosu/source_module.cc
@@ -199,7 +199,7 @@ class EthosUModuleNode : public ModuleNode {
std::unordered_map<int, relay::contrib::ethosu::BaseAddress> param_idx_to_base_address;
for (const relay::contrib::ethosu::BaseAddress& base_address : artifact->base_addresses) {
if (base_address->primfunc_param_idx.defined()) {
param_idx_to_base_address[base_address->primfunc_param_idx] = base_address;
param_idx_to_base_address[base_address->primfunc_param_idx.IntValue()] = base_address;
}
}
for (unsigned int i = 0; i < param_idx_to_base_address.size(); i++) {
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/tensorrt/codegen.cc
@@ -291,8 +291,8 @@ class TensorRTJSONSerializer : public JSONSerializer {
}
ICHECK_EQ(target_attr.size(), 3);
SetAttr(node, "tensorrt_version",
{std::to_string(target_attr[0]), std::to_string(target_attr[1]),
std::to_string(target_attr[2])});
{std::to_string(target_attr[0]->value), std::to_string(target_attr[1]->value),
std::to_string(target_attr[2]->value)});
}

{