Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Allow Primitive functions to carry virtual device annotations in PlanDevices #12095

Merged
merged 5 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/target/compilation_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ class CompilationConfigNode : public Object {
*/
Optional<Target> FindPrimitiveTargetForKind(const std::string& kind_name) const;

/*!
* \brief Returns a \p Target structurally equal to \p target, however prefer a structually equal
* known host or primitive target if the configuration has one.
*/
Target CanonicalTarget(const Target& target) const;

/*!
* \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained
* fields, however:
Expand Down
11 changes: 7 additions & 4 deletions src/relay/transforms/device_domains.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,13 @@ void DeviceDomains::SetDefault(DeviceDomainPtr domain,
ICHECK(!default_virtual_device->IsFullyUnconstrained());
domain = Lookup(domain);
if (domain->args_and_result_.empty()) {
DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(
domain, MakeFirstOrderDomain(config_->CanonicalVirtualDevice(
VirtualDevice::Default(domain->virtual_device_, default_virtual_device))));
ICHECK_NOTNULL(defaulted_domain_ptr);
DeviceDomainPtr default_domain = MakeFirstOrderDomain(config_->CanonicalVirtualDevice(
VirtualDevice::Default(domain->virtual_device_, default_virtual_device)));
DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(domain, default_domain);
ICHECK(defaulted_domain_ptr != nullptr) << "domain:" << std::endl
<< ToString(domain) << std::endl
<< "default domain:" << std::endl
<< ToString(default_domain);
} else {
for (const auto& sub_domain : domain->args_and_result_) {
SetDefault(sub_domain, default_virtual_device);
Expand Down
71 changes: 31 additions & 40 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -553,18 +553,9 @@ class DeviceAnalyzer : public MixedModeVisitor {
}

void VisitExpr_(const FunctionNode* function_node) final {
// No need to step into fused primitive functions as they are lowered individually according
// to the devices of all their call sites.
if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
return;
}

auto function = GetRef<Function>(function_node);
auto func_domain = domains_->DomainFor(function); // higher-order

// The function body domain must match the function result domain.
domains_->UnifyExprExact(function_node->body,
func_domain->function_result()); // may be higher-order
ICHECK_EQ(func_domain->function_arity(), function_node->params.size());

VLOG(2) << "initial function domain:" << std::endl
<< domains_->ToString(func_domain) << std::endl
Expand All @@ -573,39 +564,33 @@ class DeviceAnalyzer : public MixedModeVisitor {
<< "for function:" << std::endl
<< PrettyPrint(function);

ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
for (size_t i = 0; i < function_node->params.size(); ++i) {
// The parameter domains must match the function argument domains.
domains_->UnifyExprExact(function_node->params[i],
func_domain->function_param(i)); // may be higher-order
VisitExpr(function_node->params[i]);
// The function body domain must match the function result domain.
domains_->UnifyExprExact(function_node->body,
func_domain->function_result()); // may be higher-order
if (!function_node->virtual_device()->IsFullyUnconstrained()) {
// The function body domain must match any existing virtual device annotation.
domains_->UnifyExprExact(function_node->body,
domains_->ForVirtualDevice(function_node->body->checked_type(),
function_node->virtual_device()));
}

// If the function already has VirtualDevice attributes then we can further constrain the
// function's domain to match them.
if (!function_node->virtual_device()->IsFullyUnconstrained()) {
std::vector<DeviceDomainPtr> args_and_result;
for (auto param : function_node->params) {
args_and_result.emplace_back(
domains_->ForVirtualDevice(param->checked_type(), param->virtual_device()));
}
args_and_result.emplace_back(domains_->ForVirtualDevice(function_node->body->checked_type(),
function_node->virtual_device()));
auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result));
if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order
// TODO(mbs): Proper diagnostics.
LOG(FATAL) << "Function VirtualDevices are incompatible with its \"on_device\" annotation. "
"Function:"
<< std::endl
<< PrettyPrint(function) << std::endl
<< "with function virtual devices:" << std::endl
<< domains_->ToString(func_domain) << std::endl
<< "and annotation virtual devices:" << std::endl
<< domains_->ToString(annotation_domain);
for (size_t i = 0; i < function_node->params.size(); ++i) {
const auto& param = function_node->params[i];
// The parameter domain must match the function argument domain.
domains_->UnifyExprExact(param,
func_domain->function_param(i)); // may be higher-order
if (!param->virtual_device()->IsFullyUnconstrained()) {
// The parameter domain must match any existing virtual device annotation.
domains_->UnifyExprExact(
param, domains_->ForVirtualDevice(param->checked_type(), param->virtual_device()));
}
VisitExpr(param);
}

VisitExpr(function_node->body);
// No need to step into the body of Primitive functions.
if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
VisitExpr(function_node->body);
}

VLOG(2) << "final function domain:" << std::endl
<< domains_->ToString(func_domain) << std::endl
Expand Down Expand Up @@ -839,10 +824,16 @@ class DeviceDefaulter : public ExprVisitor {
// For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*)
// above. But for calls to primitives we may still need to force free domains to be
// defaulted.
VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain);
VLOG(2) << "before defaulting callee:" << std::endl
<< PrettyPrint(call_node->op) << std::endl
<< "of domain:" << std::endl
<< domains_->ToString(func_domain);
domains_->SetResultDefaultThenParams(func_domain,
domains_->config()->default_primitive_virtual_device);
VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain);
VLOG(2) << "after defaulting callee:" << std::endl
<< PrettyPrint(call_node->op) << std::endl
<< "of domain:" << std::endl
<< domains_->ToString(func_domain);
}
return ExprVisitor::VisitExpr_(call_node);
}
Expand Down
46 changes: 36 additions & 10 deletions src/target/compilation_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,43 @@ Optional<Target> CompilationConfigNode::FindPrimitiveTargetForKind(
return *itr;
}

Target CompilationConfigNode::CanonicalTarget(const Target& target) const {
// Fast path -- object identity.
if (target == host_target) {
return target;
}
for (const auto& primitive_target : primitive_targets) {
if (target == primitive_target) {
return target;
}
}
// Slow path -- structural equality. We have so few targets it does not seem worth building an
// index.
if (StructuralEqual()(target, host_target)) {
return host_target;
}
for (const auto& primitive_target : primitive_targets) {
if (StructuralEqual()(target, primitive_target)) {
return primitive_target;
}
}
// No match.
return target;
}

VirtualDevice CompilationConfigNode::CanonicalVirtualDevice(
const VirtualDevice& virtual_device) const {
if (virtual_device->target.defined()) {
return virtual_device_cache_.Unique(virtual_device);
}
DLDeviceType device_type = virtual_device->device_type();
// TODO(mbs): Proper diagnostics.
CHECK(device_type != kInvalidDeviceType)
<< "VirtualDevice annotations must include at least a device_type";
Target target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type());
Target target = virtual_device->target;
if (target.defined()) {
target = CanonicalTarget(target);
} else {
// Find the (unique) target matching the device's device type.
// TODO(mbs): Proper diagnostics.
CHECK(device_type != kInvalidDeviceType)
<< "VirtualDevice annotations must include at least a device_type";
target = FindPrimitiveTargetForDeviceOrFail(device_type);
}
return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id,
target, virtual_device->memory_scope));
}
Expand Down Expand Up @@ -222,9 +249,8 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
// Establish the default primitive VirtualDevice, choosing a known Target to match the device
// type. We do not create a default target, it must already exist as a primitive target.
//
default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice(
default_primitive_device_type,
/*virtual_device_id=*/0, FindPrimitiveTargetForDeviceOrFail(default_primitive_device_type)));
default_primitive_virtual_device = CanonicalVirtualDevice(
VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0));

