From fc8bd6683c0ac70340a63e7793dc92a2414586c4 Mon Sep 17 00:00:00 2001
From: pommedeterresautee
Date: Mon, 1 Aug 2022 17:16:00 +0200
Subject: [PATCH] make universal kernel work with libtorch

---
 .gitignore                            |   6 +
 CMakeLists.txt                        |   2 +-
 examples/00_basic_gemm/CMakeLists.txt |   7 +-
 examples/00_basic_gemm/basic_gemm.cu  | 751 ++++++++++----------------
 4 files changed, 287 insertions(+), 479 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1328f6b7d6..f1e606ec02 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,8 @@
 # PyCache files
 __pycache__/
+build/
+.history/
+libtorch/
+cmake-build-debug/
+.idea/
+.vscode/
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cfed600b72..00534ea111 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -52,7 +52,7 @@ find_package(Doxygen QUIET)
 #
 # CUTLASS 2.x requires C++11
 #
-set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)

diff --git a/examples/00_basic_gemm/CMakeLists.txt b/examples/00_basic_gemm/CMakeLists.txt
index 5af8fcf363..031897cc2a 100644
--- a/examples/00_basic_gemm/CMakeLists.txt
+++ b/examples/00_basic_gemm/CMakeLists.txt
@@ -27,9 +27,14 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+set(CMAKE_PREFIX_PATH "/home/geantvert/workspace/cutlass/libtorch")
+set(Torch_DIR "/home/geantvert/workspace/cutlass/libtorch")
+find_package(Torch REQUIRED)
 cutlass_example_add_executable(
   00_basic_gemm
   basic_gemm.cu
   )
-
+
+target_link_libraries(00_basic_gemm PRIVATE "${TORCH_LIBRARIES}")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
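Before building, it can help to confirm that the hardcoded CMAKE_PREFIX_PATH above actually resolves to a CUDA-enabled libtorch. A minimal smoke test, editor's sketch rather than part of the patch (the file torch_smoke.cu and its target are hypothetical):

#include <torch/torch.h>
#include <iostream>

int main() {
    // Only succeeds if find_package(Torch) linked a CUDA build of libtorch
    // and a GPU is visible; otherwise an exception is thrown.
    torch::Tensor t = torch::rand({2, 2}, torch::kCUDA);
    std::cout << t.sum().item<float>() << std::endl;
    return 0;
}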
diff --git a/examples/00_basic_gemm/basic_gemm.cu b/examples/00_basic_gemm/basic_gemm.cu
index 7c633b30a5..a5b57111ec 100644
--- a/examples/00_basic_gemm/basic_gemm.cu
+++ b/examples/00_basic_gemm/basic_gemm.cu
@@ -1,497 +1,294 @@
-/***************************************************************************************************
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: BSD-3-Clause
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *
- * 1. Redistributions of source code must retain the above copyright notice, this
- * list of conditions and the following disclaimer.
- *
- * 2. Redistributions in binary form must reproduce the above copyright notice,
- * this list of conditions and the following disclaimer in the documentation
- * and/or other materials provided with the distribution.
- *
- * 3. Neither the name of the copyright holder nor the names of its
- * contributors may be used to endorse or promote products derived from
- * this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
-
-/*
-  This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference
-  matrix multiply kernel to verify its correctness.
-
-  The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This is kernel computes
-  the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes
-  all matrices have column-major layout.
-
-  The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices.
-  See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
-  in CUTLASS.
-
-  https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
-
-  Aside from defining and launching the SGEMM kernel, this example does not use any other components
-  or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
-  prevalent in the CUTLASS unit tests.
-
-  This example has delibrately been kept similar to the basic_gemm example from cutass-1.3 to
-  highlight the minimum amount of differences needed to transition to cutlass-2.0.
-
-  Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu
-*/
-
-// Standard Library includes
 #include <iostream>
 #include <sstream>
-#include <vector>
-
-// Helper methods to check for errors
-#include "helper.h"
-
-//
-// CUTLASS includes needed for single-precision GEMM kernel
-//
-
-// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
+#include "cutlass/cutlass.h"
 #include "cutlass/gemm/device/gemm.h"
+#include "cutlass/util/device_memory.h"
+#include <torch/torch.h>
+#include <ctime>
-
-///////////////////////////////////////////////////////////////////////////////////////////////////
-//
-// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object,
-// and launches it on the CUDA device.
-//
-///////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Define a CUTLASS GEMM template and launch a GEMM kernel.
-cudaError_t CutlassSgemmNN(
-  int M,
-  int N,
-  int K,
-  float alpha,
-  float const *A,
-  int lda,
-  float const *B,
-  int ldb,
-  float beta,
-  float *C,
-  int ldc) {
-
-  // Define type definition for single-precision CUTLASS GEMM with column-major
-  // input matrices and 128x128x8 threadblock tile size (chosen by default).
-  //
-  // To keep the interface manageable, several helpers are defined for plausible compositions
-  // including the following example for single-precision GEMM. Typical values are used as
-  // default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details.
-  //
-  // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h`
-
-  using ColumnMajor = cutlass::layout::ColumnMajor;
-
-  using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
-                                                  ColumnMajor,  // Layout of A matrix
-                                                  float,        // Data-type of B matrix
-                                                  ColumnMajor,  // Layout of B matrix
-                                                  float,        // Data-type of C matrix
-                                                  ColumnMajor>; // Layout of C matrix
-
-  // Define a CUTLASS GEMM type
-  CutlassGemm gemm_operator;
-
-  // Construct the CUTLASS GEMM arguments object.
-  //
-  // One of CUTLASS's design patterns is to define gemm argument objects that are constructible
-  // in host code and passed to kernels by value. These may include pointers, strides, scalars,
-  // and other arguments needed by Gemm and its components.
-  //
-  // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
-  // arguments to kernels and (2.) minimized initialization overhead on kernel entry.
-  //
-  CutlassGemm::Arguments args({M, N, K},      // Gemm Problem dimensions
-                              {A, lda},       // Tensor-ref for source matrix A
-                              {B, ldb},       // Tensor-ref for source matrix B
-                              {C, ldc},       // Tensor-ref for source matrix C
-                              {C, ldc},       // Tensor-ref for destination matrix D (may be different memory than source C matrix)
-                              {alpha, beta}); // Scalars used in the Epilogue
-
-  //
-  // Launch the CUTLASS GEMM kernel.
-  //
-
-  cutlass::Status status = gemm_operator(args);
-
-  //
-  // Return a cudaError_t if the CUTLASS GEMM operator returned an error code.
-  //
-
-  if (status != cutlass::Status::kSuccess) {
-    return cudaErrorUnknown;
-  }
-
-  // Return success, if no errors were encountered.
-  return cudaSuccess;
-}
+
+#include "cutlass/tensor_ref.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "helper.h"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+
+
+using precision_a = cutlass::half_t;
+using precision_b = cutlass::half_t;
+using precision_output = cutlass::half_t;
+using precision_accumulator = float;
+using precision_epilogue = precision_accumulator;
+
+
+using LayoutInputA = cutlass::layout::RowMajor;
+using LayoutInputB = cutlass::layout::RowMajor;
+using LayoutOutput = cutlass::layout::RowMajor;
+
+// This code section describes whether to use tensor cores or regular SIMT cores on the GPU SMs
+using MMAOp = cutlass::arch::OpClassTensorOp; // cutlass::arch::OpClassSimt
+
+// This code section describes the CUDA SM architecture number
+using SmArch = cutlass::arch::Sm80; // cutlass::arch::Sm86
+
+// This code section describes the tile size a thread block will compute
+// thread block tile M = 128, N = 128, K = 32
+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
+
+// This code section describes the tile size a warp will compute
+// warp tile M = 64, N = 64, K = 32
+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
+
+// MMA Op tile M = 16, N = 8, K = 16
+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
+static int const kStages = 4;
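The three GemmShape aliases above form CUTLASS's tiling hierarchy: each threadblock tile is covered by warp tiles, and each warp tile by 16x8x16 tensor-core MMA instructions. A compile-time sketch of the divisibility this assumes, written against the aliases above (editor's illustration, not lines from the patch):

static_assert(ThreadblockShape::kM % WarpShape::kM == 0, "warp tiles must cover the threadblock tile in M");
static_assert(ThreadblockShape::kN % WarpShape::kN == 0, "warp tiles must cover the threadblock tile in N");
static_assert(WarpShape::kM % InstructionShape::kM == 0, "MMA ops must cover the warp tile in M");
static_assert(WarpShape::kN % InstructionShape::kN == 0, "MMA ops must cover the warp tile in N");
// (128/64) * (128/64) = 4 warps, i.e. 128 threads per threadblock.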
+
+// This is the number of elements per vectorized memory access.
+// For half precision, it's 8 elements.
+// This becomes the vector width of the math instructions in the epilogue too.
+static int const kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits<precision_output>::value;
+using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
+        precision_output,
+        kEpilogueElementsPerAccess,
+        precision_accumulator,
+        precision_epilogue
+>;
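For cutlass::half_t the constant works out to 128 / 16 = 8 elements, i.e. one 128-bit transaction per epilogue access. A compile-time check of that arithmetic (editor's sketch; the alignment remark is an assumption best confirmed through can_implement()):

static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
              "half_t is 16 bits wide, so a 128-bit access moves 8 elements");
// Practical consequence: leading dimensions should be multiples of 8 for this
// kernel, or can_implement() below reports kErrorMisalignedOperand.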
+
+//using Gemm = cutlass::gemm::device::Gemm<
+//        precision_a,
+//        LayoutInputA,
+//        precision_b,
+//        LayoutInputB,
+//        precision_output,
+//        LayoutOutput,
+//        precision_accumulator,
+//        MMAOp,
+//        SmArch,
+//        ThreadblockShape,
+//        WarpShape,
+//        InstructionShape,
+//        EpilogueOutputOp
+//>;
+
+using GemmU = cutlass::gemm::device::GemmUniversal<
+        precision_a, LayoutInputA,
+        precision_b, LayoutInputB,
+        precision_output, LayoutOutput,
+        precision_accumulator,
+        MMAOp,
+        SmArch,
+        ThreadblockShape,
+        WarpShape,
+        InstructionShape,
+        EpilogueOutputOp,
+        SwizzleThreadBlock,
+        kStages
+>;
+
+inline char const *to_string(cutlass::Status status) {
+
+    switch (status) {
+        case cutlass::Status::kSuccess:
+            return "kSuccess";
+        case cutlass::Status::kErrorMisalignedOperand:
+            return "kErrorMisalignedOperand";
+        case cutlass::Status::kErrorInvalidLayout:
+            return "kErrorInvalidLayout";
+        case cutlass::Status::kErrorInvalidProblem:
+            return "kErrorInvalidProblem";
+        case cutlass::Status::kErrorNotSupported:
+            return "kErrorNotSupported";
+        case cutlass::Status::kErrorWorkspaceNull:
+            return "kErrorWorkspaceNull";
+        case cutlass::Status::kErrorInternal:
+            return "kErrorInternal";
+        case cutlass::Status::kInvalid:
+            return "kInvalid";
+        default:
+            break;
+    }
+    return "invalid";
+}
 
-///////////////////////////////////////////////////////////////////////////////////////////////////
-//
-// The source code after this point in the file is generic CUDA using the CUDA Runtime API
-// and simple CUDA kernels to initialize matrices and compute the general matrix product.
-//
-///////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Kernel to initialize a matrix with small integers.
-__global__ void InitializeMatrix_kernel(
-  float *matrix,
-  int rows,
-  int columns,
-  int seed = 0) {
-
-  int i = threadIdx.x + blockIdx.x * blockDim.x;
-  int j = threadIdx.y + blockIdx.y * blockDim.y;
-
-  if (i < rows && j < columns) {
-    int offset = i + j * rows;
-
-    // Generate arbitrary elements.
-    int const k = 16807;
-    int const m = 16;
-    float value = float(((offset + seed) * k % m) - m / 2);
-
-    matrix[offset] = value;
-  }
-}
+
+template <typename precision, typename layout>
+cutlass::TensorRef<precision, layout> toRef(torch::Tensor tensor) {
+    auto leadingAxis = tensor.size(1);
+    // The stride holds the leading dimension of each matrix; since all matrices
+    // here are row-major, it is always the number of columns.
+    cutlass::TensorRef<precision, layout> ref((precision *) tensor.data_ptr(), layout(leadingAxis));
+    return ref;
+}
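The leading dimension handed to the layout is what turns a (row, column) coordinate into a linear offset into the tensor's storage. A small illustration of the arithmetic RowMajor performs (editor's sketch; rowMajorOffset is a hypothetical helper, not used by the patch):

#include "cutlass/layout/matrix.h"

// Offset of element (i, j) in a row-major matrix whose consecutive rows are
// ld elements apart; for the densely packed torch tensors used here, ld
// equals the number of columns.
inline int64_t rowMajorOffset(int i, int j, int64_t ld) {
    cutlass::layout::RowMajor layout(ld);
    return layout({i, j});  // evaluates to i * ld + j
}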
-
-/// Simple function to initialize a matrix to arbitrary small integers.
-cudaError_t InitializeMatrix(float *matrix, int rows, int columns, int seed = 0) {
-
-  dim3 block(16, 16);
-  dim3 grid(
-    (rows + block.x - 1) / block.x,
-    (columns + block.y - 1) / block.y
-  );
-
-  InitializeMatrix_kernel<<< grid, block >>>(matrix, rows, columns, seed);
-
-  return cudaGetLastError();
-}
-
-///////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Allocates device memory for a matrix then fills with arbitrary small integers.
-cudaError_t AllocateMatrix(float **matrix, int rows, int columns, int seed = 0) {
-  cudaError_t result;
-
-  size_t sizeof_matrix = sizeof(float) * rows * columns;
-
-  // Allocate device memory.
-  result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Failed to allocate matrix: "
-      << cudaGetErrorString(result) << std::endl;
-    return result;
-  }
-
-  // Clear the allocation.
-  result = cudaMemset(*matrix, 0, sizeof_matrix);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Failed to clear matrix device memory: "
-      << cudaGetErrorString(result) << std::endl;
-    return result;
-  }
-
-  // Initialize matrix elements to arbitrary small integers.
-  result = InitializeMatrix(*matrix, rows, columns, seed);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Failed to initialize matrix: "
-      << cudaGetErrorString(result) << std::endl;
-    return result;
-  }
-
-  return result;
-}
 
+// https://github.com/NVIDIA/cutlass/discussions/396 -> tuning
+/// Define a CUTLASS GEMM template and launch a GEMM kernel.
+void CutlassSgemmNN(
+        torch::Tensor &A,
+        torch::Tensor &B,
+        torch::Tensor &C,
+        float alpha,
+        float beta) {
+
+    int M = (int) A.size(0);
+    int N = (int) B.size(1);
+    int K = (int) A.size(1);
+
+    auto ref_a = toRef<precision_a, LayoutInputA>(A);
+    auto ref_b = toRef<precision_b, LayoutInputB>(B);
+    auto ref_c = toRef<precision_output, LayoutOutput>(C);
+    auto ref_d = toRef<precision_output, LayoutOutput>(C);
+//    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
+
+//    typename Gemm::Arguments arguments{
+//            {M, N, K},     // Gemm Problem dimensions
+//            ref_a,         // Tensor-ref for source matrix A
+//            ref_b,         // ... B
+//            ref_c,         // ... C
+//            ref_d,         // ... output
+//            {alpha, beta}, // Scalars used in the Epilogue
+//            1};            // split_k
+
+
+    typename GemmU::Arguments argumentsUniversal{
+            cutlass::gemm::GemmUniversalMode::kGemm,
+            {M, N, K},
+            1,
+            {alpha, beta},
+            ref_a.data(),
+            ref_b.data(),
+            ref_c.data(),
+            ref_d.data(),
+            int64_t(),
+            int64_t(),
+            int64_t(),
+            int64_t(),
+            int64_t(K),
+            int64_t(N),
+            int64_t(N),
+            int64_t(N)
+    };
+
+    cutlass::Status status = GemmU::can_implement(argumentsUniversal);
+    CHECK(status == cutlass::Status::kSuccess)
+                    << "arguments can't be implemented by this kernel: "
+                    << to_string(status);
+
+    GemmU gemm_operator;
+    auto workspace_size = GemmU::get_workspace_size(argumentsUniversal);
+    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+    status = gemm_operator.initialize(argumentsUniversal, workspace.get());
+    CHECK(status == cutlass::Status::kSuccess)
+                    << "GEMM initialization failed: "
+                    << to_string(status);
+
+    status = gemm_operator();
+    CHECK(status == cutlass::Status::kSuccess)
+                    << "GEMM execution failed: "
+                    << to_string(status);
+}
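GemmUniversal::Arguments is filled positionally above, which makes the four zeroed batch strides and the four leading dimensions easy to confuse. The same initializer with every field labeled, as an editor's annotation of the call above (field order per cutlass::gemm::device::GemmUniversal in CUTLASS 2.x):

typename GemmU::Arguments argumentsLabeled{
        cutlass::gemm::GemmUniversalMode::kGemm,  // mode: one plain, non-batched GEMM
        {M, N, K},                                // problem size
        1,                                        // batch count / split-K slice count
        {alpha, beta},                            // epilogue scalars
        ref_a.data(), ref_b.data(),               // ptr_A, ptr_B
        ref_c.data(), ref_d.data(),               // ptr_C (source), ptr_D (destination)
        int64_t(), int64_t(),                     // batch_stride_A, batch_stride_B (unused, batch count is 1)
        int64_t(), int64_t(),                     // batch_stride_C, batch_stride_D (unused)
        int64_t(K), int64_t(N),                   // lda, ldb: row-major A is M x K, B is K x N
        int64_t(N), int64_t(N)                    // ldc, ldd: C and D are M x N
};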
-
-///////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Naive reference GEMM computation.
-__global__ void ReferenceGemm_kernel(
-  int M,
-  int N,
-  int K,
-  float alpha,
-  float const *A,
-  int lda,
-  float const *B,
-  int ldb,
-  float beta,
-  float *C,
-  int ldc) {
-
-  int i = threadIdx.x + blockIdx.x * blockDim.x;
-  int j = threadIdx.y + blockIdx.y * blockDim.y;
-
-  if (i < M && j < N) {
-    float accumulator = 0;
-
-    for (int k = 0; k < K; ++k) {
-      accumulator += A[i + k * lda] * B[k + j * ldb];
+
+void TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
+    auto options = torch::TensorOptions()
+            .dtype(torch::kFloat32)
+            .layout(torch::kStrided)
+            .device(torch::kCUDA, 0)
+            .requires_grad(false);
+
+    auto tensorA = torch::rand({M, K}, options);
+    auto tensorB = torch::rand({K, N}, options);
+    auto tensorCCutlass = torch::empty({M, N}, options).toType(torch::kFloat16);
+    auto tensorCTorchFp32 = torch::empty({M, N}, options);
+    auto tensorAFp16 = tensorA.toType(torch::kFloat16);
+    auto tensorBFp16 = tensorB.toType(torch::kFloat16);
+    auto tensorCTorchFp16 = torch::mm(tensorAFp16, tensorBFp16);
+
+    clock_t start, end;
+    int nb_repeat = 10;
+    // warmup runs, not timed
+    for (int i = 0; i < nb_repeat; i++) {
+        CutlassSgemmNN(tensorAFp16, tensorBFp16, tensorCCutlass, alpha, beta);
+    }
+    cudaDeviceSynchronize();
+    start = clock();
+    for (int i = 0; i < nb_repeat; i++) {
+        CutlassSgemmNN(tensorAFp16, tensorBFp16, tensorCCutlass, alpha, beta);
+    }
+    cudaDeviceSynchronize();
+    end = clock();
+    auto cutlass_time = (double) (end - start) / CLOCKS_PER_SEC;
+    std::cout << "cutlass time: " << cutlass_time << std::endl;
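clock() measures host CPU time, so the figure above is only meaningful because cudaDeviceSynchronize() brackets the timed loop. CUDA events time the GPU stream directly and drop that dependency; a minimal sketch (editor's illustration; timeGpuMs is a hypothetical helper, not part of the patch):

// Times nb_repeat launches on the current CUDA stream, returning milliseconds.
template <typename Launch>
float timeGpuMs(int nb_repeat, Launch launch) {
    cudaEvent_t begin, stop;
    cudaEventCreate(&begin);
    cudaEventCreate(&stop);
    cudaEventRecord(begin);
    for (int i = 0; i < nb_repeat; i++) {
        launch();
    }
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.f;
    cudaEventElapsedTime(&ms, begin, stop);
    cudaEventDestroy(begin);
    cudaEventDestroy(stop);
    return ms;
}

// e.g. timeGpuMs(10, [&] { CutlassSgemmNN(tensorAFp16, tensorBFp16, tensorCCutlass, alpha, beta); });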
-    }
-
-    C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc];
-  }
-}
-
-/// Reference GEMM computation.
-cudaError_t ReferenceGemm(
-  int M,
-  int N,
-  int K,
-  float alpha,
-  float const *A,
-  int lda,
-  float const *B,
-  int ldb,
-  float beta,
-  float *C,
-  int ldc) {
-
-  dim3 block(16, 16);
-  dim3 grid(
-    (M + block.x - 1) / block.x,
-    (N + block.y - 1) / block.y
-  );
-
-  ReferenceGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
-
-  return cudaGetLastError();
-}
-
-///////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Allocate several matrices in GPU device memory and call a single-precision
-/// CUTLASS GEMM kernel.
-cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
-  cudaError_t result;
-
-  //
-  // Define several matrices to be used as operands to GEMM kernels.
-  //
-
-  // Compute leading dimensions for each matrix.
-  int lda = M;
-  int ldb = K;
-  int ldc = M;
-
-  // Compute size in bytes of the C matrix.
-  size_t sizeof_C = sizeof(float) * ldc * N;
-
-  // Define pointers to matrices in GPU device memory.
-  float *A;
-  float *B;
-  float *C_cutlass;
-  float *C_reference;
-
-  //
-  // Allocate matrices in GPU device memory with arbitrary seeds.
-  //
-
-  result = AllocateMatrix(&A, M, K, 0);
-
-  if (result != cudaSuccess) {
-    return result;
-  }
-
-  result = AllocateMatrix(&B, K, N, 17);
-
-  if (result != cudaSuccess) {
-    cudaFree(A);
-    return result;
-  }
-
-  result = AllocateMatrix(&C_cutlass, M, N, 101);
-
-  if (result != cudaSuccess) {
-    cudaFree(A);
-    cudaFree(B);
-    return result;
-  }
-
-  result = AllocateMatrix(&C_reference, M, N, 101);
-
-  if (result != cudaSuccess) {
-    cudaFree(A);
-    cudaFree(B);
-    cudaFree(C_cutlass);
-    return result;
-  }
-
-  result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Failed to copy C_cutlass matrix to C_reference: "
-      << cudaGetErrorString(result) << std::endl;
-
-    cudaFree(C_reference);
-    cudaFree(C_cutlass);
-    cudaFree(B);
-    cudaFree(A);
-
-    return result;
-  }
-
-  //
-  // Launch CUTLASS GEMM.
-  //
-
-  result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc);
-
-  if (result != cudaSuccess) {
-    std::cerr << "CUTLASS GEMM kernel failed: "
-      << cudaGetErrorString(result) << std::endl;
-
-    cudaFree(C_reference);
-    cudaFree(C_cutlass);
-    cudaFree(B);
-    cudaFree(A);
-
-    return result;
-  }
-
-  //
-  // Verify.
-  //
-
-  // Launch reference GEMM
-  result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Reference GEMM kernel failed: "
-      << cudaGetErrorString(result) << std::endl;
-
-    cudaFree(C_reference);
-    cudaFree(C_cutlass);
-    cudaFree(B);
-    cudaFree(A);
-
-    return result;
-  }
-
-  // Copy to host and verify equivalence.
-  std::vector<float> host_cutlass(ldc * N, 0);
-  std::vector<float> host_reference(ldc * N, 0);
-
-  result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Failed to copy CUTLASS GEMM results: "
-      << cudaGetErrorString(result) << std::endl;
-
-    cudaFree(C_reference);
-    cudaFree(C_cutlass);
-    cudaFree(B);
-    cudaFree(A);
-
-    return result;
-  }
-
-  result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost);
-
-  if (result != cudaSuccess) {
-    std::cerr << "Failed to copy Reference GEMM results: "
-      << cudaGetErrorString(result) << std::endl;
-
-    cudaFree(C_reference);
-    cudaFree(C_cutlass);
-    cudaFree(B);
-    cudaFree(A);
-
-    return result;
-  }
-
-  //
-  // Free device memory allocations.
-  //
-
-  cudaFree(C_reference);
-  cudaFree(C_cutlass);
-  cudaFree(B);
-  cudaFree(A);
-
-  //
-  // Test for bit equivalence of results.
-  //
-
-  if (host_cutlass != host_reference) {
-    std::cerr << "CUTLASS results incorrect." << std::endl;
-
-    return cudaErrorUnknown;
-  }
-
-  return cudaSuccess;
-}
+
+    for (int i = 0; i < nb_repeat; i++) {
+        torch::mm_out(tensorCTorchFp32, tensorA, tensorB);
+    }
+    torch::cuda::synchronize();
+    start = clock();
+    for (int i = 0; i < nb_repeat; i++) {
+        torch::mm_out(tensorCTorchFp32, tensorA, tensorB);
+    }
+    torch::cuda::synchronize();
+    end = clock();
+    auto torch_time = (double) (end - start) / CLOCKS_PER_SEC;
+    std::cout << "torch time: " << torch_time << std::endl;
+    std::cout << "speedup: " << torch_time / cutlass_time << std::endl;
+
+    auto diffPytorchOnly = tensorCTorchFp32.sub(tensorCTorchFp16).abs().sum().item<float>();
+    auto diffPytorchCutlass = tensorCTorchFp32.sub(tensorCCutlass).abs().sum().item<float>();
+    auto diffDiff = std::abs(diffPytorchOnly - diffPytorchCutlass);
+
+    std::cout << "distance FP32 (PyTorch) - FP16 (cutlass): "
+              << diffPytorchCutlass
+              << std::endl;
+    std::cout << "distance FP32 (PyTorch) - FP16 (PyTorch): "
+              << diffPytorchOnly
+              << std::endl;
+    std::cout << "distance between distances: "
+              << diffDiff
+              << std::endl;
+
+    CHECK(!torch::any(torch::isinf(tensorCCutlass)).item<bool>());
+    CHECK(!torch::any(torch::isnan(tensorCCutlass)).item<bool>());
+    CHECK(!torch::any(torch::isinf(tensorCTorchFp32)).item<bool>());
+    CHECK(!torch::any(torch::isnan(tensorCTorchFp32)).item<bool>());
+    CHECK(!torch::any(torch::isinf(tensorCTorchFp16)).item<bool>());
+    CHECK(!torch::any(torch::isnan(tensorCTorchFp16)).item<bool>());
+    CHECK(diffPytorchOnly >= diffPytorchCutlass);
+
+    // No explicit cudaFree here: each tensor's storage is returned to
+    // libtorch's CUDA caching allocator when the tensor goes out of scope.
+}
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Entry point to basic_gemm example.
-//
 // usage:
 //
-//   00_basic_gemm <M> <N> <K> <alpha> <beta>
+//   00_basic_gemm <M> <N> <K>
 //
 int main(int argc, const char *arg[]) {
-
-  //
-  // Parse the command line to obtain GEMM dimensions and scalar values.
-  //
-
-  // GEMM problem dimensions.
-  int problem[3] = { 128, 128, 128 };
-
-  for (int i = 1; i < argc && i < 4; ++i) {
-    std::stringstream ss(arg[i]);
-    ss >> problem[i - 1];
-  }
-
-  // Scalars used for linear scaling the result of the matrix product.
-  float scalars[2] = { 1, 0 };
-
-  for (int i = 4; i < argc && i < 6; ++i) {
-    std::stringstream ss(arg[i]);
-    ss >> scalars[i - 4];
-  }
-
-  //
-  // Run the CUTLASS GEMM test.
-  //
-
-  cudaError_t result = TestCutlassGemm(
-    problem[0],     // GEMM M dimension
-    problem[1],     // GEMM N dimension
-    problem[2],     // GEMM K dimension
-    scalars[0],     // alpha
-    scalars[1]      // beta
-  );
-
-  if (result == cudaSuccess) {
-    std::cout << "Passed." << std::endl;
-  }
-
-  // Exit.
-  return result == cudaSuccess ? 0 : -1;
+    torch::manual_seed(123);
+    // default problem dimensions
+    int problem[3] = {768, 256, 768};
+    for (int i = 1; i < argc && i < 4; ++i) {
+        std::stringstream ss(arg[i]);
+        ss >> problem[i - 1];
+    }
+    std::cout << "problem size: ";
+    for (const auto i: problem) {
+        std::cout << i << ' ';
+    }
+    std::cout << std::endl;
+    TestCutlassGemm(
+            problem[0],
+            problem[1],
+            problem[2],
+            1,  // alpha
+            0   // beta
+    );
+
+    return 0;
 }
-
-///////////////////////////////////////////////////////////////////////////////////////////////////
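The CHECKs in TestCutlassGemm compare aggregate L1 distances against the FP32 result. An element-wise acceptance test is a natural complement; a sketch using torch::allclose (editor's illustration; outputsMatch and its tolerances are assumptions, not part of the patch):

// True when the CUTLASS output agrees element-wise with PyTorch's own FP16
// matmul; the tolerances are loose because FP16 results accumulated over a
// K dimension hundreds deep are inherently noisy.
bool outputsMatch(const torch::Tensor &cutlassOut, const torch::Tensor &torchFp16Out) {
    return torch::allclose(cutlassOut.toType(torch::kFloat32),
                           torchFp16Out.toType(torch::kFloat32),
                           /*rtol=*/1e-2, /*atol=*/1e-2);
}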