Skip to content

Commit

Permalink
Merge pull request #31 from deepray-AI/hotfix
Browse files Browse the repository at this point in the history
simple_hash_table ops
  • Loading branch information
fuhailin authored Nov 1, 2023
2 parents ba64ea1 + c505160 commit c42a480
Show file tree
Hide file tree
Showing 15 changed files with 1,877 additions and 9 deletions.
1 change: 1 addition & 0 deletions deepray/custom_ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ py_library(
deps = [
"//deepray/custom_ops/correlation_cost",
"//deepray/custom_ops/parquet_dataset",
"//deepray/custom_ops/simple_hash_table",
"//deepray/custom_ops/zero_out:zero_out_ops",
],
)
1 change: 0 additions & 1 deletion deepray/custom_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .zero_out import zero_out
61 changes: 61 additions & 0 deletions deepray/custom_ops/simple_hash_table/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Build simple_hash_table custom ops example, which is similar to,
# but simpler than, tf.lookup.experimental.MutableHashTable

load("//deepray:deepray.bzl", "custom_op_library")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

custom_op_library(
name = "simple_hash_table_kernel.so",
srcs = [
"simple_hash_table_kernel.cc",
"simple_hash_table_op.cc",
],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
],
)

py_library(
name = "simple_hash_table_op",
srcs = ["simple_hash_table_op.py"],
data = ["simple_hash_table_kernel.so"],
srcs_version = "PY3",
deps = [
],
)

py_library(
name = "simple_hash_table",
# srcs = [
# "__init__.py",
# "simple_hash_table.py",
# "simple_hash_table_op.py",
# ],
srcs = glob(
[
"*.py",
],
),
srcs_version = "PY3",
deps = [
":simple_hash_table_op",
],
)

py_test(
name = "simple_hash_table_test",
size = "medium", # This test blocks because it writes and reads a file,
timeout = "short", # but it still runs quickly.
srcs = ["simple_hash_table_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_mac", # TODO(b/216321151): Re-enable this test.
],
deps = [
":simple_hash_table",
],
)
Loading

0 comments on commit c42a480

Please sign in to comment.