Commit

Merge pull request #4 from rapidsai/branch-0.12

merge
vishalmehta1991 authored Dec 14, 2019
2 parents a23ee8f + b0511d3 commit b0ad0a8
Showing 4 changed files with 253 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

## Improvements
- PR #1473: C++: lazy initialization of "costly" resources inside cumlHandle
- PR #1443: Added a new overloaded GEMM primitive
- PR #1463: Update FAISS submodule to 1.6.1

96 changes: 95 additions & 1 deletion cpp/src_prims/linalg/gemm.h
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, NVIDIA CORPORATION.
* Copyright (c) 2018-2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -180,5 +180,99 @@ void gemm(const math_t *a, int n_rows_a, int n_cols_a, const math_t *b,
beta, cublas_h, stream);
}

/**
* @brief A wrapper for the cuBLAS GEMM function designed to handle all possible
* combinations of operand layouts.
* It computes the following equation: Z = alpha . X * Y + beta . Z
* @tparam T Data type of input/output matrices (float/double)
* @param handle cublas handle
* @param z output matrix of size M rows x N columns
* @param x input matrix of size M rows x K columns
* @param y input matrix of size K rows x N columns
* @param _M number of rows of X and Z
* @param _N number of columns of Y and Z
* @param _K number of columns of X and rows of Y
* @param isZColMajor Storage layout of Z. true = col major, false = row major
* @param isXColMajor Storage layout of X. true = col major, false = row major
* @param isYColMajor Storage layout of Y. true = col major, false = row major
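* @param stream cuda stream on which the cuBLAS call is enqueued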
* @param alpha scalar
* @param beta scalar
* @{
*/
template <typename T>
void gemm(cublasHandle_t handle, T *z, T *x, T *y, int _M, int _N, int _K,
bool isZColMajor, bool isXColMajor, bool isYColMajor,
cudaStream_t stream, T alpha = T(1.0), T beta = T(0.0)) {
cublasOperation_t trans_a, trans_b;
T *a, *b, *c;
int lda, ldb, ldc;
int M, N, K;
// This function performs c = a * b. Based on the required output layout,
// either a = x, b = y or a = y, b = x. In either case c = z.
if (isZColMajor == true) {
// Result c is required in column major layout. Thus we perform,
// z = x * y
// Using BLAS call c = a * b. Therefore a = x, b = y and c = z

a = x;
// If x is in row major layout, cublas needs to transpose x first,
// therefore trans_a needs to be CUBLAS_OP_T. If x is in column major
// layout, trans_a needs to be CUBLAS_OP_N.
trans_a = isXColMajor == true ? CUBLAS_OP_N : CUBLAS_OP_T;
// Set leading dimension appropriately
lda = isXColMajor == true ? _M : _K;

b = y;
// If y is in row major layout, cublas needs to transpose y first,
// therefore trans_b needs to be CUBLAS_OP_T. If y is in column major
// layout, trans_b needs to be CUBLAS_OP_N.
trans_b = isYColMajor == true ? CUBLAS_OP_N : CUBLAS_OP_T;
ldb = isYColMajor == true ? _K : _N;

c = z;
ldc = _M;
M = _M;
N = _N;
K = _K;
} else {
// Result c is required in row major layout. Thus we pick
// a = y, b = x and c = a * b = y * x
// cublas produces output matrix only in column major layout. To get output
// matrix on row major layout, we need to produce transpose of output
// in column major layout. Therefore we perform,
// tr(z) = tr(y) * tr(x)
// we model this using cublas call for c = a * b
// therefore a = tr(y), b = tr(x) and c = tr(z)
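// Dimension check: tr(y) is N x K and tr(x) is K x M, so c = tr(z) is
// N x M in column major layout, which is exactly z (M x N) read in row
// major layout.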

a = y;
// If y is in row major layout, it can be interpreted as tr(y) in column
// major layout. Therefore we can pass trans_a as CUBLAS_OP_N. If y is in
// column major layout, cublas needs to transpose y first, therefore
// trans_a needs to be CUBLAS_OP_T.
trans_a = isYColMajor == true ? CUBLAS_OP_T : CUBLAS_OP_N;
// Set leading dimension appropriately
lda = isYColMajor == true ? _K : _N;

b = x;
// If x is in row major layout, it can be interpreted as tr(x) in column
// major layout. Therefore we can pass trans_b as CUBLAS_OP_N. If x is in
// column major layout, cublas needs to transpose x first, therefore
// trans_b needs to be CUBLAS_OP_T.
trans_b = isXColMajor == true ? CUBLAS_OP_T : CUBLAS_OP_N;
// Set leading dimension appropriately
ldb = isXColMajor == true ? _M : _K;

c = z;
ldc = _N;

M = _N;
N = _M;
K = _K;
}
// Actual cuBLAS call
CUBLAS_CHECK(LinAlg::cublasgemm(handle, trans_a, trans_b, M, N, K, &alpha, a,
lda, b, ldb, &beta, c, ldc, stream));
}

} // end namespace LinAlg
} // end namespace MLCommon
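
For illustration, here is a minimal usage sketch of the new overload with mixed operand layouts. This is a hypothetical example, not part of the commit; it assumes linalg/gemm.h is on the include path and elides CUDA/cuBLAS error checking for brevity.

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "linalg/gemm.h"

void gemmLayoutExample(cudaStream_t stream) {
  const int M = 4, N = 3, K = 2;
  float *X, *Y, *Z;
  cudaMalloc(&X, M * K * sizeof(float));  // X: M x K, row major
  cudaMalloc(&Y, K * N * sizeof(float));  // Y: K x N, column major
  cudaMalloc(&Z, M * N * sizeof(float));  // Z: M x N, row major
  // ... fill X and Y on the device ...

  cublasHandle_t handle;
  cublasCreate(&handle);

  // Z = 1 * X * Y + 0 * Z; alpha and beta take their default values.
  // The three flags give the layouts of Z, X and Y respectively.
  MLCommon::LinAlg::gemm(handle, Z, X, Y, M, N, K,
                         false,  // Z is row major
                         false,  // X is row major
                         true,   // Y is column major
                         stream);

  cudaStreamSynchronize(stream);
  cublasDestroy(handle);
  cudaFree(X);
  cudaFree(Y);
  cudaFree(Z);
}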
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -204,6 +204,7 @@ if(BUILD_PRIMS_TESTS)
prims/entropy.cu
prims/gather.cu
prims/gemm.cu
prims/gemm_layout.cu
prims/gram.cu
prims/grid_sync.cu
prims/hinge.cu
156 changes: 156 additions & 0 deletions cpp/test/prims/gemm_layout.cu
@@ -0,0 +1,156 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "cuda_utils.h"
#include "linalg/gemm.h"
#include "random/rng.h"
#include "test_utils.h"

namespace MLCommon {
namespace LinAlg {

template <typename T>
struct GemmLayoutInputs {
int M;
int N;
int K;
bool zLayout;
bool xLayout;
bool yLayout;
unsigned long long int seed;
};

// Reference GEMM implementation.
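// For an M x N matrix A, element (i, j) lives at A[i + j * M] in column
// major layout and at A[i * N + j] in row major layout. The index math
// below applies these formulas to X (M x K), Y (K x N) and Z (M x N),
// picking the formula that matches each operand's layout flag.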
template <typename T>
__global__ void naiveGemm(T *Z, T *X, T *Y, int M, int N, int K,
bool isZColMajor, bool isXColMajor,
bool isYColMajor) {
int tidx = blockIdx.x * blockDim.x + threadIdx.x;
int tidy = blockIdx.y * blockDim.y + threadIdx.y;

for (int m = tidy; m < M; m += (blockDim.y * gridDim.y)) {
for (int n = tidx; n < N; n += (blockDim.x * gridDim.x)) {
T temp = T(0.0);
for (int k = 0; k < K; k++) {
int xIndex = isXColMajor ? m + k * M : m * K + k;
int yIndex = isYColMajor ? k + n * K : k * N + n;
temp += X[xIndex] * Y[yIndex];
}
int zIndex = isZColMajor ? m + n * M : m * N + n;
Z[zIndex] = temp;
}
}
}

template <typename T>
class GemmLayoutTest : public ::testing::TestWithParam<GemmLayoutInputs<T>> {
protected:
void SetUp() override {
params = ::testing::TestWithParam<GemmLayoutInputs<T>>::GetParam();
cudaStream_t stream;
cublasHandle_t handle;
CUBLAS_CHECK(cublasCreate(&handle));
CUDA_CHECK(cudaStreamCreate(&stream));
Random::Rng r(params.seed);

// We compute Z = X * Y and compare against reference result
// Dimensions of X : M x K
// Dimensions of Y : K x N
// Dimensions of Z : M x N

T *X = NULL; // Argument X
T *Y = NULL; // Argument Y

size_t xElems = params.M * params.K;
size_t yElems = params.K * params.N;
size_t zElems = params.M * params.N;

CUDA_CHECK(cudaMalloc(&X, xElems * sizeof(T)));
CUDA_CHECK(cudaMalloc(&Y, yElems * sizeof(T)));
CUDA_CHECK(cudaMalloc(&refZ, zElems * sizeof(T)));
CUDA_CHECK(cudaMalloc(&Z, zElems * sizeof(T)));

r.uniform(X, xElems, T(-10.0), T(10.0), stream);
r.uniform(Y, yElems, T(-10.0), T(10.0), stream);

dim3 blocks(ceildiv<int>(params.M, 128), ceildiv<int>(params.N, 4), 1);
dim3 threads(128, 4, 1);

naiveGemm<<<blocks, threads>>>(refZ, X, Y, params.M, params.N, params.K,
params.zLayout, params.xLayout,
params.yLayout);

gemm(handle, Z, X, Y, params.M, params.N, params.K, params.zLayout,
params.xLayout, params.yLayout, stream);

CUDA_CHECK(cudaFree(X));
CUDA_CHECK(cudaFree(Y));
CUDA_CHECK(cudaStreamDestroy(stream));
CUBLAS_CHECK(cublasDestroy(handle));
}

void TearDown() override {
CUDA_CHECK(cudaFree(refZ));
CUDA_CHECK(cudaFree(Z));
}

protected:
GemmLayoutInputs<T> params;
T *refZ = NULL; // Reference result for comparison
T *Z = NULL; // Computed result
};

const std::vector<GemmLayoutInputs<float>> inputsf = {
{80, 70, 80, true, true, true, 76433ULL},
{80, 100, 40, true, true, false, 426646ULL},
{20, 100, 20, true, false, true, 237703ULL},
{100, 60, 30, true, false, false, 538004ULL},
{50, 10, 60, false, true, true, 73012ULL},
{90, 90, 30, false, true, false, 538147ULL},
{30, 100, 10, false, false, true, 412352ULL},
{40, 80, 100, false, false, false, 297941ULL}};

const std::vector<GemmLayoutInputs<double>> inputsd = {
{10, 70, 40, true, true, true, 535648ULL},
{30, 30, 30, true, true, false, 956681ULL},
{70, 80, 50, true, false, true, 875083ULL},
{80, 90, 70, true, false, false, 50744ULL},
{90, 90, 30, false, true, true, 506321ULL},
{40, 100, 70, false, true, false, 638418ULL},
{80, 50, 30, false, false, true, 701529ULL},
{50, 80, 60, false, false, false, 893038ULL}};

typedef GemmLayoutTest<float> GemmLayoutTestF;
TEST_P(GemmLayoutTestF, Result) {
ASSERT_TRUE(
devArrMatch(refZ, Z, params.M * params.N, CompareApprox<float>(1e-4)));
}

typedef GemmLayoutTest<double> GemmLayoutTestD;
TEST_P(GemmLayoutTestD, Result) {
ASSERT_TRUE(
devArrMatch(refZ, Z, params.M * params.N, CompareApprox<double>(1e-6)));
}

INSTANTIATE_TEST_CASE_P(GemmLayoutTests, GemmLayoutTestF,
::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_CASE_P(GemmLayoutTests, GemmLayoutTestD,
::testing::ValuesIn(inputsd));

} // end namespace LinAlg
} // end namespace MLCommon
