-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #54 from deepray-AI/hotfix
[ops] migrate DeepRec unique ops to Tensorflow2
- Loading branch information
Showing
21 changed files
with
2,062 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
// limitations under the License. | ||
|
||
#include "ffm_kernels.h" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
#if GOOGLE_CUDA | ||
#define EIGEN_USE_GPU | ||
#include "ffm_kernels.h" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
load("//deepray:deepray.bzl", "custom_op_library") | ||
load("@local_config_tf//:build_defs.bzl", "CPLUSPLUS_VERSION") | ||
|
||
licenses(["notice"]) # Apache 2.0 | ||
|
||
package( | ||
default_visibility = [ | ||
"//deepray:__subpackages__", | ||
], | ||
licenses = ["notice"], # Apache 2.0 | ||
) | ||
|
||
cc_library( | ||
name = "random", | ||
srcs = [ | ||
"cc/kernels/random.cc", | ||
"cc/kernels/random.h", | ||
], | ||
copts = [CPLUSPLUS_VERSION], | ||
deps = [ | ||
"@local_config_tf//:libtensorflow_framework", | ||
"@local_config_tf//:tf_header_lib", | ||
], | ||
) | ||
|
||
cc_test( | ||
name = "random_test", | ||
srcs = ["cc/kernels/random_test.cc"], | ||
deps = [ | ||
":random", | ||
"@com_google_googletest//:gtest_main", | ||
], | ||
) | ||
|
||
custom_op_library( | ||
name = "_unique_ops.so", | ||
srcs = [ | ||
"cc/kernels/task_runner.h", | ||
"cc/kernels/unique_ali_op.cc", | ||
"cc/kernels/unique_ali_op_util.h", | ||
"cc/ops/unique_ops.cc", | ||
], | ||
copts = [CPLUSPLUS_VERSION], | ||
cuda_srcs = [ | ||
"cc/kernels/unique_ali_op_gpu.cu.cc", | ||
], | ||
visibility = ["//visibility:public"], | ||
deps = [ | ||
":random", | ||
"@com_google_absl//absl/container:flat_hash_map", | ||
"@sparsehash_c11//:dense_hash_map", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "unique_ops", | ||
srcs = glob( | ||
[ | ||
"python/*.py", | ||
"python/**/*.py", | ||
"*.py", | ||
], | ||
), | ||
data = [ | ||
":_unique_ops.so", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "unique_ops_test", | ||
size = "small", | ||
srcs = glob(["python/tests/*"]), | ||
main = "python/tests/run_all_test.py", | ||
deps = [ | ||
":unique_ops", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from deepray.custom_ops.unique_ops.python.unique_ops import gen_array_ops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "random.h" | ||
|
||
#include <random> | ||
|
||
#include "tensorflow/core/platform/mutex.h" | ||
#include "tensorflow/core/platform/types.h" | ||
#include "tensorflow/core/util/env_var.h" | ||
|
||
namespace tensorflow { | ||
namespace random { | ||
|
||
namespace { | ||
std::mt19937_64* InitRngWithRandomSeed() { | ||
std::random_device device("/dev/urandom"); | ||
return new std::mt19937_64(device()); | ||
} | ||
std::mt19937_64 InitRngWithDefaultSeed() { return std::mt19937_64(); } | ||
|
||
} // anonymous namespace | ||
|
||
uint64 New64() { | ||
static std::mt19937_64* rng = InitRngWithRandomSeed(); | ||
static mutex mu(LINKER_INITIALIZED); | ||
mutex_lock l(mu); | ||
return (*rng)(); | ||
} | ||
|
||
uint64 New64DefaultSeed() { | ||
static std::mt19937_64 rng = InitRngWithDefaultSeed(); | ||
static mutex mu(LINKER_INITIALIZED); | ||
mutex_lock l(mu); | ||
return rng(); | ||
} | ||
|
||
uint64 New64Configuable() { | ||
int64 random_64; | ||
CHECK( | ||
ReadInt64FromEnvVar("DEEPREC_CONFIG_RAND_64", New64(), &random_64).ok()); | ||
return static_cast<uint64>(random_64); | ||
} | ||
|
||
} // namespace random | ||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_ | ||
#define TENSORFLOW_LIB_RANDOM_RANDOM_H_ | ||
|
||
#include "tensorflow/core/platform/types.h" | ||
|
||
namespace tensorflow { | ||
namespace random { | ||
|
||
// Return a 64-bit random value. Different sequences are generated | ||
// in different processes. | ||
uint64 New64(); | ||
|
||
// Return a 64-bit random value. Uses | ||
// std::mersenne_twister_engine::default_seed as seed value. | ||
uint64 New64DefaultSeed(); | ||
|
||
// Call New64 to generate a 64-bit random value | ||
// if env var DEEPREC_CONFIG_RAND_64 not set. | ||
// Otherwise, return int64 from DEEPREC_CONFIG_RAND_64 | ||
uint64 New64Configuable(); | ||
|
||
} // namespace random | ||
} // namespace tensorflow | ||
|
||
#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow/core/lib/random/random.h" | ||
|
||
#include <set> | ||
|
||
#include "tensorflow/core/platform/test.h" | ||
#include "tensorflow/core/platform/types.h" | ||
|
||
namespace tensorflow { | ||
namespace random { | ||
namespace { | ||
|
||
TEST(New64Test, SanityCheck) { | ||
std::set<uint64> values; | ||
for (int i = 0; i < 1000000; i++) { | ||
uint64 x = New64(); | ||
EXPECT_TRUE(values.insert(x).second) << "duplicate " << x; | ||
} | ||
} | ||
|
||
} // namespace | ||
} // namespace random | ||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef TENSORFLOW_CORE_KERNELS_TASK_RUNNER_H_ | ||
#define TENSORFLOW_CORE_KERNELS_TASK_RUNNER_H_ | ||
|
||
#include <functional> | ||
|
||
#include "tensorflow/core/lib/core/blocking_counter.h" | ||
#include "tensorflow/core/lib/core/threadpool.h" | ||
|
||
namespace tensorflow { | ||
|
||
// TaskRunner schedules tasks(function f) to ThreadPool | ||
// and wait until all finished | ||
class TaskRunner { | ||
public: | ||
explicit TaskRunner(const std::function<void(int32, int32)>& f, | ||
thread::ThreadPool* tp, int32 n) | ||
: func_(f), thread_pool_(tp), num_tasks_(n) {} | ||
|
||
void Run() { | ||
if (num_tasks_ <= 0) return; | ||
BlockingCounter bc(num_tasks_ - 1); | ||
|
||
// Sending (num_tasks - 1) tasks to threadpool for scheduling | ||
for (int32 i = 0; i < num_tasks_ - 1; ++i) { | ||
thread_pool_->Schedule([this, &bc, i]() { | ||
func_(i, num_tasks_); | ||
bc.DecrementCount(); | ||
}); | ||
} | ||
// Run the last task in current thread. | ||
func_(num_tasks_ - 1, num_tasks_); | ||
bc.Wait(); | ||
} | ||
|
||
private: | ||
std::function<void(int32 task_id, int32 num_tasks)> func_; | ||
thread::ThreadPool* thread_pool_; | ||
const int32 num_tasks_; | ||
}; | ||
|
||
// add more types of SummaryUpdater | ||
// for more types of summary or more ways of summary aggregation | ||
class StatusSummaryUpdater { | ||
public: | ||
static void UpdateSummary(Status* mine, const Status& ret) { | ||
mine->Update(ret); | ||
} | ||
}; | ||
|
||
class Int64SumSummaryUpdater { | ||
public: | ||
static void UpdateSummary(int64_t* mine, const int64_t& ret) { *mine += ret; } | ||
}; | ||
|
||
// SummaryTaskRunner schedules tasks and summary their return values. | ||
// S is the type of return values. | ||
// SUpdater is the class for aggregating the return values. | ||
template <typename S, typename SUpdater> | ||
class SummaryTaskRunner { | ||
public: | ||
explicit SummaryTaskRunner(const std::function<S(int32, int32)>& f, | ||
const S& init_summary, thread::ThreadPool* tp, | ||
int32 n) | ||
: func_(f), summary_(init_summary), thread_pool_(tp), num_tasks_(n) {} | ||
|
||
void Run() { | ||
if (num_tasks_ <= 0) return; | ||
BlockingCounter bc(num_tasks_ - 1); | ||
|
||
// Sending (num_tasks - 1) tasks to threadpool for scheduling | ||
for (int32 i = 0; i < num_tasks_ - 1; ++i) { | ||
thread_pool_->Schedule([this, &bc, i]() { | ||
const S& ret = func_(i, num_tasks_); | ||
UpdateSummaryUnlocked(ret); | ||
bc.DecrementCount(); | ||
}); | ||
} | ||
// Run the last task in current thread. | ||
const S& ret = func_(num_tasks_ - 1, num_tasks_); | ||
UpdateSummaryUnlocked(ret); | ||
bc.Wait(); | ||
} | ||
|
||
S summary() { return summary_; } | ||
|
||
private: | ||
void UpdateSummaryUnlocked(const S& ret) { | ||
mutex_lock lock(mu_); | ||
SUpdater::UpdateSummary(&summary_, ret); | ||
} | ||
|
||
mutex mu_; | ||
std::function<S(int32 task_id, int32 num_tasks)> func_; | ||
S summary_; | ||
thread::ThreadPool* thread_pool_; | ||
const int32 num_tasks_; | ||
}; | ||
|
||
} // namespace tensorflow | ||
|
||
#endif // TENSORFLOW_CORE_KERNELS_TASK_RUNNER_H_ |
Oops, something went wrong.