ICHECK(default_primitive_virtual_device.defined());
ICHECK(default_primitive_virtual_device->target.defined());
Expand Down
3 changes: 2 additions & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ Optional<Target> TargetNode::GetHost() const {
String TargetNode::ToDebugString() const {
std::ostringstream os;
os << "Target(";
os << "kind='" << kind->name << "'";
os << "id=" << std::hex << reinterpret_cast<size_t>(this);
os << ", kind='" << kind->name << "'";
if (!tag.empty()) {
os << ", tag='" << tag << "'";
}
Expand Down
29 changes: 29 additions & 0 deletions tests/cpp/target/compilation_config_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,29 @@ TEST(CompilationConfig, FindPrimitiveTargetForKind_NotFound) {
ASSERT_FALSE(config->FindPrimitiveTargetForKind("cutlass").defined());
}

TEST(CompilationConfig, CanonicalTarget) {
Target host_target = TestDefaultCpuTarget();
Target cuda_target = TestCudaTarget();
Target cpu_target = TestCpuTarget();
CompilationConfig config = TestCompilationConfig();

{
Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget());
ASSERT_NE(other_cuda_target, cuda_target);
ASSERT_EQ(config->CanonicalTarget(other_cuda_target),
config->FindPrimitiveTargetForKind("cuda"));
}
{
Target other_host_target = TestDefaultCpuTarget();
ASSERT_NE(other_host_target, cuda_target);
ASSERT_EQ(config->CanonicalTarget(other_host_target), config->host_target);
}
{
Target other_target("cuda -max_num_threads=7");
ASSERT_EQ(config->CanonicalTarget(other_target), other_target);
}
}

TEST(CompilationConfig, CanonicalVirtualDevice) {
Target host_target = TestDefaultCpuTarget();
Target cuda_target = TestCudaTarget();
Expand All @@ -306,6 +329,12 @@ TEST(CompilationConfig, CanonicalVirtualDevice) {
EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target)));
EXPECT_EQ(config->CanonicalVirtualDevice(in), actual);
}
{
Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget());
VirtualDevice in = VirtualDevice(kDLCUDA, -1, other_cuda_target);
VirtualDevice actual = config->CanonicalVirtualDevice(in);
ASSERT_EQ(actual->target, config->FindPrimitiveTargetForKind("cuda"));
}
}

