diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h index 8946a104dac4..cfb782e5ea70 100644 --- a/include/tvm/target/compilation_config.h +++ b/include/tvm/target/compilation_config.h @@ -133,6 +133,12 @@ class CompilationConfigNode : public Object { */ Optional FindPrimitiveTargetForKind(const std::string& kind_name) const; + /*! + * \brief Returns a \p Target structurally equal to \p target, however prefer a structually equal + * known host or primitive target if the configuration has one. + */ + Target CanonicalTarget(const Target& target) const; + /*! * \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained * fields, however: diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 95249f902b48..e7d3a65dfe68 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -399,10 +399,13 @@ void DeviceDomains::SetDefault(DeviceDomainPtr domain, ICHECK(!default_virtual_device->IsFullyUnconstrained()); domain = Lookup(domain); if (domain->args_and_result_.empty()) { - DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull( - domain, MakeFirstOrderDomain(config_->CanonicalVirtualDevice( - VirtualDevice::Default(domain->virtual_device_, default_virtual_device)))); - ICHECK_NOTNULL(defaulted_domain_ptr); + DeviceDomainPtr default_domain = MakeFirstOrderDomain(config_->CanonicalVirtualDevice( + VirtualDevice::Default(domain->virtual_device_, default_virtual_device))); + DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(domain, default_domain); + ICHECK(defaulted_domain_ptr != nullptr) << "domain:" << std::endl + << ToString(domain) << std::endl + << "default domain:" << std::endl + << ToString(default_domain); } else { for (const auto& sub_domain : domain->args_and_result_) { SetDefault(sub_domain, default_virtual_device); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 3562da3b0d6f..6ccbe38dbebf 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -553,18 +553,9 @@ class DeviceAnalyzer : public MixedModeVisitor { } void VisitExpr_(const FunctionNode* function_node) final { - // No need to step into fused primitive functions as they are lowered individually according - // to the devices of all their call sites. - if (function_node->HasNonzeroAttr(attr::kPrimitive)) { - return; - } - auto function = GetRef(function_node); auto func_domain = domains_->DomainFor(function); // higher-order - - // The function body domain must match the function result domain. - domains_->UnifyExprExact(function_node->body, - func_domain->function_result()); // may be higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); VLOG(2) << "initial function domain:" << std::endl << domains_->ToString(func_domain) << std::endl @@ -573,39 +564,33 @@ class DeviceAnalyzer : public MixedModeVisitor { << "for function:" << std::endl << PrettyPrint(function); - ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - for (size_t i = 0; i < function_node->params.size(); ++i) { - // The parameter domains must match the function argument domains. - domains_->UnifyExprExact(function_node->params[i], - func_domain->function_param(i)); // may be higher-order - VisitExpr(function_node->params[i]); + // The function body domain must match the function result domain. + domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + if (!function_node->virtual_device()->IsFullyUnconstrained()) { + // The function body domain must match any existing virtual device annotation. + domains_->UnifyExprExact(function_node->body, + domains_->ForVirtualDevice(function_node->body->checked_type(), + function_node->virtual_device())); } - // If the function already has VirtualDevice attributes then we can further constrain the - // function's domain to match them. - if (!function_node->virtual_device()->IsFullyUnconstrained()) { - std::vector args_and_result; - for (auto param : function_node->params) { - args_and_result.emplace_back( - domains_->ForVirtualDevice(param->checked_type(), param->virtual_device())); - } - args_and_result.emplace_back(domains_->ForVirtualDevice(function_node->body->checked_type(), - function_node->virtual_device())); - auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); - if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order - // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Function VirtualDevices are incompatible with its \"on_device\" annotation. " - "Function:" - << std::endl - << PrettyPrint(function) << std::endl - << "with function virtual devices:" << std::endl - << domains_->ToString(func_domain) << std::endl - << "and annotation virtual devices:" << std::endl - << domains_->ToString(annotation_domain); + for (size_t i = 0; i < function_node->params.size(); ++i) { + const auto& param = function_node->params[i]; + // The parameter domain must match the function argument domain. + domains_->UnifyExprExact(param, + func_domain->function_param(i)); // may be higher-order + if (!param->virtual_device()->IsFullyUnconstrained()) { + // The parameter domain must match any existing virtual device annotation. + domains_->UnifyExprExact( + param, domains_->ForVirtualDevice(param->checked_type(), param->virtual_device())); } + VisitExpr(param); } - VisitExpr(function_node->body); + // No need to step into the body of Primitive functions. + if (!function_node->HasNonzeroAttr(attr::kPrimitive)) { + VisitExpr(function_node->body); + } VLOG(2) << "final function domain:" << std::endl << domains_->ToString(func_domain) << std::endl @@ -839,10 +824,16 @@ class DeviceDefaulter : public ExprVisitor { // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) // above. But for calls to primitives we may still need to force free domains to be // defaulted. - VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + VLOG(2) << "before defaulting callee:" << std::endl + << PrettyPrint(call_node->op) << std::endl + << "of domain:" << std::endl + << domains_->ToString(func_domain); domains_->SetResultDefaultThenParams(func_domain, domains_->config()->default_primitive_virtual_device); - VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + VLOG(2) << "after defaulting callee:" << std::endl + << PrettyPrint(call_node->op) << std::endl + << "of domain:" << std::endl + << domains_->ToString(func_domain); } return ExprVisitor::VisitExpr_(call_node); } diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index cb50615ce6a5..ef54896ef187 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -72,16 +72,43 @@ Optional CompilationConfigNode::FindPrimitiveTargetForKind( return *itr; } +Target CompilationConfigNode::CanonicalTarget(const Target& target) const { + // Fast path -- object identity. + if (target == host_target) { + return target; + } + for (const auto& primitive_target : primitive_targets) { + if (target == primitive_target) { + return target; + } + } + // Slow path -- structural equality. We have so few targets it does not seem worth building an + // index. + if (StructuralEqual()(target, host_target)) { + return host_target; + } + for (const auto& primitive_target : primitive_targets) { + if (StructuralEqual()(target, primitive_target)) { + return primitive_target; + } + } + // No match. + return target; +} + VirtualDevice CompilationConfigNode::CanonicalVirtualDevice( const VirtualDevice& virtual_device) const { - if (virtual_device->target.defined()) { - return virtual_device_cache_.Unique(virtual_device); - } DLDeviceType device_type = virtual_device->device_type(); - // TODO(mbs): Proper diagnostics. - CHECK(device_type != kInvalidDeviceType) - << "VirtualDevice annotations must include at least a device_type"; - Target target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type()); + Target target = virtual_device->target; + if (target.defined()) { + target = CanonicalTarget(target); + } else { + // Find the (unique) target matching the device's device type. + // TODO(mbs): Proper diagnostics. + CHECK(device_type != kInvalidDeviceType) + << "VirtualDevice annotations must include at least a device_type"; + target = FindPrimitiveTargetForDeviceOrFail(device_type); + } return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id, target, virtual_device->memory_scope)); } @@ -222,9 +249,8 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx, // Establish the default primitive VirtualDevice, choosing a known Target to match the device // type. We do not create a default target, it must already exist as a primitive target. // - default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice( - default_primitive_device_type, - /*virtual_device_id=*/0, FindPrimitiveTargetForDeviceOrFail(default_primitive_device_type))); + default_primitive_virtual_device = CanonicalVirtualDevice( + VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0)); ICHECK(default_primitive_virtual_device.defined()); ICHECK(default_primitive_virtual_device->target.defined()); diff --git a/src/target/target.cc b/src/target/target.cc index afdfad9b76b9..07b347f09817 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -550,7 +550,8 @@ Optional TargetNode::GetHost() const { String TargetNode::ToDebugString() const { std::ostringstream os; os << "Target("; - os << "kind='" << kind->name << "'"; + os << "id=" << std::hex << reinterpret_cast(this); + os << ", kind='" << kind->name << "'"; if (!tag.empty()) { os << ", tag='" << tag << "'"; } diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index 825cb5baeb8c..e3e85110d87e 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -286,6 +286,29 @@ TEST(CompilationConfig, FindPrimitiveTargetForKind_NotFound) { ASSERT_FALSE(config->FindPrimitiveTargetForKind("cutlass").defined()); } +TEST(CompilationConfig, CanonicalTarget) { + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + CompilationConfig config = TestCompilationConfig(); + + { + Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget()); + ASSERT_NE(other_cuda_target, cuda_target); + ASSERT_EQ(config->CanonicalTarget(other_cuda_target), + config->FindPrimitiveTargetForKind("cuda")); + } + { + Target other_host_target = TestDefaultCpuTarget(); + ASSERT_NE(other_host_target, cuda_target); + ASSERT_EQ(config->CanonicalTarget(other_host_target), config->host_target); + } + { + Target other_target("cuda -max_num_threads=7"); + ASSERT_EQ(config->CanonicalTarget(other_target), other_target); + } +} + TEST(CompilationConfig, CanonicalVirtualDevice) { Target host_target = TestDefaultCpuTarget(); Target cuda_target = TestCudaTarget(); @@ -306,6 +329,12 @@ TEST(CompilationConfig, CanonicalVirtualDevice) { EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target))); EXPECT_EQ(config->CanonicalVirtualDevice(in), actual); } + { + Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget()); + VirtualDevice in = VirtualDevice(kDLCUDA, -1, other_cuda_target); + VirtualDevice actual = config->CanonicalVirtualDevice(in); + ASSERT_EQ(actual->target, config->FindPrimitiveTargetForKind("cuda")); + } } TEST(CompilationConfig, CanonicalVirtualDevice_NoDevice) { diff --git a/tests/cpp/target/virtual_device_test.cc b/tests/cpp/target/virtual_device_test.cc index 35e078713d1b..d982a8ae2153 100644 --- a/tests/cpp/target/virtual_device_test.cc +++ b/tests/cpp/target/virtual_device_test.cc @@ -107,6 +107,7 @@ TEST(VirtualDeviceCache, Memoized) { VirtualDeviceCache cache; Target target_a = Target("cuda"); Target target_b = Target("llvm"); + Target target_c = Target("cuda"); VirtualDevice virtual_device_a = cache.Make(kDLCUDA, 3, target_a, "local"); VirtualDevice virtual_device_b = cache.Make(kDLCPU, 1, target_b, "global"); @@ -115,6 +116,9 @@ TEST(VirtualDeviceCache, Memoized) { EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), virtual_device_a); EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), virtual_device_a); EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), virtual_device_a); + EXPECT_EQ(cache.Make(kDLCUDA, 3, Target("cuda"), "local"), virtual_device_a); + EXPECT_NE(cache.Make(kDLCUDA, 3, Target("cuda -max_threads_per_block=4096"), "local"), + virtual_device_a); } } // namespace diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 35f072d19d92..2749339afdce 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -47,6 +47,9 @@ CPU_SCOPE_A = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA") CPU_SCOPE_B = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB") +GPU_SCOPE_GLOBAL = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global") +GPU_SCOPE_TEXTURE = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global.texture") + CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) core = tvm.IRModule() @@ -57,7 +60,7 @@ def rewrite_and_assert(in_mod, expected_mod): """Manually run the pass and assert it's structurally equals to the expected.""" - config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET) + config = tvm.target.make_compilation_config(CTXT, TARGETS) actual_mod = relay.transform.InferType()(in_mod) actual_mod = relay.transform.PlanDevices(config)(actual_mod) actual_mod = relay.transform.InferType()(actual_mod) @@ -1774,11 +1777,59 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], metatable, ) - config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET) + config = tvm.target.make_compilation_config(CTXT, TARGETS) actual_mod = relay.transform.InferType()(input()) actual_mod = relay.transform.PlanDevices(config)(actual_mod) relay.transform.InferType()(actual_mod) +def test_primitive(): + """Annotations on Primitive functions should be accepted, even though the body + of the Primitive function is not considered during PlanDevices.""" + metatable = { + "VirtualDevice": [ + GPU_SCOPE_GLOBAL, + GPU_SCOPE_TEXTURE, + ] + } + + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 32, 40, 40), float32], + %data2: Tensor[(1, 32, 40, 40), float32]) { + %0 = fn (%a, Primitive=1) { + layout_transform(%a, src_layout="NCHW", dst_layout="NCHW4c") + }; + %1 = %0(%data1); + %3 = %0(%data2); + %5 = fn (%a {virtual_device=meta[VirtualDevice][0]}, + %b {virtual_device=meta[VirtualDevice][0]}, + virtual_device=meta[VirtualDevice][1], + Primitive=1) { + add(%a, %b) + }; + %6 = %5(%1, %3); + %10 = fn (%a, + virtual_device=meta[VirtualDevice][0], + Primitive=1) { + layout_transform(%a, src_layout="NCHW4c", dst_layout="NCHW") + }; + %10(%6) + } + """, + "from_string", + None, + metatable, + ) + print(mod) + + config = tvm.target.make_compilation_config(CTXT, GPU_TARGET) + mod = relay.transform.InferType()(mod) + # PlanDevices should succeed. + mod = relay.transform.PlanDevices(config)(mod) + print(mod) + + if __name__ == "__main__": tvm.testing.main()