Skip to content

Commit

Permalink
Add support for creating objects inside the resource manager of an XLA compilation device.
Browse files Browse the repository at this point in the history

Change: 152080944
  • Loading branch information
tensorflower-gardener committed Apr 4, 2017
1 parent d139cf3 commit 438c13e
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 1 deletion.
11 changes: 10 additions & 1 deletion tensorflow/compiler/tf2xla/xla_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,

XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(std::move(options)),
initialization_status_(Status::OK()),
next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {}
device_mgr_({device_}) {
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
}
}

XlaCompiler::~XlaCompiler() = default;

Expand Down Expand Up @@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";

// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);

xla::ComputationBuilder builder(client(), name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ class XlaCompiler {
// This is useful to prune stateful operators that should not be executed
// from a function body.
bool prune_unreachable_nodes = false;

// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
// device is created, and can be used to create metadata objects
// that can be accessed by XLA op kernels.
std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
};

explicit XlaCompiler(Options options);
Expand Down Expand Up @@ -247,6 +253,7 @@ class XlaCompiler {
Status BuildExecutable(const CompilationResult& result,
std::unique_ptr<xla::LocalExecutable>* executable);

const Options& options() const { return options_; }
xla::Client* client() const { return options_.client; }
XlaCompilationDevice* device() const { return device_; }
const DeviceMgr* device_mgr() const { return &device_mgr_; }
Expand All @@ -260,6 +267,9 @@ class XlaCompiler {
private:
Options options_;

// Status set to non-OK in the constructor if initialization fails.
Status initialization_status_;

// Returns the next step sequence number.
int64 NextStepId();

Expand Down
101 changes: 101 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
Expand All @@ -33,6 +35,65 @@ limitations under the License.
namespace tensorflow {
namespace {

// Minimal ResourceBase subclass used to verify that objects placed in the
// compilation device's ResourceMgr are reachable from XLA op kernels while
// a graph is being compiled.
class DummyResourceForTest : public ResourceBase {
 public:
  string DebugString() override { return "dummy"; }
  // Bumps the access counter; invoked by DummyReadResourceOp::Compile.
  void Increment() { count_ += 1; }
  // Returns the number of times Increment() has been called.
  int Get() { return count_; }

 private:
  int count_ = 0;  // Number of recorded accesses.
};

// XLA op kernel that, during compilation, looks up the DummyResourceForTest
// registered under the name "dummy" in the default container, increments its
// counter, and forwards its input unchanged (identity on the data path).
class DummyReadResourceOp : public XlaOpKernel {
 public:
  explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    // The ResourceMgr here belongs to the compilation device; it is the one
    // populate_resource_manager filled in when the XlaCompiler was created.
    ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
    OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
    DummyResourceForTest* dummy;
    OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
                            rm->default_container(), "dummy", &dummy));
    dummy->Increment();
    dummy->Unref();  // Lookup() took a reference on our behalf; release it.

    // Pass the input straight through as the output.
    ctx->SetOutput(0, ctx->Input(0));
  }
};

// Hand-written graph-construction wrapper for the "DummyReadResource" op,
// mirroring the shape of what the cc_op generator would emit: it adds a node
// to the scope's graph and exposes its single output.
class DummyReadResourceCC {
 public:
  DummyReadResourceCC(const Scope& scope, const Input& value) {
    // Each step records failures on the scope; bail out as soon as the scope
    // carries an error so later calls don't run on an invalid state.
    if (!scope.ok()) return;
    auto _value = ops::AsNodeOut(scope, value);
    if (!scope.ok()) return;
    Node* ret;
    const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
    auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
    scope.UpdateBuilder(&builder);
    scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
    if (!scope.ok()) return;
    this->output_ = Output(ret, 0);
  }
  // The node this wrapper added to the graph (null if construction failed).
  Node* node() const { return output_.node(); }

  Output output_;
};

// Registers the test-only op with TensorFlow's op registry so graphs may
// reference it by name.
REGISTER_OP("DummyReadResource")
    .Input("input: int32")
    .Output("output: int32")
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output: dummy output.
)doc");

// Binds the op name to the XLA kernel implementation above.
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);

class XlaCompilerTest : public ::testing::Test {
protected:
void SetUp() override {
Expand Down Expand Up @@ -224,5 +285,45 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
}
}

// Tests that a resource installed via Options::populate_resource_manager is
// visible to (and mutated by) an XLA op kernel during graph compilation.
TEST_F(XlaCompilerTest, ResourceManager) {
  // Builds a graph that calls the dummy resource Op.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
  auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the argument.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  DummyResourceForTest* resource = new DummyResourceForTest();

  // Compiles the graph.
  auto options = DefaultOptions();
  // The populate callback hands ownership of one reference to the
  // ResourceMgr, so take an extra ref for the test's own use below.
  std::function<Status(ResourceMgr*)> populate_function =
      [resource](ResourceMgr* rm) {
        resource->Ref();
        return rm->Create(rm->default_container(), "dummy", resource);
      };
  options.populate_resource_manager = &populate_function;
  XlaCompiler compiler(options);
  auto flr = BuildFunctionLibraryRuntime(compiler);

  // The kernel has not run yet, so the counter must still be zero.
  EXPECT_EQ(0, resource->Get());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args,
                                     &result));

  // Compiling the graph executed DummyReadResourceOp::Compile exactly once,
  // which incremented the shared resource.
  EXPECT_EQ(1, resource->Get());

  resource->Unref();  // Drop the test's reference taken above.
}

} // namespace
} // namespace tensorflow
4 changes: 4 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() {
XlaContext::Get(context_).AddSideEffects();
}

// Exposes the options of the XlaCompiler driving this compilation, so that
// ops building nested computations (e.g. While) can inherit them.
const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const {
  XlaContext& xla_context = XlaContext::Get(context_);
  return xla_context.compiler()->options();
}

// Records a non-OK status on the wrapped OpKernelContext.
void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
context_->CtxFailureWithWarning(s);
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
Expand Down Expand Up @@ -182,6 +183,11 @@ class XlaOpKernelContext {
// Returns the underlying OpKernelContext. Use rarely.
OpKernelContext* op_kernel_context() const { return context_; }

// Returns the options passed to the XlaCompiler that is being
// run. Used for, e.g., While to inherit options needed for nested
// computation.
const XlaCompiler::Options& GetCompilerOptions() const;

// TODO(phawkins): find a better home for these helpers.

// Get an XLA lambda to compute Max. This is cached in the
Expand Down

0 comments on commit 438c13e

Please sign in to comment.