TEST(CompilationConfig, CanonicalVirtualDevice_NoDevice) {
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/target/virtual_device_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ TEST(VirtualDeviceCache, Memoized) {
VirtualDeviceCache cache;
Target target_a = Target("cuda");
Target target_b = Target("llvm");
Target target_c = Target("cuda");
VirtualDevice virtual_device_a = cache.Make(kDLCUDA, 3, target_a, "local");
VirtualDevice virtual_device_b = cache.Make(kDLCPU, 1, target_b, "global");

Expand All @@ -115,6 +116,9 @@ TEST(VirtualDeviceCache, Memoized) {
EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), virtual_device_a);
EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), virtual_device_a);
EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), virtual_device_a);
EXPECT_EQ(cache.Make(kDLCUDA, 3, Target("cuda"), "local"), virtual_device_a);
EXPECT_NE(cache.Make(kDLCUDA, 3, Target("cuda -max_threads_per_block=4096"), "local"),
virtual_device_a);
}

} // namespace
Expand Down
55 changes: 53 additions & 2 deletions tests/python/relay/test_pass_plan_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
CPU_SCOPE_A = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA")
CPU_SCOPE_B = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB")

GPU_SCOPE_GLOBAL = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global")
GPU_SCOPE_TEXTURE = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global.texture")

CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int})

core = tvm.IRModule()
Expand All @@ -57,7 +60,7 @@

def rewrite_and_assert(in_mod, expected_mod):
"""Manually run the pass and assert it's structurally equals to the expected."""
config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET)
config = tvm.target.make_compilation_config(CTXT, TARGETS)
actual_mod = relay.transform.InferType()(in_mod)
actual_mod = relay.transform.PlanDevices(config)(actual_mod)
actual_mod = relay.transform.InferType()(actual_mod)
Expand Down Expand Up @@ -1774,11 +1777,59 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
metatable,
)

config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET)
config = tvm.target.make_compilation_config(CTXT, TARGETS)
actual_mod = relay.transform.InferType()(input())
actual_mod = relay.transform.PlanDevices(config)(actual_mod)
relay.transform.InferType()(actual_mod)


def test_primitive():
"""Annotations on Primitive functions should be accepted, even though the body
of the Primitive function is not considered during PlanDevices."""
metatable = {
"VirtualDevice": [
GPU_SCOPE_GLOBAL,
GPU_SCOPE_TEXTURE,
]
}

mod = tvm.parser.parse(
"""
#[version = "0.0.5"]
def @main(%data1: Tensor[(1, 32, 40, 40), float32],
%data2: Tensor[(1, 32, 40, 40), float32]) {
%0 = fn (%a, Primitive=1) {
layout_transform(%a, src_layout="NCHW", dst_layout="NCHW4c")
};
%1 = %0(%data1);
%3 = %0(%data2);
%5 = fn (%a {virtual_device=meta[VirtualDevice][0]},
%b {virtual_device=meta[VirtualDevice][0]},
virtual_device=meta[VirtualDevice][1],
Primitive=1) {
add(%a, %b)
};
%6 = %5(%1, %3);
%10 = fn (%a,
virtual_device=meta[VirtualDevice][0],
Primitive=1) {
layout_transform(%a, src_layout="NCHW4c", dst_layout="NCHW")
};
%10(%6)
}
""",
"from_string",
None,
metatable,
)
print(mod)

config = tvm.target.make_compilation_config(CTXT, GPU_TARGET)
mod = relay.transform.InferType()(mod)
# PlanDevices should succeed.
mod = relay.transform.PlanDevices(config)(mod)
print(mod)


if __name__ == "__main__":
tvm.testing.main()