diff --git a/build_deps/tf_dependency/BUILD.tpl b/build_deps/tf_dependency/BUILD.tpl index 047baee..3c5526f 100644 --- a/build_deps/tf_dependency/BUILD.tpl +++ b/build_deps/tf_dependency/BUILD.tpl @@ -14,5 +14,13 @@ cc_library( visibility = ["//visibility:public"], ) + +cc_library( + name = "libtensorflow_cc", + srcs = ["%{TF_SHARED_CC_LIBRARY_NAME}"], + visibility = ["//visibility:public"], +) + %{TF_HEADER_GENRULE} -%{TF_SHARED_LIBRARY_GENRULE} \ No newline at end of file +%{TF_SHARED_LIBRARY_GENRULE} +%{TF_SHARED_CC_LIBRARY_GENRULE} \ No newline at end of file diff --git a/build_deps/tf_dependency/tf_configure.bzl b/build_deps/tf_dependency/tf_configure.bzl index 0c0b5e7..be03e21 100644 --- a/build_deps/tf_dependency/tf_configure.bzl +++ b/build_deps/tf_dependency/tf_configure.bzl @@ -6,6 +6,8 @@ _TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR" _TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME" +_TF_SHARED_CC_LIBRARY_NAME = "TF_SHARED_CC_LIBRARY_NAME" + _TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG" _TF_CPLUSPLUS_VER = "TF_CPLUSPLUS_VER" @@ -204,7 +206,9 @@ def _tf_pip_impl(repository_ctx): tf_shared_library_dir = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR] tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME] + tf_shared_cc_library_name = repository_ctx.os.environ[_TF_SHARED_CC_LIBRARY_NAME] tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name) + tf_shared_cc_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_cc_library_name) tf_cx11_abi = "-D_GLIBCXX_USE_CXX11_ABI=%s" % (repository_ctx.os.environ[_TF_CXX11_ABI_FLAG]) tf_cplusplus_ver = "-std=%s" % repository_ctx.os.environ[_TF_CPLUSPLUS_VER] @@ -217,10 +221,21 @@ def _tf_pip_impl(repository_ctx): [tf_shared_library_name], ) + tf_shared_cc_library_rule = _symlink_genrule_for_dir( + repository_ctx, + None, + "", + tf_shared_cc_library_name, + [tf_shared_cc_library_path], + [tf_shared_cc_library_name], + ) + _tpl(repository_ctx, "BUILD", { "%{TF_HEADER_GENRULE}": tf_header_rule, "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule, + "%{TF_SHARED_CC_LIBRARY_GENRULE}": tf_shared_cc_library_rule, "%{TF_SHARED_LIBRARY_NAME}": tf_shared_library_name, + "%{TF_SHARED_CC_LIBRARY_NAME}": tf_shared_cc_library_name, }) _tpl( @@ -237,6 +252,7 @@ tf_configure = repository_rule( _TF_HEADER_DIR, _TF_SHARED_LIBRARY_DIR, _TF_SHARED_LIBRARY_NAME, + _TF_SHARED_CC_LIBRARY_NAME, _TF_CXX11_ABI_FLAG, _TF_CPLUSPLUS_VER, ], diff --git a/configure.py b/configure.py index 4be1837..62ca034 100644 --- a/configure.py +++ b/configure.py @@ -689,10 +689,10 @@ def set_gcc_host_compiler_path(environ_cp): def choose_compiler(environ_cp): - question = 'Do you want to use Clang to build TensorFlow?' - yes_reply = 'Clang will be used to compile TensorFlow.' - no_reply = 'GCC will be used to compile TensorFlow.' - var = int(get_var(environ_cp, 'TF_NEED_CLANG', None, True, question, yes_reply, no_reply)) + question = 'Do you want to use Clang to build Deepray?' + yes_reply = 'Clang will be used to compile Deepray.' + no_reply = 'GCC will be used to compile Deepray.' 
+ var = int(get_var(environ_cp, 'TF_NEED_CLANG', None, False, question, yes_reply, no_reply)) return var @@ -1141,6 +1141,10 @@ def main(): write_action_env_to_bazelrc("TF_HEADER_DIR", get_tf_header_dir()) write_action_env_to_bazelrc("TF_SHARED_LIBRARY_DIR", get_tf_shared_lib_dir()) write_action_env_to_bazelrc("TF_SHARED_LIBRARY_NAME", get_shared_lib_name()) + write_action_env_to_bazelrc( + "TF_SHARED_CC_LIBRARY_NAME", + get_shared_lib_name().replace("libtensorflow_framework", "libtensorflow_cc") + ) write_action_env_to_bazelrc("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG) # This should be replaced with a call to tf.sysconfig if it's added write_action_env_to_bazelrc("TF_CPLUSPLUS_VER", get_cpp_version()) diff --git a/deepray/custom_ops/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu.cc b/deepray/custom_ops/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu.cc index ba70649..eca712d 100644 --- a/deepray/custom_ops/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu.cc +++ b/deepray/custom_ops/distributed_embeddings/cc/kernels/embedding_lookup_kernels.cu.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,10 +31,11 @@ namespace cg = cooperative_groups; namespace tensorflow { template -__device__ void EmbeddingReduceByIndices(cg::thread_block_tile g, T* out, const T* params, - int embedding_width, int query_nnz, const TIndex* indices, - TIndex* shmem_indices, Combiner combiner, - const T* weights) { +__device__ void EmbeddingReduceByIndices(cg::thread_block_tile g, T* out, + const T* params, int embedding_width, + int query_nnz, const TIndex* indices, + TIndex* shmem_indices, + Combiner combiner, const T* weights) { T weight = 1; int tid = g.thread_rank(); int row_off = tid / row * row; @@ -43,31 +44,32 @@ __device__ void EmbeddingReduceByIndices(cg::thread_block_tile g, T* out, // Remainder is handled first int remainder = query_nnz % tile; - // First stage, each CTA load one segment of indices in the sample into shared memory + // First stage, each CTA load one segment of indices in the sample into shared + // memory g.sync(); if (tid < remainder) { shmem_indices[tid] = indices[tid]; } g.sync(); // Second stage - // A CTA first reads indices from shared memory and finds the corresponding entry in the - // embedding table. Then the CTA reads the embedding vector and accumulates into register file. - // Each thread in the CTA reads one element of the embedding vector - _Pragma("unroll 4") - for (int i = tid / row; i < remainder; i += (tile + row - 1) / row) { + // A CTA first reads indices from shared memory and finds the corresponding + // entry in the embedding table. Then the CTA reads the embedding vector and + // accumulates into register file. 
Each thread in the CTA reads one element of + // the embedding vector + _Pragma("unroll 4") for (int i = tid / row; i < remainder; + i += (tile + row - 1) / row) { if (weights != nullptr) weight = weights[shmem_indices[i]]; - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { if (j * tile + row_tid < embedding_width) { result[j] += weight * - params[shmem_indices[i] * static_cast(embedding_width) + j * tile + row_tid]; + params[shmem_indices[i] * static_cast(embedding_width) + + j * tile + row_tid]; } } } - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { out[j] += result[j]; result[j] = 0; } @@ -77,20 +79,18 @@ __device__ void EmbeddingReduceByIndices(cg::thread_block_tile g, T* out, for (int processed = remainder; processed < query_nnz; processed += tile) { shmem_indices[tid] = indices[processed + tid]; g.sync(); - _Pragma("unroll 4") - for (int i = 0; i < row && i < tile; ++i) { + _Pragma("unroll 4") for (int i = 0; i < row && i < tile; ++i) { if (weights != nullptr) weight = weights[shmem_indices[i + row_off]]; - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { if (j * tile + row_tid < embedding_width) { result[j] += - weight * params[shmem_indices[i + row_off] * static_cast(embedding_width) + + weight * params[shmem_indices[i + row_off] * + static_cast(embedding_width) + j * tile + row_tid]; } } } - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { out[j] += result[j]; result[j] = 0; } @@ -99,49 +99,48 @@ __device__ void EmbeddingReduceByIndices(cg::thread_block_tile g, T* out, // reduce down to row elements, only first row have correct result for (int i = tile / 2; i >= row; i /= 2) { - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { out[j] += g.shfl_down(out[j], i); } } } template -__device__ void EmbeddingReduceByIndicesWide(cg::thread_block_tile g, T* out, const T* params, - int embedding_width, int query_nnz, - const TIndex* indices, TIndex* shmem_indices, - Combiner combiner, const T* weights, int rem_width) { +__device__ void EmbeddingReduceByIndicesWide( + cg::thread_block_tile g, T* out, const T* params, int embedding_width, + int query_nnz, const TIndex* indices, TIndex* shmem_indices, + Combiner combiner, const T* weights, int rem_width) { T weight = 1; int tid = g.thread_rank(); T result[(row + tile - 1) / tile] = {0}; // Remainder is handled first int remainder = query_nnz % tile; - // First stage, each CTA load one segment of indices in the sample into shared memory + // First stage, each CTA load one segment of indices in the sample into shared + // memory g.sync(); if (tid < remainder) { shmem_indices[tid] = indices[tid]; } g.sync(); // Second stage - // A CTA first reads indices from shared memory and finds the corresponding entry in the - // embedding table. Then the CTA reads the embedding vector and accumulates into register file. - // Each thread in the CTA reads one element of the embedding vector - _Pragma("unroll 4") - for (int i = 0; i < remainder; ++i) { + // A CTA first reads indices from shared memory and finds the corresponding + // entry in the embedding table. 
Then the CTA reads the embedding vector and + // accumulates into register file. Each thread in the CTA reads one element of + // the embedding vector + _Pragma("unroll 4") for (int i = 0; i < remainder; ++i) { if (weights != nullptr) weight = weights[shmem_indices[i]]; - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { if (j * tile + tid < rem_width) { result[j] += weight * - params[shmem_indices[i] * static_cast(embedding_width) + j * tile + tid]; + params[shmem_indices[i] * static_cast(embedding_width) + + j * tile + tid]; } } } - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { out[j] += result[j]; result[j] = 0; } @@ -151,20 +150,18 @@ __device__ void EmbeddingReduceByIndicesWide(cg::thread_block_tile g, T* o for (int processed = remainder; processed < query_nnz; processed += tile) { shmem_indices[tid] = indices[processed + tid]; g.sync(); - _Pragma("unroll 4") - for (int i = 0; i < tile; ++i) { + _Pragma("unroll 4") for (int i = 0; i < tile; ++i) { if (weights != nullptr) weight = weights[shmem_indices[i]]; - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { if (j * tile + tid < rem_width) { result[j] += weight * - params[shmem_indices[i] * static_cast(embedding_width) + j * tile + tid]; + params[shmem_indices[i] * static_cast(embedding_width) + + j * tile + tid]; } } } - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { out[j] += result[j]; result[j] = 0; } @@ -174,8 +171,10 @@ __device__ void EmbeddingReduceByIndicesWide(cg::thread_block_tile g, T* o template __global__ void EmbeddingLookUpVariableHot(const T* params, int embedding_width, - const TIndex* indptr, const TIndex* indices, T* out, - Combiner combiner, TIndex num_rows, const T* weights) { + const TIndex* indptr, + const TIndex* indices, T* out, + Combiner combiner, TIndex num_rows, + const T* weights) { auto row_group = cg::tiled_partition(cg::this_thread_block()); // smem same size as block size. @@ -192,25 +191,27 @@ __global__ void EmbeddingLookUpVariableHot(const T* params, int embedding_width, for (int step = 0; step < num_step; step++) { int64_t block_ind_offset = indptr[0]; int query_nnz = indptr[1] - block_ind_offset; - // we only want break down skewed long reductions, i.e, power law input backward. - // These reduction length correlate strongly to batchsize. Let's say we care about perf - // beyond 1k batchsize in general, then we probably need this threshold <512 to be able - // to breakdown long reduction in these cases. - // 128 is chosen so each warp have a full read into indptr when there are 4 of them. - // it seems works fine, but we can make it a function of launch config if needed + // we only want break down skewed long reductions, i.e, power law input + // backward. These reduction length correlate strongly to batchsize. Let's + // say we care about perf beyond 1k batchsize in general, then we probably + // need this threshold <512 to be able to breakdown long reduction in these + // cases. 128 is chosen so each warp have a full read into indptr when there + // are 4 of them. 
it seems works fine, but we can make it a function of + // launch config if needed if (query_nnz > 128 && blockDim.y > 1) { T result[(row + tile - 1) / tile] = {0}; - int prev_row_extra = - (query_nnz % blockDim.y) > threadIdx.y ? threadIdx.y : query_nnz % blockDim.y; + int prev_row_extra = (query_nnz % blockDim.y) > threadIdx.y + ? threadIdx.y + : query_nnz % blockDim.y; int row_extra = (query_nnz % blockDim.y) > threadIdx.y ? 1 : 0; int row_offset = (query_nnz / blockDim.y) * threadIdx.y + prev_row_extra; int row_nnz = (query_nnz / blockDim.y) + row_extra; EmbeddingReduceByIndices( row_group, result, params, embedding_width, row_nnz, - indices + block_ind_offset + row_offset, shmem_indices, combiner, weights); + indices + block_ind_offset + row_offset, shmem_indices, combiner, + weights); __syncthreads(); - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { shmem_values[threadIdx.y * blockDim.x + threadIdx.x] = result[j]; __syncthreads(); if (threadIdx.y == 0) { @@ -220,7 +221,8 @@ __global__ void EmbeddingLookUpVariableHot(const T* params, int embedding_width, if (combiner == Combiner::Mean) { result[j] /= query_nnz; } - if (j * tile + threadIdx.x < embedding_width) out[j * tile] = result[j]; + if (j * tile + threadIdx.x < embedding_width) + out[j * tile] = result[j]; } __syncthreads(); } @@ -230,15 +232,15 @@ __global__ void EmbeddingLookUpVariableHot(const T* params, int embedding_width, if (!step_counter) { step_counter = blockDim.y; T result[(row + tile - 1) / tile] = {0}; - EmbeddingReduceByIndices(row_group, result, params, embedding_width, - query_nnz, indices + block_ind_offset, - shmem_indices, combiner, weights); - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + EmbeddingReduceByIndices( + row_group, result, params, embedding_width, query_nnz, + indices + block_ind_offset, shmem_indices, combiner, weights); + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { if (combiner == Combiner::Mean) { result[j] /= query_nnz; } - if (j * tile + threadIdx.x < embedding_width) out[j * tile] = result[j]; + if (j * tile + threadIdx.x < embedding_width) + out[j * tile] = result[j]; } } step_counter -= 1; @@ -252,18 +254,21 @@ __global__ void EmbeddingLookUpVariableHot(const T* params, int embedding_width, // each tile not within warp so no reduction with shfldown // have an outer loop to handle arbitrary embedding_width template -__global__ void EmbeddingLookUpVariableHotWide(const T* params, int embedding_width, - const TIndex* indptr, const TIndex* indices, T* out, - Combiner combiner, TIndex num_rows, - const T* weights) { +__global__ void EmbeddingLookUpVariableHotWide( + const T* params, int embedding_width, const TIndex* indptr, + const TIndex* indices, T* out, Combiner combiner, TIndex num_rows, + const T* weights) { #if __CUDACC_VER_MAJOR__ >= 12 - // According to cuda doc, on compute capability 80 or higher, this should consume no memory + // According to cuda doc, on compute capability 80 or higher, this should + // consume no memory __shared__ cg::block_tile_memory shared_for_cg; cg::thread_block thb = cg::this_thread_block(shared_for_cg); auto row_group = cg::tiled_partition(thb); #else - // unchanged legacy code. these are under experimental namespace before cuda 12.0 - __shared__ cg::experimental::block_tile_memory shared_for_cg; + // unchanged legacy code. 
these are under experimental namespace before + // cuda 12.0 + __shared__ cg::experimental::block_tile_memory + shared_for_cg; cg::thread_block thb = cg::experimental::this_thread_block(shared_for_cg); auto row_group = cg::experimental::tiled_partition(thb); #endif @@ -283,18 +288,20 @@ __global__ void EmbeddingLookUpVariableHotWide(const T* params, int embedding_wi int64_t block_out_offset = cur_id * embedding_width; if (query_nnz > 128 && blockDim.y > 1) { T result[(row + tile - 1) / tile] = {0}; - int prev_row_extra = - (query_nnz % blockDim.y) > threadIdx.y ? threadIdx.y : query_nnz % blockDim.y; + int prev_row_extra = (query_nnz % blockDim.y) > threadIdx.y + ? threadIdx.y + : query_nnz % blockDim.y; int row_extra = (query_nnz % blockDim.y) > threadIdx.y ? 1 : 0; - int row_offset = (query_nnz / blockDim.y) * threadIdx.y + prev_row_extra; + int row_offset = + (query_nnz / blockDim.y) * threadIdx.y + prev_row_extra; int row_nnz = (query_nnz / blockDim.y) + row_extra; EmbeddingReduceByIndicesWide( row_group, result, params, embedding_width, row_nnz, - indices + block_ind_offset + row_offset, shmem_indices, combiner, weights, rem_width); + indices + block_ind_offset + row_offset, shmem_indices, combiner, + weights, rem_width); __syncthreads(); - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { shmem_values[threadIdx.y * blockDim.x + threadIdx.x] = result[j]; __syncthreads(); if (threadIdx.y == 0) { @@ -315,10 +322,10 @@ __global__ void EmbeddingLookUpVariableHotWide(const T* params, int embedding_wi if ((cur_id / gridDim.x) % blockDim.y == threadIdx.y) { T result[(row + tile - 1) / tile] = {0}; EmbeddingReduceByIndicesWide( - row_group, result, params, embedding_width, query_nnz, indices + block_ind_offset, - shmem_indices, combiner, weights, rem_width); - _Pragma("unroll") - for (int j = 0; j < (row + tile - 1) / tile; ++j) { + row_group, result, params, embedding_width, query_nnz, + indices + block_ind_offset, shmem_indices, combiner, weights, + rem_width); + _Pragma("unroll") for (int j = 0; j < (row + tile - 1) / tile; ++j) { if (combiner == Combiner::Mean) { result[j] /= query_nnz; } @@ -335,8 +342,8 @@ __global__ void EmbeddingLookUpVariableHotWide(const T* params, int embedding_wi } } template -__global__ void RowToSplit(TIndex* split_ptr, const TIndex* row_ptr, TIndex num_ids, - TIndex num_rows) { +__global__ void RowToSplit(TIndex* split_ptr, const TIndex* row_ptr, + TIndex num_ids, TIndex num_rows) { // effectively parallel binary search auto tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid == num_rows) split_ptr[tid] = num_ids; @@ -356,33 +363,41 @@ __global__ void RowToSplit(TIndex* split_ptr, const TIndex* row_ptr, TIndex num_ } template -__global__ void OffsetToWeightsAndRowId(const TIndex* indptr, int32_t* out, T* weights) { +__global__ void OffsetToWeightsAndRowId(const TIndex* indptr, int32_t* out, + T* weights) { TIndex block_start_offset = indptr[blockIdx.x]; TIndex block_end_offset = indptr[blockIdx.x + 1]; - for (TIndex i = block_start_offset + threadIdx.x; i < block_end_offset; i += blockDim.x) { + for (TIndex i = block_start_offset + threadIdx.x; i < block_end_offset; + i += blockDim.x) { out[i] = blockIdx.x; } if (threadIdx.x == 0 && weights) - weights[blockIdx.x] = static_cast(1) / static_cast(block_end_offset - block_start_offset); + weights[blockIdx.x] = static_cast(1) / + static_cast(block_end_offset - block_start_offset); } template struct RowToSplitFunctor { - 
void operator()(const Eigen::GpuDevice& d, TIndex* split_ptr, const TIndex* row_ptr, - TIndex num_ids, TIndex num_rows) const { - TF_CHECK_OK(GpuLaunchKernel(RowToSplit, num_rows / 512 + 1, 512, 0, d.stream(), - split_ptr, row_ptr, num_ids, num_rows)); + void operator()(const Eigen::GpuDevice& d, TIndex* split_ptr, + const TIndex* row_ptr, TIndex num_ids, + TIndex num_rows) const { + TF_CHECK_OK(GpuLaunchKernel(RowToSplit, num_rows / 512 + 1, 512, 0, + d.stream(), split_ptr, row_ptr, num_ids, + num_rows)); } }; // The kernel does following things: // - generate available indices from count array // - try insert new value from available indices with each key -// - insert either succeed, or get existed value from pervious batch/other parallel threads -// - now we have needed output, update count array for future available index generation +// - insert either succeed, or get existed value from pervious batch/other +// parallel threads +// - now we have needed output, update count array for future available index +// generation template -__global__ void SearchAndUpdate(ViewT view, const T* keys, T* values, T* avails, CountT* counts, - T num_elem, int* g_counter, T capacity) { +__global__ void SearchAndUpdate(ViewT view, const T* keys, T* values, T* avails, + CountT* counts, T num_elem, int* g_counter, + T capacity) { cg::grid_group grid = cg::this_grid(); int tid = blockDim.x * blockIdx.x + threadIdx.x; // set global atomic counters to save a memset outside @@ -407,7 +422,8 @@ __global__ void SearchAndUpdate(ViewT view, const T* keys, T* values, T* avails, // now we have available indices, try insert them with keys int num_avail = g_counter[0]; T key, value; - // First deal with case where we still have empty slot but not enough to do in one go + // First deal with case where we still have empty slot but not enough to do in + // one go if (num_avail > 0 && num_avail < num_elem) { if (tid < num_avail) { int cur_offset = atomicAdd(g_counter + 1, 1); @@ -417,31 +433,37 @@ __global__ void SearchAndUpdate(ViewT view, const T* keys, T* values, T* avails, #if __CUDA_ARCH__ < 700 if constexpr (cuco::detail::is_packable()) { #endif - auto [iter, inserted] = view.insert_and_find(cuco::make_pair(key, value)); + auto [iter, inserted] = + view.insert_and_find(cuco::make_pair(key, value)); counts[iter->second] += 1; values[cur_offset] = iter->second; if (inserted) break; #if __CUDA_ARCH__ < 700 - // TODO(deyuf): add fallback logic determinism and pre-volta gpu. might need multi-kernel + // TODO(deyuf): add fallback logic determinism and pre-volta gpu. 
+ // might need multi-kernel } #endif cur_offset = atomicAdd(g_counter + 1, 1); } } - // above run could stop before checking all keys, when all avaiable indices are inserted + // above run could stop before checking all keys, when all avaiable indices + // are inserted grid.sync(); - // threads with tid>g_counter will continue and insert_and_find remaining keys with default - // g_counter >= num_elem means all thread should returns since all keys are looked up already + // threads with tid>g_counter will continue and insert_and_find remaining + // keys with default g_counter >= num_elem means all thread should returns + // since all keys are looked up already if (tid < g_counter[1]) return; } // drop rest of no longer needed threads after possible grid sync if (tid >= num_elem) return; // Three cases we end up here: - // - we have enough new indices for use in one go, all num_elem threads got here + // - we have enough new indices for use in one go, all num_elem threads got + // here // - there is no new indices to use at all, all num_elem threads got here - // - we run out of new indices during above if, only tid matching never looked up key got here + // - we run out of new indices during above if, only tid matching never looked + // up key got here key = keys[tid]; // Don't insert OOV keys so table remain not full @@ -470,23 +492,27 @@ __global__ void SearchAndUpdate(ViewT view, const T* keys, T* values, T* avails, template struct IntegerLookupFunctor { - void operator()(OpKernelContext* context, T* table_ptr, CountT* count_ptr, const T* keys_ptr, - T* value_ptr, T num_elem, bool init, int64_t capacity) const { + void operator()(OpKernelContext* context, T* table_ptr, CountT* count_ptr, + const T* keys_ptr, T* value_ptr, T num_elem, bool init, + int64_t capacity) const { const auto& cu_stream = GetGpuStream(context); // get a mutable view from TF managed memory, initialize if needed auto table_capacity = capacity * 3 / 2; T constexpr empty_key_sentinel = -1; T constexpr empty_value_sentinel = -1; - auto slot = reinterpret_cast::pair_atomic_type*>(table_ptr); + auto slot = + reinterpret_cast::pair_atomic_type*>( + table_ptr); if (init) { using atomic_key_type = typename cuco::static_map::atomic_key_type; - using atomic_mapped_type = typename cuco::static_map::atomic_mapped_type; + using atomic_mapped_type = + typename cuco::static_map::atomic_mapped_type; auto grid_size = (table_capacity + 1023) / 1024; cuco::detail::initialize<256, atomic_key_type, atomic_mapped_type> - <<>>(slot, cuco::empty_key{empty_key_sentinel}, - cuco::empty_value{empty_value_sentinel}, - table_capacity); + <<>>( + slot, cuco::empty_key{empty_key_sentinel}, + cuco::empty_value{empty_value_sentinel}, table_capacity); } auto view = typename cuco::static_map::device_mutable_view( slot, table_capacity, cuco::empty_key{empty_key_sentinel}, @@ -494,23 +520,27 @@ struct IntegerLookupFunctor { // counters to figure out offsets between threads Tensor atomic_counter; - context->allocate_temp(DT_INT32, TensorShape({static_cast(2)}), &atomic_counter); + context->allocate_temp(DT_INT32, TensorShape({static_cast(2)}), + &atomic_counter); auto atomic_counter_ptr = atomic_counter.flat().data(); // DRAM workspace buffer to store new indices available for use Tensor temp_avail; - context->allocate_temp(DataTypeToEnum::value, TensorShape({static_cast(num_elem)}), + context->allocate_temp(DataTypeToEnum::value, + TensorShape({static_cast(num_elem)}), &temp_avail); auto temp_avail_ptr = temp_avail.flat().data(); int num_threads = 
512; - // TODO: add loop for batch dim and get device prop from TF to figure safe/largest num_blocks - // For now, use max(enough_for_batch, 64) since most cards we care have more than 64 sm + // TODO: add loop for batch dim and get device prop from TF to figure + // safe/largest num_blocks For now, use max(enough_for_batch, 64) since most + // cards we care have more than 64 sm auto num_blocks = (num_elem + num_threads - 1) / num_threads; num_blocks = num_blocks < 64 ? 64 : num_blocks; void* args[] = {&view, &keys_ptr, &value_ptr, &temp_avail_ptr, &count_ptr, &num_elem, &atomic_counter_ptr, &capacity}; cudaLaunchCooperativeKernel( - (void*)SearchAndUpdate::device_mutable_view, T, CountT>, + (void*)SearchAndUpdate< + typename cuco::static_map::device_mutable_view, T, CountT>, num_blocks, num_threads, args, 0, cu_stream); } }; @@ -518,8 +548,9 @@ struct IntegerLookupFunctor { template struct EmbeddingLookupVariableHotnessFunctor { void operator()(const Eigen::GpuDevice& d, T* output_ptr, const T* param_ptr, - const TIndex* ids_ptr, const TIndex* offsets_ptr, TIndex num_rows, - TIndex embedding_width, Combiner combiner, TIndex ave_red_len) const { + const TIndex* ids_ptr, const TIndex* offsets_ptr, + TIndex num_rows, TIndex embedding_width, Combiner combiner, + TIndex ave_red_len) const { int next_power_of_two = 1 << Log2Ceiling64(embedding_width); // decide number of parallel tile base on reduction length @@ -527,7 +558,8 @@ struct EmbeddingLookupVariableHotnessFunctor { if (ave_red_len >= 256) parallel_tile = 2; if (ave_red_len >= 1024) parallel_tile = 4; - // decide number of threads per tile and adjust number of tile with CUDA limits + // decide number of threads per tile and adjust number of tile with CUDA + // limits int blockX = next_power_of_two / ILP; if (blockX < 32) blockX = 32; if (blockX > 256) blockX = 256; @@ -542,59 +574,70 @@ struct EmbeddingLookupVariableHotnessFunctor { switch (next_power_of_two) { case 1: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, - smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, - ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 2: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, - smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, - ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 4: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, - smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, - ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 8: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, blockDim, - smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, - ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, 
output_ptr, combiner, num_rows, nullptr)); break; case 16: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 32: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 64: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 128: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, d.stream(), param_ptr, embedding_width, offsets_ptr, + ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 256: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; case 512: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; default: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, - blockDim, smem_size, d.stream(), param_ptr, embedding_width, - offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, d.stream(), param_ptr, embedding_width, + offsets_ptr, ids_ptr, output_ptr, combiner, num_rows, nullptr)); break; } } @@ -602,9 +645,11 @@ struct EmbeddingLookupVariableHotnessFunctor { template struct EmbeddingLookupVariableHotnessGradFunctor { - void operator()(OpKernelContext* context, const TIndex* ids_ptr, const TIndex* offset_in_ptr, - const T* grad_ptr, int64_t num_ids, TIndex embedding_width, TIndex num_rows, - int64_t dense_shape_dim0, int64_t max_red_len, Combiner combiner) const { + void operator()(OpKernelContext* context, const TIndex* ids_ptr, + const TIndex* offset_in_ptr, const T* grad_ptr, + int64_t num_ids, TIndex embedding_width, TIndex num_rows, + int64_t 
dense_shape_dim0, int64_t max_red_len, + Combiner combiner) const { const auto& cu_stream = GetGpuStream(context); cub::CountingInputIterator itr(0); @@ -613,10 +658,14 @@ struct EmbeddingLookupVariableHotnessGradFunctor { Tensor offsets; Tensor num_unique_ids; Tensor sorted_ids; - context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &tmp_unique_ids); - context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &offsets); - context->allocate_temp(DataTypeToEnum::value, TensorShape({1}), &num_unique_ids); - context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &sorted_ids); + context->allocate_temp(DataTypeToEnum::value, + TensorShape({num_ids}), &tmp_unique_ids); + context->allocate_temp(DataTypeToEnum::value, + TensorShape({num_ids}), &offsets); + context->allocate_temp(DataTypeToEnum::value, TensorShape({1}), + &num_unique_ids); + context->allocate_temp(DataTypeToEnum::value, + TensorShape({num_ids}), &sorted_ids); auto tmp_unique_ids_ptr = tmp_unique_ids.flat().data(); auto offsets_ptr = offsets.flat().data(); auto num_unique_ids_ptr = num_unique_ids.flat().data(); @@ -624,62 +673,72 @@ struct EmbeddingLookupVariableHotnessGradFunctor { Tensor row; Tensor sorted_row; - context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &row); - context->allocate_temp(DataTypeToEnum::value, TensorShape({num_ids}), &sorted_row); + context->allocate_temp(DataTypeToEnum::value, + TensorShape({num_ids}), &row); + context->allocate_temp(DataTypeToEnum::value, + TensorShape({num_ids}), &sorted_row); auto row_ptr = row.flat().data(); auto sorted_row_ptr = sorted_row.flat().data(); T* weights_ptr = nullptr; Tensor weights; if (combiner == Combiner::Mean) { - context->allocate_temp(DataTypeToEnum::value, TensorShape({num_rows}), &weights); + context->allocate_temp(DataTypeToEnum::value, TensorShape({num_rows}), + &weights); weights_ptr = weights.flat().data(); } - TF_CHECK_OK(GpuLaunchKernel(OffsetToWeightsAndRowId, num_rows, 32, 0, cu_stream, - offset_in_ptr, row_ptr, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel(OffsetToWeightsAndRowId, num_rows, + 32, 0, cu_stream, offset_in_ptr, row_ptr, + weights_ptr)); // Determine temporary device storage requirements size_t temp_sort = 0; size_t temp_unique = 0; - cub::DeviceRadixSort::SortPairs(nullptr, temp_sort, ids_ptr, sorted_ids_ptr, row_ptr, - sorted_row_ptr, num_ids, 0, Log2Ceiling64(dense_shape_dim0), - cu_stream); - cub::DeviceSelect::UniqueByKey(nullptr, temp_unique, sorted_ids_ptr, itr, tmp_unique_ids_ptr, - offsets_ptr, num_unique_ids_ptr, num_ids, cu_stream); + cub::DeviceRadixSort::SortPairs(nullptr, temp_sort, ids_ptr, sorted_ids_ptr, + row_ptr, sorted_row_ptr, num_ids, 0, + Log2Ceiling64(dense_shape_dim0), cu_stream); + cub::DeviceSelect::UniqueByKey(nullptr, temp_unique, sorted_ids_ptr, itr, + tmp_unique_ids_ptr, offsets_ptr, + num_unique_ids_ptr, num_ids, cu_stream); Tensor temp_storage; - size_t temp_storage_bytes = temp_sort > temp_unique ? temp_sort : temp_unique; - context->allocate_temp(DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), - &temp_storage); + size_t temp_storage_bytes = + temp_sort > temp_unique ? 
temp_sort : temp_unique; + context->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage); auto temp_storage_ptr = temp_storage.flat().data(); - cub::DeviceRadixSort::SortPairs(temp_storage_ptr, temp_sort, ids_ptr, sorted_ids_ptr, row_ptr, - sorted_row_ptr, num_ids, 0, Log2Ceiling64(dense_shape_dim0), - cu_stream); - cub::DeviceSelect::UniqueByKey(temp_storage_ptr, temp_unique, sorted_ids_ptr, itr, - tmp_unique_ids_ptr, offsets_ptr, num_unique_ids_ptr, num_ids, - cu_stream); - - // copy this back to host. should be ok to sync since there is not much to do in between - // TF way of doing it seems to be event query base + cub::DeviceRadixSort::SortPairs( + temp_storage_ptr, temp_sort, ids_ptr, sorted_ids_ptr, row_ptr, + sorted_row_ptr, num_ids, 0, Log2Ceiling64(dense_shape_dim0), cu_stream); + cub::DeviceSelect::UniqueByKey( + temp_storage_ptr, temp_unique, sorted_ids_ptr, itr, tmp_unique_ids_ptr, + offsets_ptr, num_unique_ids_ptr, num_ids, cu_stream); + + // copy this back to host. should be ok to sync since there is not much to + // do in between TF way of doing it seems to be event query base TIndex num_unique_ids_host = 0; cudaMemcpyAsync(&num_unique_ids_host, num_unique_ids_ptr, sizeof(TIndex), cudaMemcpyDeviceToHost, cu_stream); - cudaMemcpyAsync(offsets_ptr + num_unique_ids_host, &num_ids, sizeof(int32_t), - cudaMemcpyHostToDevice, cu_stream); + cudaMemcpyAsync(offsets_ptr + num_unique_ids_host, &num_ids, + sizeof(int32_t), cudaMemcpyHostToDevice, cu_stream); // allocate output Tensor* unique_ids = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({num_unique_ids_host}), &unique_ids)); + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape({num_unique_ids_host}), + &unique_ids)); auto unique_ids_ptr = unique_ids->flat().data(); - cudaMemcpyAsync(unique_ids_ptr, tmp_unique_ids_ptr, num_unique_ids_host * sizeof(TIndex), + cudaMemcpyAsync(unique_ids_ptr, tmp_unique_ids_ptr, + num_unique_ids_host * sizeof(TIndex), cudaMemcpyDeviceToDevice, cu_stream); Tensor* unique_grad = nullptr; OP_REQUIRES_OK(context, - context->allocate_output(1, TensorShape({num_unique_ids_host, embedding_width}), - &unique_grad)); + context->allocate_output( + 1, TensorShape({num_unique_ids_host, embedding_width}), + &unique_grad)); auto unique_grad_ptr = unique_grad->flat().data(); int next_power_of_two = 1 << Log2Ceiling64(embedding_width); @@ -690,7 +749,8 @@ struct EmbeddingLookupVariableHotnessGradFunctor { if (max_red_len > 4096) parallel_tile = 4; if (max_red_len > 65536) parallel_tile = 6; - // decide number of threads per tile and adjust number of tile with CUDA limits + // decide number of threads per tile and adjust number of tile with CUDA + // limits int blockX = next_power_of_two / ILP; if (blockX < 32) blockX = 32; if (blockX > 256) blockX = 256; @@ -705,70 +765,81 @@ struct EmbeddingLookupVariableHotnessGradFunctor { switch (next_power_of_two) { case 1: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 2: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, 
grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 4: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 8: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 16: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 32: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 64: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 128: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHot, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHot, gridDim, blockDim, + smem_size, cu_stream, grad_ptr, embedding_width, offsets_ptr, + sorted_row_ptr, unique_grad_ptr, Combiner::Sum, num_unique_ids_host, + weights_ptr)); break; case 256: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHotWide, gridDim, + 
blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); break; case 512: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); break; default: - TF_CHECK_OK(GpuLaunchKernel(EmbeddingLookUpVariableHotWide, gridDim, - blockDim, smem_size, cu_stream, grad_ptr, embedding_width, - offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, - num_unique_ids_host, weights_ptr)); + TF_CHECK_OK(GpuLaunchKernel( + EmbeddingLookUpVariableHotWide, gridDim, + blockDim, smem_size, cu_stream, grad_ptr, embedding_width, + offsets_ptr, sorted_row_ptr, unique_grad_ptr, Combiner::Sum, + num_unique_ids_host, weights_ptr)); break; } } @@ -776,10 +847,14 @@ struct EmbeddingLookupVariableHotnessGradFunctor { template struct RowToSplitFunctor; template struct RowToSplitFunctor; -template struct EmbeddingLookupVariableHotnessFunctor; -template struct EmbeddingLookupVariableHotnessFunctor; -template struct EmbeddingLookupVariableHotnessGradFunctor; -template struct EmbeddingLookupVariableHotnessGradFunctor; +template struct EmbeddingLookupVariableHotnessFunctor; +template struct EmbeddingLookupVariableHotnessFunctor; +template struct EmbeddingLookupVariableHotnessGradFunctor; +template struct EmbeddingLookupVariableHotnessGradFunctor; template struct IntegerLookupFunctor; } // namespace tensorflow diff --git a/deepray/custom_ops/training_ops/BUILD b/deepray/custom_ops/training_ops/BUILD index c631c21..8ff1851 100644 --- a/deepray/custom_ops/training_ops/BUILD +++ b/deepray/custom_ops/training_ops/BUILD @@ -5,7 +5,6 @@ licenses(["notice"]) # Apache 2.0 custom_op_library( name = "_training_ops.so", srcs = [ - # "cc/kernels/training_op_helpers.h", "cc/kernels/training_ops.cc", "cc/kernels/training_ops.h", "cc/ops/training_ops.cc", diff --git a/deepray/custom_ops/training_ops/cc/kernels/training_op_helpers.h b/deepray/custom_ops/training_ops/cc/kernels/training_op_helpers.h deleted file mode 100644 index 186fc8a..0000000 --- a/deepray/custom_ops/training_ops/cc/kernels/training_op_helpers.h +++ /dev/null @@ -1,285 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef DEEPRAY_CUSTOM_OPS_TRAINING_OP_HELPERS_H_ -#define DEEPRAY_CUSTOM_OPS_TRAINING_OP_HELPERS_H_ - -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/variant_op_registry.h" -#include "tensorflow/core/kernels/dense_update_functor.h" -#include "tensorflow/core/kernels/variable_ops.h" -#include "tensorflow/core/lib/core/refcount.h" - -namespace tensorflow { - -// Must be called before performing a sparse operation on a variable. Ensures -// that no concurrent dense operations can happen while holding the variable's -// lock. -// @param ctx OpKernelContext for variable tensor cloning -// @param var Variable to be shared -// @param lock_held Whether the variable mutex was already held or not -// NOTE: This function uses variable's `copy_on_read_mode` flag to decide if -// it should immediately return or continue to lock the variable mutex for more -// processing, and always sets the `copy_on_read_mode` flag to true when this -// function returns. However, there is no guarantee that another op won't set -// the `copy_on_read_mode` flag back to false after this function. -// Therefore, for the operation that requires `copy_on_read` to stay true during -// its execution, the caller needs to lock the variable mutex outside and call -// this function with `lock_held = true` to avoid double locking. -template -Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var, - bool lock_held = false) { - if (var->copy_on_read_mode.load()) { - return Status::OK(); - } - - std::optional ml; - if (!lock_held) { - ml.emplace(*var->mu()); - } - - // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can - // also happen if there are no concurrent reads of the variable and - // copy-on-read mode is false. - if (var->tensor()->RefCountIsOne()) { - var->copy_on_read_mode.store(true); - return Status::OK(); - } - Tensor tmp; - if (std::is_same::value) { - AllocatorAttributes attr; - attr.set_on_host(true); - TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(), - var->tensor()->shape(), &tmp, attr)); - - const auto elements_in = var->tensor()->flat(); - auto elements_out = tmp.flat(); - for (int64_t i = 0; i < elements_in.size(); ++i) { - elements_out(i) = elements_in(i); - } - } else { - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(), - var->tensor()->shape(), &tmp, attr)); - functor::DenseUpdate copy_functor; - copy_functor(ctx->eigen_device(), tmp.flat(), - const_cast(var->tensor())->flat()); - } - *var->tensor() = tmp; - var->copy_on_read_mode.store(true); - return Status::OK(); -} - -// Utility structure that releases a sequence of borrowed mutexes when it is -// deleted. -struct VariableInputLockHolder { - public: - VariableInputLockHolder( - std::vector vars, std::unique_ptr> locks, - std::unique_ptr> shared_locks) - : vars_(std::move(vars)), - locks_(std::move(locks)), - shared_locks_(std::move(shared_locks)) {} - - VariableInputLockHolder(VariableInputLockHolder&& other) - : vars_(std::move(other.vars_)), - locks_(std::move(other.locks_)), - shared_locks_(std::move(other.shared_locks_)) {} - - ~VariableInputLockHolder() { - // Release the locks before unreffing the Vars, because each lock - // is potentially borrowed from a Var in vars_. 
- locks_.reset(); - for (Var* var : vars_) { - var->Unref(); - } - } - - private: - std::vector vars_; - // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, - // because a `std::vector` is not movable on all platforms. - std::unique_ptr> locks_; - std::unique_ptr> shared_locks_; -}; - -// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. -// -// If `input` corresponds to a `DT_RESOURCE`-type variable input, -// `*maybe_resource` will be updated to contain the underlying resource, and the -// caller will be responsible for calling `Unref()` on that resource. -template -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, - Var** maybe_resource) { - *maybe_resource = nullptr; - if (ctx->input_dtype(input) == DT_RESOURCE) { - if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { - return (*maybe_resource)->mu(); - } else { - ctx->CtxFailureWithWarning( - errors::Internal("Invalid variable reference.")); - return nullptr; - } - } - return ctx->input_ref_mutex(input); -} - -// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes -// in address order to mitigate deadlock. Returns a structure that, when -// deleted, will release the acquired mutexes. Safe to pass duplicates - will -// only lock each distinct mutex once. If sparse is true will ensure the -// variable gets switched to copy-on-read mode before trying to acquire the -// locks. If do_lock is false, returns immediately for reference variables. For -// resource variables in copy-on-read-mode it will grab a shared lock if do_lock -// is false, exclusive lock otherwise. Note that this silently doesn't lock -// mutexes for invalid variable references; in all usages this is followed by -// GetInputTensor which will signal a failure. -template -VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( - OpKernelContext* ctx, bool do_lock, bool sparse, - const std::vector& input_ids) { - bool any_resource = false; - for (auto i : input_ids) { - if (ctx->input_dtype(i) == DT_RESOURCE) { - any_resource = true; - break; - } - } - if (!do_lock && !any_resource) { - return VariableInputLockHolder({}, {}, {}); - } - std::vector vars; - std::vector mutexes; - std::vector acquire_order; - for (auto input : input_ids) { - Var* var; - mutex* mutex = GetTrainingVariableMutex(ctx, input, &var); - if (var) vars.push_back(var); - // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). - if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { - acquire_order.push_back(mutexes.size()); - mutexes.push_back(mutex); - } - } - std::sort(acquire_order.begin(), acquire_order.end(), - [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); - - auto locks = std::make_unique>(); - auto shared_locks = std::make_unique>(); - locks->reserve(acquire_order.size()); - - for (auto acquire : acquire_order) { - mutex* mu = mutexes[acquire]; - if (mu != nullptr) { - if (!sparse || do_lock) { - locks->emplace_back(*mu); - } else { - shared_locks->emplace_back(*mu); - } - } - } - auto variableInputLock = - VariableInputLockHolder(vars, std::move(locks), std::move(shared_locks)); - if (sparse) { - // Enable sparse variables' access. 
- // NOTE: This can not be done before the variable input locks are held, - // because a race condition can happen between this and another thread that - // turns off some variable's `copy_on_read_mode` after this thread enables - // sparse access; when a later function sees `copy_on_read_mode` is off, it - // will try to lock the variable again for updating `copy_on_read_mode` and - // cause the deadlock, since the variable mutex is non-re-entrant. - for (auto* var : vars) { - EnsureSparseVariableAccess(ctx, var, /*lock_held=*/true) - .IgnoreError(); - } - } - return variableInputLock; -} - -void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, - int output); - -// This is for use with ResourceVariables to ensure *tensor has a -// reference count of 1 before you update it. -// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held. -template -Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor, - bool copy_on_read_mode) { - if (copy_on_read_mode || !tensor->RefCountIsOne()) { - // Tensor's buffer is in use by some read, so we need to copy before - // updating. - Tensor tmp; - if (std::is_same::value) { - AllocatorAttributes attr; - attr.set_on_host(true); - TF_RETURN_IF_ERROR( - ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); - - const auto elements_in = tensor->flat(); - auto elements_out = tmp.flat(); - for (int64_t i = 0; i < elements_in.size(); ++i) { - elements_out(i) = elements_in(i); - } - } else { - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - TF_RETURN_IF_ERROR( - ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); - functor::DenseUpdate copy_functor; - copy_functor(ctx->eigen_device(), tmp.flat(), - const_cast(tensor)->flat()); - } - *tensor = tmp; - } - return Status::OK(); -} - -// This gives you `*out`, a tensor you can update, corresponding to a variable -// passed as input index `input`. This handles the differences between -// reference and resource variables. For reference variables we can just grab -// the tensor, grabbing the lock if lock_held is False. -// -// For resource variables we, if sparse is true, ensure it's in copy-on-read -// mode, and then, regardless of the value of sparse, ensure its refcount is 1 -// (by potentially copying its contents). In this case lock_held is ignored. -template -Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, - bool lock_held, bool sparse, Tensor* out) { - if (ctx->input_dtype(input) == DT_RESOURCE) { - core::RefCountPtr var; - TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var)); - if (sparse) { - TF_RETURN_IF_ERROR(EnsureSparseVariableAccess(ctx, var.get())); - *out = *var->tensor(); - return Status::OK(); - } - TF_RETURN_IF_ERROR(PrepareToUpdateVariable( - ctx, var->tensor(), var->copy_on_read_mode.load())); - *out = *var->tensor(); - return Status::OK(); - } - *out = ctx->mutable_input(input, lock_held); - return Status::OK(); -} - -} // end namespace tensorflow - -#endif // DEEPRAY_CUSTOM_OPS_TRAINING_OP_HELPERS_H_ diff --git a/deepray/custom_ops/training_ops/cc/kernels/training_ops.cc b/deepray/custom_ops/training_ops/cc/kernels/training_ops.cc index 0f2cbe9..0cf0988 100644 --- a/deepray/custom_ops/training_ops/cc/kernels/training_ops.cc +++ b/deepray/custom_ops/training_ops/cc/kernels/training_ops.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/variable_ops.h" -// #include "training_op_helpers.h" namespace tensorflow { diff --git a/deepray/deepray.bzl b/deepray/deepray.bzl index 8aa16e1..74a6403 100644 --- a/deepray/deepray.bzl +++ b/deepray/deepray.bzl @@ -11,6 +11,7 @@ def custom_op_library( **kwargs): deps = deps + [ "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:libtensorflow_cc", "@local_config_tf//:tf_header_lib", ]
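Reviewer note (illustration, not part of the patch): tying the three build changes together, configure.py now derives the C++ API library name by string replacement, e.g. "libtensorflow_framework.so.2" becomes "libtensorflow_cc.so.2", writes it to TF_SHARED_CC_LIBRARY_NAME, and tf_configure.bzl substitutes it into BUILD.tpl alongside a symlink genrule. A rough sketch of what the generated @local_config_tf//BUILD could then contain is shown below; it assumes a TF 2.x pip install whose framework library is libtensorflow_framework.so.2, and the target/genrule names and cmd are hypothetical stand-ins for whatever _symlink_genrule_for_dir actually emits.

    # Hypothetical rendering of BUILD.tpl after substitution -- names, paths,
    # and the genrule command are illustrative only.
    cc_library(
        name = "libtensorflow_cc",
        srcs = ["libtensorflow_cc.so.2"],  # filled from %{TF_SHARED_CC_LIBRARY_NAME}
        visibility = ["//visibility:public"],
    )

    genrule(
        # emitted via %{TF_SHARED_CC_LIBRARY_GENRULE}; real name comes from
        # _symlink_genrule_for_dir and may differ
        name = "libtensorflow_cc_so_genrule",
        outs = ["libtensorflow_cc.so.2"],
        cmd = "cp -f /path/to/site-packages/tensorflow/libtensorflow_cc.so.2 $(@D)",  # placeholder path
    )

With that target in place, every op built through custom_op_library in deepray.bzl links against the TensorFlow C++ API automatically via the new "@local_config_tf//:libtensorflow_cc" entry in deps, which is what allows training_ops.cc to include the upstream tensorflow/core/kernels/training_op_helpers.h instead of the now-deleted local copy.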