Skip to content

Commit

Permalink
[mlir][spirv] Change the return type for {Min|Max}VersionBase
Browse files Browse the repository at this point in the history
For synthesizing an op's implementation of the generated interface
from {Min|Max}Version, we need to define an `initializer` and
`mergeAction`. The `initializer` specifies the initial version,
and `mergeAction` specifies how version specifications from
different parts of the op should be merged to generate a final
version requirements.

Previously we use the specified version enum as the type for both
the initializer and thus the final return type. This means we need
to perform `static_cast` over some hopefully not used number (`~0u`)
as the initializer. This is quite opaque and sort of not guaranteed
to work. Also, there are ops that have an enum attribute where some
values declare version requirements (e.g., enumerant `B` requires
v1.1+) but some not (e.g., enumerant `A` requires nothing). Then a
concrete op instance with `A` will still declare it implements the
version interface (because interface implementation is static for
an op) but actually theirs no requirements for version.

So this commit changes to use an more explicit `llvm::Optional`
to wrap around the returned version enum.  This should make it
more clear.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D108312
  • Loading branch information
antiagainst committed Nov 24, 2021
1 parent 68e2231 commit cb395f6
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 34 deletions.
22 changes: 14 additions & 8 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,15 @@ class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
: Availability {
let interfaceName = name;

let queryFnRetType = scheme.returnType;
let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">";
let queryFnName = "getMinVersion";

let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
"std::max($overall, $instance))";
let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))";
let mergeAction = "{ "
"if ($overall.hasValue()) { "
"$overall = static_cast<" # scheme.returnType # ">("
"std::max(*$overall, $instance)); "
"} else { $overall = $instance; }}";
let initializer = "::llvm::None";
let instanceType = scheme.cppNamespace # "::" # scheme.className;

let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
Expand All @@ -76,12 +79,15 @@ class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
: Availability {
let interfaceName = name;

let queryFnRetType = scheme.returnType;
let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">";
let queryFnName = "getMaxVersion";

let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
"std::min($overall, $instance))";
let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))";
let mergeAction = "{ "
"if ($overall.hasValue()) { "
"$overall = static_cast<" # scheme.returnType # ">("
"std::min(*$overall, $instance)); "
"} else { $overall = $instance; }}";
let initializer = "::llvm::None";
let instanceType = scheme.cppNamespace # "::" # scheme.className;

let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,12 @@ class SPIRVOpInterface<string name> : OpInterface<name> {
// TODO: the following interfaces definitions are duplicating with the above.
// Remove them once we are able to support dialect-specific contents in ODS.
def QueryMinVersionInterface : SPIRVOpInterface<"QueryMinVersionInterface"> {
let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">];
let methods = [InterfaceMethod<
"", "::llvm::Optional<::mlir::spirv::Version>", "getMinVersion">];
}
def QueryMaxVersionInterface : SPIRVOpInterface<"QueryMaxVersionInterface"> {
let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">];
let methods = [InterfaceMethod<
"", "::llvm::Optional<::mlir::spirv::Version>", "getMaxVersion">];
}
def QueryExtensionInterface : SPIRVOpInterface<"QueryExtensionInterface"> {
let methods = [InterfaceMethod<
Expand Down
18 changes: 10 additions & 8 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,22 +843,24 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Make sure this op is available at the given version. Ops not implementing
// QueryMinVersionInterface/QueryMaxVersionInterface are available to all
// SPIR-V versions.
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
if (minVersion && *minVersion > this->targetEnv.getVersion()) {
LLVM_DEBUG(llvm::dbgs()
<< op->getName() << " illegal: requiring min version "
<< spirv::stringifyVersion(minVersion.getMinVersion())
<< "\n");
<< spirv::stringifyVersion(*minVersion) << "\n");
return false;
}
if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
}
if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
LLVM_DEBUG(llvm::dbgs()
<< op->getName() << " illegal: requiring max version "
<< spirv::stringifyVersion(maxVersion.getMaxVersion())
<< "\n");
<< spirv::stringifyVersion(*maxVersion) << "\n");
return false;
}
}

// Make sure this op's required extensions are allowed to use. Ops not
// implementing QueryExtensionInterface do not require extensions to be
Expand Down
18 changes: 11 additions & 7 deletions mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,17 @@ void UpdateVCEPass::runOnOperation() {
// requirements.
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
// Op min version requirements
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
if (deducedVersion > allowedVersion) {
return op->emitError("'") << op->getName() << "' requires min version "
<< spirv::stringifyVersion(deducedVersion)
<< " but target environment allows up to "
<< spirv::stringifyVersion(allowedVersion);
if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
if (minVersion) {
deducedVersion = std::max(deducedVersion, *minVersion);
if (deducedVersion > allowedVersion) {
return op->emitError("'")
<< op->getName() << "' requires min version "
<< spirv::stringifyVersion(deducedVersion)
<< " but target environment allows up to "
<< spirv::stringifyVersion(allowedVersion);
}
}
}

Expand Down
24 changes: 17 additions & 7 deletions mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,23 @@ void PrintOpAvailability::runOnFunction() {
auto opName = op->getName();
auto &os = llvm::outs();

if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
os << opName << " min version: "
<< spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
os << opName << " min version: ";
if (minVersion)
os << spirv::stringifyVersion(*minVersion) << "\n";
else
os << "None\n";
}

if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
os << opName << " max version: "
<< spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
os << opName << " max version: ";
if (maxVersion)
os << spirv::stringifyVersion(*maxVersion) << "\n";
else
os << "None\n";
}

if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
os << opName << " extensions: [";
Expand Down Expand Up @@ -81,7 +91,7 @@ void PrintOpAvailability::runOnFunction() {
}

namespace mlir {
void registerPrintOpAvailabilityPass() {
void registerPrintSpirvAvailabilityPass() {
PassRegistration<PrintOpAvailability>();
}
} // namespace mlir
Expand Down
4 changes: 2 additions & 2 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using namespace mlir;
namespace mlir {
void registerConvertToTargetEnvPass();
void registerPassManagerTestPass();
void registerPrintOpAvailabilityPass();
void registerPrintSpirvAvailabilityPass();
void registerShapeFunctionTestPasses();
void registerSideEffectTestPasses();
void registerSliceAnalysisTestPass();
Expand Down Expand Up @@ -119,7 +119,7 @@ void registerTestDialect(DialectRegistry &);
void registerTestPasses() {
registerConvertToTargetEnvPass();
registerPassManagerTestPass();
registerPrintOpAvailabilityPass();
registerPrintSpirvAvailabilityPass();
registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
registerSliceAnalysisTestPass();
Expand Down

0 comments on commit cb395f6

Please sign in to comment.