[BUG] Computing result error in cutlass gemm with specfied shape #1829

Open
chenhongyu2048 opened this issue Sep 19, 2024 · 3 comments

@chenhongyu2048

When I use the CUTLASS templates to write my own GEMM kernel, I get an internal error, even though I follow the settings provided by the CUTLASS profiler.

The full code is below:

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result
{

    double runtime_ms;
    double gflops;
    cutlass::Status status;
    cudaError_t error;
    bool passed;

    //
    // Methods
    //

    Result(
        double runtime_ms = 0,
        double gflops = 0,
        cutlass::Status status = cutlass::Status::kSuccess,
        cudaError_t error = cudaSuccess) : runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) {}
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {
    
    bool help;

    cutlass::gemm::GemmCoord problem_size;
    int batch_count;
    float alpha;
    float beta;

    bool reference_check;
    int iterations;

    // Initializers follow the member declaration order; beta defaults to 0
    Options() : help(false),
                problem_size({32, 24576, 6144}),
                batch_count(1),
                alpha(1),
                beta(0),
                reference_check(true),
                iterations(20) {}

    bool valid() {
        return true;
    }

    // Parses the command line
    void parse(int argc, char const **args) {

        cutlass::CommandLine cmd(argc, args);
        if (cmd.check_cmd_line_flag("help")) {
            help = true;
        }

        cmd.get_cmd_line_argument("m", problem_size.m());
        cmd.get_cmd_line_argument("n", problem_size.n());
        cmd.get_cmd_line_argument("k", problem_size.k());

        cmd.get_cmd_line_argument("alpha", alpha);
        cmd.get_cmd_line_argument("beta", beta);

        cmd.get_cmd_line_argument("iterations", iterations);
    }

    /// Prints the usage statement.
    std::ostream &print_usage(std::ostream &out) const {

        out << "cutlass_gemm_example\n\n"
            << "This example uses the CUTLASS Library to execute FP16 tensorop GEMM computations.\n\n"
            << "Options:\n\n"
            << "  --help                      If specified, displays this usage statement.\n\n"
            << "  --m=<int>                   GEMM M dimension\n"
            << "  --n=<int>                   GEMM N dimension\n"
            << "  --k=<int>                   GEMM K dimension\n"
            << "  --alpha=<f32>               Epilogue scalar alpha\n"
            << "  --beta=<f32>                Epilogue scalar beta\n"
            << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

        out << "\nExamples:\n\n"
            << "$ ./cutlass_gemm_example --m=1024 --n=512 --k=1024 \\\n"
            << "     --alpha=2 --beta=0.707 \n";

        return out;
    }

    /// Compute performance in GFLOP/s
    double gflops(double runtime_s) const {
        // Number of real-valued multiply-adds
        int64_t fmas = problem_size.product() * batch_count;

        // Two flops per multiply-add
        return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
    }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = cutlass::half_t;                  // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t;                       // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t;                       // <- data type of elements in input matrix B
using ElementOutput = cutlass::half_t;                       // <- data type of elements in output matrix D

// The code section below describes matrix layout of input and output matrices. Row Major for A and B matrices, Column Major for C and D matrices.
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm80;

// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<256, 128, 32>; // <- threadblock tile M = 256, N = 128, K = 32
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8

// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                    // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
                                                      // memory access. For 16-bit half_t, it's 8
                                                      // elements. This becomes the vector width of
                                                      // math instructions in the epilogue too
    ElementAccumulator,                               // <- data type of accumulator
    ElementComputeEpilogue>;                          // <- data type for alpha/beta in linear combination function

// Number of pipeline stages you want to use
constexpr int NumStages = 2;

using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages,
                                         8,       // AlignmentA
                                         8>;      //AlignmentB

int run(Options &options) {

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size = options.problem_size;

    // Initialize tensors using CUTLASS helper functions
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
        problem_size.mk()); // <- Create matrix A with dimensions M x K
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
        problem_size.kn()); // <- Create matrix B with dimensions K x N
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
        problem_size.mn()); // <- Create matrix C with dimensions M x N
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
        problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
                            // CUTLASS kernel
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
        problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
                            // reference kernel

    // Fill input and output matrices on host using CUTLASS helper functions
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0); // <- Fill matrix A on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0); // <- Fill matrix B on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_c.host_view(),
        1,
        ElementOutput(4),
        ElementOutput(-4),
        0); // <- Fill matrix C on host with uniform-distribution random data
    cutlass::reference::host::TensorFill(
        tensor_d.host_view()); // <- fill matrix D on host with zeros
    cutlass::reference::host::TensorFill(
        tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros

    // Copy data from host to GPU
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();
    tensor_ref_d.sync_device();

    // Initialize alpha and beta for dot product computation
    ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
    ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);

    // Split K dimension into 1 partition (i.e. no split-K)
    int split_k_slices = 1;

    // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
    // instantiated CUTLASS kernel
    typename Gemm::Arguments arguments{problem_size,          // <- problem size of matrix multiplication
                                       tensor_a.device_ref(), // <- reference to matrix A on device
                                       tensor_b.device_ref(), // <- reference to matrix B on device
                                       tensor_c.device_ref(), // <- reference to matrix C on device
                                       tensor_d.device_ref(), // <- reference to matrix D on device
                                       {alpha, beta},         // <- tuple of alpha and beta
                                       split_k_slices};       // <- k-dimension split factor

    // Using the arguments, query for extra workspace required for matrix multiplication computation
    size_t workspace_size = Gemm::get_workspace_size(arguments);

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Instantiate CUTLASS kernel depending on templates
    Gemm gemm_op;

    // Check the problem size is supported or not
    cutlass::Status status = gemm_op.can_implement(arguments);
    CUTLASS_CHECK(status);

    // Initialize CUTLASS kernel with arguments and workspace pointer
    status = gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);

    // warmup loop
    for (int iter = 0; iter < 5; ++iter) {
        // Launch initialized CUTLASS kernel
        status = gemm_op();
        CUTLASS_CHECK(status);
    }

    // Result structure
    Result result;

    // Construct events
    cudaEvent_t events[2];
    for (auto &event : events) {
        result.error = cudaEventCreate(&event);
        if (result.error != cudaSuccess) {
            std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
            return -1;
        }
    }

    // Record an event at the start of a series of GEMMs
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
        std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
        return -1;
    }

    // Run profiling loop
    for (int iter = 0; iter < options.iterations; ++iter) {
        // Launch initialized CUTLASS kernel
        status = gemm_op();
        // CUDA_CHECK(cudaDeviceSynchronize());
        CUTLASS_CHECK(status);
    }

    // Stop profiling loop
    // Record an event when the GEMMs are complete
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
        std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
        return -1;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
        std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
        return -1;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
        std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
        return -1;
    }

    // Compute average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    for (auto event : events) {
        (void)cudaEventDestroy(event);
    }

    // Create instantiation for device reference gemm kernel
    cutlass::reference::device::Gemm<ElementInputA,
                                     LayoutInputA,
                                     ElementInputB,
                                     LayoutInputB,
                                     ElementOutput,
                                     LayoutOutput,
                                     ElementComputeEpilogue,
                                     ElementComputeEpilogue> gemm_device;

    // Launch device reference gemm kernel
    gemm_device(problem_size,
                alpha,
                tensor_a.device_ref(),
                tensor_b.device_ref(),
                beta,
                tensor_c.device_ref(),
                tensor_ref_d.device_ref());

    // Wait for kernels to finish
    CUDA_CHECK(cudaDeviceSynchronize());

    // Copy output data from CUTLASS and reference kernel to host for comparison
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    bool passed = cutlass::reference::host::TensorEquals(
        tensor_d.host_view(),
        tensor_ref_d.host_view());

    std::cout << "Runtime: " << result.runtime_ms * 1000 << " us" << std::endl;
    std::cout << " GFLOPs: " << result.gflops << std::endl;

    // if (passed) {
    //     std::cout << "Runtime: " << result.runtime_ms * 1000 << " us" << std::endl;
    //     std::cout << " GFLOPs: " << result.gflops << std::endl;
    // }

    std::cout << (passed ? "Passed" : "Failed") << std::endl;

    return (passed ? 0 : -1);
}

int main(int argc, const char **argv) {
    bool notSupported = false;

    // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
    // in CUDA 11.0.
    // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
    if (!(__CUDACC_VER_MAJOR__ >= 11)) {
        std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
        notSupported = true;
    }

    cudaDeviceProp props;
    cudaError_t error = cudaGetDeviceProperties(&props, 0);
    if (error != cudaSuccess) {
        std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
        return -1;
    }

    if (!((props.major * 10 + props.minor) >= 80)) {
        std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
                  << std::endl;
        notSupported = true;
    }

    if (notSupported) {
        // Returning zero so this test passes on older Toolkits. Its actions are no-op.
        return 0;
    }

    std::cout << "Device Name: " << props.name << std::endl;
    std::cout << "Number of Streaming Multiprocessors: " << props.multiProcessorCount << std::endl;

    Options options;
    options.parse(argc, argv);

    if (options.help) {
        options.print_usage(std::cout) << std::endl;
        return 0;
    }

    printf("%d x %d x %d Half tensor op Matrix Multiply\n", options.problem_size.m(), options.problem_size.n(), options.problem_size.k());

    if (!options.valid()) {
        std::cerr << "Invalid problem." << std::endl;
        return -1;
    }

    return run(options);
}

The settings above come from the CUTLASS profiler:

Problem ID: 1

        Provider: CUTLASS
   OperationKind: gemm
       Operation: cutlass_tensorop_h1688gemm_256x128_32x2_tt_align8

          Status: Success
    Verification: ON
     Disposition: Passed

reference_device: Passed
          cuBLAS: Not run
           cuDNN: Not run

       Arguments: --gemm_kind=universal --m=32 --n=24576 --k=6144 --A=f16:row --B=f16:row --C=f16:column --D=f16:column  \
                  --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1 --batch_count=1 --raster_order=heuristic  \
                  --swizzle_size=1 --op_class=tensorop --accum=f16 --cta_m=256 --cta_n=128 --cta_k=32 --cluster_m=1 --cluster_n=1  \
                  --cluster_k=1 --stages=2 --warps_m=4 --warps_n=2 --warps_k=1 --inst_m=16 --inst_n=8 --inst_k=8 --min_cc=75  \
                  --max_cc=1024

           Bytes: 303955968  bytes
           FLOPs: 9665249280  flops
           FLOPs/Byte: 31

         Runtime: 0.328755  ms
          Memory: 861.069 GiB/s

            Math: 29399.5 GFLOP/s
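
The Bytes, FLOPs, and Math figures in the report can be cross-checked by hand. The sketch below is only an illustration: the helper names are hypothetical, and the formulas are inferred from the reported numbers for this problem (M=32, N=24576, K=6144, 2-byte elements, beta=0), not taken from the profiler's source.

```cpp
#include <cstdint>

// Hypothetical helpers. The formulas are inferred from the figures in the
// profiler report, not taken from the profiler's source code.

// 2*M*N*K for the multiply-adds plus 2*M*N for the alpha/beta epilogue.
constexpr std::int64_t gemm_flops(std::int64_t m, std::int64_t n, std::int64_t k) {
    return 2 * m * n * k + 2 * m * n;
}

// A (M x K) and B (K x N) are read and D (M x N) is written.
// With beta = 0, C is assumed not to be counted.
constexpr std::int64_t gemm_bytes(std::int64_t m, std::int64_t n, std::int64_t k,
                                  std::int64_t bytes_per_element) {
    return bytes_per_element * (m * k + k * n + m * n);
}

// Throughput in GFLOP/s from a FLOP count and a runtime in milliseconds.
inline double gflops_per_second(std::int64_t flops, double runtime_ms) {
    return static_cast<double>(flops) / 1.0e9 / (runtime_ms / 1.0e3);
}

// Both values match the report for the 32 x 24576 x 6144 problem.
static_assert(gemm_flops(32, 24576, 6144) == 9665249280LL, "reported FLOPs");
static_assert(gemm_bytes(32, 24576, 6144, 2) == 303955968LL, "reported Bytes");
```

Calling `gflops_per_second(gemm_flops(32, 24576, 6144), 0.328755)` gives roughly 29399.5, in line with the reported Math figure. Note that `Options::gflops` in the listing counts only 2*M*N*K, so it would report a slightly lower number for the same runtime.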

I compiled it with:

nvcc -std=c++17 -arch=sm_80 -I/xxx/third_party/cutlass/include -I/xxx/third_party/cutlass/tools/util/include -I/xxx/third_party/cutlass/tools/library/include -I/xxx/third_party/cutlass/examples/common -lcublas ./cutlass_gemm.cu --expt-relaxed-constexpr -o cutlass_gemm_example

I use CUDA 12.6 and an RTX 6000 Ada GPU.
Is this an issue with the way I'm using CUTLASS?

@chenhongyu2048

UPDATE:
The program runs to completion without reporting an error, but cutlass::reference::host::TensorEquals fails.

Could this be a floating-point accuracy issue?
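
One way to reason about the accuracy question: the listing accumulates in `cutlass::half_t` (`ElementAccumulator`) over K=6144 products. IEEE binary16 has 10 explicit mantissa bits, so the gap between adjacent representable values grows with magnitude, and above 2048 not every integer is representable. If partial sums reach that range, a different accumulation order in the tensor-op kernel and in the reference kernel could plausibly produce results that differ by a few units. This is a possible explanation, not something confirmed in this thread. The sketch below (plain C++, hypothetical helper name, normal range only) computes that gap.

```cpp
#include <cmath>

// Hypothetical helper: spacing between adjacent IEEE binary16 (half) values
// near |x|. Valid only for the normal range (2^-14 <= |x| < 65536);
// zero, subnormals, infinities, and NaN are not handled.
inline double half_ulp(double x) {
    int exp = 0;
    std::frexp(std::fabs(x), &exp);  // |x| = f * 2^exp with f in [0.5, 1)
    // binary16 stores 10 explicit mantissa bits, so ulp = 2^((exp - 1) - 10).
    return std::ldexp(1.0, (exp - 1) - 10);
}
```

For example, `half_ulp(1024.0)` is 1, `half_ulp(2048.0)` is 2, and `half_ulp(4096.0)` is 4, so a single rounding step on a value of a few thousand already moves the result by 2 or 4.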

@chenhongyu2048

After further debugging, we found that the failure is caused by a few values in the computed result differing from tensor_ref_d. We added the following code:

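// Compare the CUTLASS result with the reference element by element on the host
// (assumes this runs inside run(), after both sync_host() calls)
ElementOutput sum = (ElementOutput)0;
ElementOutput *d_ptr = tensor_d.host_data();
ElementOutput *ref_d_ptr = tensor_ref_d.host_data();
// Cover every output element (M * N) rather than a hard-coded count
int num_elements = problem_size.m() * problem_size.n();
for (int i = 0; i < num_elements; ++i) {
    sum += *(d_ptr+i) - *(ref_d_ptr+i);
    if (*(d_ptr+i) - *(ref_d_ptr+i) != 0) {
        // Print the linear index and the difference for each mismatch
        std::cout<<i<<" "<<*(d_ptr+i) - *(ref_d_ptr+i)<<std::endl;
    }
}
std::cout << sum << std::endl;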

and got the following result:
at i=60390 it prints -4, which is the difference between *(d_ptr+i) and *(ref_d_ptr+i).
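
Since `TensorEquals` appears to require every element to match exactly, a single rounding difference like the one above is enough to make it fail. A tolerance-based comparison is a common alternative for reduced-precision GEMMs. The sketch below is plain C++ under stated assumptions: `count_mismatches` is a hypothetical helper, not a CUTLASS API, and it expects the half values to have been converted to `float` first. The tolerances are placeholders to be chosen for the problem at hand.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical helper (not part of CUTLASS): counts elements where
// |actual - expected| exceeds abs_tol + rel_tol * |expected|.
// Inputs are assumed to be float copies of the half-precision tensors.
inline std::size_t count_mismatches(const float* actual, const float* expected,
                                    std::size_t count, float rel_tol, float abs_tol) {
    std::size_t mismatches = 0;
    for (std::size_t i = 0; i < count; ++i) {
        float diff = std::fabs(actual[i] - expected[i]);
        float limit = abs_tol + rel_tol * std::fabs(expected[i]);
        // Written as !(diff <= limit) so that NaN also counts as a mismatch.
        if (!(diff <= limit)) {
            ++mismatches;
        }
    }
    return mismatches;
}
```

With `rel_tol = 0` and `abs_tol = 0` this behaves like an exact comparison. With a small relative tolerance, a difference of 4 on a value in the thousands would no longer be flagged, while large errors would still be reported.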

@chenhongyu2048 changed the title from "[QST] Internal error in cutlass gemm" to "[BUG] Computing result error in cutlass gemm with specfied shape" on Sep 22, 2024