From 098289426b3b032d511237cecdf525c7b1786722 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 14 Jul 2022 11:58:58 -0700 Subject: [PATCH 1/5] [Relay] Allow Primitive function to carry virtual device annotations in PlanDevices Previously Primitive=1 functions were not analyzed and calls to them were completely unconstrained. With this change at least any virtual device annotations on the function are respected and accounted for in calls, even though the body is not analyzed. This may help with piggy-backing on PlanDevices for doing memory scope analysis, since it is now possible to express cross-scope functions on Primitive functions. However, I believe there are other issues to deal with in addition to this one. --- src/relay/transforms/device_domains.cc | 11 +-- src/relay/transforms/device_planner.cc | 71 +++++++++----------- tests/python/relay/test_pass_plan_devices.py | 54 ++++++++++++++- 3 files changed, 90 insertions(+), 46 deletions(-) diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 95249f902b48..e7d3a65dfe68 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -399,10 +399,13 @@ void DeviceDomains::SetDefault(DeviceDomainPtr domain, ICHECK(!default_virtual_device->IsFullyUnconstrained()); domain = Lookup(domain); if (domain->args_and_result_.empty()) { - DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull( - domain, MakeFirstOrderDomain(config_->CanonicalVirtualDevice( - VirtualDevice::Default(domain->virtual_device_, default_virtual_device)))); - ICHECK_NOTNULL(defaulted_domain_ptr); + DeviceDomainPtr default_domain = MakeFirstOrderDomain(config_->CanonicalVirtualDevice( + VirtualDevice::Default(domain->virtual_device_, default_virtual_device))); + DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(domain, default_domain); + ICHECK(defaulted_domain_ptr != nullptr) << "domain:" << std::endl + << ToString(domain) << std::endl + << "default domain:" << std::endl + << 
ToString(default_domain); } else { for (const auto& sub_domain : domain->args_and_result_) { SetDefault(sub_domain, default_virtual_device); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 3562da3b0d6f..6ccbe38dbebf 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -553,18 +553,9 @@ class DeviceAnalyzer : public MixedModeVisitor { } void VisitExpr_(const FunctionNode* function_node) final { - // No need to step into fused primitive functions as they are lowered individually according - // to the devices of all their call sites. - if (function_node->HasNonzeroAttr(attr::kPrimitive)) { - return; - } - auto function = GetRef(function_node); auto func_domain = domains_->DomainFor(function); // higher-order - - // The function body domain must match the function result domain. - domains_->UnifyExprExact(function_node->body, - func_domain->function_result()); // may be higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); VLOG(2) << "initial function domain:" << std::endl << domains_->ToString(func_domain) << std::endl @@ -573,39 +564,33 @@ class DeviceAnalyzer : public MixedModeVisitor { << "for function:" << std::endl << PrettyPrint(function); - ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - for (size_t i = 0; i < function_node->params.size(); ++i) { - // The parameter domains must match the function argument domains. - domains_->UnifyExprExact(function_node->params[i], - func_domain->function_param(i)); // may be higher-order - VisitExpr(function_node->params[i]); + // The function body domain must match the function result domain. + domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + if (!function_node->virtual_device()->IsFullyUnconstrained()) { + // The function body domain must match any existing virtual device annotation. 
+ domains_->UnifyExprExact(function_node->body, + domains_->ForVirtualDevice(function_node->body->checked_type(), + function_node->virtual_device())); } - // If the function already has VirtualDevice attributes then we can further constrain the - // function's domain to match them. - if (!function_node->virtual_device()->IsFullyUnconstrained()) { - std::vector args_and_result; - for (auto param : function_node->params) { - args_and_result.emplace_back( - domains_->ForVirtualDevice(param->checked_type(), param->virtual_device())); - } - args_and_result.emplace_back(domains_->ForVirtualDevice(function_node->body->checked_type(), - function_node->virtual_device())); - auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); - if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order - // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Function VirtualDevices are incompatible with its \"on_device\" annotation. " - "Function:" - << std::endl - << PrettyPrint(function) << std::endl - << "with function virtual devices:" << std::endl - << domains_->ToString(func_domain) << std::endl - << "and annotation virtual devices:" << std::endl - << domains_->ToString(annotation_domain); + for (size_t i = 0; i < function_node->params.size(); ++i) { + const auto& param = function_node->params[i]; + // The parameter domain must match the function argument domain. + domains_->UnifyExprExact(param, + func_domain->function_param(i)); // may be higher-order + if (!param->virtual_device()->IsFullyUnconstrained()) { + // The parameter domain must match any existing virtual device annotation. + domains_->UnifyExprExact( + param, domains_->ForVirtualDevice(param->checked_type(), param->virtual_device())); } + VisitExpr(param); } - VisitExpr(function_node->body); + // No need to step into the body of Primitive functions. 
+ if (!function_node->HasNonzeroAttr(attr::kPrimitive)) { + VisitExpr(function_node->body); + } VLOG(2) << "final function domain:" << std::endl << domains_->ToString(func_domain) << std::endl @@ -839,10 +824,16 @@ class DeviceDefaulter : public ExprVisitor { // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) // above. But for calls to primitives we may still need to force free domains to be // defaulted. - VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + VLOG(2) << "before defaulting callee:" << std::endl + << PrettyPrint(call_node->op) << std::endl + << "of domain:" << std::endl + << domains_->ToString(func_domain); domains_->SetResultDefaultThenParams(func_domain, domains_->config()->default_primitive_virtual_device); - VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + VLOG(2) << "after defaulting callee:" << std::endl + << PrettyPrint(call_node->op) << std::endl + << "of domain:" << std::endl + << domains_->ToString(func_domain); } return ExprVisitor::VisitExpr_(call_node); } diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 35f072d19d92..a13f77ad5996 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -47,6 +47,9 @@ CPU_SCOPE_A = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA") CPU_SCOPE_B = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB") +GPU_SCOPE_GLOBAL = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global") +GPU_SCOPE_TEXTURE = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global.texture") + CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) core = tvm.IRModule() @@ -57,7 +60,7 @@ def rewrite_and_assert(in_mod, expected_mod): """Manually run the pass and assert it's structurally equals to the 
expected.""" - config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET) + config = tvm.target.make_compilation_config(CTXT, TARGETS) actual_mod = relay.transform.InferType()(in_mod) actual_mod = relay.transform.PlanDevices(config)(actual_mod) actual_mod = relay.transform.InferType()(actual_mod) @@ -1774,11 +1777,58 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], metatable, ) - config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET) + config = tvm.target.make_compilation_config(CTXT, TARGETS) actual_mod = relay.transform.InferType()(input()) actual_mod = relay.transform.PlanDevices(config)(actual_mod) relay.transform.InferType()(actual_mod) +def test_primitive(): + metatable = { + "VirtualDevice": [ + GPU_SCOPE_GLOBAL, + GPU_SCOPE_TEXTURE, + ] + } + + # This module should pass PlanDevices without failure thanks to the annotations + # the Primitive functions. + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 32, 40, 40), float32], + %data2: Tensor[(1, 32, 40, 40), float32], + virtual_device=meta[VirtualDevice][1]) { + %0 = fn (%a, Primitive=1) { + layout_transform(%a, src_layout="NCHW", dst_layout="NCHW4c") + }; + %1 = %0(%data1); + %3 = %0(%data2); + %5 = fn (%a {virtual_device=meta[VirtualDevice][0]}, + %b {virtual_device=meta[VirtualDevice][0]}, + virtual_device=meta[VirtualDevice][1], + Primitive=1) { + add(%a, %b) + }; + %6 = %5(%1, %3); + %10 = fn (%a, Primitive=1) { + layout_transform(%a, src_layout="NCHW4c", dst_layout="NCHW") + }; + %10(%6) + } + """, + "from_string", + None, + metatable, + ) + print(mod) + + config = tvm.target.make_compilation_config(CTXT, GPU_TARGET) + mod = relay.transform.InferType()(mod) + mod = relay.transform.PlanDevices(config)(mod) + mod = relay.transform.InferType()(mod) + print(mod) + + if __name__ == "__main__": tvm.testing.main() From ca8b0ba2677238b1b8267d0097b95d07e488c7a0 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 14 Jul 2022 
12:51:14 -0700 Subject: [PATCH 2/5] - comments --- tests/python/relay/test_pass_plan_devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index a13f77ad5996..07a977644510 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -1784,6 +1784,8 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], def test_primitive(): + """Annotations on Primitive functions should be accepted, even though the body + of the Primitive function is not considered during PlanDevices.""" metatable = { "VirtualDevice": [ GPU_SCOPE_GLOBAL, @@ -1791,8 +1793,6 @@ def test_primitive(): ] } - # This module should pass PlanDevices without failure thanks to the annotations - # the Primitive functions. mod = tvm.parser.parse( """ #[version = "0.0.5"] @@ -1825,8 +1825,8 @@ def @main(%data1: Tensor[(1, 32, 40, 40), float32], config = tvm.target.make_compilation_config(CTXT, GPU_TARGET) mod = relay.transform.InferType()(mod) + # PlanDevices should succeed. mod = relay.transform.PlanDevices(config)(mod) - mod = relay.transform.InferType()(mod) print(mod) From c70dafdd708b58c9224a92d72786293798a15976 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 14 Jul 2022 13:48:21 -0700 Subject: [PATCH 3/5] - also canonicalize targets When including virtual device annotations in test relay programs the annotation will typically use a target which was used as an input to the make_compilation_config helper, but due to various canonicalizations that target may not be pointer equal to the final structurally equal target which ends up inside the constructed CompilationConfig. However, VirtualDevices use pointer equality when comparing their target field. So make sure the notion of CanonicalVirtualDevice also accounts for canonical targets. 
--- include/tvm/target/compilation_config.h | 6 +++ src/target/compilation_config.cc | 46 +++++++++++++++----- src/target/target.cc | 3 +- tests/cpp/target/compilation_config_test.cc | 29 ++++++++++++ tests/cpp/target/virtual_device_test.cc | 4 ++ tests/python/relay/test_pass_plan_devices.py | 3 +- 6 files changed, 78 insertions(+), 13 deletions(-) diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h index 8946a104dac4..cfb782e5ea70 100644 --- a/include/tvm/target/compilation_config.h +++ b/include/tvm/target/compilation_config.h @@ -133,6 +133,12 @@ class CompilationConfigNode : public Object { */ Optional FindPrimitiveTargetForKind(const std::string& kind_name) const; + /*! + * \brief Returns a \p Target structurally equal to \p target, however prefers a structurally equal + * known host or primitive target if the configuration has one. + */ + Target CanonicalTarget(const Target& target) const; + /*! * \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained * fields, however: diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index cb50615ce6a5..32ba5b6c5164 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -72,16 +72,43 @@ Optional CompilationConfigNode::FindPrimitiveTargetForKind( return *itr; } +Target CompilationConfigNode::CanonicalTarget(const Target& target) const { + // Fast path -- object identity. + if (target == host_target) { + return target; + } + for (const auto& primitive_target : primitive_targets) { + if (target == primitive_target) { + return target; + } + } + // Slow path -- structural equality. We have so few targets it does not seem worth building an + // index. + if (StructuralEqual()(target, host_target)) { + return host_target; + } + for (const auto& primitive_target : primitive_targets) { + if (StructuralEqual()(target, primitive_target)) { + return primitive_target; + } + } + // No match. 
+ return target; +} + VirtualDevice CompilationConfigNode::CanonicalVirtualDevice( const VirtualDevice& virtual_device) const { - if (virtual_device->target.defined()) { - return virtual_device_cache_.Unique(virtual_device); - } DLDeviceType device_type = virtual_device->device_type(); - // TODO(mbs): Proper diagnostics. - CHECK(device_type != kInvalidDeviceType) - << "VirtualDevice annotations must include at least a device_type"; - Target target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type()); + Target target = virtual_device->target; + if (target.defined()) { + target = CanonicalTarget(target); + } else { + // Find the (unique) target matching the device's device type. + // TODO(mbs): Proper diagnostics. + CHECK(device_type != kInvalidDeviceType) + << "VirtualDevice annotations must include at least a device_type"; + target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type()); + } return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id, target, virtual_device->memory_scope)); } @@ -222,9 +249,8 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx, // Establish the default primitive VirtualDevice, choosing a known Target to match the device // type. We do not create a default target, it must already exist as a primitive target. 
// - default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice( - default_primitive_device_type, - /*virtual_device_id=*/0, FindPrimitiveTargetForDeviceOrFail(default_primitive_device_type))); + default_primitive_virtual_device = CanonicalVirtualDevice( + VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0)); ICHECK(default_primitive_virtual_device.defined()); ICHECK(default_primitive_virtual_device->target.defined()); diff --git a/src/target/target.cc b/src/target/target.cc index afdfad9b76b9..07b347f09817 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -550,7 +550,8 @@ Optional TargetNode::GetHost() const { String TargetNode::ToDebugString() const { std::ostringstream os; os << "Target("; - os << "kind='" << kind->name << "'"; + os << "id=" << std::hex << reinterpret_cast(this); + os << ", kind='" << kind->name << "'"; if (!tag.empty()) { os << ", tag='" << tag << "'"; } diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index 825cb5baeb8c..e3e85110d87e 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -286,6 +286,29 @@ TEST(CompilationConfig, FindPrimitiveTargetForKind_NotFound) { ASSERT_FALSE(config->FindPrimitiveTargetForKind("cutlass").defined()); } +TEST(CompilationConfig, CanonicalTarget) { + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + CompilationConfig config = TestCompilationConfig(); + + { + Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget()); + ASSERT_NE(other_cuda_target, cuda_target); + ASSERT_EQ(config->CanonicalTarget(other_cuda_target), + config->FindPrimitiveTargetForKind("cuda")); + } + { + Target other_host_target = TestDefaultCpuTarget(); + ASSERT_NE(other_host_target, cuda_target); + ASSERT_EQ(config->CanonicalTarget(other_host_target), 
config->host_target); + } + { + Target other_target("cuda -max_num_threads=7"); + ASSERT_EQ(config->CanonicalTarget(other_target), other_target); + } +} + TEST(CompilationConfig, CanonicalVirtualDevice) { Target host_target = TestDefaultCpuTarget(); Target cuda_target = TestCudaTarget(); @@ -306,6 +329,12 @@ TEST(CompilationConfig, CanonicalVirtualDevice) { EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target))); EXPECT_EQ(config->CanonicalVirtualDevice(in), actual); } + { + Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget()); + VirtualDevice in = VirtualDevice(kDLCUDA, -1, other_cuda_target); + VirtualDevice actual = config->CanonicalVirtualDevice(in); + ASSERT_EQ(actual->target, config->FindPrimitiveTargetForKind("cuda")); + } } TEST(CompilationConfig, CanonicalVirtualDevice_NoDevice) { diff --git a/tests/cpp/target/virtual_device_test.cc b/tests/cpp/target/virtual_device_test.cc index 35e078713d1b..d982a8ae2153 100644 --- a/tests/cpp/target/virtual_device_test.cc +++ b/tests/cpp/target/virtual_device_test.cc @@ -107,6 +107,7 @@ TEST(VirtualDeviceCache, Memoized) { VirtualDeviceCache cache; Target target_a = Target("cuda"); Target target_b = Target("llvm"); + Target target_c = Target("cuda"); VirtualDevice virtual_device_a = cache.Make(kDLCUDA, 3, target_a, "local"); VirtualDevice virtual_device_b = cache.Make(kDLCPU, 1, target_b, "global"); @@ -115,6 +116,9 @@ TEST(VirtualDeviceCache, Memoized) { EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), virtual_device_a); EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), virtual_device_a); EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), virtual_device_a); + EXPECT_EQ(cache.Make(kDLCUDA, 3, Target("cuda"), "local"), virtual_device_a); + EXPECT_NE(cache.Make(kDLCUDA, 3, Target("cuda -max_threads_per_block=4096"), "local"), + virtual_device_a); } } // namespace diff --git a/tests/python/relay/test_pass_plan_devices.py 
b/tests/python/relay/test_pass_plan_devices.py index 07a977644510..112f6966bd7d 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -1797,8 +1797,7 @@ def test_primitive(): """ #[version = "0.0.5"] def @main(%data1: Tensor[(1, 32, 40, 40), float32], - %data2: Tensor[(1, 32, 40, 40), float32], - virtual_device=meta[VirtualDevice][1]) { + %data2: Tensor[(1, 32, 40, 40), float32]) { %0 = fn (%a, Primitive=1) { layout_transform(%a, src_layout="NCHW", dst_layout="NCHW4c") }; From ba6667c8a26642da552b9af814c06f044fe37d28 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 14 Jul 2022 13:59:36 -0700 Subject: [PATCH 4/5] - update unit test to reflect the Adreno example --- tests/python/relay/test_pass_plan_devices.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 112f6966bd7d..2749339afdce 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -1810,7 +1810,9 @@ def @main(%data1: Tensor[(1, 32, 40, 40), float32], add(%a, %b) }; %6 = %5(%1, %3); - %10 = fn (%a, Primitive=1) { + %10 = fn (%a, + virtual_device=meta[VirtualDevice][0], + Primitive=1) { layout_transform(%a, src_layout="NCHW4c", dst_layout="NCHW") }; %10(%6) } From efea08279d3a23fd75976c4780ce7d0d3ea80033 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 14 Jul 2022 15:07:16 -0700 Subject: [PATCH 5/5] - trivial cleanup --- src/target/compilation_config.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index 32ba5b6c5164..ef54896ef187 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -107,7 +107,7 @@ VirtualDevice CompilationConfigNode::CanonicalVirtualDevice( // TODO(mbs): Proper diagnostics. 
CHECK(device_type != kInvalidDeviceType) << "VirtualDevice annotations must include at least a device_type"; - target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type()); + target = FindPrimitiveTargetForDeviceOrFail(device_type); } return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id, target, virtual_device->memory_scope));