From c505160b34f060e860a7547ea20038092a502707 Mon Sep 17 00:00:00 2001 From: Vincent Date: Wed, 1 Nov 2023 11:53:33 +0800 Subject: [PATCH] custom ops example: simple_hash_table --- deepray/custom_ops/BUILD | 1 + deepray/custom_ops/__init__.py | 1 - deepray/custom_ops/simple_hash_table/BUILD | 61 ++ .../custom_ops/simple_hash_table/README.md | 899 ++++++++++++++++++ .../custom_ops/simple_hash_table/__init__.py | 0 .../simple_hash_table/simple_hash_table.py | 244 +++++ .../simple_hash_table_kernel.cc | 336 +++++++ .../simple_hash_table/simple_hash_table_op.cc | 173 ++++ .../simple_hash_table/simple_hash_table_op.py | 21 + .../simple_hash_table_test.py | 125 +++ 10 files changed, 1860 insertions(+), 1 deletion(-) create mode 100644 deepray/custom_ops/simple_hash_table/BUILD create mode 100644 deepray/custom_ops/simple_hash_table/README.md create mode 100644 deepray/custom_ops/simple_hash_table/__init__.py create mode 100644 deepray/custom_ops/simple_hash_table/simple_hash_table.py create mode 100644 deepray/custom_ops/simple_hash_table/simple_hash_table_kernel.cc create mode 100644 deepray/custom_ops/simple_hash_table/simple_hash_table_op.cc create mode 100644 deepray/custom_ops/simple_hash_table/simple_hash_table_op.py create mode 100644 deepray/custom_ops/simple_hash_table/simple_hash_table_test.py diff --git a/deepray/custom_ops/BUILD b/deepray/custom_ops/BUILD index c3d9840..71b7d8e 100644 --- a/deepray/custom_ops/BUILD +++ b/deepray/custom_ops/BUILD @@ -6,6 +6,7 @@ py_library( deps = [ "//deepray/custom_ops/correlation_cost", "//deepray/custom_ops/parquet_dataset", + "//deepray/custom_ops/simple_hash_table", "//deepray/custom_ops/zero_out:zero_out_ops", ], ) diff --git a/deepray/custom_ops/__init__.py b/deepray/custom_ops/__init__.py index 03a6fce..e69de29 100644 --- a/deepray/custom_ops/__init__.py +++ b/deepray/custom_ops/__init__.py @@ -1 +0,0 @@ -from .zero_out import zero_out diff --git a/deepray/custom_ops/simple_hash_table/BUILD b/deepray/custom_ops/simple_hash_table/BUILD new file mode 100644 index 0000000..586005e --- /dev/null +++ b/deepray/custom_ops/simple_hash_table/BUILD @@ -0,0 +1,61 @@ +# Build simple_hash_table custom ops example, which is similar to, +# but simpler than, tf.lookup.experimental.MutableHashTable + +load("//deepray:deepray.bzl", "custom_op_library") + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +custom_op_library( + name = "simple_hash_table_kernel.so", + srcs = [ + "simple_hash_table_kernel.cc", + "simple_hash_table_op.cc", + ], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +py_library( + name = "simple_hash_table_op", + srcs = ["simple_hash_table_op.py"], + data = ["simple_hash_table_kernel.so"], + srcs_version = "PY3", + deps = [ + ], +) + +py_library( + name = "simple_hash_table", + # srcs = [ + # "__init__.py", + # "simple_hash_table.py", + # "simple_hash_table_op.py", + # ], + srcs = glob( + [ + "*.py", + ], + ), + srcs_version = "PY3", + deps = [ + ":simple_hash_table_op", + ], +) + +py_test( + name = "simple_hash_table_test", + size = "medium", # This test blocks because it writes and reads a file, + timeout = "short", # but it still runs quickly. + srcs = ["simple_hash_table_test.py"], + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_mac", # TODO(b/216321151): Re-enable this test. 
+ ], + deps = [ + ":simple_hash_table", + ], +) diff --git a/deepray/custom_ops/simple_hash_table/README.md b/deepray/custom_ops/simple_hash_table/README.md new file mode 100644 index 0000000..5c79c64 --- /dev/null +++ b/deepray/custom_ops/simple_hash_table/README.md @@ -0,0 +1,899 @@ + +# Create a stateful custom op + +This example shows you how to create TensorFlow custom ops with internal state. +These custom ops use resources to hold their state. You do this by implementing +a stateful C++ data structure instead of using tensors to hold the state. This +example also covers: + +* Using `tf.Variable`s for state as an alternative to using a resource (which + is recommended for all cases where storing the state in fixed size tensors + is reasonable) +* Saving and restoring the state of the custom op by using `SavedModel`s +* The `IsStateful` flag which is true for ops with state and other unrelated + cases +* A set of related custom ops and ops with various signatures (number of + inputs and outputs) + +For additional context, +read the +[OSS guide on creating custom ops](https://www.tensorflow.org/guide/create_op). + +## Background + +### Overview of resources and ref-counting + +TensorFlow C++ resource classes are derived from the +[`ResourceBase` class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_base.h). +A resource is referenced in the Tensorflow Graph as a +[`ResourceHandle`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_handle.h) +Tensor, of the type +[`DT_RESOURCE`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto). +A resource can be owned by a ref-counting `ResourceHandle`, or by a +[`ResourceMgr`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h) +(deprecated for new ops). + +A handle-owned resource using ref-counting is automatically destroyed when all +resource handles pointing to it go out of scope. If the resource needs to be +looked up from a name, for example, if the resource handle is serialized and +deserialized, the resource must be published as an unowned entry with +[`ResourceMgr::CreateUnowned`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h#L641) +upon creation. The entry is a weak reference to the resource, and is +automatically removed when the resource goes out of scope. + +In contrast, with the deprecated `ResourceMgr` owned resources, a resource +handle behaves like a weak ref - it does not destroy the underlying resource +when its lifetime ends. +[`DestroyResourceOp`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/resource_variable_ops.cc#L339) +must be explicitly called to destroy the resource. An example of why requiring +calling `DestroyResourceOp` is problematic is that it is easy to introduce a bug +when adding a new code path to an op that returns without calling +`DestroyResourceOp`. + +While it would be possible to have a data structure that has a single operation +implemented by a single custom op (for example, something that just has a 'next +state' operation that requires no initialization), typically a related set of +custom ops are used with a resource. **Separating creation and use into +different custom ops is recommended.** Typically, one custom op creates the +resource and one or more additional custom ops implement functionality to access +or modify it. 
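+
+To make the recommended handle-owned pattern concrete, here is a minimal
+sketch of a `create`-style kernel fragment. `MyResource` (a `ResourceBase`
+subclass) and the container/name strings are illustrative assumptions, not
+part of this example; the optional `CreateUnowned` call publishes the weak,
+automatically removed entry described above.
+
+```c++
+Status CreateMyResource(OpKernelContext* ctx) {
+  Tensor handle_tensor;
+  AllocatorAttributes attr;
+  TF_RETURN_IF_ERROR(
+      ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle_tensor, attr));
+  // The handle owns the resource via ref-counting.
+  auto* resource = new MyResource();
+  handle_tensor.scalar<ResourceHandle>()() =
+      ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name(),
+                                            /*dtypes_and_shapes=*/{},
+                                            ctx->stack_trace());
+  // Optional: publish a weak (unowned) entry so the resource can later be
+  // looked up by name, e.g. after the handle is serialized and deserialized.
+  TF_RETURN_IF_ERROR(ctx->resource_manager()->CreateUnowned<MyResource>(
+      "container", "shared_name", resource));
+  ctx->set_output(0, handle_tensor);
+  return OkStatus();
+}
+```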
+ +### Using `tf.Variable`s for state + +An alternative to custom ops with internal state is to store the state +externally in one or more +[`tf.Variable`](https://www.tensorflow.org/api_docs/python/tf/Variable)s as +tensors (as detailed [here](https://www.tensorflow.org/guide/variable)) and have +one or more (normal, stateless) ops that use tensors stored in these +`tf.Variable`s. One example is `tf.random.Generator`. + +For cases where using variables is possible and efficient, using them is +preferred since the implementation does not require adding a new C++ resource +type. An example of a case where using variables is not possible and custom ops +using a resource must be used is where the amount of space for data grows +dynamically. + +Here is a toy example of using the +`multiplex_2` +custom op with a `tf.Variable` in the same manner as an in-built TensorFlow op +is used with variables. The variable is initialized to 1 for every value. +Indices from a +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +cause the corresponding element of the variable to be doubled. + + +```python +import tensorflow as tf +from tensorflow.examples.custom_ops_doc.multiplex_2 import multiplex_2_op + +def variable_and_stateless_op(): + n = 10 + v = tf.Variable(tf.ones(n, dtype=tf.int64), trainable=False) + dataset = tf.data.Dataset.from_tensor_slices([5, 1, 7, 5]) + for position in dataset: + print(v.numpy()) + cond = tf.one_hot( + position, depth=n, on_value=True, off_value=False, dtype=bool) + v.assign(multiplex_2_op.multiplex(cond, v*2, v)) + print(v.numpy()) +``` + +This outputs: + + +``` +[1 1 1 1 1 1 1 1 1 1] +[1 1 1 1 1 2 1 1 1 1] +[1 2 1 1 1 2 1 1 1 1] +[1 2 1 1 1 2 1 2 1 1] +[1 2 1 1 1 4 1 2 1 1] +``` + +It is also possible to pass a Python `tf.Variable` handle as a +[`DT_RESOURCE`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto) +input to a custom op where its C++ kernel uses the handle to access the variable +using the variable's internal C++ API. This is not a common case because the +variable's C++ API provides little extra functionality. It can be appropriate +for cases that only sparsely update the variable. + +### The `IsStateful` flag + +All TensorFlow ops have an +[`IsStateful`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h#L505) +flag. It is set to `True` for ops that have internal state, both for ops that +follow the recommended pattern of using a resource and those that have internal +state in some other way. In addition, it is also set to `True` for ops with side +effects (for example, I/O) and for disabling certain optimizations for an op. + +### Save and restore state with `SavedModel` + +A [`SavedModel`](https://www.tensorflow.org/guide/saved_model) contains a +complete TensorFlow program, including trained parameters (like `tf.Variable`s) +and computation. In some cases, you can save and restore the state in a custom +op by using `SavedModel`. This is optional, but is important for some cases, for +example if a custom op is used during training and the state needs to persist +when training is restarted from a checkpoint. 
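+
+For example, with the `SimpleHashTable` Python wrapper built later in this
+example, the table contents survive a checkpoint round trip. This is a sketch;
+the checkpoint path and values are illustrative.
+
+```python
+import tensorflow as tf
+
+table = simple_hash_table.SimpleHashTable(tf.int32, tf.float32, -999.0)
+table.insert(1, 100.0)
+path = tf.train.Checkpoint(table=table).save("/tmp/hash_table_ckpt")
+
+# Later, e.g. after a training job restarts, restore into a fresh table.
+restored = simple_hash_table.SimpleHashTable(tf.int32, tf.float32, -999.0)
+tf.train.Checkpoint(table=restored).restore(path)
+result = restored.find(1, -999.0)  # 100.0
+```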
The hash table in this example +supports `SavedModel` with a custom Python wrapper that implements the +`_serialize_to_tensors` and `_restore_from_tensors` methods of +[`tf.saved_model.experimental.TrackableResource`](https://www.tensorflow.org/api_docs/python/tf/saved_model/experimental/TrackableResource) +and ops that are used by these methods. + +## Creating a hash table with a stateful custom op + +This example demonstrates how to implement custom ops that use ref-counting +resources to track state. This example creates a `simple_hash_table` which is +similar to +[`tf.lookup.experimental.MutableHashTable`](https://www.tensorflow.org/api_docs/python/tf/lookup/experimental/MutableHashTable). + +The hash table implements a set of 4 CRUD (Create/Read/Update/Delete) style +custom ops. Each of these ops has a different signature and they illustrate +cases such as custom ops with more than one output. + +The hash table supports `SavedModel` through the `export` and `import` ops to +save/restore state. For an actual hash table use case, it is preferable to use +(or extend) the existing +[`tf.lookup`](https://www.tensorflow.org/api_docs/python/tf/lookup) ops. In this +simple example, `insert`, `find`, and `remove` use only a single key-value pair +per call. In contrast, existing `tf.lookup` ops can use multiple key-value +pairs in a single call. + +The table below summarizes the six custom ops implemented in this example. + +Operation | Purpose | Resource class method | Kernel class | Custom op | Python class member +--------- | --------------------------- | --------------------- | ------------------------------- | ------------------------------ | ------------------- +create | CRUD and SavedModel: create | Default constructor | `SimpleHashTableCreateOpKernel` | Examples>SimpleHashTableCreate | `__init__` +find | CRUD: read | Find | `SimpleHashTableFindOpKernel` | Examples>SimpleHashTableFind | `find` +insert | CRUD: update | Insert | `SimpleHashTableInsertOpKernel` | Examples>SimpleHashTableInsert | `insert` +remove | CRUD: delete | Remove | `SimpleHashTableRemoveOpKernel` | Examples>SimpleHashTableRemove | `remove` +import | SavedModel: restore | Import | `SimpleHashTableImportOpKernel` | Examples>SimpleHashTableImport | `do_import` +export | SavedModel: save | Export | `SimpleHashTableExportOpKernel` | Examples>SimpleHashTableExport | `export` + +You can use this hash table as: + + +```python +hash_table = simple_hash_table_op.SimpleHashTable(tf.int32, float, + default_value=-999.0) +result1 = hash_table.find(key=1, dynamic_default_value=-999.0) +# -999.0 +hash_table.insert(key=1, value=100.0) +result2 = hash_table.find(key=1, dynamic_default_value=-999.0) +# 100.0 +``` + +The example below contains C++ and Python code snippets to illustrate the code +flow. These snippets are not all complete; some are missing namespace +declarations, imports, and test cases. + +This example deviates slightly from the general recipe for creating TensorFlow +custom ops. The most significant differences are noted in each step below. + +### Step 0 - Implement the resource class + +Implement a +`SimpleHashTableResource` +resource class derived from +[`ResourceBase`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_base.h). 
+
+```
+template <class K, class V>
+class SimpleHashTableResource : public ::tensorflow::ResourceBase {
+ public:
+  Status Insert(const Tensor& key, const Tensor& value) {
+    const K key_val = key.flat<K>()(0);
+    const V value_val = value.flat<V>()(0);
+
+    mutex_lock l(mu_);
+    table_[key_val] = value_val;
+    return OkStatus();
+  }
+
+  Status Find(const Tensor& key, Tensor* value, const Tensor& default_value) {
+    // Note that tf_shared_lock could be used instead of mutex_lock
+    // in ops that do not modify data protected by a mutex, but
+    // go/totw/197 recommends using an exclusive lock instead of a shared
+    // lock when the lock is not going to be held for a significant amount
+    // of time.
+    mutex_lock l(mu_);
+
+    const V default_val = default_value.flat<V>()(0);
+    const K key_val = key.flat<K>()(0);
+    auto value_val = value->flat<V>();
+    value_val(0) = gtl::FindWithDefault(table_, key_val, default_val);
+    return OkStatus();
+  }
+
+  Status Remove(const Tensor& key) {
+    mutex_lock l(mu_);
+
+    const K key_val = key.flat<K>()(0);
+    if (table_.erase(key_val) != 1) {
+      return errors::NotFound("Key for remove not found: ", key_val);
+    }
+    return OkStatus();
+  }
+
+  // Save all key, value pairs to tensor outputs to support SavedModel
+  Status Export(OpKernelContext* ctx) {
+    mutex_lock l(mu_);
+    int64_t size = table_.size();
+    Tensor* keys;
+    Tensor* values;
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output("keys", TensorShape({size}), &keys));
+    TF_RETURN_IF_ERROR(
+        ctx->allocate_output("values", TensorShape({size}), &values));
+    auto keys_data = keys->flat<K>();
+    auto values_data = values->flat<V>();
+    int64_t i = 0;
+    for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
+      keys_data(i) = it->first;
+      values_data(i) = it->second;
+    }
+    return OkStatus();
+  }
+
+  // Load all key, value pairs from tensor inputs to support SavedModel
+  Status Import(const Tensor& keys, const Tensor& values) {
+    const auto key_values = keys.flat<K>();
+    const auto value_values = values.flat<V>();
+
+    mutex_lock l(mu_);
+    table_.clear();
+    for (int64_t i = 0; i < key_values.size(); ++i) {
+      gtl::InsertOrUpdate(&table_, key_values(i), value_values(i));
+    }
+    return OkStatus();
+  }
+
+  // Create a debug string with the content of the map if it is small,
+  // or some example data if it is large, handling both the case where the
+  // hash table has many entries and the case where entries are long strings.
+  std::string DebugString() const override { return DebugString(3); }
+  std::string DebugString(int num_pairs) const {
+    std::string rval = "SimpleHashTable {";
+    size_t count = 0;
+    const size_t max_kv_str_len = 100;
+    mutex_lock l(mu_);
+    for (const auto& pair : table_) {
+      if (count >= num_pairs) {
+        strings::StrAppend(&rval, "...");
+        break;
+      }
+      std::string kv_str = strings::StrCat(pair.first, ": ", pair.second);
+      strings::StrAppend(&rval, kv_str.substr(0, max_kv_str_len));
+      if (kv_str.length() > max_kv_str_len) strings::StrAppend(&rval, " ...");
+      strings::StrAppend(&rval, ", ");
+      count += 1;
+    }
+    strings::StrAppend(&rval, "}");
+    return rval;
+  }
+
+ private:
+  mutable mutex mu_;
+  absl::flat_hash_map<K, V> table_ TF_GUARDED_BY(mu_);
+};
+```
+
+Note that this class provides:
+
+* Helper methods to access the hash table. These methods correspond to the
+  `find`, `insert`, and `remove` ops.
+* Helper methods to `import`/`export` the complete internal state of the hash
+  table. These methods help support `SavedModel`.
+* A [mutex](https://en.wikipedia.org/wiki/Lock_\(computer_science\)) that the
+  helper methods use for exclusive access to the
+  [`absl::flat_hash_map`](https://abseil.io/docs/cpp/guides/container#abslflat_hash_map-and-abslflat_hash_set).
+  This provides thread safety by ensuring that only one thread at a time can
+  access the data in the hash table.
+
+### Step 1 - Define the op interface
+
+Define op interfaces and register all the custom ops you create for the hash
+table. You typically define one custom op that creates the resource and one or
+more custom ops that correspond to operations on the data structure. You can
+also define custom ops that import and export the whole internal state. These
+are optional; this example defines them to support `SavedModel`. Since the
+resource is automatically deleted based on ref-counting, no custom op is
+required to delete the resource.
+
+The `simple_hash_table` has kernels for the `create`, `insert`, `find`,
+`remove`, `import`, and `export` ops, which use the resource object that
+actually stores and manipulates the data. The resource is passed between ops
+using a resource handle: the `create` op has an `Output` of type `resource`,
+and the other ops have an `Input` of type `resource`. The interface
+definitions for all the ops, along with their shape functions, are in
+`simple_hash_table_op.cc`.
+
+Note that these definitions do not explicitly use `SetIsStateful`. The
+`IsStateful` flag is set automatically for any op with an input or output of
+type `DT_RESOURCE`.
+
+The definitions below and their corresponding kernels and generated Python
+wrappers illustrate the following cases (see the Python-side sketch after this
+list):
+
+* 0, 1, 2, and 3 inputs
+* 0, 1, and 2 outputs
+* type `Attr`s where none are used by an `Input` or `Output`, so all must be
+  explicitly passed into the generated wrapper (for example,
+  `SimpleHashTableCreate`)
+* type `Attr`s where all are used by an `Input` and/or `Output`, so they are
+  set implicitly inside the generated wrapper and none are passed in (for
+  example, `SimpleHashTableFind`)
+* type `Attr`s where some but not all are used by an `Input` or `Output`, so
+  only the unused ones are explicitly passed into the generated wrapper (for
+  example, `SimpleHashTableRemove`, where an `Input` uses `key_dtype` but the
+  `value_dtype` `Attr` is a parameter to the generated wrapper)
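+
+The sketch below shows how these cases surface on the Python side. It calls
+the generated wrapper module described in Step 4; the exact generated
+signatures are inferred from the definitions that follow, so treat this as an
+illustration rather than generated code.
+
+```python
+import tensorflow as tf
+
+# create: no Input/Output uses the type attrs, so both are passed explicitly.
+handle = gen_simple_hash_table_op.examples_simple_hash_table_create(
+    key_dtype=tf.int64, value_dtype=tf.float32)
+# find: both attrs are implied by the key and default_value inputs.
+value = gen_simple_hash_table_op.examples_simple_hash_table_find(
+    handle, tf.constant(1, tf.int64), tf.constant(-999.0))
+# remove: key_dtype is implied by key; value_dtype is passed explicitly.
+gen_simple_hash_table_op.examples_simple_hash_table_remove(
+    handle, tf.constant(1, tf.int64), value_dtype=tf.float32)
+```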
+
+```
+REGISTER_OP("Examples>SimpleHashTableCreate")
+    .Output("output: resource")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .SetShapeFn(ScalarOutput);
+```
+```
+REGISTER_OP("Examples>SimpleHashTableFind")
+    .Input("resource_handle: resource")
+    .Input("key: key_dtype")
+    .Input("default_value: value_dtype")
+    .Output("value: value_dtype")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .SetShapeFn(ThreeScalarInputsScalarOutput);
+```
+```
+REGISTER_OP("Examples>SimpleHashTableInsert")
+    .Input("resource_handle: resource")
+    .Input("key: key_dtype")
+    .Input("value: value_dtype")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .SetShapeFn(ThreeScalarInputs);
+```
+```
+REGISTER_OP("Examples>SimpleHashTableRemove")
+    .Input("resource_handle: resource")
+    .Input("key: key_dtype")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .SetShapeFn(TwoScalarInputs);
+```
+```
+REGISTER_OP("Examples>SimpleHashTableExport")
+    .Input("table_handle: resource")
+    .Output("keys: key_dtype")
+    .Output("values: value_dtype")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .SetShapeFn(ExportShapeFunction);
+```
+```
+REGISTER_OP("Examples>SimpleHashTableImport")
+    .Input("table_handle: resource")
+    .Input("keys: key_dtype")
+    .Input("values: value_dtype")
+    .Attr("key_dtype: type")
+    .Attr("value_dtype: type")
+    .SetShapeFn(ImportShapeFunction);
+```
+
+### Step 2 - Register the op implementation (kernel)
+
+Declare kernels for specific types of key-value pairs. Register the kernels by
+calling the `REGISTER_KERNEL_BUILDER` macro.
+
+```
+#define REGISTER_KERNEL(key_dtype, value_dtype)                  \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Examples>SimpleHashTableCreate")                     \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<key_dtype>("key_dtype")                \
+          .TypeConstraint<value_dtype>("value_dtype"),           \
+      SimpleHashTableCreateOpKernel<key_dtype, value_dtype>);    \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Examples>SimpleHashTableFind")                       \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<key_dtype>("key_dtype")                \
+          .TypeConstraint<value_dtype>("value_dtype"),           \
+      SimpleHashTableFindOpKernel<key_dtype, value_dtype>);      \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Examples>SimpleHashTableInsert")                     \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<key_dtype>("key_dtype")                \
+          .TypeConstraint<value_dtype>("value_dtype"),           \
+      SimpleHashTableInsertOpKernel<key_dtype, value_dtype>);    \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Examples>SimpleHashTableRemove")                     \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<key_dtype>("key_dtype")                \
+          .TypeConstraint<value_dtype>("value_dtype"),           \
+      SimpleHashTableRemoveOpKernel<key_dtype, value_dtype>);    \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Examples>SimpleHashTableExport")                     \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<key_dtype>("key_dtype")                \
+          .TypeConstraint<value_dtype>("value_dtype"),           \
+      SimpleHashTableExportOpKernel<key_dtype, value_dtype>);    \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("Examples>SimpleHashTableImport")                     \
+          .Device(DEVICE_CPU)                                    \
+          .TypeConstraint<key_dtype>("key_dtype")                \
+          .TypeConstraint<value_dtype>("value_dtype"),           \
+      SimpleHashTableImportOpKernel<key_dtype, value_dtype>);
+```
+```
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int32, tstring);
+REGISTER_KERNEL(int64_t, double);
+REGISTER_KERNEL(int64_t, float);
+REGISTER_KERNEL(int64_t, int32);
+REGISTER_KERNEL(int64_t, int64_t);
+REGISTER_KERNEL(int64_t, tstring);
+REGISTER_KERNEL(tstring, bool);
+REGISTER_KERNEL(tstring, double);
+REGISTER_KERNEL(tstring, float);
+REGISTER_KERNEL(tstring, int32);
+REGISTER_KERNEL(tstring, int64_t);
+REGISTER_KERNEL(tstring, tstring);
+```
+
+### Step 3 - Implement the op kernel(s)
+
+The kernels for the hash table ops all use helper methods of the
+`SimpleHashTableResource` resource class.
+
+You implement the op kernels in two phases:
+
+1. Implement the kernel for the `create` op, whose `Compute` method creates a
+   resource object and a ref-counted handle for it:
+
+   ```c++
+   handle_tensor.scalar<ResourceHandle>()() =
+       ResourceHandle::MakeRefCountingHandle(
+           new SimpleHashTableResource<K, V>(), /* … */);
+   ```
+
+1. Implement the kernel(s) for each of the other operations on the resource,
+   whose `Compute` methods get the resource object and use one or more of its
+   helper methods:
+
+   ```c++
+   MyResource* resource;
+   OP_REQUIRES_OK(ctx, GetResource(ctx, &resource));
+   // The GetResource local function uses handle.GetResource<resource_type>()
+   OP_REQUIRES_OK(ctx, resource->Find(key, out, default_value));
+   ```
+
+#### Creating the resource
+
+The `create` op creates a resource object of type `SimpleHashTableResource`
+and then uses `MakeRefCountingHandle` to pass ownership to a resource handle.
+This op outputs a `resource` handle.
+
+```
+template <class K, class V>
+class SimpleHashTableCreateOpKernel : public OpKernel {
+ public:
+  explicit SimpleHashTableCreateOpKernel(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor handle_tensor;
+    AllocatorAttributes attr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
+                                           &handle_tensor, attr));
+    handle_tensor.scalar<ResourceHandle>()() =
+        ResourceHandle::MakeRefCountingHandle(
+            new SimpleHashTableResource<K, V>(), ctx->device()->name(),
+            /*dtypes_and_shapes=*/{}, ctx->stack_trace());
+    ctx->set_output(0, handle_tensor);
+  }
+
+ private:
+  // Just to be safe, avoid accidentally copying the kernel.
+  SimpleHashTableCreateOpKernel(const SimpleHashTableCreateOpKernel&) = delete;
+  void operator=(const SimpleHashTableCreateOpKernel&) = delete;
+};
+```
+
+#### Getting the resource
+
+In `simple_hash_table_kernel.cc`, the `GetResource` helper function uses an
+input `resource` handle to retrieve the corresponding
+`SimpleHashTableResource` object. It is used by all the custom ops that use
+the resource (that is, all the custom ops in the set other than `create`).
+
+```
+template <class K, class V>
+Status GetResource(OpKernelContext* ctx,
+                   SimpleHashTableResource<K, V>** resource) {
+  const Tensor& handle_tensor = ctx->input(0);
+  const ResourceHandle& handle = handle_tensor.scalar<ResourceHandle>()();
+  typedef SimpleHashTableResource<K, V> resource_type;
+  TF_ASSIGN_OR_RETURN(*resource, handle.GetResource<resource_type>());
+  return OkStatus();
+}
+```
+
+#### Using the resource
+
+The ops that use the resource call `GetResource` to get a pointer to the
+resource object and then call the corresponding helper method on that object.
+Below is the source code for the `find` op, which uses `resource->Find`. The
+other ops similarly use the corresponding helper method in the resource class.
+
+```
+template <class K, class V>
+class SimpleHashTableFindOpKernel : public OpKernel {
+ public:
+  explicit SimpleHashTableFindOpKernel(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    DataTypeVector expected_inputs = {DT_RESOURCE, DataTypeToEnum<K>::v(),
+                                      DataTypeToEnum<V>::v()};
+    DataTypeVector expected_outputs = {DataTypeToEnum<V>::v()};
+    OP_REQUIRES_OK(ctx,
+                   ctx->MatchSignature(expected_inputs, expected_outputs));
+    SimpleHashTableResource<K, V>* resource;
+    OP_REQUIRES_OK(ctx, GetResource(ctx, &resource));
+    // Note that ctx->input(0) is the Resource handle
+    const Tensor& key = ctx->input(1);
+    const Tensor& default_value = ctx->input(2);
+    TensorShape output_shape = default_value.shape();
+    Tensor* out;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("value", output_shape, &out));
+    OP_REQUIRES_OK(ctx, resource->Find(key, out, default_value));
+  }
+};
+```
+
+#### Compile the op
+
+Compile the C++ op to create a kernel library and Python wrapper that enables
+you to use the op with TensorFlow.
+
+Create a `BUILD` file for the op which declares the dependencies and the
+output build targets. Refer to
+[building for OSS](https://www.tensorflow.org/guide/create_op#build_the_op_library).
+Note that you will be reusing this `BUILD` file later on in this example.
+
+```
+tf_custom_op_library(
+    name = "simple_hash_table_kernel.so",
+    srcs = [
+        "simple_hash_table_kernel.cc",
+        "simple_hash_table_op.cc",
+    ],
+    deps = [
+        "//third_party/absl/container:flat_hash_map",
+        "//third_party/tensorflow/core/lib/gtl:map_util",
+        "//third_party/tensorflow/core/platform:strcat",
+    ],
+)
+
+py_strict_library(
+    name = "simple_hash_table_op",
+    srcs = ["simple_hash_table_op.py"],
+    data = ["simple_hash_table_kernel.so"],
+    srcs_version = "PY3",
+    deps = [
+        "//third_party/py/tensorflow",
+    ],
+)
+
+py_strict_library(
+    name = "simple_hash_table",
+    srcs = ["simple_hash_table.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":simple_hash_table_op",
+        "//third_party/py/tensorflow",
+    ],
+)
+```
+
+### Step 4 - Create the Python wrapper
+
+To create the Python wrapper, import and implement a function that serves as
+the op's public API and provides a docstring.
+
+The Python wrapper for this example, `simple_hash_table.py`, uses the
+`SimpleHashTable` class to provide methods that allow access to the data
+structure. The custom ops are used to implement these methods. This class also
+supports `SavedModel`, which allows the state of the resource to be saved and
+restored for checkpointing. The class is derived from the
+`tf.saved_model.experimental.TrackableResource` base class and implements the
+`_serialize_to_tensors` (which uses the `export` op) and
+`_restore_from_tensors` (which uses the `import` op) methods defined by this
+base class. The `__init__` method of this class calls the `_create_resource`
+method (an override of a base class method required for `SavedModel`), which
+in turn calls the C++ `examples_simple_hash_table_create` kernel. The resource
+handle returned by this kernel is stored in the private
+`self._resource_handle` object member. Methods that use this handle access it
+through the public `self.resource_handle` property provided by the
+`tf.saved_model.experimental.TrackableResource` base class.
+
+```
+  def _create_resource(self):
+    """Create the resource tensor handle.
+
+    `_create_resource` is an override of a method in base class
+    `TrackableResource` that is required for SavedModel support.
+    It can be called by the `resource_handle` property defined by
+    `TrackableResource`.
+
+    Returns:
+      A tensor handle to the lookup table.
+    """
+    assert self._default_value.get_shape().ndims == 0
+    table_ref = gen_simple_hash_table_op.examples_simple_hash_table_create(
+        key_dtype=self._key_dtype,
+        value_dtype=self._value_dtype,
+        name=self._name)
+    return table_ref
+
+  def _serialize_to_tensors(self):
+    """Implements checkpointing protocols for `Trackable`."""
+    tensors = self.export()
+    return {"table-keys": tensors[0], "table-values": tensors[1]}
+
+  def _restore_from_tensors(self, restored_tensors):
+    """Implements checkpointing protocols for `Trackable`."""
+    return gen_simple_hash_table_op.examples_simple_hash_table_import(
+        self.resource_handle, restored_tensors["table-keys"],
+        restored_tensors["table-values"])
+```
+
+The `find` method in this class calls the `examples_simple_hash_table_find`
+custom op using the resource handle (from the public `self.resource_handle`
+property), a key, and a default value. The other methods are similar. It is
+recommended that methods that use a resource contain no logic other than a
+call to a generated Python wrapper, so that no Python logic is lost when a
+`tf.function` is created (avoiding eager/`tf.function` inconsistencies). In
+the case of `find`, using `tf.convert_to_tensor` cannot be avoided, and it is
+not lost during `tf.function` creation.
+
+```
+  def find(self, key, dynamic_default_value=None, name=None):
+    """Looks up `key` in a table, outputs the corresponding value.
+
+    The `default_value` is used if `key` is not present in the table.
+
+    Args:
+      key: Key to look up. Must match the table's key_dtype.
+      dynamic_default_value: The value to use if the key is missing in the
+        table. If None (by default), the `table.default_value` will be used.
+      name: A name for the operation (optional).
+
+    Returns:
+      A tensor containing the value in the same shape as `key` using the
+      table's value type.
+
+    Raises:
+      TypeError: when `key` does not match the table data types.
+    """
+    with tf.name_scope(name or "%s_lookup_table_find" % self._name):
+      key = tf.convert_to_tensor(key, dtype=self._key_dtype, name="key")
+      if dynamic_default_value is not None:
+        dynamic_default_value = tf.convert_to_tensor(
+            dynamic_default_value,
+            dtype=self._value_dtype,
+            name="default_value")
+      value = gen_simple_hash_table_op.examples_simple_hash_table_find(
+          self.resource_handle, key, dynamic_default_value
+          if dynamic_default_value is not None else self._default_value)
+      return value
+```
+
+The Python wrapper specifies that gradients are not implemented in this
+example. For general information about gradients, read
+[Implement the gradient in Python](https://www.tensorflow.org/guide/create_op#implement_the_gradient_in_python)
+in the OSS guide.
+
+```
+tf.no_gradient("Examples>SimpleHashTableCreate")
+tf.no_gradient("Examples>SimpleHashTableFind")
+tf.no_gradient("Examples>SimpleHashTableInsert")
+tf.no_gradient("Examples>SimpleHashTableRemove")
+```
+
+The full source code for the Python wrapper is in
+`simple_hash_table_op.py`
+and
+`simple_hash_table.py`.
+
+### Step 5 - Test the op
+
+Create op tests using classes derived from
+[`tf.test.TestCase`](https://www.tensorflow.org/api_docs/python/tf/test/TestCase)
+and the
+[parameterized tests provided by Abseil](https://github.com/abseil/abseil-py/blob/main/absl/testing/parameterized.py).
+
+Create tests that use three or more of the custom ops to create
+`SimpleHashTable` objects and:
+
+1.
Update the state using the `insert`, `remove`, and/or `import` methods +1. Observe the state using the `find` and/or `export` methods + +Here is an example test: + +``` + def test_find_insert_find_strings_eager(self): + default = 'Default' + foo = 'Foo' + bar = 'Bar' + hash_table = simple_hash_table.SimpleHashTable(tf.string, tf.string, + default) + result1 = hash_table.find(foo, default) + self.assertEqual(result1, default) + hash_table.insert(foo, bar) + result2 = hash_table.find(foo, default) + self.assertEqual(result2, bar) +``` + +Create a helper function to work with the hash table: + +``` + def _use_table(self, key_dtype, value_dtype): + hash_table = simple_hash_table.SimpleHashTable(key_dtype, value_dtype, 111) + result1 = hash_table.find(1, -999) + hash_table.insert(1, 100) + result2 = hash_table.find(1, -999) + hash_table.remove(1) + result3 = hash_table.find(1, -999) + results = tf.stack((result1, result2, result3)) + return results # expect [-999, 100, -999] +``` + +The following test explicitly creates a `tf.function` with `_use_table`. +Ref-counting causes the C++ resource object created in `_use_table` to be +destroyed when this `tf.function` returns inside `self.evaluate(results)`. By +explicitly creating the `tf.function` instead of relying on decorators like +`@test_util.with_eager_op_as_function` and +`@test_util.run_in_graph_and_eager_modes` (such as in +`multiplex_1`), +you have an explicit place in the test corresponding to where the C++ resource's +destructor is called. + +This test also shows a test that is parameterized for different data types; it +is actually two tests, one with `tf.int32` input / `float` output and the other +with `tf.int64` input / `tf.int32` output. + +``` + def test_find_insert_find_tf_function(self, key_dtype, value_dtype): + results = def_function.function( + lambda: self._use_table(key_dtype, value_dtype)) + self.assertAllClose(self.evaluate(results), [-999.0, 100.0, -999.0]) +``` + +Reuse the `BUILD` file to add build +rules for the Python API wrapper and the op test. + +``` +tf_custom_op_library( + name = "simple_hash_table_kernel.so", + srcs = [ + "simple_hash_table_kernel.cc", + "simple_hash_table_op.cc", + ], + deps = [ + "//third_party/absl/container:flat_hash_map", + "//third_party/tensorflow/core/lib/gtl:map_util", + "//third_party/tensorflow/core/platform:strcat", + ], +) + +py_strict_library( + name = "simple_hash_table_op", + srcs = ["simple_hash_table_op.py"], + data = ["simple_hash_table_kernel.so"], + srcs_version = "PY3", + deps = [ + "//third_party/py/tensorflow", + ], +) + +py_strict_library( + name = "simple_hash_table", + srcs = ["simple_hash_table.py"], + srcs_version = "PY3", + deps = [ + ":simple_hash_table_op", + "//third_party/py/tensorflow", + ], +) + +tf_py_test( + name = "simple_hash_table_test", + size = "medium", # This test blocks because it writes and reads a file, + timeout = "short", # but it still runs quickly. + srcs = ["simple_hash_table_test.py"], + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_mac", # TODO(b/216321151): Re-enable this test. 
+ ], + deps = [ + ":simple_hash_table", + "//third_party/py/numpy", + "//third_party/py/tensorflow", + "//third_party/tensorflow/python/framework:errors", + "//third_party/tensorflow/python/framework:test_lib", + ], +) +``` + +Test the op by running: + + +```shell +$ bazel test //third_party/tensorflow/google/g3doc/example/simple_hash_table:simple_hash_table_test +``` + +### Use the op + +Use the op by importing and calling it as follows: + + +```python +import tensorflow as tf + +from tensorflow.examples.custom_ops_doc.simple_hash_table import simple_hash_table + +hash_table = simple_hash_table.SimpleHashTable(tf.int32, float, -999.0) +result1 = hash_table.find(1, -999.0) # -999.0 +hash_table.insert(1, 100.0) +result2 = hash_table.find(1, -999.0) # 100.0 +``` + +Here, `simple_hash_table` is the name of the Python wrapper that was created +in this example. + +### Summary + +In this example, you learned how to implement a simple hash table data structure +using stateful custom ops. + +The table below summarizes the build rules and targets for building and testing +the `simple_hash_table` op. + +Op components | Build rule | Build target | Source +--------------------------------------- | ---------------------- | -------------------------- | ------ +Kernels (C++) | `tf_custom_op_library` | `simple_hash_table_kernel` | `simple_hash_table_kernel.cc`, `simple_hash_table_op.cc` +Wrapper (automatically generated) | N/A. | `gen_simple_hash_table_op` | N/A +Wrapper (with public API and docstring) | `py_strict_library` | `simple_hash_table_op`, `simple_hash_table` | `simple_hash_table_op.py`, `simple_hash_table.py` +Tests | `tf_py_test` | `simple_hash_table_test` | `simple_hash_table_test.py` + + diff --git a/deepray/custom_ops/simple_hash_table/__init__.py b/deepray/custom_ops/simple_hash_table/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepray/custom_ops/simple_hash_table/simple_hash_table.py b/deepray/custom_ops/simple_hash_table/simple_hash_table.py new file mode 100644 index 0000000..a11e9d9 --- /dev/null +++ b/deepray/custom_ops/simple_hash_table/simple_hash_table.py @@ -0,0 +1,244 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A wrapper for gen_simple_hash_table_op.py. + +This defines a public API and provides a docstring for the C++ Op defined by +simple_hash_table_kernel.cc +""" + +import tensorflow as tf +from deepray.custom_ops.simple_hash_table.simple_hash_table_op import gen_simple_hash_table_op + + + +class SimpleHashTable(tf.saved_model.experimental.TrackableResource): + """A simple mutable hash table implementation. + + Implement a simple hash table as a Resource using ref-counting. + This demonstrates a Stateful Op for a general Create/Read/Update/Delete + (CRUD) style use case. 
To instead make an op for a specific lookup table + case, it is preferable to follow the implementation style of + TensorFlow's internal ops, e.g. use LookupInterface. + + Data can be inserted by calling the `insert` method and removed by calling + the `remove` method. It does not support initialization via the init method. + + The `import` and `export` methods allow loading and restoring all of + the key, value pairs. These methods (or their corresponding kernels) + are intended to be used for supporting SavedModel. + + Example usage: + hash_table = simple_hash_table_op.SimpleHashTable(key_dtype, value_dtype, + 111) + result1 = hash_table.find(1, -999) # -999 + hash_table.insert(1, 100) + result2 = hash_table.find(1, -999) # 100 + hash_table.remove(1) + result3 = hash_table.find(1, -999) # -999 + """ + + def __init__(self, + key_dtype, + value_dtype, + default_value, + name="SimpleHashTable"): + """Creates an empty `SimpleHashTable` object. + + Creates a table, the type of its keys and values are specified by key_dtype + and value_dtype, respectively. + + Args: + key_dtype: the type of the key tensors. + value_dtype: the type of the value tensors. + default_value: The value to use if a key is missing in the table. + name: A name for the operation (optional). + + Returns: + A `SimpleHashTable` object. + """ + super(SimpleHashTable, self).__init__() + self._default_value = tf.convert_to_tensor(default_value, dtype=value_dtype) + self._value_shape = self._default_value.get_shape() + self._key_dtype = key_dtype + self._value_dtype = value_dtype + self._name = name + self._resource_handle = self._create_resource() + # Methods that use the Resource get its handle using the + # public self.resource_handle property (defined by TrackableResource). + # This property calls self._create_resource() the first time + # if private self._resource_handle is not preemptively initialized. + + def _create_resource(self): + """Create the resource tensor handle. + + `_create_resource` is an override of a method in base class + `TrackableResource` that is required for SavedModel support. It can be + called by the `resource_handle` property defined by `TrackableResource`. + + Returns: + A tensor handle to the lookup table. + """ + assert self._default_value.get_shape().ndims == 0 + table_ref = gen_simple_hash_table_op.examples_simple_hash_table_create( + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + name=self._name) + return table_ref + + def _serialize_to_tensors(self): + """Implements checkpointing protocols for `Trackable`.""" + tensors = self.export() + return {"table-keys": tensors[0], "table-values": tensors[1]} + + def _restore_from_tensors(self, restored_tensors): + """Implements checkpointing protocols for `Trackable`.""" + return gen_simple_hash_table_op.examples_simple_hash_table_import( + self.resource_handle, restored_tensors["table-keys"], + restored_tensors["table-values"]) + + @property + def key_dtype(self): + """The table key dtype.""" + return self._key_dtype + + @property + def value_dtype(self): + """The table value dtype.""" + return self._value_dtype + + def find(self, key, dynamic_default_value=None, name=None): + """Looks up `key` in a table, outputs the corresponding value. + + The `default_value` is used if key not present in the table. + + Args: + key: Key to look up. Must match the table's key_dtype. + dynamic_default_value: The value to use if the key is missing in the + table. If None (by default), the `table.default_value` will be used. 
+ name: A name for the operation (optional). + + Returns: + A tensor containing the value in the same shape as `key` using the + table's value type. + + Raises: + TypeError: when `key` do not match the table data types. + """ + with tf.name_scope(name or "%s_lookup_table_find" % self._name): + key = tf.convert_to_tensor(key, dtype=self._key_dtype, name="key") + if dynamic_default_value is not None: + dynamic_default_value = tf.convert_to_tensor( + dynamic_default_value, + dtype=self._value_dtype, + name="default_value") + value = gen_simple_hash_table_op.examples_simple_hash_table_find( + self.resource_handle, key, dynamic_default_value + if dynamic_default_value is not None else self._default_value) + return value + + def insert(self, key, value, name=None): + """Associates `key` with `value`. + + Args: + key: Scalar key to insert. + value: Scalar value to be associated with key. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `key` or `value` doesn't match the table data + types. + """ + with tf.name_scope(name or "%s_lookup_table_insert" % self._name): + key = tf.convert_to_tensor(key, self._key_dtype, name="key") + value = tf.convert_to_tensor(value, self._value_dtype, name="value") + # pylint: disable=protected-access + op = gen_simple_hash_table_op.examples_simple_hash_table_insert( + self.resource_handle, key, value) + return op + + def remove(self, key, name=None): + """Remove `key`. + + Args: + key: Scalar key to remove. + name: A name for the operation (optional). + + Returns: + The created Operation. + + Raises: + TypeError: when `key` doesn't match the table data type. + """ + with tf.name_scope(name or "%s_lookup_table_remove" % self._name): + key = tf.convert_to_tensor(key, self._key_dtype, name="key") + + # For remove, just the key is used by the kernel; no value is used. + # But the kernel is specifc to key_dtype and value_dtype + # (i.e. it uses a template). + # So value_dtype is passed in explicitly. (While + # key_dtype is specificed implicitly by the dtype of key.) + + # pylint: disable=protected-access + op = gen_simple_hash_table_op.examples_simple_hash_table_remove( + self.resource_handle, key, value_dtype=self._value_dtype) + return op + + def export(self, name=None): + """Export all `key` and `value` pairs. + + Args: + name: A name for the operation (optional). + + Returns: + A tuple of two tensors, the first with the `keys` and the second with + the `values`. + """ + with tf.name_scope(name or "%s_lookup_table_export" % self._name): + # pylint: disable=protected-access + keys, values = gen_simple_hash_table_op.examples_simple_hash_table_export( + self.resource_handle, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype) + return keys, values + + def do_import(self, keys, values, name=None): + """Import all `key` and `value` pairs. + + (Note that "import" is a python reserved word, so it cannot be the name of + a method.) + + Args: + keys: Tensor of all keys. + values: Tensor of all values. + name: A name for the operation (optional). + + Returns: + A tuple of two tensors, the first with the `keys` and the second with + the `values`. 
+ """ + with tf.name_scope(name or "%s_lookup_table_import" % self._name): + # pylint: disable=protected-access + op = gen_simple_hash_table_op.examples_simple_hash_table_import( + self.resource_handle, keys, values) + return op + + +tf.no_gradient("Examples>SimpleHashTableCreate") +tf.no_gradient("Examples>SimpleHashTableFind") +tf.no_gradient("Examples>SimpleHashTableInsert") +tf.no_gradient("Examples>SimpleHashTableRemove") diff --git a/deepray/custom_ops/simple_hash_table/simple_hash_table_kernel.cc b/deepray/custom_ops/simple_hash_table/simple_hash_table_kernel.cc new file mode 100644 index 0000000..034c10a --- /dev/null +++ b/deepray/custom_ops/simple_hash_table/simple_hash_table_kernel.cc @@ -0,0 +1,336 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/strcat.h" + +// Please use the appropriate namespace for your project +namespace tensorflow { +namespace custom_op_examples { + +// Implement a simple hash table as a Resource using ref-counting. +// This demonstrates a Stateful Op for a general Create/Read/Update/Delete +// (CRUD) style use case. To instead make an op for a specific lookup table +// case, it is preferable to follow the implementation style of +// the kernels for TensorFlow's tf.lookup internal ops which use +// LookupInterface. +template +class SimpleHashTableResource : public ::tensorflow::ResourceBase { + public: + Status Insert(const Tensor& key, const Tensor& value) { + const K key_val = key.flat()(0); + const V value_val = value.flat()(0); + + mutex_lock l(mu_); + table_[key_val] = value_val; + return Status::OK(); + } + + Status Find(const Tensor& key, Tensor* value, const Tensor& default_value) { + // Note that tf_shared_lock could be used instead of mutex_lock + // in ops that do not not modify data protected by a mutex, but + // go/totw/197 recommends using exclusive lock instead of a shared + // lock when the lock is not going to be held for a significant amount + // of time. 
+ mutex_lock l(mu_); + + const V default_val = default_value.flat()(0); + const K key_val = key.flat()(0); + auto value_val = value->flat(); + value_val(0) = gtl::FindWithDefault(table_, key_val, default_val); + return Status::OK(); + } + + Status Remove(const Tensor& key) { + mutex_lock l(mu_); + + const K key_val = key.flat()(0); + if (table_.erase(key_val) != 1) { + return errors::NotFound("Key for remove not found: ", key_val); + } + return Status::OK(); + } + + // Save all key, value pairs to tensor outputs to support SavedModel + Status Export(OpKernelContext* ctx) { + mutex_lock l(mu_); + int64_t size = table_.size(); + Tensor* keys; + Tensor* values; + TF_RETURN_IF_ERROR( + ctx->allocate_output("keys", TensorShape({size}), &keys)); + TF_RETURN_IF_ERROR( + ctx->allocate_output("values", TensorShape({size}), &values)); + auto keys_data = keys->flat(); + auto values_data = values->flat(); + int64_t i = 0; + for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { + keys_data(i) = it->first; + values_data(i) = it->second; + } + return Status::OK(); + } + + // Load all key, value pairs from tensor inputs to support SavedModel + Status Import(const Tensor& keys, const Tensor& values) { + const auto key_values = keys.flat(); + const auto value_values = values.flat(); + + mutex_lock l(mu_); + table_.clear(); + for (int64_t i = 0; i < key_values.size(); ++i) { + gtl::InsertOrUpdate(&table_, key_values(i), value_values(i)); + } + return Status::OK(); + } + + // Create a debug string with the content of the map if this is small, + // or some example data if this is large, handling both the cases where the + // hash table has many entries and where the entries are long strings. + std::string DebugString() const override { return DebugString(3); } + std::string DebugString(int num_pairs) const { + std::string rval = "SimpleHashTable {"; + size_t count = 0; + const size_t max_kv_str_len = 100; + mutex_lock l(mu_); + for (const auto& pair : table_) { + if (count >= num_pairs) { + strings::StrAppend(&rval, "..."); + break; + } + std::string kv_str = strings::StrCat(pair.first, ": ", pair.second); + strings::StrAppend(&rval, kv_str.substr(0, max_kv_str_len)); + if (kv_str.length() > max_kv_str_len) strings::StrAppend(&rval, " ..."); + strings::StrAppend(&rval, ", "); + count += 1; + } + strings::StrAppend(&rval, "}"); + return rval; + } + + private: + mutable mutex mu_; + absl::flat_hash_map table_ TF_GUARDED_BY(mu_); +}; + +template +class SimpleHashTableCreateOpKernel : public OpKernel { + public: + explicit SimpleHashTableCreateOpKernel(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Tensor handle_tensor; + AllocatorAttributes attr; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), + &handle_tensor, attr)); + handle_tensor.scalar()() = + ResourceHandle::MakeRefCountingHandle( + new SimpleHashTableResource(), ctx->device()->name(), + /*dtypes_and_shapes=*/{}, ctx->stack_trace()); + ctx->set_output(0, handle_tensor); + } + + private: + // Just to be safe, avoid accidentally copying the kernel. + SimpleHashTableCreateOpKernel(const SimpleHashTableCreateOpKernel&) = delete; + void operator=(const SimpleHashTableCreateOpKernel&) = delete; +}; + +// GetResource retrieves a Resource using a handle from the first +// input in "ctx" and saves it in "resource" without increasing +// the reference count for that resource. 
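+// The pointer written to "resource" is borrowed: it stays valid only while
+// the input handle (which owns a reference) is alive, so kernels should use
+// it within a single Compute call rather than caching it.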
+template +Status GetResource(OpKernelContext* ctx, + SimpleHashTableResource** resource) { + const Tensor& handle_tensor = ctx->input(0); + const ResourceHandle& handle = handle_tensor.scalar()(); + typedef SimpleHashTableResource resource_type; + TF_ASSIGN_OR_RETURN(*resource, handle.GetResource()); + return Status::OK(); +} + +template +class SimpleHashTableFindOpKernel : public OpKernel { + public: + explicit SimpleHashTableFindOpKernel(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DataTypeVector expected_inputs = {DT_RESOURCE, DataTypeToEnum::v(), + DataTypeToEnum::v()}; + DataTypeVector expected_outputs = {DataTypeToEnum::v()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); + SimpleHashTableResource* resource; + OP_REQUIRES_OK(ctx, GetResource(ctx, &resource)); + // Note that ctx->input(0) is the Resource handle + const Tensor& key = ctx->input(1); + const Tensor& default_value = ctx->input(2); + TensorShape output_shape = default_value.shape(); + Tensor* out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("value", output_shape, &out)); + OP_REQUIRES_OK(ctx, resource->Find(key, out, default_value)); + } +}; + +template +class SimpleHashTableInsertOpKernel : public OpKernel { + public: + explicit SimpleHashTableInsertOpKernel(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DataTypeVector expected_inputs = {DT_RESOURCE, DataTypeToEnum::v(), + DataTypeToEnum::v()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + SimpleHashTableResource* resource; + OP_REQUIRES_OK(ctx, GetResource(ctx, &resource)); + // Note that ctx->input(0) is the Resource handle + const Tensor& key = ctx->input(1); + const Tensor& value = ctx->input(2); + OP_REQUIRES_OK(ctx, resource->Insert(key, value)); + } +}; + +template +class SimpleHashTableRemoveOpKernel : public OpKernel { + public: + explicit SimpleHashTableRemoveOpKernel(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DataTypeVector expected_inputs = {DT_RESOURCE, DataTypeToEnum::v()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + SimpleHashTableResource* resource; + OP_REQUIRES_OK(ctx, GetResource(ctx, &resource)); + // Note that ctx->input(0) is the Resource handle + const Tensor& key = ctx->input(1); + OP_REQUIRES_OK(ctx, resource->Remove(key)); + } +}; + +template +class SimpleHashTableExportOpKernel : public OpKernel { + public: + explicit SimpleHashTableExportOpKernel(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DataTypeVector expected_inputs = {DT_RESOURCE}; + DataTypeVector expected_outputs = {DataTypeToEnum::v(), + DataTypeToEnum::v()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); + SimpleHashTableResource* resource; + OP_REQUIRES_OK(ctx, GetResource(ctx, &resource)); + OP_REQUIRES_OK(ctx, resource->Export(ctx)); + } +}; + +template +class SimpleHashTableImportOpKernel : public OpKernel { + public: + explicit SimpleHashTableImportOpKernel(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SimpleHashTableResource* resource; + OP_REQUIRES_OK(ctx, GetResource(ctx, &resource)); + DataTypeVector expected_inputs = {DT_RESOURCE, DataTypeToEnum::v(), + DataTypeToEnum::v()}; + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + const Tensor& keys = ctx->input(1); + const Tensor& values = 
ctx->input(2); + OP_REQUIRES( + ctx, keys.shape() == values.shape(), + errors::InvalidArgument("Shapes of keys and values are not the same: ", + keys.shape().DebugString(), " for keys, ", + values.shape().DebugString(), "for values.")); + + int memory_used_before = 0; + if (ctx->track_allocations()) { + memory_used_before = resource->MemoryUsed(); + } + OP_REQUIRES_OK(ctx, resource->Import(keys, values)); + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation(resource->MemoryUsed() - + memory_used_before); + } + } +}; + +// The "Name" used by REGISTER_KERNEL_BUILDER is defined by REGISTER_OP, +// see simple_hash_table_op.cc. +#define REGISTER_KERNEL(key_dtype, value_dtype) \ + REGISTER_KERNEL_BUILDER( \ + Name("Examples>SimpleHashTableCreate") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + SimpleHashTableCreateOpKernel); \ + REGISTER_KERNEL_BUILDER( \ + Name("Examples>SimpleHashTableFind") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + SimpleHashTableFindOpKernel); \ + REGISTER_KERNEL_BUILDER( \ + Name("Examples>SimpleHashTableInsert") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + SimpleHashTableInsertOpKernel) \ + REGISTER_KERNEL_BUILDER( \ + Name("Examples>SimpleHashTableRemove") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + SimpleHashTableRemoveOpKernel) \ + REGISTER_KERNEL_BUILDER( \ + Name("Examples>SimpleHashTableExport") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + SimpleHashTableExportOpKernel) \ + REGISTER_KERNEL_BUILDER( \ + Name("Examples>SimpleHashTableImport") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + SimpleHashTableImportOpKernel); + +REGISTER_KERNEL(int32, double); +REGISTER_KERNEL(int32, float); +REGISTER_KERNEL(int32, int32); +REGISTER_KERNEL(int32, tstring); +REGISTER_KERNEL(int64_t, double); +REGISTER_KERNEL(int64_t, float); +REGISTER_KERNEL(int64_t, int32); +REGISTER_KERNEL(int64_t, int64_t); +REGISTER_KERNEL(int64_t, tstring); +REGISTER_KERNEL(tstring, bool); +REGISTER_KERNEL(tstring, double); +REGISTER_KERNEL(tstring, float); +REGISTER_KERNEL(tstring, int32); +REGISTER_KERNEL(tstring, int64_t); +REGISTER_KERNEL(tstring, tstring); +#undef REGISTER_KERNEL + +} // namespace custom_op_examples +} // namespace tensorflow diff --git a/deepray/custom_ops/simple_hash_table/simple_hash_table_op.cc b/deepray/custom_ops/simple_hash_table/simple_hash_table_op.cc new file mode 100644 index 0000000..cb70bfa --- /dev/null +++ b/deepray/custom_ops/simple_hash_table/simple_hash_table_op.cc @@ -0,0 +1,173 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+// Please use the appropriate namespace for your project
+namespace tensorflow {
+namespace custom_op_examples {
+
+using ::tensorflow::shape_inference::DimensionHandle;
+using ::tensorflow::shape_inference::InferenceContext;
+using ::tensorflow::shape_inference::ShapeAndType;
+using ::tensorflow::shape_inference::ShapeHandle;
+
+Status ScalarOutput(InferenceContext* c) {
+  c->set_output(0, c->Scalar());
+  return Status::OK();
+}
+
+Status TwoScalarInputs(InferenceContext* c) {
+  ShapeHandle handle;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+  return Status::OK();
+}
+
+Status TwoScalarInputsScalarOutput(InferenceContext* c) {
+  ShapeHandle handle;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+  return ScalarOutput(c);
+}
+
+Status ThreeScalarInputs(InferenceContext* c) {
+  ShapeHandle handle;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &handle));
+  return Status::OK();
+}
+
+Status ThreeScalarInputsScalarOutput(InferenceContext* c) {
+  ShapeHandle handle;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &handle));
+  return ScalarOutput(c);
+}
+
+Status ValidateTableType(InferenceContext* c,
+                         const ShapeAndType& key_shape_and_type,
+                         const string& key_dtype_attr,
+                         const ShapeAndType& value_shape_and_type,
+                         const string& value_dtype_attr) {
+  DataType key_dtype;
+  TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
+  if (key_shape_and_type.dtype != key_dtype) {
+    return errors::InvalidArgument(
+        "Trying to read value with wrong dtype. "
+        "Expected ",
+        DataTypeString(key_shape_and_type.dtype), " got ",
+        DataTypeString(key_dtype));
+  }
+  DataType value_dtype;
+  TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
+  if (value_shape_and_type.dtype != value_dtype) {
+    return errors::InvalidArgument(
+        "Trying to read value with wrong dtype. "
+        "Expected ",
+        DataTypeString(value_shape_and_type.dtype), " got ",
+        DataTypeString(value_dtype));
+  }
+  return Status::OK();
+}
+
+Status ExportShapeFunction(InferenceContext* c) {
+  ShapeHandle handle;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
+  auto* handle_data = c->input_handle_shapes_and_types(0);
+  if (handle_data != nullptr && handle_data->size() == 2) {
+    const ShapeAndType& key_shape_and_type = (*handle_data)[0];
+    const ShapeAndType& value_shape_and_type = (*handle_data)[1];
+    TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type,
+                                         /*key_dtype_attr*/ "key_dtype",
+                                         value_shape_and_type,
+                                         /*value_dtype_attr*/ "value_dtype"));
+  }
+  // Different lookup tables have different output shapes.
+ c->set_output(0, c->UnknownShape()); + c->set_output(1, c->UnknownShape()); + return Status::OK(); +} + +Status ImportShapeFunction(InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + DimensionHandle unused; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(keys, 0), c->Dim(c->input(2), 0), &unused)); + return Status::OK(); +} + +// Note that if an op has any Input or Output of type "resource", it +// is automatically marked as stateful so there is no need to explicitly +// use "SetIsStateful()". +// (See FinalizeInputOrOutput in core/framework/op_def_builder.cc.) + +REGISTER_OP("Examples>SimpleHashTableCreate") + .Output("output: resource") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn(ScalarOutput); + +REGISTER_OP("Examples>SimpleHashTableFind") + .Input("resource_handle: resource") + .Input("key: key_dtype") + .Input("default_value: value_dtype") + .Output("value: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn(ThreeScalarInputsScalarOutput); + +REGISTER_OP("Examples>SimpleHashTableInsert") + .Input("resource_handle: resource") + .Input("key: key_dtype") + .Input("value: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn(ThreeScalarInputs); + +REGISTER_OP("Examples>SimpleHashTableRemove") + .Input("resource_handle: resource") + .Input("key: key_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn(TwoScalarInputs); + +REGISTER_OP("Examples>SimpleHashTableExport") + .Input("table_handle: resource") + .Output("keys: key_dtype") + .Output("values: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn(ExportShapeFunction); + +REGISTER_OP("Examples>SimpleHashTableImport") + .Input("table_handle: resource") + .Input("keys: key_dtype") + .Input("values: value_dtype") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetShapeFn(ImportShapeFunction); + +} // namespace custom_op_examples +} // namespace tensorflow diff --git a/deepray/custom_ops/simple_hash_table/simple_hash_table_op.py b/deepray/custom_ops/simple_hash_table/simple_hash_table_op.py new file mode 100644 index 0000000..54c1232 --- /dev/null +++ b/deepray/custom_ops/simple_hash_table/simple_hash_table_op.py @@ -0,0 +1,21 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Load the simple_hash_table_op kernel."""
+
+import tensorflow as tf
+from tensorflow.python.platform import resource_loader
+
+gen_simple_hash_table_op = tf.load_op_library(
+    resource_loader.get_path_to_datafile("simple_hash_table_kernel.so"))
diff --git a/deepray/custom_ops/simple_hash_table/simple_hash_table_test.py b/deepray/custom_ops/simple_hash_table/simple_hash_table_test.py
new file mode 100644
index 0000000..627e4fa
--- /dev/null
+++ b/deepray/custom_ops/simple_hash_table/simple_hash_table_test.py
@@ -0,0 +1,125 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for simple_hash_table."""
+
+import os.path
+import tempfile
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from deepray.custom_ops.simple_hash_table import simple_hash_table
+from tensorflow.python.eager import def_function
+# This pylint disable is only needed for internal google users
+from tensorflow.python.framework import test_util  # pylint: disable=g-direct-tensorflow-import
+
+
+class SimpleHashTableTest(tf.test.TestCase, parameterized.TestCase):
+
+  # Helper function using "create, find, insert, find, remove, find".
+  def _use_table(self, key_dtype, value_dtype):
+    hash_table = simple_hash_table.SimpleHashTable(key_dtype, value_dtype, 111)
+    result1 = hash_table.find(1, -999)
+    hash_table.insert(1, 100)
+    result2 = hash_table.find(1, -999)
+    hash_table.remove(1)
+    result3 = hash_table.find(1, -999)
+    results = tf.stack((result1, result2, result3))
+    return results  # expect [-999, 100, -999]
+
+  # Test of "create, find, insert, find" in eager mode.
+  @parameterized.named_parameters(('int32_float', tf.int32, float),
+                                  ('int64_int32', tf.int64, tf.int32))
+  def test_find_insert_find_eager(self, key_dtype, value_dtype):
+    results = self._use_table(key_dtype, value_dtype)
+    self.assertAllClose(results, [-999, 100, -999])
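+
+  # For reference, a rough sketch of what the SimpleHashTable wrapper in
+  # simple_hash_table.py does under the hood with the generated ops (op
+  # endpoint names assumed from the REGISTER_OP declarations in
+  # simple_hash_table_op.cc; see the wrapper for the actual calls):
+  #
+  #   handle = gen_simple_hash_table_op.examples_simple_hash_table_create(
+  #       key_dtype=tf.int64, value_dtype=tf.int64)
+  #   gen_simple_hash_table_op.examples_simple_hash_table_insert(
+  #       handle, key, value)
+  #   value = gen_simple_hash_table_op.examples_simple_hash_table_find(
+  #       handle, key, default_value)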
+
+  # Test of "create, find, insert, find" in a tf.function. Note that the
+  # creation and use of the ref-counted resource occurs inside a single
+  # self.evaluate.
+  @parameterized.named_parameters(('int32_float', tf.int32, float),
+                                  ('int64_int32', tf.int64, tf.int32))
+  def test_find_insert_find_tf_function(self, key_dtype, value_dtype):
+    results = def_function.function(
+        lambda: self._use_table(key_dtype, value_dtype))
+    self.assertAllClose(self.evaluate(results()), [-999.0, 100.0, -999.0])
+
+  # Strings for key and value.
+  def test_find_insert_find_strings_eager(self):
+    default = 'Default'
+    foo = 'Foo'
+    bar = 'Bar'
+    hash_table = simple_hash_table.SimpleHashTable(tf.string, tf.string,
+                                                   default)
+    result1 = hash_table.find(foo, default)
+    self.assertEqual(result1, default)
+    hash_table.insert(foo, bar)
+    result2 = hash_table.find(foo, default)
+    self.assertEqual(result2, bar)
+
+  def test_export(self):
+    table = simple_hash_table.SimpleHashTable(
+        tf.int64, tf.int64, default_value=-1)
+    table.insert(1, 100)
+    table.insert(2, 200)
+    table.insert(3, 300)
+    keys, values = self.evaluate(table.export())
+    self.assertAllEqual(sorted(keys), [1, 2, 3])
+    self.assertAllEqual(sorted(values), [100, 200, 300])
+
+  def test_import(self):
+    table = simple_hash_table.SimpleHashTable(
+        tf.int64, tf.int64, default_value=-1)
+    keys = tf.constant([1, 2, 3], dtype=tf.int64)
+    values = tf.constant([100, 200, 300], dtype=tf.int64)
+    table.do_import(keys, values)
+    self.assertEqual(table.find(1), 100)
+    self.assertEqual(table.find(2), 200)
+    self.assertEqual(table.find(3), 300)
+    self.assertEqual(table.find(9), -1)
+
+  @test_util.run_v2_only
+  def testSavedModelSaveRestore(self):
+    save_dir = os.path.join(self.get_temp_dir(), 'save_restore')
+    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), 'hash')
+
+    # TODO(b/203097231) is there an alternative that is not __internal__?
+    root = tf.__internal__.tracking.AutoTrackable()
+
+    default_value = -1
+    root.table = simple_hash_table.SimpleHashTable(
+        tf.int64, tf.int64, default_value=default_value)
+
+    @def_function.function(input_signature=[tf.TensorSpec((), tf.int64)])
+    def lookup(key):
+      return root.table.find(key)
+
+    root.lookup = lookup
+
+    root.table.insert(1, 100)
+    root.table.insert(2, 200)
+    root.table.insert(3, 300)
+    self.assertEqual(root.lookup(2), 200)
+    self.assertAllEqual(3, len(self.evaluate(root.table.export()[0])))
+    tf.saved_model.save(root, save_path)
+
+    del root
+    loaded = tf.saved_model.load(save_path)
+    self.assertEqual(loaded.lookup(2), 200)
+    self.assertEqual(loaded.lookup(10), -1)
+
+
+if __name__ == '__main__':
+  tf.test.main()
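+
+# A minimal end-to-end usage sketch of the Python wrapper exercised by the
+# tests above (method names taken from the tests; see simple_hash_table.py
+# for the actual API):
+#
+#   table = simple_hash_table.SimpleHashTable(tf.int64, tf.float32,
+#                                             default_value=-1.0)
+#   table.insert(7, 3.5)
+#   table.find(7, -1.0)    # -> 3.5
+#   table.remove(7)
+#   table.find(7)          # -> -1.0, the constructor's default_value
+#   keys, values = table.export()
+#   table.do_import(keys, values)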