Describe the bug
Some files are missing the headers that they rely on, which means they cannot be included by themselves. This is "hidden" in most of the examples because they include many headers, and one of those will often pull in what's needed, but this is fragile and has caused build failures as we rearrange includes. I assume every header in cutlass/include and cutlass/tools/util/include/ is supposed to be includable individually? If so, then the following headers are broken (generated with a simple shell script):
Build this with `-c -std=c++20 -I cutlass/include/ -I cutlass/tools/util/include`.
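One way to approximate the check described above is to generate, for each header, a translation unit that includes only that header, and then compile it with the flags quoted in this issue. The sketch below is not the reporter's script, which is not shown here. `probe_source` and `probe_command` are hypothetical helper names, and the sketch only builds the strings; writing the file and launching the compiler are left out.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: the full text of a translation unit that includes
// exactly one header, so nothing else can supply its missing dependencies.
std::string probe_source(const std::string& header) {
  assert(!header.empty());
  return "#include \"" + header + "\"\n";
}

// Hypothetical helper: the compile command for one probe file, using the
// flags and include paths quoted in this issue.
std::string probe_command(const std::string& probe_file) {
  assert(!probe_file.empty());
  return "nvcc -c -std=c++20 " + probe_file +
         " -I cutlass/include/ -I cutlass/tools/util/include/";
}
```

Writing `probe_source(h)` to a `.cu` file and running `probe_command` on it for every header under the two include directories would flag each header that fails to compile on its own.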
Expected behavior
This should compile cleanly.
Environment details (please complete the following information):
Environment location: Bare-metal
CUTLASS main branch
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
Additional context
Example output
$ nvcc -c -std=c++20 test.cu -I cutlass/include/ -I cutlass/tools/util/include/
cutlass/include/cutlass/epilogue/fusion/operations.hpp(134): error: identifier "sizeof_bits_v" is undefined
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(152): error: identifier "sizeof_bits_v" is undefined
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(169): error: identifier "sizeof_bits_v" is undefined
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(190): error: identifier "sizeof_bits_v" is undefined
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(191): error: type name is not allowed
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(191): error: expected a "," or ">"
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(211): error: identifier "sizeof_bits_v" is undefined
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(212): error: type name is not allowed
int AlignmentScalar_ = 128 / sizeof_bits_v<ElementScalar_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(212): error: expected a "," or ">"
int AlignmentScalar_ = 128 / sizeof_bits_v<ElementScalar_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(234): error: identifier "sizeof_bits_v" is undefined
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(264): error: identifier "sizeof_bits_v" is undefined
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(265): error: type name is not allowed
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(265): error: expected a "," or ">"
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(291): error: identifier "sizeof_bits_v" is undefined
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(318): error: identifier "sizeof_bits_v" is undefined
int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(319): error: type name is not allowed
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(319): error: expected a "," or ">"
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: identifier "cute" is undefined
cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: enable_if_t is not a template
cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: identifier "is_base_of_v" is undefined
cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: expected a ">"
cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: type name is not allowed
cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(104): error: expected a ";"
> {
^
cutlass/include/cute/arch/mma_sm90.hpp(154): error: complex is not a template
using DRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(155): error: complex is not a template
using ARegisters = complex<double>[2];
^
cutlass/include/cute/arch/mma_sm90.hpp(156): error: complex is not a template
using BRegisters = complex<double>[1];
^
cutlass/include/cute/arch/mma_sm90.hpp(157): error: complex is not a template
using CRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(160): error: complex is not a template
fma(complex<double> & d0, complex<double> & d1,
^
cutlass/include/cute/arch/mma_sm90.hpp(160): error: complex is not a template
fma(complex<double> & d0, complex<double> & d1,
^
cutlass/include/cute/arch/mma_sm90.hpp(161): error: complex is not a template
complex<double> & d2, complex<double> & d3,
^
cutlass/include/cute/arch/mma_sm90.hpp(161): error: complex is not a template
complex<double> & d2, complex<double> & d3,
^
cutlass/include/cute/arch/mma_sm90.hpp(162): error: complex is not a template
complex<double> const& a0, complex<double> const& a1,
^
cutlass/include/cute/arch/mma_sm90.hpp(162): error: complex is not a template
complex<double> const& a0, complex<double> const& a1,
^
cutlass/include/cute/arch/mma_sm90.hpp(163): error: complex is not a template
complex<double> const& b0,
^
cutlass/include/cute/arch/mma_sm90.hpp(164): error: complex is not a template
complex<double> const& c0, complex<double> const& c1,
^
cutlass/include/cute/arch/mma_sm90.hpp(164): error: complex is not a template
complex<double> const& c0, complex<double> const& c1,
^
cutlass/include/cute/arch/mma_sm90.hpp(165): error: complex is not a template
complex<double> const& c2, complex<double> const& c3)
^
cutlass/include/cute/arch/mma_sm90.hpp(165): error: complex is not a template
complex<double> const& c2, complex<double> const& c3)
^
cutlass/include/cute/arch/mma_sm90.hpp(212): error: complex is not a template
using DRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(213): error: complex is not a template
using ARegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(214): error: complex is not a template
using BRegisters = complex<double>[2];
^
cutlass/include/cute/arch/mma_sm90.hpp(215): error: complex is not a template
using CRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(218): error: complex is not a template
fma(complex<double> & d0, complex<double> & d1,
^
cutlass/include/cute/arch/mma_sm90.hpp(218): error: complex is not a template
fma(complex<double> & d0, complex<double> & d1,
^
cutlass/include/cute/arch/mma_sm90.hpp(219): error: complex is not a template
complex<double> & d2, complex<double> & d3,
^
cutlass/include/cute/arch/mma_sm90.hpp(219): error: complex is not a template
complex<double> & d2, complex<double> & d3,
^
cutlass/include/cute/arch/mma_sm90.hpp(220): error: complex is not a template
complex<double> const& a0, complex<double> const& a1,
^
cutlass/include/cute/arch/mma_sm90.hpp(220): error: complex is not a template
complex<double> const& a0, complex<double> const& a1,
^
cutlass/include/cute/arch/mma_sm90.hpp(221): error: complex is not a template
complex<double> const& a2, complex<double> const& a3,
^
cutlass/include/cute/arch/mma_sm90.hpp(221): error: complex is not a template
complex<double> const& a2, complex<double> const& a3,
^
cutlass/include/cute/arch/mma_sm90.hpp(222): error: complex is not a template
complex<double> const& b0, complex<double> const& b1,
^
cutlass/include/cute/arch/mma_sm90.hpp(222): error: complex is not a template
complex<double> const& b0, complex<double> const& b1,
^
cutlass/include/cute/arch/mma_sm90.hpp(223): error: complex is not a template
complex<double> const& c0, complex<double> const& c1,
^
cutlass/include/cute/arch/mma_sm90.hpp(223): error: complex is not a template
complex<double> const& c0, complex<double> const& c1,
^
cutlass/include/cute/arch/mma_sm90.hpp(224): error: complex is not a template
complex<double> const& c2, complex<double> const& c3)
^
cutlass/include/cute/arch/mma_sm90.hpp(224): error: complex is not a template
complex<double> const& c2, complex<double> const& c3)
^
cutlass/include/cute/arch/mma_sm90.hpp(271): error: complex is not a template
using DRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(272): error: complex is not a template
using ARegisters = complex<double>[8];
^
cutlass/include/cute/arch/mma_sm90.hpp(273): error: complex is not a template
using BRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(274): error: complex is not a template
using CRegisters = complex<double>[4];
^
cutlass/include/cute/arch/mma_sm90.hpp(277): error: complex is not a template
fma(complex<double> & d0, complex<double> & d1,
^
cutlass/include/cute/arch/mma_sm90.hpp(277): error: complex is not a template
fma(complex<double> & d0, complex<double> & d1,
^
cutlass/include/cute/arch/mma_sm90.hpp(278): error: complex is not a template
complex<double> & d2, complex<double> & d3,
^
cutlass/include/cute/arch/mma_sm90.hpp(278): error: complex is not a template
complex<double> & d2, complex<double> & d3,
^
cutlass/include/cute/arch/mma_sm90.hpp(279): error: complex is not a template
complex<double> const& a0, complex<double> const& a1,
^
cutlass/include/cute/arch/mma_sm90.hpp(279): error: complex is not a template
complex<double> const& a0, complex<double> const& a1,
^
cutlass/include/cute/arch/mma_sm90.hpp(280): error: complex is not a template
complex<double> const& a2, complex<double> const& a3,
^
cutlass/include/cute/arch/mma_sm90.hpp(280): error: complex is not a template
complex<double> const& a2, complex<double> const& a3,
^
cutlass/include/cute/arch/mma_sm90.hpp(281): error: complex is not a template
complex<double> const& a4, complex<double> const& a5,
^
cutlass/include/cute/arch/mma_sm90.hpp(281): error: complex is not a template
complex<double> const& a4, complex<double> const& a5,
^
cutlass/include/cute/arch/mma_sm90.hpp(282): error: complex is not a template
complex<double> const& a6, complex<double> const& a7,
^
cutlass/include/cute/arch/mma_sm90.hpp(282): error: complex is not a template
complex<double> const& a6, complex<double> const& a7,
^
cutlass/include/cute/arch/mma_sm90.hpp(283): error: complex is not a template
complex<double> const& b0, complex<double> const& b1,
^
cutlass/include/cute/arch/mma_sm90.hpp(283): error: complex is not a template
complex<double> const& b0, complex<double> const& b1,
^
cutlass/include/cute/arch/mma_sm90.hpp(284): error: complex is not a template
complex<double> const& b2, complex<double> const& b3,
^
cutlass/include/cute/arch/mma_sm90.hpp(284): error: complex is not a template
complex<double> const& b2, complex<double> const& b3,
^
cutlass/include/cute/arch/mma_sm90.hpp(285): error: complex is not a template
complex<double> const& c0, complex<double> const& c1,
^
cutlass/include/cute/arch/mma_sm90.hpp(285): error: complex is not a template
complex<double> const& c0, complex<double> const& c1,
^
cutlass/include/cute/arch/mma_sm90.hpp(286): error: complex is not a template
complex<double> const& c2, complex<double> const& c3)
^
cutlass/include/cute/arch/mma_sm90.hpp(286): error: complex is not a template
complex<double> const& c2, complex<double> const& c3)
^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(191): error: expected a ","
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
detected during processing of template argument list for "cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux" based on template arguments <GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute, Schedule::ElementT, Schedule::ElementBias, ElementC_, ElementCompute, <error-constant>> at line 696 of cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
cutlass/include/cutlass/epilogue/fusion/operations.hpp(152): error: type name is not allowed
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
detected during processing of template argument list for "cutlass::epilogue::fusion::LinCombPerRowBiasEltAct" based on template arguments <Schedule::ActivationFunctor, ElementD, ElementCompute, Schedule::ElementBias, ElementC_, ElementCompute> at line 704 of cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
cutlass/include/cutlass/epilogue/fusion/operations.hpp(152): error: expected a ","
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
^
detected during processing of template argument list for "cutlass::epilogue::fusion::LinCombPerRowBiasEltAct" based on template arguments <Schedule::ActivationFunctor, ElementD, ElementCompute, Schedule::ElementBias, ElementC_, ElementCompute> at line 704 of cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
83 errors detected in the compilation of "test.cu".
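Most of the errors in the log appear to cascade from a few undeclared names. When `sizeof_bits_v` is not declared, the compiler most likely parses the following `<` as a less-than operator, which would explain the "type name is not allowed" and "expected a ',' or '>'" errors on the same lines. The sketch below reproduces the shape of the failing declarations in a single file. The `demo` namespace, the stand-in `sizeof_bits_v`, and `BiasOp` are illustrative assumptions, not CUTLASS code.

```cpp
#include <climits>
#include <cstdint>

namespace demo {

// Stand-in for the trait that the real header expects to be declared already.
// In a self-contained header, the declaration (or the #include that provides
// it) must appear before the first use below.
template <class T>
constexpr int sizeof_bits_v = static_cast<int>(sizeof(T) * CHAR_BIT);

// Same shape as the declarations in the log: a default template argument
// that names sizeof_bits_v. Removing the declaration above makes this line
// fail in the same way as the reported errors.
template <class ElementBias_,
          int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>>
struct BiasOp {
  static constexpr int kAlignmentBias = AlignmentBias_;
};

}  // namespace demo

// A 32-bit element gives 128 / 32 = 4 elements per 128-bit access.
static_assert(demo::BiasOp<std::uint32_t>::kAlignmentBias == 4,
              "default alignment is derived from the element width");
```

The fix for a header like this is to include the file that declares the name it uses, rather than relying on whichever header the includer happened to pull in first.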
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
The pre-commit hook (#584) re-ordered the headers, which broke
compilation of `group_gemm_sm90.cuh`. This PR adds a clang-format
section to preserve the include order.
Related: NVIDIA/cutlass#1827
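The PR text does not show what the clang-format section looks like. One common way to keep a formatter from re-sorting an include block is to wrap it in clang-format's `// clang-format off` and `// clang-format on` comments, as sketched below. The standard headers and the `tile_count` helper are placeholders for illustration only; they are not taken from `group_gemm_sm90.cuh`.

```cpp
// clang-format off
// The formatter leaves everything between these markers untouched, so the
// include order below survives a pre-commit formatting pass.
#include <cstddef>  // std::size_t
#include <array>    // std::array
// clang-format on

// Hypothetical helper so the snippet has something to compile and call.
constexpr std::size_t tile_count() {
  return std::array<int, 4>{}.size();
}
```

This only hides the order dependency from the formatter. Making each header include what it uses removes the dependency itself.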