Skip to content

Commit

Permalink
Raise if kernel not supported on device
Browse files (browse the repository at this point in the history)
ghstack-source-id: e6bcca5ca4b995751e3f60c40bc59791c23b44f1
Pull Request resolved: #509
  • Loading branch information
danthe3rd committed Nov 10, 2022
1 parent 3a16b20 commit 034464a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
// Optionally, we might not need intermediate GEMM outputs
constexpr bool kStoreD0 = true;
constexpr bool kStoreD1 = true;
using ArchTag = cutlass::arch::Sm80;

using DualGemm = cutlass::gemm::device::DualGemm<
scalar_t,
Expand All @@ -77,7 +78,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Expand All @@ -90,6 +91,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
kStoreD1,
kSplitKSerial
>;
// Runtime guard: refuse to run when the device holding `x` is older than the
// architecture this kernel is instantiated for (ArchTag, aliased to
// cutlass::arch::Sm80 above). The braces keep `p` local to the check.
{
// Properties of the device that `x` lives on.
cudaDeviceProp* p = at::cuda::getDeviceProperties(x.device().index());
// Compute capability is folded into one integer as major * 10 + minor
// (e.g. 8.0 -> 80) so it can be compared with ArchTag::kMinComputeCapability.
// TORCH_CHECK reports "GPU not supported" when the condition is false.
TORCH_CHECK(p->major * 10 + p->minor >= ArchTag::kMinComputeCapability, "GPU not supported");
}

int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
using RefA = typename cutlass::TensorRef<typename DualGemm::ElementA, typename DualGemm::LayoutA>;
Expand Down
4 changes: 4 additions & 0 deletions xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ void gemm_fused_operand_sum_(
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
// Runtime guard: refuse to run when the device holding `a` does not meet the
// minimum compute capability required by SmArch (declared outside this
// excerpt). The braces keep `p` local to the check.
{
// Properties of the device that `a` lives on.
cudaDeviceProp* p = at::cuda::getDeviceProperties(a.device().index());
// Compute capability is folded into one integer as major * 10 + minor
// so it can be compared with SmArch::kMinComputeCapability.
// TORCH_CHECK reports "GPU not supported" when the condition is false.
TORCH_CHECK(p->major * 10 + p->minor >= SmArch::kMinComputeCapability, "GPU not supported");
}

// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;
Expand Down

0 comments on commit 034464a

Please sign in to comment.