From 9efcdd7344039f7544cb70bf6997d99d04146f8d Mon Sep 17 00:00:00 2001 From: Vincent Date: Sun, 17 Dec 2023 23:36:02 +0800 Subject: [PATCH 1/2] migrate DeepRec unique ops to Tensorflow2 --- deepray/custom_ops/BUILD | 1 + deepray/custom_ops/unique_ops/BUILD | 77 ++ deepray/custom_ops/unique_ops/__init__.py | 1 + .../unique_ops/cc/kernels/random.cc | 58 ++ .../custom_ops/unique_ops/cc/kernels/random.h | 40 + .../unique_ops/cc/kernels/random_test.cc | 37 + .../unique_ops/cc/kernels/task_runner.h | 116 +++ .../unique_ops/cc/kernels/unique_ali_op.cc | 242 ++++++ .../cc/kernels/unique_ali_op_gpu.cu.cc | 277 +++++++ .../cc/kernels/unique_ali_op_util.h | 768 ++++++++++++++++++ .../unique_ops/cc/ops/unique_ops.cc | 85 ++ .../custom_ops/unique_ops/python/__init__.py | 0 .../unique_ops/python/tests/__init__.py | 0 .../unique_ops/python/tests/run_all_test.py | 9 + .../unique_ops/python/tests/unique_op_test.py | 303 +++++++ .../unique_ops/python/unique_ops.py | 23 + deepray/workspace2.bzl | 9 + third_party/sparsehash_c11.BUILD | 13 + 18 files changed, 2059 insertions(+) create mode 100644 deepray/custom_ops/unique_ops/BUILD create mode 100644 deepray/custom_ops/unique_ops/__init__.py create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/random.cc create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/random.h create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/random_test.cc create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/task_runner.h create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc create mode 100644 deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_util.h create mode 100644 deepray/custom_ops/unique_ops/cc/ops/unique_ops.cc create mode 100644 deepray/custom_ops/unique_ops/python/__init__.py create mode 100644 deepray/custom_ops/unique_ops/python/tests/__init__.py create mode 100644 deepray/custom_ops/unique_ops/python/tests/run_all_test.py create mode 100644 deepray/custom_ops/unique_ops/python/tests/unique_op_test.py create mode 100644 deepray/custom_ops/unique_ops/python/unique_ops.py create mode 100644 third_party/sparsehash_c11.BUILD diff --git a/deepray/custom_ops/BUILD b/deepray/custom_ops/BUILD index 59d655f..4b0226f 100644 --- a/deepray/custom_ops/BUILD +++ b/deepray/custom_ops/BUILD @@ -14,6 +14,7 @@ py_library( "//deepray/custom_ops/simple_hash_table", "//deepray/custom_ops/sleep:sleep_op", "//deepray/custom_ops/training_ops", + "//deepray/custom_ops/unique_ops", "//deepray/custom_ops/zero_out:zero_out_ops", ], ) diff --git a/deepray/custom_ops/unique_ops/BUILD b/deepray/custom_ops/unique_ops/BUILD new file mode 100644 index 0000000..e240e9a --- /dev/null +++ b/deepray/custom_ops/unique_ops/BUILD @@ -0,0 +1,77 @@ +load("//deepray:deepray.bzl", "custom_op_library") +load("@local_config_tf//:build_defs.bzl", "CPLUSPLUS_VERSION") + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//deepray:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "random", + srcs = [ + "cc/kernels/random.cc", + "cc/kernels/random.h", + ], + copts = [CPLUSPLUS_VERSION], + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], +) + +cc_test( + name = "random_test", + srcs = ["cc/kernels/random_test.cc"], + deps = [ + ":random", + "@com_google_googletest//:gtest_main", + ], +) + +custom_op_library( + name = "_unique_ops.so", + srcs = [ + "cc/kernels/task_runner.h", + 
"cc/kernels/unique_ali_op.cc", + "cc/kernels/unique_ali_op_util.h", + "cc/ops/unique_ops.cc", + ], + copts = [CPLUSPLUS_VERSION], + cuda_srcs = [ + "cc/kernels/unique_ali_op_gpu.cu.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":random", + "@com_google_absl//absl/container:flat_hash_map", + "@sparsehash_c11//:dense_hash_map", + ], +) + +py_library( + name = "unique_ops", + srcs = glob( + [ + "python/*.py", + "python/**/*.py", + "*.py", + ], + ), + data = [ + ":_unique_ops.so", + ], +) + +py_test( + name = "unique_ops_test", + size = "small", + srcs = glob(["python/tests/*"]), + main = "python/tests/run_all_test.py", + deps = [ + ":unique_ops", + ], +) diff --git a/deepray/custom_ops/unique_ops/__init__.py b/deepray/custom_ops/unique_ops/__init__.py new file mode 100644 index 0000000..8bd0ed2 --- /dev/null +++ b/deepray/custom_ops/unique_ops/__init__.py @@ -0,0 +1 @@ +from deepray.custom_ops.unique_ops.python.unique_ops import gen_array_ops diff --git a/deepray/custom_ops/unique_ops/cc/kernels/random.cc b/deepray/custom_ops/unique_ops/cc/kernels/random.cc new file mode 100644 index 0000000..1bf8491 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/random.cc @@ -0,0 +1,58 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "random.h" + +#include + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { +namespace random { + +namespace { +std::mt19937_64* InitRngWithRandomSeed() { + std::random_device device("/dev/urandom"); + return new std::mt19937_64(device()); +} +std::mt19937_64 InitRngWithDefaultSeed() { return std::mt19937_64(); } + +} // anonymous namespace + +uint64 New64() { + static std::mt19937_64* rng = InitRngWithRandomSeed(); + static mutex mu(LINKER_INITIALIZED); + mutex_lock l(mu); + return (*rng)(); +} + +uint64 New64DefaultSeed() { + static std::mt19937_64 rng = InitRngWithDefaultSeed(); + static mutex mu(LINKER_INITIALIZED); + mutex_lock l(mu); + return rng(); +} + +uint64 New64Configuable() { + int64 random_64; + CHECK( + ReadInt64FromEnvVar("DEEPREC_CONFIG_RAND_64", New64(), &random_64).ok()); + return static_cast(random_64); +} + +} // namespace random +} // namespace tensorflow diff --git a/deepray/custom_ops/unique_ops/cc/kernels/random.h b/deepray/custom_ops/unique_ops/cc/kernels/random.h new file mode 100644 index 0000000..29aae90 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/random.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_ +#define TENSORFLOW_LIB_RANDOM_RANDOM_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace random { + +// Return a 64-bit random value. Different sequences are generated +// in different processes. +uint64 New64(); + +// Return a 64-bit random value. Uses +// std::mersenne_twister_engine::default_seed as seed value. +uint64 New64DefaultSeed(); + +// Call New64 to generate a 64-bit random value +// if env var DEEPREC_CONFIG_RAND_64 not set. +// Otherwise, return int64 from DEEPREC_CONFIG_RAND_64 +uint64 New64Configuable(); + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_ diff --git a/deepray/custom_ops/unique_ops/cc/kernels/random_test.cc b/deepray/custom_ops/unique_ops/cc/kernels/random_test.cc new file mode 100644 index 0000000..d37c47e --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/random_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/random/random.h" + +#include + +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace random { +namespace { + +TEST(New64Test, SanityCheck) { + std::set values; + for (int i = 0; i < 1000000; i++) { + uint64 x = New64(); + EXPECT_TRUE(values.insert(x).second) << "duplicate " << x; + } +} + +} // namespace +} // namespace random +} // namespace tensorflow diff --git a/deepray/custom_ops/unique_ops/cc/kernels/task_runner.h b/deepray/custom_ops/unique_ops/cc/kernels/task_runner.h new file mode 100644 index 0000000..922f059 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/task_runner.h @@ -0,0 +1,116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TASK_RUNNER_H_ +#define TENSORFLOW_CORE_KERNELS_TASK_RUNNER_H_ + +#include + +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +// TaskRunner schedules tasks(function f) to ThreadPool +// and wait until all finished +class TaskRunner { + public: + explicit TaskRunner(const std::function& f, + thread::ThreadPool* tp, int32 n) + : func_(f), thread_pool_(tp), num_tasks_(n) {} + + void Run() { + if (num_tasks_ <= 0) return; + BlockingCounter bc(num_tasks_ - 1); + + // Sending (num_tasks - 1) tasks to threadpool for scheduling + for (int32 i = 0; i < num_tasks_ - 1; ++i) { + thread_pool_->Schedule([this, &bc, i]() { + func_(i, num_tasks_); + bc.DecrementCount(); + }); + } + // Run the last task in current thread. + func_(num_tasks_ - 1, num_tasks_); + bc.Wait(); + } + + private: + std::function func_; + thread::ThreadPool* thread_pool_; + const int32 num_tasks_; +}; + +// add more types of SummaryUpdater +// for more types of summary or more ways of summary aggregation +class StatusSummaryUpdater { + public: + static void UpdateSummary(Status* mine, const Status& ret) { + mine->Update(ret); + } +}; + +class Int64SumSummaryUpdater { + public: + static void UpdateSummary(int64_t* mine, const int64_t& ret) { *mine += ret; } +}; + +// SummaryTaskRunner schedules tasks and summary their return values. +// S is the type of return values. +// SUpdater is the class for aggregating the return values. +template +class SummaryTaskRunner { + public: + explicit SummaryTaskRunner(const std::function& f, + const S& init_summary, thread::ThreadPool* tp, + int32 n) + : func_(f), summary_(init_summary), thread_pool_(tp), num_tasks_(n) {} + + void Run() { + if (num_tasks_ <= 0) return; + BlockingCounter bc(num_tasks_ - 1); + + // Sending (num_tasks - 1) tasks to threadpool for scheduling + for (int32 i = 0; i < num_tasks_ - 1; ++i) { + thread_pool_->Schedule([this, &bc, i]() { + const S& ret = func_(i, num_tasks_); + UpdateSummaryUnlocked(ret); + bc.DecrementCount(); + }); + } + // Run the last task in current thread. + const S& ret = func_(num_tasks_ - 1, num_tasks_); + UpdateSummaryUnlocked(ret); + bc.Wait(); + } + + S summary() { return summary_; } + + private: + void UpdateSummaryUnlocked(const S& ret) { + mutex_lock lock(mu_); + SUpdater::UpdateSummary(&summary_, ret); + } + + mutex mu_; + std::function func_; + S summary_; + thread::ThreadPool* thread_pool_; + const int32 num_tasks_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TASK_RUNNER_H_ diff --git a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc new file mode 100644 index 0000000..122f294 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc @@ -0,0 +1,242 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "sparsehash/dense_hash_map" +#include "task_runner.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/env_var.h" +#include "unique_ali_op_util.h" + +namespace tensorflow { + +namespace { +const char *kUniqueOpSerialEnv = "DEEPREC_UNIQUE_OP_SERIAL"; +const char *kUniqueOpHashMapEnv = "DEEPREC_UNIQUE_OP_HASH_MAP"; +const char *kUniqueOpUniqRatioHint = "DEEPREC_UNIQUE_OP_UNIQ_RATIO_HINT"; +const char *kUniqueOpPartitionSizeEnv = "DEEPREC_UNIQUE_OP_PARTITION_SIZE"; +const char *kMultiMapString = "MULTIMAP"; +const char *kStlHashMapString = "STL"; +const char *kAbslHashMapString = "ABSL"; +const char *kGoogleHashMapString = "GOOGLE"; +const int64 kDefaultUniqueRatioHint = 4; +} // namespace + +template +class UniqueAliOp : public OpKernel { + public: + explicit UniqueAliOp(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK( + context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv, kPartitionSize, + &partition_size_)); + OP_REQUIRES( + context, partition_size_ > 0, + errors::InvalidArgument("Invaild PARTITION_SIZE=", partition_size_)); + + OP_REQUIRES_OK(context, + ReadBoolFromEnvVar(kUniqueOpSerialEnv, false, &serial_)); + + // NOTE(zycao>: Hash map insertion and lookup performance is dominating in + // Unique Op. Based on benchmark results, 'google::dense_hash_map' will be + // used as default for most key types except string. + // + // By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a + // particular hash map could be seleteed to use. Possible choices are listed + // below: + // "MULTIMAP" for multimap parrallel process, + // "STL" for std::unordred_map, + // "ABSL" for absl::flat_hash_map, + // "GOOGLE" for google::dense_hash_map. 
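+    // Example (illustrative shell usage, not part of this patch): running with
+    //   DEEPREC_UNIQUE_OP_HASH_MAP=ABSL DEEPREC_UNIQUE_OP_SERIAL=1
+    // in the environment selects absl::flat_hash_map and forces the serial
+    // code path; unrecognized map names fall back to the GOOGLE map below.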
+ std::string hash_map_str; + OP_REQUIRES_OK( + context, ReadStringFromEnvVar(kUniqueOpHashMapEnv, kGoogleHashMapString, + &hash_map_str)); + std::transform(hash_map_str.begin(), hash_map_str.end(), + hash_map_str.begin(), ::toupper); + + OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpUniqRatioHint, + kDefaultUniqueRatioHint, + &unique_ratio_hint_)); + OP_REQUIRES(context, unique_ratio_hint_ > 0, + errors::InvalidArgument("Invaild ", kUniqueOpUniqRatioHint, "=", + unique_ratio_hint_)); + + if (!hash_map_str.compare(kMultiMapString)) { + map_flag_ = MULTIMAP; + static char print_once = [] { + LOG(INFO) << "MultiMapCompute preserved " + "dense hash map key: " + << kPreseverdEmptyKey; + return '\0'; + }(); + } else if (!hash_map_str.compare(kStlHashMapString)) { + map_flag_ = STL; + } else if (!hash_map_str.compare(kAbslHashMapString)) { + map_flag_ = ABSL; + } else if (!hash_map_str.compare(kGoogleHashMapString)) { + map_flag_ = GOOGLE; + } else { + map_flag_ = GOOGLE; + } + } + + void Compute(OpKernelContext *context) override { + VLOG(2) << "Unique V2 executed"; + ComputeInternal(context); + } + + private: + void ComputeInternal(OpKernelContext *context) { + const Tensor &input = context->input(0); + Tensor idx; + Tensor output; + Tensor output_counter; + if (context->num_inputs() == 1) { + UniqueWithoutAxis( + context, input, &idx, &output, &output_counter, num_outputs(), + partition_size_, serial_, unique_ratio_hint_, map_flag_); + } else { + const Tensor &axis_tensor = context->input(1); + UniqueWithAxis(context, input, axis_tensor, &idx, &output, + &output_counter, num_outputs(), partition_size_, + serial_, unique_ratio_hint_, map_flag_); + } + context->set_output(0, output); + context->set_output(1, idx); + if (num_outputs() > 2) { + context->set_output(2, output_counter); + } + } + + bool serial_ = false; + int64 partition_size_ = 0; + int64 unique_ratio_hint_; + UniqueMaps map_flag_ = GOOGLE; // "GOOGLE" dense hash map is default +}; + +#define REGISTER_UNIQUE(type) \ + REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp); \ + REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp); \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp); \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp); \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp) \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp); \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCountsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp) \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCountsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp) +TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE); +REGISTER_UNIQUE(tstring) +#undef REGISTER_UNIQUE + +#if GOOGLE_CUDA +#define REGISTER_UNIQUE(type) \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + .Device(DEVICE_GPU) \ + .HostMemory("x") \ + .HostMemory("y") \ + .HostMemory("idx") 
\ + .HostMemory("count") \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp) \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + .Device(DEVICE_GPU) \ + .HostMemory("x") \ + .HostMemory("y") \ + .HostMemory("idx") \ + .HostMemory("count") \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliOp); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE); +REGISTER_UNIQUE(tstring) +#undef REGISTER_UNIQUE +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_idx") + .HostMemory("x") + .HostMemory("y") + .HostMemory("idx"), + UniqueAliOp); +REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_idx") + .HostMemory("x") + .HostMemory("y") + .HostMemory("idx"), + UniqueAliOp); +REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_idx") + .HostMemory("x") + .HostMemory("y") + .HostMemory("idx"), + UniqueAliOp); +REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_idx") + .HostMemory("x") + .HostMemory("y") + .HostMemory("idx"), + UniqueAliOp); +#endif // TENSORFLOW_USE_SYCL +} // namespace tensorflow diff --git a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc new file mode 100644 index 0000000..9bb0f91 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc @@ -0,0 +1,277 @@ +/* Copyright 2022 The DeepRec Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include + +#include "cub/device/device_radix_sort.cuh" +#include "cub/device/device_scan.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/constant_input_iterator.cuh" +#include "cub/iterator/counting_input_iterator.cuh" +#include "cub/iterator/transform_input_iterator.cuh" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cuda.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace +#include "tensorflow/stream_executor/stream_executor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +using GPUDevice = Eigen::GpuDevice; + +// Returns true iff index is at the end of a segment (which is equivalent to the +// beginning of the next segment). 
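+// The functor is consumed below through a cub::TransformInputIterator over a
+// cub::CountingInputIterator, so position i yields a 0/1 "new segment" flag;
+// an inclusive prefix sum of those flags assigns each sorted key its 0-based
+// unique index.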
+template +struct SegmentIndicatorFunctor { + const TKey* __restrict__ sorted_input_ptr_; + SegmentIndicatorFunctor(const TKey* sorted_input_ptr) + : sorted_input_ptr_(sorted_input_ptr) {} + __device__ bool operator()(const TIndex& i) const { + return i > 0 && sorted_input_ptr_[i] != sorted_input_ptr_[i - 1]; + } +}; + +template +__global__ void RangeInitKernel(const TIndex start, const TIndex delta, + const int64 size, TIndex* out) { + GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; } +} +template +__global__ void MoveValuesKernel(const TIndex* keys, const TIndex* values, + const int64 size, TIndex* out) { + GPU_1D_KERNEL_LOOP(i, size) { + TIndex key = ldg(keys + i); + out[key] = ldg(values + i); + } +} +template +__global__ void MoveValuesKernel(const TIndex* keys, const TIndex* values, + const int64* size_ptr, TIndex* out) { + int64 size = ldg(size_ptr); + GPU_1D_KERNEL_LOOP(i, size) { + TIndex key = ldg(keys + i); + out[key] = ldg(values + i); + } +} +template +__global__ void MoveSparseValuesKernel(const TIndex* keys, const TIndex* idxs, + const T* values, const int64 size, + T* out) { + GPU_1D_KERNEL_LOOP(i, size) { + TIndex key = ldg(keys + i); + TIndex idx = ldg(idxs + i); + out[key] = ldg(values + idx); + } +} +template +__global__ void CompareAdjacentKernel(const T* in, const int64 size, + TIndex* out) { + GPU_1D_KERNEL_LOOP(i, size) { + out[i] = (i == 0 || ldg(in + (i - 1)) == ldg(in + i)) ? 0 : 1; + } +} +template +void RangeInit(const GPUDevice& d, const TIndex start, const TIndex delta, + const int64 size, TIndex* out) { + GpuLaunchConfig config = GetGpuLaunchConfig(size, d); + RangeInitKernel + <<>>( + start, delta, size, out); +} +template +void MoveValues(const GPUDevice& d, const TIndex* keys, const TIndex* values, + const int64 size, TIndex* out) { + GpuLaunchConfig config = GetGpuLaunchConfig(size, d); + MoveValuesKernel<<>>(keys, values, size, out); +} + +template +void MoveValues(const GPUDevice& d, const TIndex* keys, const TIndex* values, + const int64* size_ptr, const int64 size, TIndex* out) { + GpuLaunchConfig config = GetGpuLaunchConfig(size, d); + MoveValuesKernel<<>>(keys, values, size_ptr, out); +} +template +void MoveSparseValues(const GPUDevice& d, const TIndex* keys, + const TIndex* idxs, const T* values, const int64 size, + T* out) { + GpuLaunchConfig config = GetGpuLaunchConfig(size, d); + MoveSparseValuesKernel<<>>(keys, idxs, values, size, out); +} +template +void CompareAdjacent(const GPUDevice& d, const T* in, const int64 size, + TIndex* out) { + GpuLaunchConfig config = GetGpuLaunchConfig(size, d); + CompareAdjacentKernel + <<>>(in, size, + out); +} +template +class UniqueAliV2GpuOp : public AsyncOpKernel { + public: + explicit UniqueAliV2GpuOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& input_tensor = ctx->input(0); + const T* keys = input_tensor.flat().data(); + + int64 N = input_tensor.NumElements(); + const GPUDevice& device = ctx->eigen_device(); + const cudaStream_t& cu_stream = GetGpuStream(ctx); + + Tensor* output_tensor = nullptr; + Tensor* idx_tensor = nullptr; + Tensor* part_tensor = nullptr; + auto allocate_output = [ctx, &output_tensor, &idx_tensor, &part_tensor, N, + &device, this](int64 N_out) { + TF_RETURN_IF_ERROR(ctx->allocate_output(0, {N_out}, &output_tensor)); + TF_RETURN_IF_ERROR(ctx->allocate_output(1, {N}, &idx_tensor)); + return Status::OK(); + }; + if (N == 0) { + OP_REQUIRES_OK_ASYNC(ctx, allocate_output(0), done); + 
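+      // Empty input: zero-sized y and idx were just allocated, so signal
+      // completion and return without launching any kernels.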
done(); + return; + } + + Tensor keys_sort_tensor; + Tensor indicies_sort_tensor; + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->allocate_temp(DataTypeToEnum::value, {N}, &keys_sort_tensor), + done); + OP_REQUIRES_OK_ASYNC(ctx, + ctx->allocate_temp(DataTypeToEnum::value, {N}, + &indicies_sort_tensor), + done); + T* keys_sort = keys_sort_tensor.flat().data(); + TIndex* indicies_sort = indicies_sort_tensor.flat().data(); + + Tensor indices_in_tensor; + OP_REQUIRES_OK_ASYNC(ctx, + ctx->allocate_temp(DataTypeToEnum::value, {N}, + &indices_in_tensor), + done); + TIndex* indices_in = indices_in_tensor.flat().data(); + RangeInit(device, (TIndex)0, (TIndex)1, N, indices_in); + + { + const T* keys_in; + Tensor keys_in_tensor; + keys_in = keys; + using U = typename std::make_unsigned::type; + const U* keys_u_in = reinterpret_cast(keys_in); + U* keys_u_sort = reinterpret_cast(keys_sort); + + Tensor cub_temp_storage; + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs(NULL, temp_storage_bytes, keys_u_in, + keys_u_sort, indices_in, indicies_sort, N, + 0, sizeof(T) * 8, cu_stream); + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &cub_temp_storage), + done); + cub::DeviceRadixSort::SortPairs(cub_temp_storage.flat().data(), + temp_storage_bytes, keys_u_in, + keys_u_sort, indices_in, indicies_sort, N, + 0, sizeof(T) * 8, cu_stream); + } + + Tensor output_indices_tensor; + OP_REQUIRES_OK_ASYNC(ctx, + ctx->allocate_temp(DataTypeToEnum::value, {N}, + &output_indices_tensor), + done); + TIndex* output_indices = output_indices_tensor.flat().data(); + + { + cub::TransformInputIterator, + cub::CountingInputIterator> + segment_indicator_iter(0, {keys_sort}); + Tensor cub_temp_storage; + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, + segment_indicator_iter, output_indices, N, + cu_stream); + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &cub_temp_storage), + done); + cub::DeviceScan::InclusiveSum(cub_temp_storage.flat().data(), + temp_storage_bytes, segment_indicator_iter, + output_indices, N, cu_stream); + } + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES_ASYNC(ctx, stream, errors::Internal("No GPU stream available."), + done); + ScratchSpace N_out(ctx, 1, /*on_host=*/true); + se::DeviceMemoryBase wrapped_num_out(output_indices + (N - 1), + sizeof(TIndex)); + TensorReference ref_output_indices(output_indices_tensor); + OP_REQUIRES_ASYNC( + ctx, + stream + ->ThenMemcpy(N_out.mutable_data(), wrapped_num_out, sizeof(TIndex)) + .ok(), + errors::Internal("Failed to launch copy from device to host."), done); + ctx->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, [ref_output_indices]() { ref_output_indices.Unref(); }); + stream->BlockHostUntilDone(); + int64_t uniq_size = (*N_out.data()) + 1; + OP_REQUIRES_OK_ASYNC(ctx, allocate_output(uniq_size), done); + T* output = output_tensor->flat().data(); + TIndex* idx = idx_tensor->flat().data(); + MoveValues(device, indicies_sort, output_indices, N, idx); + MoveSparseValues(device, output_indices, indicies_sort, keys, N, output); + done(); + } +}; + +#define REGISTER_UNIQUE_ALI_V2_GPU_KERNEL(T, TIndex) \ + REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueAliV2GpuOp) +#define REGISTER_UNIQUE_ALI_V2_GPU(T) \ + REGISTER_UNIQUE_ALI_V2_GPU_KERNEL(T, int32); \ + 
REGISTER_UNIQUE_ALI_V2_GPU_KERNEL(T, int64) + +TF_CALL_int32(REGISTER_UNIQUE_ALI_V2_GPU); +TF_CALL_int64(REGISTER_UNIQUE_ALI_V2_GPU); +TF_CALL_uint32(REGISTER_UNIQUE_ALI_V2_GPU); +TF_CALL_uint64(REGISTER_UNIQUE_ALI_V2_GPU); + +#undef REGISTER_UNIQUE_ALI_V2_GPU +#undef REGISTER_UNIQUE_ALI_V2_GPU_KERNEL + +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_util.h b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_util.h new file mode 100644 index 0000000..c27afd2 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_util.h @@ -0,0 +1,768 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_UNIQUE_ALI_OP_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_UNIQUE_ALI_OP_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "random.h" +#include "sparsehash/dense_hash_map" +#include "task_runner.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +namespace { +const int64 kPartitionLimit = 14336; +const int64 kPartitionSize = 8192; +const int64_t kPreseverdEmptyKey = tensorflow::random::New64Configuable(); + +typedef enum { MULTIMAP = 0, STL = 1, ABSL = 2, GOOGLE = 3 } UniqueMaps; + +} // namespace + +template +const T InvalidHashKey() { + return std::numeric_limits::max(); +} + +template +struct HashMapInitializer { + static void InitSize(HashMap* hash_map, int64 capacity) { + hash_map->reserve(2 * capacity); + } +}; + +template +struct HashMapInitializer> { + static void InitSize(google::dense_hash_map* hash_map, int64 capacity) { + hash_map->set_empty_key(InvalidHashKey()); + hash_map->resize(2 * capacity); + } + static void Reserve(google::dense_hash_map* hash_map, int64 capacity) { + hash_map->set_empty_key(InvalidHashKey()); + hash_map->resize(capacity); + } +}; + +template +struct HashMapInitializer> { + static void Reserve(google::dense_hash_map* hash_map, + int64 capacity) { + hash_map->set_empty_key(kPreseverdEmptyKey); + hash_map->resize(capacity); + } +}; + +struct Range { + public: + explicit Range(int64 start, int64 end) : start_(start), end_(end) {} + inline const int64 Start() const { return start_; } + inline const int64 End() const { return end_; } + inline const int64 Size() 
const { return end_ - start_; } + + private: + const int64 start_, end_; +}; + +struct Partitioner { + public: + explicit Partitioner(int64 work_size, int32 num_parts) { + if (work_size <= 0 || num_parts <= 0) { + return; + } + num_parts_ = num_parts; + parts_.reserve(num_parts); + int64 start = 0; + for (int32 i = 0; i < num_parts; ++i) { + int64 end = start + (work_size + i) / num_parts; + parts_.emplace_back(Range(start, end)); + start = end; + } + } + + const Range* GetRange(const int32 id) const { + if (id < 0 || id >= num_parts_) { + return nullptr; + } + return &parts_[id]; + } + + bool LocatePos(const int64 pos, int32* task_id) const { + for (int32 i = 0; i < num_parts_; ++i) { + if (pos >= parts_[i].Start() && pos < parts_[i].End()) { + *task_id = i; + return true; + } + } + return false; + } + + private: + std::vector parts_; + int32 num_parts_ = 0; +}; + +namespace { +struct IdHash : public std::hash { + inline std::size_t operator()(int64 const& i) const noexcept { + size_t x = (i ^ (i >> 30)) * 0xbf58476d1ce4e5b9ULL; + x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL; + x = x ^ (x >> 31); + return x; + } +}; +} // namespace + +namespace { +void NewSizes(OpKernelContext* context, const Tensor& input, + const Tensor& axis_tensor, std::vector& new_sizes, + int64& axis) { + OP_REQUIRES(context, TensorShapeUtils::IsVector(axis_tensor.shape()), + errors::InvalidArgument("axis expects a 1D vector.")); + OP_REQUIRES( + context, axis_tensor.NumElements() <= 1, + errors::InvalidArgument( + "axis does not support input tensors larger than 1 elements")); + + if (axis_tensor.NumElements() == 0) { + OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("unique expects a 1D vector.")); + } else { + OP_REQUIRES( + context, + (axis_tensor.dtype() == DT_INT32 || axis_tensor.dtype() == DT_INT64), + errors::InvalidArgument( + "axis tensor should be int32 or int64, but got ", + DataTypeString(axis_tensor.dtype()))); + if (axis_tensor.dtype() == DT_INT32) { + axis = internal::SubtleMustCopy(axis_tensor.scalar()()); + } else { + axis = internal::SubtleMustCopy(axis_tensor.scalar()()); + } + axis = axis < 0 ? 
axis + input.dims() : axis; + OP_REQUIRES(context, 0 <= axis && axis < input.dims(), + errors::InvalidArgument("axis has to be between [0, ", + input.dims(), ")")); + if (axis > 0) { + for (int64 i = 0; i < axis; i++) { + new_sizes[0] *= input.dim_size(i); + } + } + new_sizes[1] = input.dim_size(axis); + if (axis + 1 < input.dims()) { + for (int64 i = axis + 1; i < input.dims(); i++) { + new_sizes[2] *= input.dim_size(i); + } + } + } +} +} // namespace + +template +void SerialComputeV1(OpKernelContext* context, const Tensor& input, Tensor* idx, + int64 axis, int64* uniq_size, Tensor* output) { + auto Tin = input.flat(); + const int64 N = input.NumElements(); + auto idx_vec = idx->template vec(); + + HashMap uniq; + HashMapInitializer::InitSize(&uniq, N); + for (int64 i = 0, j = 0; i < N; ++i) { + auto it = uniq.emplace(Tin(i), j); + idx_vec(i) = it.first->second; + if (it.second) { + ++j; + } + } + + *uniq_size = static_cast(uniq.size()); + TensorShape output_shape(input.shape()); + output_shape.set_dim(axis, *uniq_size); + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + output_shape, output, attr)); + auto Tout = output->flat(); + + for (auto it : uniq) { + Tout(it.second) = it.first; + } +} + +template +void ParallelComputeV1(OpKernelContext* context, const Tensor& input, + Tensor* idx, int64 axis, int64* uniq_size, + Tensor* output) { + // Struct INode was used to store an inverse mapping for each node in the + // hash map container. + struct INode { + explicit INode(const TIndex index, const T& key) + : owner_ptr_(nullptr), index_(index), key_(key) {} + + const INode* owner_ptr_; + TIndex index_; + const T key_; + }; + + // Struct UniqueSubMap is used to build and operate local hash map and keep + // index mapping information. + struct UniqueSubMap { + public: + inline void Init(int64 size) { + next_index_ = 0; + HashMapInitializer::InitSize(&uniq_, size); + inodes_.reserve(size); + } + + inline void UniqueInsert(const T& key) { + auto it = uniq_.emplace(key, next_index_); + if (it.second) { + inodes_.emplace_back(INode(next_index_, key)); + ++next_index_; + } + } + + inline const INode* GetINodeByPos(const TIndex pos) const { + const INode* inode = &inodes_[pos]; + return (inode->owner_ptr_ == nullptr) ? inode : inode->owner_ptr_; + } + + inline const INode* GetINodeByKey(const T& key) const { + auto item = uniq_.find(key); + if (item != uniq_.end()) { + return GetINodeByPos(item->second); + } + return nullptr; + } + + bool DeDup(const TIndex pos, const UniqueSubMap& prior_map) { + INode* my_inode = &inodes_[pos]; + if (my_inode->owner_ptr_ != nullptr) { + return false; + } + const INode* prior_inode = prior_map.GetINodeByKey(my_inode->key_); + if (prior_inode == nullptr) { + return false; + } + my_inode->owner_ptr_ = prior_inode; + return true; + } + + bool TryIndexAndGetKey(const TIndex pos, const TIndex new_id, T* out) { + INode* inode = &inodes_[pos]; + if (inode->owner_ptr_ != nullptr) { + return false; + } + inode->index_ = new_id; + *out = inode->key_; + return true; + } + + inline int64 Size() const { return static_cast(next_index_); } + + private: + TIndex next_index_; + HashMap uniq_; + std::vector inodes_; + }; + + // NOTE(zycao): A four-step scheme is adopted for parallel unique computing. + // Step 1: Seperate input data into T1 sections. build individual local hash + // maps M(0 .. (T1 - 1)) for each section. + // Step 2: Mark and count duplicated keys accross all T1 hash maps. 
For each + // key stored in hasp map M(i), it needs to do lookups from hash map + // M(0) to M(i-1) to check possible duplicates. Thus keys stored in + // M(i, i = 1 .. (T1 - 1) would be divided into T2 parts, and then + // processed simultanously in T2 tasks. + // Step 3: Calculate the global unique index for all keys, based on marking + // and counting result of Step 2. Hash maps would be processed by + // T1 tasks in parallel. + // Step 4: Fill the output Tensor with multiple tasks as many as possible. + // + // Since the complexity of Step (1,3) and Step 2 would be affected by the + // number of T1 in opposite direction. A simple deduction was done and it + // indicates that ideal T1 size should be in the order of O(T2 ^ 1/3 * c). + // >> T1_ideal ~= ((beta * max_threads) ^ 1/3) + 1/2 + // Here 'beta' is a factor used to approximately describe hash map lookup + // speed compared to insert operations. This result is adopted in current + // implemetation to decide Step 1 task size T1. + auto Tin = input.flat(); + const int64 N = input.NumElements(); + int32 max_threads = + context->device()->tensorflow_cpu_worker_threads()->num_threads; + auto thread_pool = + context->device()->tensorflow_cpu_worker_threads()->workers; + + // Parallel Step 1: Build hash maps. + const double factor = 10; // Suppose lookup is 10x faster than insert. + int32 max_tasks_t1 = static_cast(std::cbrt(factor * max_threads) + 1); + int32 num_tasks_t1 = std::max(std::min(max_threads, max_tasks_t1), 1); + VLOG(1) << "[UniqueParallel] Step 1 num_tasks: " << num_tasks_t1; + + Partitioner map_parter(N, num_tasks_t1); + std::vector uniq_maps(num_tasks_t1); + + auto MapBuildTask = [&Tin, &uniq_maps, &map_parter](int32 task_id, + int32 num_tasks) { + UniqueSubMap& uniq_map = uniq_maps[task_id]; + const Range* range = map_parter.GetRange(task_id); + uniq_map.Init(range->Size()); + for (int64 i = range->Start(); i < range->End(); ++i) { + uniq_map.UniqueInsert(Tin(i)); + } + }; + TaskRunner t1_runner(MapBuildTask, thread_pool, num_tasks_t1); + t1_runner.Run(); + + int64 est_dup_count_cost = 0; + for (int32 i = 0; i < num_tasks_t1; ++i) { + est_dup_count_cost += uniq_maps[i].Size() * i; + } + + // Parallel Step 2: Check and count duplicated keys. + int32 max_tasks_t2 = + (est_dup_count_cost + kPartitionSize - 1) / kPartitionSize; + int32 num_tasks_t2 = std::max(std::min(max_threads, max_tasks_t2), 1); + VLOG(1) << "[UniqueParallel] Step 2 num_tasks: " << num_tasks_t2; + + // Divide each of T1 hash maps into T2 parts, remember the offsets. + std::vector dups(num_tasks_t1 * num_tasks_t2, 0); + std::vector dup_parters; + dup_parters.reserve(num_tasks_t1); + for (int32 i = 0; i < num_tasks_t1; ++i) { + dup_parters.emplace_back(Partitioner(uniq_maps[i].Size(), num_tasks_t2)); + } + + auto DupCountTask = [&uniq_maps, &dups, &dup_parters, num_tasks_t1]( + int32 task_id, int32 num_tasks) { + // Using 3 layer loop to make all checks. 
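+    // Layer 1 walks every earlier map (prior_id), layer 2 every later map
+    // (lat_id), and layer 3 this task's slice of the later map; each duplicate
+    // found is counted in dups[lat_id * num_tasks + task_id].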
+ for (int32 prior_id = 0; prior_id < num_tasks_t1 - 1; ++prior_id) { + const UniqueSubMap& prior_map = uniq_maps[prior_id]; + for (int32 lat_id = prior_id + 1; lat_id < num_tasks_t1; ++lat_id) { + UniqueSubMap& lat_map = uniq_maps[lat_id]; + int64 dup_offsets = lat_id * num_tasks; + const Range* range = dup_parters[lat_id].GetRange(task_id); + for (int64 i = range->Start(); i < range->End(); ++i) { + if (lat_map.DeDup(i, prior_map)) { + ++dups[dup_offsets + task_id]; + } + } + } + } + }; + TaskRunner t2_runner(DupCountTask, thread_pool, num_tasks_t2); + t2_runner.Run(); + + // Calculate the global unique index numbers and global offset for every + // hash map based on duplication checking results. + std::vector global_offsets(num_tasks_t1, 0); + for (int32 i = 0; i < num_tasks_t1 - 1; ++i) { + global_offsets[i + 1] = global_offsets[i] + uniq_maps[i].Size(); + for (int32 j = 0; j < num_tasks_t2; ++j) { + global_offsets[i + 1] -= dups[i * num_tasks_t2 + j]; + } + } + int64 num_tot_indices = + global_offsets[num_tasks_t1 - 1] + uniq_maps[num_tasks_t1 - 1].Size(); + for (int32 j = 0; j < num_tasks_t2; ++j) { + num_tot_indices -= dups[(num_tasks_t1 - 1) * num_tasks_t2 + j]; + } + + // Parallel Step 3: Recalculate global index for all keys in all hash maps. + // Write the output keys Tensor at the same time. + *uniq_size = num_tot_indices; + TensorShape output_shape(input.shape()); + output_shape.set_dim(axis, num_tot_indices); + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + output_shape, output, attr)); + auto key_output_vec = output->template vec(); + + auto GlobalIndexTask = [&key_output_vec, &uniq_maps, &global_offsets]( + int32 task_id, int32 num_tasks) { + TIndex cur_id = global_offsets[task_id]; + UniqueSubMap& uniq_map = uniq_maps[task_id]; + for (int64 i = 0; i < uniq_map.Size(); ++i) { + if (uniq_map.TryIndexAndGetKey(i, cur_id, &key_output_vec(cur_id))) { + ++cur_id; + } + } + }; + TaskRunner t3_runner(GlobalIndexTask, thread_pool, num_tasks_t1); + t3_runner.Run(); + + // Parallel Step 4: Write output indicies Tensor. 
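+  // Each task scans a contiguous slice of the input, locates the local map
+  // that owns each position via map_parter, and emits the globally
+  // re-numbered index that Step 3 stored on the owning INode.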
+ int32 max_tasks_t4 = (N + kPartitionSize - 1) / kPartitionSize; + int32 num_tasks_t4 = std::max(std::min(max_threads, max_tasks_t4), 1); + VLOG(1) << "[UniqueParallel] Step 4 num_tasks: " << num_tasks_t4; + + Partitioner fill_parter(N, num_tasks_t4); + auto idx_vec = idx->template vec(); + + auto OutputTask = [&Tin, &idx_vec, &uniq_maps, &fill_parter, &map_parter]( + int32 task_id, int32 num_tasks) { + const Range* out_range = fill_parter.GetRange(task_id); + int64 out_pos = out_range->Start(); + int32 map_id; + if (!map_parter.LocatePos(out_pos, &map_id)) { + return; + } + int64 map_range_end = map_parter.GetRange(map_id)->End(); + while (out_pos < out_range->End()) { + const INode* inode = uniq_maps[map_id].GetINodeByKey(Tin(out_pos)); + idx_vec(out_pos) = inode->index_; + ++out_pos; + if (out_pos == map_range_end && out_pos < out_range->End()) { + ++map_id; + map_range_end = map_parter.GetRange(map_id)->End(); + } + } + }; + TaskRunner t4_runner(OutputTask, thread_pool, num_tasks_t4); + t4_runner.Run(); +} + +template +void MultiMapCompute(OpKernelContext* context, const Tensor& input, Tensor* idx, + int64 axis, int64* uniq_size_out, int32 num_buckets, + int64 unique_ratio_hint, Tensor* output) { + auto Tin = input.vec(); + const int64 N = input.NumElements(); + + auto idx_vec = idx->template vec(); + + auto thread_pool = + context->device()->tensorflow_cpu_worker_threads()->workers; + + // Parallel Step 0: Partition. + int32 num_partitions = num_buckets; + std::unique_ptr partitions{new int64[num_partitions * num_buckets]}; + + static IdHash hasher; + Partitioner map_parter(N, num_partitions); + auto PartitionTask = [N, num_buckets, &Tin, &partitions, &map_parter, + &idx_vec](int32 task_id, int32 num_tasks) { + auto st = Status::OK(); + int64* partition = partitions.get() + task_id * num_buckets; + for (int64 i = 0; i < num_buckets; ++i) { + partition[i] = -1; + } + const Range* range = map_parter.GetRange(task_id); + + for (int64 i = range->Start(); i < range->End(); ++i) { + auto& id = Tin(i); + if (unlikely(id == kPreseverdEmptyKey)) { + st = errors::InvalidArgument( + "Input id is preserved key of dense_hash_map, " + "not supported: ", + id); + break; + } + int64 bucket = (hasher(id) >> 54) % num_buckets; + idx_vec(i) = partition[bucket]; + partition[bucket] = i; + } + return st; + }; + + SummaryTaskRunner t0_runner( + PartitionTask, Status::OK(), thread_pool, num_partitions); + t0_runner.Run(); + OP_REQUIRES_OK(context, t0_runner.summary()); + + // Parallel Step 1: Build hash maps. + HashMap uniq_maps_ptr[num_buckets]; + HashMap* uniq_maps = uniq_maps_ptr; + auto MapBuildTask = [N, num_partitions, &Tin, &partitions, uniq_maps, + &idx_vec, + unique_ratio_hint](int32 task_id, int32 num_tasks) { + auto& uniq = uniq_maps[task_id]; + HashMapInitializer::Reserve(&uniq, + N / num_tasks / unique_ratio_hint); + + for (int64 k = 0; k < num_partitions; ++k) { + int64* partition = partitions.get() + k * num_tasks + task_id; + int64 next_idx = *partition; + int64 cur_idx; + while (next_idx != -1) { + cur_idx = next_idx; + next_idx = idx_vec(cur_idx); + auto it = uniq.emplace(Tin(cur_idx), cur_idx); + if (!it.second) { + idx_vec(cur_idx) = it.first->second; + it.first->second = cur_idx; + } else { + idx_vec(cur_idx) = -1; + } + } + } + }; + + TaskRunner t1_runner(MapBuildTask, thread_pool, num_buckets); + t1_runner.Run(); + + // Calculate the global unique index numbers and global offset for every + // hash map. 
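+  // Bucket b's unique keys occupy the contiguous output range
+  // [global_offsets[b], global_offsets[b] + uniq_maps[b].size()), which the
+  // OutputTask below fills while rewriting idx entries to global offsets.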
+ std::vector global_offsets(num_buckets, 0); + for (int32 i = 0; i < num_buckets - 1; ++i) { + global_offsets[i + 1] = global_offsets[i] + uniq_maps[i].size(); + } + int64 uniq_size = + global_offsets[num_buckets - 1] + uniq_maps[num_buckets - 1].size(); + + *uniq_size_out = uniq_size; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::v(), + TensorShape({uniq_size}), output, attr)); + auto key_output_vec = output->template vec(); + + auto OutputTask = [&key_output_vec, &uniq_maps, &global_offsets, &Tin, + &idx_vec, &map_parter](int32 task_id, int32 num_tasks) { + TIndex offset = global_offsets[task_id]; + for (auto iter = uniq_maps[task_id].begin(); + iter != uniq_maps[task_id].end(); ++iter) { + // output uniq id + key_output_vec(offset) = iter->first; + // output uniq index + int64 next_idx = iter->second; + int64 cur_idx; + while (next_idx != -1) { + cur_idx = next_idx; + next_idx = idx_vec(cur_idx); + idx_vec(cur_idx) = offset; + } + + ++offset; + } + }; + TaskRunner t2_runner(OutputTask, thread_pool, num_buckets); + t2_runner.Run(); +} + +template +void MultipleElements(OpKernelContext* context, const Tensor& input, + Tensor* idx, Tensor* output, int64* uniq_size, int64 axis, + std::vector& new_sizes) { + // General implementation when unique is run over multiple elements. + auto Tin = input.shaped(new_sizes); + auto idx_vec = idx->template vec(); + + auto hash_fn = [&Tin](const int64& key) { + size_t h = 0; + for (int64 i = 0; i < Tin.dimension(0); i++) { + for (int64 j = 0; j < Tin.dimension(2); j++) { + h = Hash64Combine(h, hash{}(Tin(i, key, j))); + } + } + return h; + }; + + auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) { + for (int64 i = 0; i < Tin.dimension(0); i++) { + for (int64 j = 0; j < Tin.dimension(2); j++) { + if (Tin(i, lhs, j) != Tin(i, rhs, j)) { + return false; + } + } + } + return true; + }; + + std::unordered_map + uniq(0, hash_fn, equal_to_fn); + + uniq.reserve(2 * Tin.dimension(1)); + + for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) { + auto it = uniq.insert(std::make_pair(i, j)); + idx_vec(i) = it.first->second; + if (it.second) { + ++j; + } + } + *uniq_size = static_cast(uniq.size()); + new_sizes[1] = *uniq_size; + TensorShape output_shape(input.shape()); + output_shape.set_dim(axis, *uniq_size); + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + output_shape, output, attr)); + auto Tout = output->shaped(new_sizes); + + for (auto it : uniq) { + Tout.chip(it.second, 1) = Tin.chip(it.first, 1); + } +} + +template +void CheckCountOutput(OpKernelContext* context, Tensor* output_counter, + Tensor* idx, int num_outputs, int64 uniq_size) { + if (num_outputs > 2) { + auto idx_vec = idx->template vec(); + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + TensorShape({uniq_size}), + output_counter, attr)); + auto count_output_vec = output_counter->template vec(); + count_output_vec.setZero(); + const int N = idx_vec.size(); + for (int64 i = 0; i < N; ++i) { + count_output_vec(idx_vec(i))++; + } + } +} + +template +void ComputeInternalWithHashMap(OpKernelContext* context, const Tensor& input, + Tensor* idx, int64 axis, int64* uniq_size, + int64 N, bool serial, Tensor* output) { + OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("unique expects a 1D vector.")); + // TODO(dga): Make unique 
polymorphic for returning int32 and int64 + // vectors to support large tensors. + OP_REQUIRES(context, input.NumElements() <= std::numeric_limits::max(), + errors::InvalidArgument( + "unique does not support input tensors larger than ", + std::numeric_limits::max(), " elements")); + + if (N >= kPartitionLimit && !serial) { + ParallelComputeV1(context, input, idx, axis, uniq_size, + output); + } else { + SerialComputeV1(context, input, idx, axis, uniq_size, + output); + } +} + +template +void UniqueInternal(OpKernelContext* context, const Tensor& input, Tensor* idx, + Tensor* output, Tensor* output_counter, int num_outputs, + int64 partition_size, bool serial, int64 axis, + int64 unique_ratio_hint, std::vector& new_sizes, + UniqueMaps map_flag) { + typedef google::dense_hash_map DefaultHashMap; + + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::v(), + TensorShape({new_sizes[1]}), idx, attr)); + + int64 uniq_size_out; + + if (new_sizes[0] == 1 && new_sizes[2] == 1) { + // Specialized and faster implementation when unique is run over single + // elements. Here we put T directly into the map rather than ints pointing + // to them as in the general case. + auto Tin = input.vec(); + const int64 N = static_cast(Tin.size()); + int32 max_threads = + context->device()->tensorflow_cpu_worker_threads()->num_threads; + int32 num_buckets = + std::min(N / partition_size, static_cast(max_threads)); + + switch (map_flag) { + case MULTIMAP: + if (num_buckets > 1 && !serial) { + MultiMapCompute>( + context, input, idx, axis, &uniq_size_out, num_buckets, + unique_ratio_hint, output); + } else { + SerialComputeV1(context, input, idx, axis, + &uniq_size_out, output); + } + break; + case STL: + ComputeInternalWithHashMap>( + context, input, idx, axis, &uniq_size_out, N, serial, output); + break; + case ABSL: + ComputeInternalWithHashMap>( + context, input, idx, axis, &uniq_size_out, N, serial, output); + break; + case GOOGLE: + ComputeInternalWithHashMap( + context, input, idx, axis, &uniq_size_out, N, serial, output); + break; + default: + ComputeInternalWithHashMap( + context, input, idx, axis, &uniq_size_out, N, serial, output); + } + } else { + MultipleElements(context, input, idx, output, &uniq_size_out, + axis, new_sizes); + } + + CheckCountOutput(context, output_counter, idx, num_outputs, + uniq_size_out); +} + +template +void UniqueWithoutAxis(OpKernelContext* context, const Tensor& input, + Tensor* idx, Tensor* output, Tensor* output_counter, + int num_outputs, int64 partition_size, bool serial, + int64 unique_ratio_hint, UniqueMaps map_flag) { + int64 axis = 0; + std::vector new_sizes{1, input.NumElements(), 1}; + OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("unique expects a 1D vector.")); + UniqueInternal(context, input, idx, output, output_counter, + num_outputs, partition_size, serial, axis, + unique_ratio_hint, new_sizes, map_flag); +} + +template +void UniqueWithAxis(OpKernelContext* context, const Tensor& input, + const Tensor& axis_tensor, Tensor* idx, Tensor* output, + Tensor* output_counter, int num_outputs, + int64 partition_size, bool serial, int64 unique_ratio_hint, + UniqueMaps map_flag) { + int64 axis = 0; + std::vector new_sizes{1, input.NumElements(), 1}; + NewSizes(context, input, axis_tensor, new_sizes, axis); + UniqueInternal(context, input, idx, output, output_counter, + num_outputs, partition_size, serial, axis, + unique_ratio_hint, new_sizes, map_flag); +} + +} 
// namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_UNIQUE_ALI_OP_UTIL_H_ diff --git a/deepray/custom_ops/unique_ops/cc/ops/unique_ops.cc b/deepray/custom_ops/unique_ops/cc/ops/unique_ops.cc new file mode 100644 index 0000000..f815833 --- /dev/null +++ b/deepray/custom_ops/unique_ops/cc/ops/unique_ops.cc @@ -0,0 +1,85 @@ +/* Copyright 2023 The Deepray Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using ::tensorflow::shape_inference::InferenceContext; +using ::tensorflow::shape_inference::ShapeHandle; + +REGISTER_OP("Deepray>Unique") + .Input("x: T") + .Output("y: T") + .Output("idx: out_idx") + .Attr("T: type") + .Attr("out_idx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(1, c->input(0)); + // Assert that the input rank is 1. + ShapeHandle dummy; + return c->WithRank(c->input(0), 1, &dummy); + }); + +REGISTER_OP("Deepray>UniqueV2") + .Input("x: T") + .Input("axis: Taxis") + .Output("y: T") + .Output("idx: out_idx") + .Attr("T: type") + .Attr("Taxis: {int32,int64} = DT_INT64") + .Attr("out_idx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(1, c->input(0)); + return Status::OK(); + }); + +// -------------------------------------------------------------------------- +REGISTER_OP("Deepray>UniqueWithCounts") + .Input("x: T") + .Output("y: T") + .Output("idx: out_idx") + .Output("count: out_idx") + .Attr("T: type") + .Attr("out_idx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + auto uniq = c->Vector(InferenceContext::kUnknownDim); + c->set_output(0, uniq); + c->set_output(1, c->input(0)); + c->set_output(2, uniq); + return Status::OK(); + }); + +REGISTER_OP("Deepray>UniqueWithCountsV2") + .Input("x: T") + .Input("axis: Taxis") + .Output("y: T") + .Output("idx: out_idx") + .Output("count: out_idx") + .Attr("T: type") + .Attr("Taxis: {int32,int64} = DT_INT64") + .Attr("out_idx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + auto uniq = c->Vector(InferenceContext::kUnknownDim); + c->set_output(0, uniq); + c->set_output(1, c->input(0)); + c->set_output(2, uniq); + return Status::OK(); + }); + +} // namespace tensorflow \ No newline at end of file diff --git a/deepray/custom_ops/unique_ops/python/__init__.py b/deepray/custom_ops/unique_ops/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepray/custom_ops/unique_ops/python/tests/__init__.py b/deepray/custom_ops/unique_ops/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepray/custom_ops/unique_ops/python/tests/run_all_test.py b/deepray/custom_ops/unique_ops/python/tests/run_all_test.py new file mode 100644 index 
0000000..7f259e0 --- /dev/null +++ b/deepray/custom_ops/unique_ops/python/tests/run_all_test.py @@ -0,0 +1,9 @@ +from pathlib import Path +import sys + +import pytest + +if __name__ == "__main__": + dirname = Path(__file__).absolute().parent + sys.exit(pytest.main(['-s', str(dirname)])) + # sys.exit(pytest.main([str(dirname)])) diff --git a/deepray/custom_ops/unique_ops/python/tests/unique_op_test.py b/deepray/custom_ops/unique_ops/python/tests/unique_op_test.py new file mode 100644 index 0000000..a3b8a47 --- /dev/null +++ b/deepray/custom_ops/unique_ops/python/tests/unique_op_test.py @@ -0,0 +1,303 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.kernels.unique_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test +from tensorflow.python.framework import errors_impl + +from deepray.custom_ops.unique_ops import gen_array_ops + +unique = gen_array_ops.deepray_unique + +# set environ before tf initializing global varialbes +PreservedKey = 1 << 10 +os.environ["DEEPREC_CONFIG_RAND_64"] = str(PreservedKey) + + +class UniqueTest(test.TestCase): + + def testInt32(self): + x = np.random.randint(2, high=10, size=7000) + with self.cached_session() as sess: + y, idx = gen_array_ops.deepray_unique(x) + tf_y, tf_idx = self.evaluate([y, idx]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + + def testInt32OutIdxInt64(self): + x = np.random.randint(2, high=10, size=7000) + with self.cached_session() as sess: + y, idx = gen_array_ops.deepray_unique(x, out_idx=dtypes.int64) + tf_y, tf_idx = self.evaluate([y, idx]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + + def testInt64OutIdxInt64(self): + np.random.seed(0) + x = np.random.randint(-1000000000, high=1000000000, size=1000000, dtype=np.int64) + with self.cached_session(use_gpu=True) as sess: + y, idx = unique(x, out_idx=dtypes.int64) + tf_y, tf_idx = sess.run([y, idx]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + + def testInt64OutIdxInt32(self): + np.random.seed(0) + x = np.random.randint(-1000000000, high=1000000000, size=1000000, dtype=np.int64) + with self.cached_session(use_gpu=True) as sess: + y, idx = unique(x, out_idx=dtypes.int32) + tf_y, tf_idx = sess.run([y, idx]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + + def testString(self): + 
indx = np.random.randint(65, high=122, size=7000) + x = [chr(i) for i in indx] + with self.cached_session() as sess: + y, idx = gen_array_ops.deepray_unique(x) + tf_y, tf_idx = self.evaluate([y, idx]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii')) + + def testInt32Axis(self): + for dtype in [np.int32, np.int64]: + x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]]) + with self.cached_session() as sess: + y0, idx0 = gen_array_ops.deepray_unique_v2(x, axis=np.array([0], dtype)) + tf_y0, tf_idx0 = self.evaluate([y0, idx0]) + y1, idx1 = gen_array_ops.deepray_unique_v2(x, axis=np.array([1], dtype)) + tf_y1, tf_idx1 = self.evaluate([y1, idx1]) + self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]])) + self.assertAllEqual(tf_idx0, np.array([0, 0, 1])) + self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]])) + self.assertAllEqual(tf_idx1, np.array([0, 1, 1])) + + def testInt32V2(self): + # This test is only temporary, once V2 is used + # by default, the axis will be wrapped to allow `axis=None`. + x = np.random.randint(2, high=10, size=7000) + with self.cached_session() as sess: + y, idx = gen_array_ops.deepray_unique_v2(x, axis=np.array([], np.int32)) + tf_y, tf_idx = self.evaluate([y, idx]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + + def IllegalIdForMultMapUnique(self): + recover_env = False + if 'DEEPREC_UNIQUE_OP_PARTITION_SIZE' in os.environ: + recover_env = True + old_env = os.environ['DEEPREC_UNIQUE_OP_PARTITION_SIZE'] + os.environ['DEEPREC_UNIQUE_OP_PARTITION_SIZE'] = '2' + + with self.cached_session() as sess: + x = np.array([-1, 0, 1, PreservedKey], dtype=np.int64) + y, idx = unique(x, out_idx=dtypes.int64) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, "Input id is preserved key of dense_hash_map, " + "not supported: " + str(PreservedKey) + ): + tf_y, tf_idx = sess.run([y, idx]) + + del os.environ['DEEPREC_UNIQUE_OP_PARTITION_SIZE'] + if recover_env: + os.environ['DEEPREC_UNIQUE_OP_PARTITION_SIZE'] = old_env + + def RunUniqueWithDifferentMaps(self, map_type, test_illegal_key=False): + recover_env = False + if 'DEEPREC_UNIQUE_OP_HASH_MAP' in os.environ: + recover_env = True + old_env = os.environ['DEEPREC_UNIQUE_OP_HASH_MAP'] + + os.environ['DEEPREC_UNIQUE_OP_HASH_MAP'] = map_type + self.testInt32() + self.testInt32OutIdxInt64() + self.testInt64OutIdxInt64() + self.testInt64OutIdxInt32() + self.testInt32Axis() + self.testInt32V2() + if test_illegal_key: + self.IllegalIdForMultMapUnique() + + del os.environ['DEEPREC_UNIQUE_OP_HASH_MAP'] + if recover_env: + os.environ['DEEPREC_UNIQUE_OP_HASH_MAP'] = old_env + + def testUniqueMultiMap(self): + self.RunUniqueWithDifferentMaps('MULTIMAP') + + def testUniqueStlMap(self): + self.RunUniqueWithDifferentMaps('STL') + + def testUniqueAbslMap(self): + self.RunUniqueWithDifferentMaps('ABSL') + + def testUniqueDenseHashMap(self): + self.RunUniqueWithDifferentMaps('GOOGLE') + + # def testBool(self): + # x = np.random.choice([True, False], size=7000) + # with self.cached_session() as sess: + # y, idx = gen_array_ops.deepray_unique(x) + # tf_y, tf_idx = self.evaluate([y, idx]) + + # self.assertEqual(len(x), len(tf_idx)) + # self.assertEqual(len(tf_y), len(np.unique(x))) + # for i in range(len(x)): + # self.assertEqual(x[i], tf_y[tf_idx[i]]) + + # def testBoolV2(self): + # x = 
np.random.choice([True, False], size=7000) + # with self.cached_session() as sess: + # y, idx = gen_array_ops.deepray_unique_v2(x, axis=np.array([], np.int32)) + # tf_y, tf_idx = self.evaluate([y, idx]) + + # self.assertEqual(len(x), len(tf_idx)) + # self.assertEqual(len(tf_y), len(np.unique(x))) + # for i in range(len(x)): + # self.assertEqual(x[i], tf_y[tf_idx[i]]) + + +# class UniqueWithCountsTest(test.TestCase): + +# def testInt32(self): +# x = np.random.randint(2, high=10, size=7000) +# with self.cached_session() as sess: +# y, idx, count = array_ops.unique_with_counts(x) +# tf_y, tf_idx, tf_count = self.evaluate([y, idx, count]) + +# self.assertEqual(len(x), len(tf_idx)) +# self.assertEqual(len(tf_y), len(np.unique(x))) +# for i in range(len(x)): +# self.assertEqual(x[i], tf_y[tf_idx[i]]) +# for value, count in zip(tf_y, tf_count): +# self.assertEqual(count, np.sum(x == value)) + +# def testInt32OutIdxInt64(self): +# x = np.random.randint(2, high=10, size=7000) +# with self.cached_session() as sess: +# y, idx, count = array_ops.unique_with_counts(x, out_idx=dtypes.int64) +# tf_y, tf_idx, tf_count = self.evaluate([y, idx, count]) + +# self.assertEqual(len(x), len(tf_idx)) +# self.assertEqual(len(tf_y), len(np.unique(x))) +# for i in range(len(x)): +# self.assertEqual(x[i], tf_y[tf_idx[i]]) +# for value, count in zip(tf_y, tf_count): +# self.assertEqual(count, np.sum(x == value)) + +# def testString(self): +# indx = np.random.randint(65, high=122, size=7000) +# x = [chr(i) for i in indx] + +# with self.cached_session() as sess: +# y, idx, count = array_ops.unique_with_counts(x) +# tf_y, tf_idx, tf_count = self.evaluate([y, idx, count]) + +# self.assertEqual(len(x), len(tf_idx)) +# self.assertEqual(len(tf_y), len(np.unique(x))) +# for i in range(len(x)): +# self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii')) +# for value, count in zip(tf_y, tf_count): +# v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)] +# self.assertEqual(count, sum(v)) + +# def testInt32Axis(self): +# for dtype in [np.int32, np.int64]: +# x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]]) +# with self.cached_session() as sess: +# y0, idx0, count0 = gen_array_ops.deepray_unique_with_counts_v2( +# x, axis=np.array([0], dtype)) +# tf_y0, tf_idx0, tf_count0 = self.evaluate([y0, idx0, count0]) +# y1, idx1, count1 = gen_array_ops.deepray_unique_with_counts_v2( +# x, axis=np.array([1], dtype)) +# tf_y1, tf_idx1, tf_count1 = self.evaluate([y1, idx1, count1]) +# self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]])) +# self.assertAllEqual(tf_idx0, np.array([0, 0, 1])) +# self.assertAllEqual(tf_count0, np.array([2, 1])) +# self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]])) +# self.assertAllEqual(tf_idx1, np.array([0, 1, 1])) +# self.assertAllEqual(tf_count1, np.array([1, 2])) + +# def testInt32V2(self): +# # This test is only temporary, once V2 is used +# # by default, the axis will be wrapped to allow `axis=None`. 
+# x = np.random.randint(2, high=10, size=7000) +# with self.cached_session() as sess: +# y, idx, count = gen_array_ops.deepray_unique_with_counts_v2( +# x, axis=np.array([], np.int32)) +# tf_y, tf_idx, tf_count = self.evaluate([y, idx, count]) + +# self.assertEqual(len(x), len(tf_idx)) +# self.assertEqual(len(tf_y), len(np.unique(x))) +# for i in range(len(x)): +# self.assertEqual(x[i], tf_y[tf_idx[i]]) +# for value, count in zip(tf_y, tf_count): +# self.assertEqual(count, np.sum(x == value)) + +# def testBool(self): +# x = np.random.choice([True, False], size=7000) +# with self.cached_session() as sess: +# y, idx, count = array_ops.unique_with_counts(x) +# tf_y, tf_idx, tf_count = self.evaluate([y, idx, count]) + +# self.assertEqual(len(x), len(tf_idx)) +# self.assertEqual(len(tf_y), len(np.unique(x))) +# for i in range(len(x)): +# self.assertEqual(x[i], tf_y[tf_idx[i]]) +# for value, count in zip(tf_y, tf_count): +# self.assertEqual(count, np.sum(x == value)) + +# def testBoolV2(self): +# x = np.random.choice([True, False], size=7000) +# with self.cached_session() as sess: +# y, idx, count = gen_array_ops.deepray_unique_with_counts_v2( +# x, axis=np.array([], np.int32)) +# tf_y, tf_idx, tf_count = self.evaluate([y, idx, count]) + +# self.assertEqual(len(x), len(tf_idx)) +# self.assertEqual(len(tf_y), len(np.unique(x))) +# for i in range(len(x)): +# self.assertEqual(x[i], tf_y[tf_idx[i]]) +# for value, count in zip(tf_y, tf_count): +# self.assertEqual(count, np.sum(x == value)) + +if __name__ == '__main__': + test.main() diff --git a/deepray/custom_ops/unique_ops/python/unique_ops.py b/deepray/custom_ops/unique_ops/python/unique_ops.py new file mode 100644 index 0000000..0424e16 --- /dev/null +++ b/deepray/custom_ops/unique_ops/python/unique_ops.py @@ -0,0 +1,23 @@ +# Copyright 2018 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Use array ops in python.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.platform import resource_loader + +gen_array_ops = tf.load_op_library(resource_loader.get_path_to_datafile("../_unique_ops.so")) diff --git a/deepray/workspace2.bzl b/deepray/workspace2.bzl index a95a323..b26d49c 100644 --- a/deepray/workspace2.bzl +++ b/deepray/workspace2.bzl @@ -266,6 +266,15 @@ def _tf_repositories(): ], ) + http_archive( + name = "sparsehash_c11", + build_file = "//third_party:sparsehash_c11.BUILD", + sha256 = "d4a43cad1e27646ff0ef3a8ce3e18540dbcb1fdec6cc1d1cb9b5095a9ca2a755", + strip_prefix = "sparsehash-c11-2.11.1", + urls = [ + "https://github.com/sparsehash/sparsehash-c11/archive/v2.11.1.tar.gz", + ], + ) def workspace(): # Import all other repositories. 
This should happen before initializing # any external repositories, because those come with their own diff --git a/third_party/sparsehash_c11.BUILD b/third_party/sparsehash_c11.BUILD new file mode 100644 index 0000000..c1a6ace --- /dev/null +++ b/third_party/sparsehash_c11.BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # # BSD 3-Clause + +cc_library( + name = "dense_hash_map", + hdrs = glob([ + "sparsehash/**", + ]), + includes = ["."], + visibility = ["//visibility:public"], +) + From dec6ac8eabdba7753c573fbeabfd6ebf9cb6418b Mon Sep 17 00:00:00 2001 From: Vincent Date: Sun, 17 Dec 2023 23:39:30 +0800 Subject: [PATCH 2/2] format with clang --- deepray/custom_ops/ffm_ops/BUILD | 2 +- .../ffm_ops/cc/kernels/ffm_kernels.cc | 1 + .../ffm_ops/cc/kernels/ffm_kernels.cu.cc | 1 + .../unique_ops/cc/kernels/unique_ali_op.cc | 20 +++++++++---------- .../cc/kernels/unique_ali_op_gpu.cu.cc | 2 +- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/deepray/custom_ops/ffm_ops/BUILD b/deepray/custom_ops/ffm_ops/BUILD index 5813bc3..3e2bce2 100644 --- a/deepray/custom_ops/ffm_ops/BUILD +++ b/deepray/custom_ops/ffm_ops/BUILD @@ -5,12 +5,12 @@ licenses(["notice"]) # Apache 2.0 custom_op_library( name = "_ffm_ops.so", - copts = [CPLUSPLUS_VERSION], srcs = [ "cc/kernels/ffm_kernels.cc", "cc/kernels/ffm_kernels.h", "cc/ops/ffm_ops.cc", ], + copts = [CPLUSPLUS_VERSION], cuda_srcs = [ "cc/kernels/ffm_kernels.h", "cc/kernels/ffm_kernels.cu.cc", diff --git a/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cc b/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cc index 0dbb90c..7a50c35 100644 --- a/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cc +++ b/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "ffm_kernels.h" + #include #include diff --git a/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cu.cc b/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cu.cc index 3614a14..6299c82 100644 --- a/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cu.cc +++ b/deepray/custom_ops/ffm_ops/cc/kernels/ffm_kernels.cu.cc @@ -15,6 +15,7 @@ #if GOOGLE_CUDA #define EIGEN_USE_GPU #include "ffm_kernels.h" + #include #include diff --git a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc index 122f294..047ff3b 100644 --- a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc +++ b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op.cc @@ -136,42 +136,42 @@ class UniqueAliOp : public OpKernel { }; #define REGISTER_UNIQUE(type) \ - REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp); \ - REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp); \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueV2") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueV2") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp); \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueV2") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueV2") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp); \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp) \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp); \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCountsV2") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCountsV2") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp) \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCountsV2") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCountsV2") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ @@ -182,7 +182,7 @@ REGISTER_UNIQUE(tstring) #if GOOGLE_CUDA #define REGISTER_UNIQUE(type) \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ .Device(DEVICE_GPU) \ .HostMemory("x") \ .HostMemory("y") \ @@ -191,7 +191,7 @@ REGISTER_UNIQUE(tstring) .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ UniqueAliOp) \ - REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>UniqueWithCounts") \ .Device(DEVICE_GPU) \ .HostMemory("x") \ .HostMemory("y") \ diff --git a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc index 9bb0f91..c3677d2 100644 --- a/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc +++ b/deepray/custom_ops/unique_ops/cc/kernels/unique_ali_op_gpu.cu.cc @@ -256,7 +256,7 @@ class UniqueAliV2GpuOp : public AsyncOpKernel { }; #define REGISTER_UNIQUE_ALI_V2_GPU_KERNEL(T, TIndex) \ - REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ + REGISTER_KERNEL_BUILDER(Name("Deepray>Unique") \ 
.Device(DEVICE_GPU) \ .TypeConstraint<T>("T") \ .TypeConstraint<TIndex>("out_idx"), \
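
Usage sketch (editorial, not part of the diffs above): assuming the `//deepray/custom_ops/unique_ops` target has been built so that `_unique_ops.so` is loadable, and assuming TF2 eager execution, the migrated op can be driven as below. The module path, the `deepray_unique` / `deepray_unique_v2` endpoints and the `DEEPREC_UNIQUE_OP_HASH_MAP` switch are taken from `unique_ops.py` and `unique_op_test.py` in this patch; the input values and the choice of backend are illustrative only.

import os
import numpy as np

# Choose the CPU hash-map backend before the kernel runs; the switch in
# unique_ali_op_util.h accepts MULTIMAP, STL, ABSL or GOOGLE and falls back
# to google::dense_hash_map by default.
os.environ["DEEPREC_UNIQUE_OP_HASH_MAP"] = "GOOGLE"

from deepray.custom_ops.unique_ops import gen_array_ops

x = np.random.randint(0, 1000, size=7000, dtype=np.int64)

# Same contract as tf.unique: y holds the distinct values and idx maps every
# input element to its slot in y, i.e. x[i] == y[idx[i]] for all i.
y, idx = gen_array_ops.deepray_unique(x)
assert np.array_equal(x, y.numpy()[idx.numpy()])

# V2 variant: an empty axis vector means "treat the input as a flat vector".
y2, idx2 = gen_array_ops.deepray_unique_v2(x, axis=np.array([], np.int32))

The tests toggle `DEEPREC_UNIQUE_OP_HASH_MAP` between runs (see RunUniqueWithDifferentMaps above), which suggests the variable is consulted when the kernel is constructed, so it only needs to be set before the ops it should affect are created.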