diff --git a/lldb/include/lldb/API/SBDebugger.h b/lldb/include/lldb/API/SBDebugger.h
index af19b1faf3bf51..84ea9c0f772e16 100644
--- a/lldb/include/lldb/API/SBDebugger.h
+++ b/lldb/include/lldb/API/SBDebugger.h
@@ -57,6 +57,8 @@ class LLDB_API SBDebugger {
 
   static const char *GetBroadcasterClass();
 
+  static bool SupportsLanguage(lldb::LanguageType language);
+
   lldb::SBBroadcaster GetBroadcaster();
 
   /// Get progress data from a SBEvent whose type is eBroadcastBitProgress.
diff --git a/lldb/include/lldb/Symbol/TypeSystem.h b/lldb/include/lldb/Symbol/TypeSystem.h
index b4025c173a1861..7d48f9b316138c 100644
--- a/lldb/include/lldb/Symbol/TypeSystem.h
+++ b/lldb/include/lldb/Symbol/TypeSystem.h
@@ -209,6 +209,7 @@ class TypeSystem : public PluginInterface,
 
   // TypeSystems can support more than one language
   virtual bool SupportsLanguage(lldb::LanguageType language) = 0;
+  static bool SupportsLanguageStatic(lldb::LanguageType language);
 
   // Type Completion
   virtual bool GetCompleteType(lldb::opaque_compiler_type_t type) = 0;
diff --git a/lldb/source/API/SBDebugger.cpp b/lldb/source/API/SBDebugger.cpp
index 7ef0d6efd4aaa5..29da7d33dd80b8 100644
--- a/lldb/source/API/SBDebugger.cpp
+++ b/lldb/source/API/SBDebugger.cpp
@@ -1742,3 +1742,7 @@ bool SBDebugger::InterruptRequested() {
     return m_opaque_sp->InterruptRequested();
   return false;
 }
+
+bool SBDebugger::SupportsLanguage(lldb::LanguageType language) {
+  return TypeSystem::SupportsLanguageStatic(language);
+}
diff --git a/lldb/source/Symbol/TypeSystem.cpp b/lldb/source/Symbol/TypeSystem.cpp
index 4956f10a0b0a73..5d56d9b1829dac 100644
--- a/lldb/source/Symbol/TypeSystem.cpp
+++ b/lldb/source/Symbol/TypeSystem.cpp
@@ -335,3 +335,14 @@ TypeSystemMap::GetTypeSystemForLanguage(lldb::LanguageType language,
   }
   return GetTypeSystemForLanguage(language);
 }
+
+bool TypeSystem::SupportsLanguageStatic(lldb::LanguageType language) {
+  if (language == eLanguageTypeUnknown)
+    return false;
+
+  LanguageSet languages =
+      PluginManager::GetAllTypeSystemSupportedLanguagesForTypes();
+  if (languages.Empty())
+    return false;
+  return languages[language];
+}
diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp
index d419f821999e6c..8bcad72a74a80c 100644
--- a/lldb/tools/lldb-dap/DAP.cpp
+++ b/lldb/tools/lldb-dap/DAP.cpp
@@ -65,8 +65,51 @@ DAP::DAP()
 
 DAP::~DAP() = default;
 
+void DAP::PopulateExceptionBreakpoints() {
+  llvm::call_once(initExceptionBreakpoints, [this]() {
+    exception_breakpoints = {};
+    if (lldb::SBDebugger::SupportsLanguage(lldb::eLanguageTypeC_plus_plus)) {
+      exception_breakpoints->emplace_back("cpp_catch", "C++ Catch",
+                                          lldb::eLanguageTypeC_plus_plus);
+      exception_breakpoints->emplace_back("cpp_throw", "C++ Throw",
+                                          lldb::eLanguageTypeC_plus_plus);
+    }
+    if (lldb::SBDebugger::SupportsLanguage(lldb::eLanguageTypeObjC)) {
+      exception_breakpoints->emplace_back("objc_catch", "Objective-C Catch",
+                                          lldb::eLanguageTypeObjC);
+      exception_breakpoints->emplace_back("objc_throw", "Objective-C Throw",
+                                          lldb::eLanguageTypeObjC);
+    }
+    if (lldb::SBDebugger::SupportsLanguage(lldb::eLanguageTypeSwift)) {
+      exception_breakpoints->emplace_back("swift_catch", "Swift Catch",
+                                          lldb::eLanguageTypeSwift);
+      exception_breakpoints->emplace_back("swift_throw", "Swift Throw",
+                                          lldb::eLanguageTypeSwift);
+    }
+  });
+}
+
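Aside: the pattern above, an `llvm::once_flag` guarding a `std::optional`, guarantees the population runs exactly once no matter which accessor fires first. A minimal, self-contained sketch of the same idiom; the `Registry` type and its members are illustrative placeholders, not lldb-dap API:

#include <optional>
#include <string>
#include <vector>
#include "llvm/Support/Threading.h"

struct Registry {
  std::optional<std::vector<std::string>> entries;
  llvm::once_flag init_flag;

  void Populate() {
    llvm::call_once(init_flag, [this]() {
      entries = {}; // marks the optional engaged even if nothing is added
      entries->push_back("cpp_catch"); // added only when supported
    });
  }

  const std::string *Find(const std::string &name) {
    Populate(); // cheap after the first call; the lambda runs exactly once
    for (const auto &e : *entries)
      if (e == name)
        return &e;
    return nullptr;
  }
};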
 ExceptionBreakpoint *DAP::GetExceptionBreakpoint(const std::string &filter) {
-  for (auto &bp : exception_breakpoints) {
+  // PopulateExceptionBreakpoints() is called after g_dap.debugger is created
+  // in request_initialize().
+  //
+  // But this GetExceptionBreakpoint() method may be called before attaching,
+  // in which case we may not have populated the filter list yet.
+  //
+  // We also cannot call PopulateExceptionBreakpoints() in DAP::DAP() because
+  // we need SBDebugger::Initialize() to have been called first.
+  //
+  // So just calling PopulateExceptionBreakpoints(), which populates the list
+  // lazily, seems easiest. Two other options include:
+  //  + call g_dap.PopulateExceptionBreakpoints() in lldb-dap.cpp::main()
+  //    right after the call to SBDebugger::Initialize()
+  //  + just call PopulateExceptionBreakpoints() to get a fresh list every
+  //    time we query (a bit overkill since it's not likely to change?)
+  PopulateExceptionBreakpoints();
+  assert(exception_breakpoints.has_value() &&
+         "exception_breakpoints must have been populated");
+
+  for (auto &bp : *exception_breakpoints) {
     if (bp.filter == filter)
       return &bp;
   }
@@ -74,7 +117,12 @@ ExceptionBreakpoint *DAP::GetExceptionBreakpoint(const std::string &filter) {
 }
 
 ExceptionBreakpoint *DAP::GetExceptionBreakpoint(const lldb::break_id_t bp_id) {
-  for (auto &bp : exception_breakpoints) {
+  // See comment in the other GetExceptionBreakpoint().
+  PopulateExceptionBreakpoints();
+  assert(exception_breakpoints.has_value() &&
+         "exception_breakpoints must have been populated");
+
+  for (auto &bp : *exception_breakpoints) {
     if (bp.bp.GetID() == bp_id)
       return &bp;
   }
diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h
index a88ee3e1dec6bc..daa0d9f1aa7f04 100644
--- a/lldb/tools/lldb-dap/DAP.h
+++ b/lldb/tools/lldb-dap/DAP.h
@@ -156,7 +156,8 @@ struct DAP {
   std::unique_ptr<std::ofstream> log;
   llvm::StringMap<SourceBreakpointMap> source_breakpoints;
   FunctionBreakpointMap function_breakpoints;
-  std::vector<ExceptionBreakpoint> exception_breakpoints;
+  std::optional<std::vector<ExceptionBreakpoint>> exception_breakpoints;
+  llvm::once_flag initExceptionBreakpoints;
   std::vector<std::string> init_commands;
   std::vector<std::string> pre_run_commands;
   std::vector<std::string> post_run_commands;
@@ -228,6 +229,8 @@ struct DAP {
 
   llvm::json::Value CreateTopLevelScopes();
 
+  void PopulateExceptionBreakpoints();
+
   /// \return
   ///   Attempt to determine if an expression is a variable expression or
   ///   lldb command using a heuristic based on the first term of the
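The populated list is advertised to the client through the `exceptionBreakpointFilters` capability in the `initialize` response (next hunk). A hedged sketch of the per-filter JSON shape mandated by the DAP specification; `makeFilter` is a hypothetical stand-in, not lldb-dap's actual `CreateExceptionBreakpointFilter`:

#include "llvm/Support/JSON.h"

// One {filter, label} object per ExceptionBreakpoint, e.g.
// {"filter": "cpp_throw", "label": "C++ Throw"}.
llvm::json::Object makeFilter(llvm::StringRef filter, llvm::StringRef label) {
  llvm::json::Object obj;
  obj.try_emplace("filter", filter); // stable id the client sends back
  obj.try_emplace("label", label);   // human-readable name shown in the UI
  return obj;
}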
diff --git a/lldb/tools/lldb-dap/lldb-dap.cpp b/lldb/tools/lldb-dap/lldb-dap.cpp
index 7746afb6cbbf38..470c9f84c6a203 100644
--- a/lldb/tools/lldb-dap/lldb-dap.cpp
+++ b/lldb/tools/lldb-dap/lldb-dap.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #if defined(_WIN32)
@@ -1586,6 +1587,7 @@ void request_initialize(const llvm::json::Object &request) {
   bool source_init_file = GetBoolean(arguments, "sourceInitFile", true);
   g_dap.debugger = lldb::SBDebugger::Create(source_init_file, log_cb, nullptr);
+  g_dap.PopulateExceptionBreakpoints();
   auto cmd = g_dap.debugger.GetCommandInterpreter().AddMultiwordCommand(
       "lldb-dap", "Commands for managing lldb-dap.");
   if (GetBoolean(arguments, "supportsStartDebuggingRequest", false)) {
@@ -1621,7 +1623,7 @@ void request_initialize(const llvm::json::Object &request) {
   body.try_emplace("supportsEvaluateForHovers", true);
   // Available filters or options for the setExceptionBreakpoints request.
   llvm::json::Array filters;
-  for (const auto &exc_bp : g_dap.exception_breakpoints) {
+  for (const auto &exc_bp : *g_dap.exception_breakpoints) {
     filters.emplace_back(CreateExceptionBreakpointFilter(exc_bp));
   }
   body.try_emplace("exceptionBreakpointFilters", std::move(filters));
@@ -2476,7 +2478,7 @@ void request_setExceptionBreakpoints(const llvm::json::Object &request) {
   // Keep a list of any exception breakpoint filter names that weren't set
   // so we can clear any exception breakpoints if needed.
   std::set<std::string> unset_filters;
-  for (const auto &bp : g_dap.exception_breakpoints)
+  for (const auto &bp : *g_dap.exception_breakpoints)
     unset_filters.insert(bp.filter);
 
   for (const auto &value : *filters) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7b8e3230bf5534..10a1e9f129f429 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1396,6 +1396,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
     insertAssignPtrTypeIntrs(I, B);
+    deduceOperandElementType(I);
     insertAssignTypeIntrs(I, B);
     insertPtrCastOrAssignTypeInstr(I, B);
     insertSpirvDecorations(I, B);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014a..ed5b4ff2de4dce 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,8 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class ConvolutionOpInterface;
+class GroupedConvolutionOpInterface;
 class GenericOp;
 
 namespace detail {
@@ -133,6 +135,38 @@ std::optional<Value> isaFillOpInterface(GenericOp genericOp);
 
 namespace detail {
 
+// Common implementations for ConvolutionOpInterface.
+namespace convolution_impl {
+// Returns strides as a vector.
+SmallVector<int64_t> getStrides(ConvolutionOpInterface op);
+// Returns dilations as a vector.
+SmallVector<int64_t> getDilations(ConvolutionOpInterface op);
+// Region builder for basic convolution.
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                   ArrayRef<NamedAttribute> attrs);
+// Region builder for basic quantized convolution.
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                            ArrayRef<NamedAttribute> attrs);
+void getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result,
+                  bool isQuantized = false);
+void print(LinalgOp op, OpAsmPrinter &p);
+} // namespace convolution_impl
+
+// Common implementations for GroupedConvolutionOpInterface.
+namespace grouped_convolution_impl {
+int64_t getSpatialRank(GroupedConvolutionOpInterface op);
+ArrayAttr createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations);
+ArrayAttr getIteratorTypes(GroupedConvolutionOpInterface op);
+} // namespace grouped_convolution_impl
+
 /// Returns true if the block contains a contraction of the following form:
 ///
 ///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
@@ -189,6 +223,9 @@ LogicalResult verifyContractionInterface(Operation *op);
 /// Verify that `op` conforms to the ConvolutionOpInterface.
 LogicalResult verifyConvolutionInterface(Operation *op);
 
+/// Verify that `op` conforms to the GroupedConvolutionOpInterface.
+LogicalResult verifyGroupedConvolutionInterface(Operation *op);
+
 /// Verify that `op` conforms to the FillOpInterface.
 LogicalResult verifyFillInterface(Operation *op);
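Note the defaulting contract of `getStrides`/`getDilations` declared above: an absent attribute means unit strides/dilations. A standalone sketch of that behavior, assuming (as the definitions later in this diff do) a dense i64 `strides` attribute:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"

// Read the dense i64 "strides" attribute if present; otherwise assume
// stride 1 along every spatial dimension.
llvm::SmallVector<int64_t> getStridesOrOnes(mlir::Operation *op,
                                            int64_t spatialRank) {
  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>("strides"))
    return llvm::to_vector(attr.getValues<int64_t>());
  return llvm::SmallVector<int64_t>(spatialRank, 1);
}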
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..5ae481a222e3c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,101 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
         return $_op.getOperation()->getOperand(1);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the spatial rank.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getSpatialRank",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Most convolutions' inputs have batch, channel and spatial dims.
+        return cast<ShapedType>(image().getType()).getRank() - 2;
+      }]
+    >
   ];
 }
 
+def LinalgGroupedConvolutionOpInterface : OpInterface<"GroupedConvolutionOpInterface", [
+  LinalgConvolutionOpInterface]> {
+  let description = [{
+    A grouped convolution is defined in general terms:
+    1. It is a convolution as defined by `ConvolutionOpInterface`.
+    2. Operands have the following distinct dimensions (excluding batch in
+       input/output): group, channel, spatial.
+    3. `input_rank == kernel_rank == output_rank` (including batch in
+       input/output).
+    4. Reductions are along the input channel and filter spatial dimensions,
+       while the group, output channel and output spatial dimensions are
+       parallel.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyGroupedConvolutionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the layout of each operand as a vector of dimension enums.
+    }],
+    "SmallVector<SmallVector<::mlir::utils::GroupedConvDim>>", "getLayoutsEnums", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the groups position for the input.
+    }],
+    "int64_t", "getInputGroupsPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the input.
+    }],
+    "int64_t", "getInputChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Returns the channel position for the output.
+    }],
+    "int64_t", "getOutputChannelPosition", (ins)
+    >,
+    InterfaceMethod<[{
+      Get number of groups.
+    }],
+    "int64_t", "getNumGroups", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputGroupsPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Get number of input channels.
+    }],
+    "int64_t", "getNumInputChannels", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.image().getType()).getShape()[$_op.getInputChannelPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Get number of output channels.
+    }],
+    "int64_t", "getNumOutputChannels", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return cast<ShapedType>($_op.getDpsInits()[0].getType()).getShape()[$_op.getOutputChannelPosition()];
+    }]>,
+    InterfaceMethod<[{
+      Returns the iterator types.
+    }],
+    "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return detail::grouped_convolution_impl::getIteratorTypes($_op);
+    }]>,
+    InterfaceMethod<[{
+      Returns strides.
+    }],
+    "::llvm::SmallVector<int64_t>", "getStridesVector", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return detail::convolution_impl::getStrides($_op);
+    }]>,
+    InterfaceMethod<[{
+      Returns dilations.
+    }],
+    "::llvm::SmallVector<int64_t>", "getDilationsVector", (ins),
+    /*methodBody=*/[{}],
+    /*defaultImplementation=*/[{
+      return detail::convolution_impl::getDilations($_op);
+    }]>
+  ];
+}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..7db7c54a4ea098 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,147 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp.
+//===----------------------------------------------------------------------===//
+
+def GroupedConvNDOp : LinalgStructuredBase_Op<"grouped_conv_nd",
+    [AttrSizedOperandSegments, LinalgGroupedConvolutionOpInterface]> {
+
+  let summary = [{
+    Performs N-D grouped convolution with switchable channel position; either
+    first or last.
+  }];
+  let description = [{
+    Allows any number of spatial dimensions but treats all of them as
+    contiguous. Throughout, `S` will represent all spatial dimensions.
+    Operand layouts are determined by the `layouts` `StrArrayAttr` attribute.
+    Each element of the array is a string representing the layout of the
+    corresponding operand and should be mappable to a `GroupedConvDim` enum,
+    i.e. one of:
+      n: (batch dim)
+      g: (group dim)
+      f: (feature or output channel dim)
+      s: (all spatial dims)
+      c: (input channel dim)
+
+    The domain will always be in the order `(N, G, F, S, C, KS)`.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$inits,
+    DefaultValuedAttr<StrArrayAttr,
+                      "{\"ngcs\", \"gfcs\", \"ngfs\"}">:$layouts,
+    OptionalAttr<I64ElementsAttr>:$strides,
+    OptionalAttr<I64ElementsAttr>:$dilations
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<
+      (ins "Value":$input, "Value":$filter, "Value":$init,
+        CArg<"ArrayRef<int64_t>", "{}">:$strides,
+        CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        int64_t numSpatialDims =
+            cast<ShapedType>(input.getType()).getRank() - 3;
+        // Copy into owned storage before defaulting so the ArrayRef
+        // arguments never point at a dead temporary.
+        ::llvm::SmallVector<int64_t> stridesVec = ::llvm::to_vector(strides);
+        if (stridesVec.empty())
+          stridesVec.assign(numSpatialDims, 1);
+        ::llvm::SmallVector<int64_t> dilationsVec =
+            ::llvm::to_vector(dilations);
+        if (dilationsVec.empty())
+          dilationsVec.assign(numSpatialDims, 1);
+        $_state.addAttribute(getStridesAttrName($_state.name),
+            ::mlir::DenseElementsAttr::get(
+                ::mlir::RankedTensorType::get(numSpatialDims,
+                                              $_builder.getI64Type()),
+                ::llvm::ArrayRef<int64_t>(stridesVec)));
+        $_state.addAttribute(getDilationsAttrName($_state.name),
+            ::mlir::DenseElementsAttr::get(
+                ::mlir::RankedTensorType::get(numSpatialDims,
+                                              $_builder.getI64Type()),
+                ::llvm::ArrayRef<int64_t>(dilationsVec)));
+        buildStructuredOp($_builder, $_state, std::nullopt, {input, filter},
+                          init, attributes, GroupedConvNDOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+        "Value":$init,
+        CArg<"ArrayRef<int64_t>", "{}">:$strides,
+        CArg<"ArrayRef<int64_t>", "{}">:$dilations,
+        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        int64_t numSpatialDims =
+            cast<ShapedType>(input.getType()).getRank() - 3;
+        ::llvm::SmallVector<int64_t> stridesVec = ::llvm::to_vector(strides);
+        if (stridesVec.empty())
+          stridesVec.assign(numSpatialDims, 1);
+        ::llvm::SmallVector<int64_t> dilationsVec =
+            ::llvm::to_vector(dilations);
+        if (dilationsVec.empty())
+          dilationsVec.assign(numSpatialDims, 1);
+        $_state.addAttribute(getStridesAttrName($_state.name),
+            ::mlir::DenseElementsAttr::get(
+                ::mlir::RankedTensorType::get(numSpatialDims,
+                                              $_builder.getI64Type()),
+                ::llvm::ArrayRef<int64_t>(stridesVec)));
+        $_state.addAttribute(getDilationsAttrName($_state.name),
+            ::mlir::DenseElementsAttr::get(
+                ::mlir::RankedTensorType::get(numSpatialDims,
+                                              $_builder.getI64Type()),
+                ::llvm::ArrayRef<int64_t>(dilationsVec)));
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+                          {input, filter}, init, attributes,
+                          GroupedConvNDOp::getRegionBuilder());
+      }]>,
+    OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "Value":$input, "Value":$filter,
+        "Value":$init, "Attribute":$strides, "Attribute":$dilations,
+        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute(getStridesAttrName($_state.name), strides);
+        $_state.addAttribute(getDilationsAttrName($_state.name), dilations);
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+                          {input, filter}, init, attributes,
+                          GroupedConvNDOp::getRegionBuilder());
+      }]>
+  ];
+
+  // TODO: Figure out how to move this to the interface.
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    void print(::mlir::OpAsmPrinter &printer) {
+      return detail::convolution_impl::print(*this, printer);
+    }
+    static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+                                     ::mlir::OperationState &result) {
+      return detail::convolution_impl::parse(parser, result);
+    }
+    static std::function<void(ImplicitLocOpBuilder &, Block &,
+                              ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return detail::convolution_impl::regionBuilder;
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+
+    // Implement functions necessary for LinalgOp.
+    ArrayAttr getIndexingMaps();
+
+    // Implement functions necessary for GroupedConvolutionOpInterface.
+    int64_t getSpatialRank() {
+      return detail::grouped_convolution_impl::getSpatialRank(*this);
+    }
+
+    SmallVector<SmallVector<utils::GroupedConvDim>> getLayoutsEnums() {
+      SmallVector<SmallVector<utils::GroupedConvDim>> layouts;
+      for (auto attr : (*this).getLayoutsAttr().getValue()) {
+        std::string layoutStr = cast<StringAttr>(attr).getValue().str();
+        SmallVector<::mlir::utils::GroupedConvDim> layout(layoutStr.size());
+        for (size_t i = 0; i < layoutStr.size(); i++) {
+          auto maybeDimEnum =
+              ::mlir::utils::symbolizeGroupedConvDim(layoutStr.substr(i, 1));
+          assert(maybeDimEnum && "expected a valid layout character");
+          layout[i] = maybeDimEnum.value();
+        }
+        layouts.push_back(layout);
+      }
+      return layouts;
+    }
+
+    int64_t getOutputChannelPosition() {
+      return 2;
+    }
+
+    int64_t getInputChannelPosition() {
+      return 2;
+    }
+
+    int64_t getInputGroupsPosition() {
+      return 1;
+    }
+  }];
+}
 
 //===----------------------------------------------------------------------===//
 // Transpose op.
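A hedged usage sketch for the first builder above, producing the same 3-D channel-first op exercised by the bufferize test later in this patch; `builder`, `loc`, and the operand values are assumed to come from surrounding lowering code:

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Sketch only: operand types are noted in comments; the default layouts
// ["ngcs", "gfcs", "ngfs"] are assumed.
mlir::linalg::GroupedConvNDOp
buildGroupedConv(mlir::OpBuilder &builder, mlir::Location loc,
                 mlir::Value input,   // tensor<64x2x16x26x26x26xf32>
                 mlir::Value filter,  // tensor<2x20x16x3x3x3xf32>
                 mlir::Value init) {  // tensor<64x2x20x8x8x8xf32>
  return builder.create<mlir::linalg::GroupedConvNDOp>(
      loc, input, filter, init,
      /*strides=*/llvm::ArrayRef<int64_t>{3, 3, 3},
      /*dilations=*/llvm::ArrayRef<int64_t>{2, 2, 2});
}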
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
index 4200343ce3e132..c7c5d617f6492c 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td
@@ -20,4 +20,16 @@ def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
   let cppNamespace = "::mlir::utils";
 }
 
+def GroupedConvDim : I32EnumAttr<"GroupedConvDim", "Convolution dim",
+    [
+      I32EnumAttrCase<"n", 0>, // batch
+      I32EnumAttrCase<"g", 1>, // group
+      I32EnumAttrCase<"f", 2>, // feature (output channel)
+      I32EnumAttrCase<"s", 3>, // spatial
+      I32EnumAttrCase<"c", 4>  // channel (input channel)
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::utils";
+}
+
 #endif // STRUCTURED_OPS_UTILS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f35ab3b856b4ea..c2db6670e4167a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -766,6 +766,135 @@ enum class MatchConvolutionResult {
 };
 } // namespace mlir::linalg::detail
 
+SmallVector<int64_t>
+mlir::linalg::detail::convolution_impl::getStrides(ConvolutionOpInterface op) {
+  auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
+  // Default to unit strides when the attribute is absent.
+  if (!maybeStridesAttr)
+    return SmallVector<int64_t>(op.getSpatialRank(), 1);
+  return llvm::to_vector(maybeStridesAttr.getValues<int64_t>());
+}
+
+SmallVector<int64_t> mlir::linalg::detail::convolution_impl::getDilations(
+    ConvolutionOpInterface op) {
+  auto maybeDilationsAttr =
+      op->getAttrOfType<DenseIntElementsAttr>("dilations");
+  // Default to unit dilations when the attribute is absent.
+  if (!maybeDilationsAttr)
+    return SmallVector<int64_t>(op.getSpatialRank(), 1);
+  return llvm::to_vector(maybeDilationsAttr.getValues<int64_t>());
+}
+
+int64_t mlir::linalg::detail::grouped_convolution_impl::getSpatialRank(
+    GroupedConvolutionOpInterface op) {
+  return cast<ShapedType>(op.image().getType()).getRank() - 3;
+}
+
+ArrayAttr mlir::linalg::detail::grouped_convolution_impl::getIteratorTypes(
+    GroupedConvolutionOpInterface op) {
+  int64_t numSpatialDims = op.getSpatialRank();
+  SmallVector<Attribute> iteratorTypes(
+      3 + numSpatialDims,
+      IteratorTypeAttr::get(op.getContext(), utils::IteratorType::parallel));
+  SmallVector<Attribute> reductions(
+      numSpatialDims + 1,
+      IteratorTypeAttr::get(op.getContext(), utils::IteratorType::reduction));
+  iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+                       reductions.end());
+
+  return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
+ArrayAttr
+mlir::linalg::detail::grouped_convolution_impl::createCommonIndexingMaps(
+    MLIRContext *ctx, int64_t numSpatial,
+    const SmallVector<SmallVector<utils::GroupedConvDim>> &layouts,
+    const SmallVectorImpl<int64_t> &strides,
+    const SmallVectorImpl<int64_t> &dilations) {
+  assert(layouts.size() == 3 && "expected 3 layouts: image, filter, init");
+
+  // Domain: (n, g, f, os, c, ks)
+  AffineExpr n = getAffineDimExpr(0, ctx);
+  AffineExpr g = getAffineDimExpr(1, ctx);
+  AffineExpr f = getAffineDimExpr(2, ctx);
+  SmallVector<AffineExpr> s(
+      llvm::map_range(llvm::seq<int64_t>(3, numSpatial + 3),
+                      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+  AffineExpr c = getAffineDimExpr(numSpatial + 3, ctx);
+  SmallVector<AffineExpr> ks(llvm::map_range(
+      llvm::seq<int64_t>(numSpatial + 4, 2 * numSpatial + 4),
+      [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+  SmallVector<AffineExpr> inSpatials;
+  inSpatials.reserve(numSpatial);
+  for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, strides, dilations)) {
+    inSpatials.push_back(sp * st + ksp * di);
+  }
+
+  auto getExprs = [&](const SmallVector<utils::GroupedConvDim> &layout,
+                      const SmallVector<AffineExpr> &spatials) {
+    SmallVector<AffineExpr> exprs(layout.size());
+    int64_t spatialDim = -1;
+    for (const auto &[i, dim] : llvm::enumerate(layout)) {
+      switch (dim) {
+      case utils::GroupedConvDim::n:
+        exprs[i] = n;
+        break;
+      case utils::GroupedConvDim::g:
+        exprs[i] = g;
+        break;
+      case utils::GroupedConvDim::f:
+        exprs[i] = f;
+        break;
+      case utils::GroupedConvDim::s:
+        exprs[i] = spatials[0];
+        spatialDim = i;
+        break;
+      case utils::GroupedConvDim::c:
+        exprs[i] = c;
+        break;
+      default:
+        llvm_unreachable("unexpected GroupedConvDim");
+      }
+    }
+    // The layout string holds a single `s`; splice the remaining spatial
+    // expressions in right after it.
+    if (spatials.size() > 1)
+      exprs.insert(exprs.begin() + spatialDim + 1, spatials.begin() + 1,
+                   spatials.end());
+    return exprs;
+  };
+  SmallVector<AffineExpr> inExprs = getExprs(layouts[0], inSpatials);
+  SmallVector<AffineExpr> kExprs = getExprs(layouts[1], ks);
+  SmallVector<AffineExpr> outExprs = getExprs(layouts[2], s);
+  SmallVector<AffineMap> maps(
+      {AffineMap::get(4 + 2 * numSpatial, 0, inExprs, ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, kExprs, ctx),
+       AffineMap::get(4 + 2 * numSpatial, 0, outExprs, ctx)});
+
+  return Builder(ctx).getAffineMapArrayAttr(maps);
+}
+
+LogicalResult
+mlir::linalg::detail::verifyGroupedConvolutionInterface(Operation *op) {
+  if (failed(verifyConvolutionInterface(op)))
+    return failure();
+  if (GroupedConvolutionOpInterface conv =
+          dyn_cast<GroupedConvolutionOpInterface>(op)) {
+    const auto imageRank =
+        cast<ShapedType>(conv.image().getType()).getRank();
+    const auto kernelRank =
+        cast<ShapedType>(conv.filter().getType()).getRank();
+    const auto initRank =
+        cast<ShapedType>(cast<LinalgOp>(op).getDpsInits()[0].getType())
+            .getRank();
+    if (imageRank != kernelRank || imageRank != initRank)
+      return op->emitError(
+          "rank relationship must be `input_rank == kernel_rank == output_rank`");
+    return success();
+  }
+  return failure();
+}
+
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
     Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa81587..21cc22f034aa66 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1735,6 +1735,110 @@ LogicalResult ReduceOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface
+//===----------------------------------------------------------------------===//
+
+// There must be a way to avoid defining the following 3 functions.
+ParseResult mlir::linalg::detail::convolution_impl::parse(
+    OpAsmParser &parser, OperationState &result, bool isQuantized) {
+  if (isQuantized)
+    return parseNamedStructuredOp(
+        parser, result, 5,
+        mlir::linalg::detail::convolution_impl::quantizedRegionBuilder);
+  return parseNamedStructuredOp(
+      parser, result, 3, mlir::linalg::detail::convolution_impl::regionBuilder);
+}
+
+void mlir::linalg::detail::convolution_impl::print(LinalgOp op,
+                                                   OpAsmPrinter &p) {
+  printNamedStructuredOp(p, op.getOperation(), op.getDpsInputs(),
+                         op.getDpsInits());
+}
+
+// Build {mul, add} region for convolution.
+void mlir::linalg::detail::convolution_impl::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ConvolutionInterface regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+                         block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+void mlir::linalg::detail::convolution_impl::quantizedRegionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 5 &&
+         "ConvolutionInterface quantized regionBuilder expects 5 args");
+  RegionBuilderHelper helper(b, block);
+  Value value1 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(0));
+  Value value2 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(2));
+  Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+  Value value4 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(1));
+  Value value5 =
+      helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+                         block.getArgument(3));
+  Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+  Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+  Value value8 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+  helper.yieldOutputs({value8});
+}
+
+void mlir::linalg::detail::convolution_impl::getEffects(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!isa<ConvolutionOpInterface>(op))
+    return;
+  if (LinalgOp linalgOp = dyn_cast<LinalgOp>(op)) {
+    if (linalgOp.hasPureTensorSemantics())
+      return;
+    getGenericEffectsImpl(effects, linalgOp);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GroupedConvNDOp
+//===----------------------------------------------------------------------===//
+
+void GroupedConvNDOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  return detail::convolution_impl::getEffects(*this, effects);
+}
+
+ArrayAttr GroupedConvNDOp::getIndexingMaps() {
+  ArrayAttr cached = (*this)->getAttrOfType<ArrayAttr>(
+      LinalgDialect::kMemoizedIndexingMapsAttrName);
+  if (cached)
+    return cached;
+
+  cached = detail::grouped_convolution_impl::createCommonIndexingMaps(
+      getContext(), getSpatialRank(), getLayoutsEnums(), getStridesVector(),
+      getDilationsVector());
+
+  (*this)->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+  return cached;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
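To make `createCommonIndexingMaps` concrete, here is the 2-D case worked out by hand from the code above (layouts `["ngcs", "gfcs", "ngfs"]`; hand-derived, not compiler output):

// Domain has 4 + 2 * numSpatial dims; for numSpatial = 2:
//   (d0, d1, d2, d3, d4, d5, d6, d7) = (n, g, f, oh, ow, c, kh, kw)
// With strides [sh, sw] and dilations [dh, dw] the three maps are:
//   image  "ngcs": (n, g, c, oh * sh + kh * dh, ow * sw + kw * dw)
//   filter "gfcs": (g, f, c, kh, kw)
//   init   "ngfs": (n, g, f, oh, ow)
// getIteratorTypes matches this domain: n, g, f, oh, ow are parallel
// (3 + numSpatial), then c, kh, kw are reductions (numSpatial + 1).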
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd26..56fc39f5fc073f 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,24 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
 // CHECK:     %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
 // CHECK:     return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @gen_grouped_3D_channel_first_tensor(
+// CHECK-SAME:    %[[ARG0_TENSOR:.*]]: tensor<64x2x16x26x26x26xf32>,
+// CHECK-SAME:    %[[ARG1_TENSOR:.*]]: tensor<2x20x16x3x3x3xf32>,
+// CHECK-SAME:    %[[ARG2_TENSOR:.*]]: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+// CHECK-DAG:     %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x2x16x26x26x26xf32>
+// CHECK-DAG:     %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<2x20x16x3x3x3xf32>
+// CHECK-DAG:     %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x2x20x8x8x8xf32>
+// CHECK-DAG:     %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x2x20x8x8x8xf32>
+// CHECK:         memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32> to memref<64x2x20x8x8x8xf32>
+// CHECK:         linalg.grouped_conv_nd
+// CHECK-SAME:      dilations = dense<2> : tensor<3xi64>
+// CHECK-SAME:      strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x2x16x26x26x26xf32>, memref<2x20x16x3x3x3xf32>)
+// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<64x2x20x8x8x8xf32>)
+func.func @gen_grouped_3D_channel_first_tensor(%arg0: tensor<64x2x16x26x26x26xf32>, %arg1: tensor<2x20x16x3x3x3xf32>, %arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32> {
+  %0 = linalg.grouped_conv_nd {strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x2x16x26x26x26xf32>, tensor<2x20x16x3x3x3xf32>) outs(%arg2: tensor<64x2x20x8x8x8xf32>) -> tensor<64x2x20x8x8x8xf32>
+  return %0 : tensor<64x2x20x8x8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b818170a8e7974..df89029e27d868 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,10 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 
 func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..a231569672209b 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// -----
+
+// CHECK-LABEL: func @gen_grouped_1D_channel_first_memref
+func.func @gen_grouped_1D_channel_first_memref(%arg0: memref<64x8x16x10xf32>, %arg1: memref<8x32x16x3xf32>, %arg2: memref<64x8x32x8xf32>) {
+  // CHECK: grouped_conv_nd
+  linalg.grouped_conv_nd ins(%arg0, %arg1: memref<64x8x16x10xf32>, memref<8x32x16x3xf32>) outs(%arg2: memref<64x8x32x8xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
 func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
   %zero = arith.constant 0.000000e+00 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index f674996e42f333..6d19665931662c 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
 // CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
@@ -41,3 +41,63 @@ module attributes {transform.with_named_sequence} {
 // CHECK:        linalg.conv_2d
 // CHECK-SAME:     ins(%[[SVIN]], %[[SVKER]]
 // CHECK-SAME:     outs(%[[SVOUT]]
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 6)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (d0 + s0 - 1)>
+
+func.func @grouped_conv_2D(%arg0 : memref<?x?x?x?x?xf32>, %arg1 : memref<?x?x?x?x?xf32>, %arg2 : memref<?x?x?x?x?xf32>) {
+  linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]} ins(%arg0, %arg1 : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) outs(%arg2 : memref<?x?x?x?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.grouped_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:5 = transform.structured.tile_using_for %0 tile_sizes [2, 3, 4, 5, 6] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: func @grouped_conv_2D
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[GROUPS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[IN_CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUT_CHANNELS:.*]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[KW:.*]] = memref.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[KH:.*]] = memref.dim %[[ARG1]], %[[C4]]
+// CHECK-DAG: %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+// CHECK-DAG: %[[H:.*]] = memref.dim %[[ARG2]], %[[C4]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[GROUPS]] step %[[C3]]
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[OUT_CHANNELS]] step %[[C4]]
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[W]] step %[[C5]]
+// CHECK: scf.for %[[M:.*]] = %[[C0]] to %[[H]] step %[[C6]]
+// CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+// CHECK: %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[GROUPS]]]
+// CHECK-DAG: %[[T6:.*]] = affine.min #[[MAP2]](%[[K]])[%[[OUT_CHANNELS]]]
+// CHECK-DAG: %[[T7:.*]] = affine.min #[[MAP3]](%[[L]])[%[[W]]]
+// CHECK-DAG: %[[T8:.*]] = affine.min #[[MAP4]](%[[M]])[%[[H]]]
+// CHECK-DAG: %[[T9:.*]] = affine.apply #[[MAP5]](%[[T7]])[%[[KW]]]
+// CHECK-DAG: %[[T10:.*]] = affine.apply #[[MAP5]](%[[T8]])[%[[KH]]]
+// CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], %[[J]], 0, %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[IN_CHANNELS]], %[[T9]], %[[T10]]]
+// CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][%[[J]], %[[K]], 0, 0, 0] [%[[T5]], %[[T6]], %[[IN_CHANNELS]], %[[KW]], %[[KH]]]
+// CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], %[[J]], %[[K]], %[[L]], %[[M]]] [%[[T4]], %[[T5]], %[[T6]], %[[T7]], %[[T8]]]
+// CHECK: linalg.grouped_conv_nd {layouts = ["ngcs", "gfcs", "ngfs"]}
+// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
+// CHECK-SAME: outs(%[[SVOUT]]
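As a sanity check on the shapes used in the tests above, the usual output-size arithmetic for an unpadded convolution; the bufferize test's 26-wide input with kernel 3, stride 3, dilation 2 yields (26 - 5) / 3 + 1 = 8:

#include <cassert>
#include <cstdint>

// floor((in - (dilation * (k - 1) + 1)) / stride) + 1 for an unpadded window.
int64_t convOutSize(int64_t in, int64_t k, int64_t stride, int64_t dilation) {
  return (in - (dilation * (k - 1) + 1)) / stride + 1;
}

int main() {
  assert(convOutSize(26, 3, 3, 2) == 8); // bufferize.mlir grouped 3-D test
  assert(convOutSize(10, 3, 1, 1) == 8); // named-ops.mlir grouped 1-D test
  return 0;
}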