Skip to content

Commit

Permalink
fixing merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan100jain committed Apr 4, 2017
2 parents 8908272 + 0873aa5 commit 5ee21f2
Show file tree
Hide file tree
Showing 211 changed files with 5,481 additions and 4,644 deletions.
2 changes: 2 additions & 0 deletions tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ filegroup(
"//tensorflow/contrib/boosted_trees:all_files",
"//tensorflow/contrib/boosted_trees/lib:all_files",
"//tensorflow/contrib/boosted_trees/proto:all_files",
"//tensorflow/contrib/boosted_trees/resources:all_files",
"//tensorflow/contrib/cloud:all_files",
"//tensorflow/contrib/cloud/kernels:all_files",
"//tensorflow/contrib/compiler:all_files",
Expand Down Expand Up @@ -256,6 +257,7 @@ filegroup(
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/xla_tf_graph:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/debug:all_files",
"//tensorflow/core/distributed_runtime:all_files",
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/compiler/aot/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ genrule(
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tffunction.pb",
],
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
tags = ["manual"],
Expand Down Expand Up @@ -114,6 +115,15 @@ tf_library(
tags = ["manual"],
)

tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
tags = ["manual"],
)

cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
Expand All @@ -122,6 +132,7 @@ cc_test(
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
Expand Down
16 changes: 14 additions & 2 deletions tensorflow/compiler/aot/tests/make_test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
Expand Down Expand Up @@ -95,6 +96,17 @@ def tfmatmulandadd(_):
math_ops.add(x, y, name='x_y_sum')


def tffunction(_):

@function.Defun(dtypes.int32, dtypes.int32)
def test_func(a, b):
return a + b

x = constant_op.constant([1], name='x_const')
y = constant_op.constant([2], name='y_const')
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg


def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
Expand All @@ -112,6 +124,7 @@ def main(_):
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)


if __name__ == '__main__':
Expand All @@ -121,7 +134,6 @@ def main(_):
'--out_dir',
type=str,
default='',
help='Output directory for graphs, checkpoints and savers.'
)
help='Output directory for graphs, checkpoints and savers.')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)
16 changes: 16 additions & 0 deletions tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "func_call" }
}
16 changes: 16 additions & 0 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
Expand Down Expand Up @@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}

TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);

add_fn.arg0() = 1;
add_fn.arg1() = 2;
EXPECT_TRUE(add_fn.Run());
EXPECT_EQ(add_fn.error_msg(), "");
EXPECT_EQ(add_fn.result0(), 3);
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}

} // namespace
} // namespace tfcompile
} // namespace tensorflow
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/mark_for_compilation_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
}

// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 5;
const int kMaxRecursionDepth = 10;

bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime);
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/tests/randomized_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
});
}

TEST_F(OpTest, OnesLike) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
});
}

} // anonymous namespace
} // namespace tensorflow

Expand Down
5 changes: 5 additions & 0 deletions tensorflow/compiler/tests/unary_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def testNumericOps(self):
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[0, 0], [0, 0]], dtype=dtype))

self._assertOpOutputMatchesExpected(
array_ops.ones_like,
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))

def testLogicalOps(self):
self._assertOpOutputMatchesExpected(
math_ops.logical_not,
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/shape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {

REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);

class OnesLikeOp : public XlaOpKernel {
public:
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);

auto one = XlaHelpers::One(ctx->builder(), input_type(0));
ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
}
};

REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);

} // namespace
} // namespace tensorflow
11 changes: 10 additions & 1 deletion tensorflow/compiler/tf2xla/xla_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,

XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(std::move(options)),
initialization_status_(Status::OK()),
next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {}
device_mgr_({device_}) {
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
}
}

XlaCompiler::~XlaCompiler() = default;

Expand Down Expand Up @@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";

// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);

xla::ComputationBuilder builder(client(), name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ class XlaCompiler {
// This is useful to prune stateful operators that should not be executed
// from a function body.
bool prune_unreachable_nodes = false;

// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
// device is created, and can be used to create metadata objects
// that can be accessed by XLA op kernels.
std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
};

explicit XlaCompiler(Options options);
Expand Down Expand Up @@ -247,6 +253,7 @@ class XlaCompiler {
Status BuildExecutable(const CompilationResult& result,
std::unique_ptr<xla::LocalExecutable>* executable);

const Options& options() const { return options_; }
xla::Client* client() const { return options_.client; }
XlaCompilationDevice* device() const { return device_; }
const DeviceMgr* device_mgr() const { return &device_mgr_; }
Expand All @@ -260,6 +267,9 @@ class XlaCompiler {
private:
Options options_;

// Status set to non-OK in the constructor if initialization fails.
Status initialization_status_;

// Returns the next step sequence number.
int64 NextStepId();

Expand Down
101 changes: 101 additions & 0 deletions tensorflow/compiler/tf2xla/xla_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
Expand All @@ -33,6 +35,65 @@ limitations under the License.
namespace tensorflow {
namespace {

// Helper class to test the ability to pass resources through to XLA
// compiled kernels.
class DummyResourceForTest : public ResourceBase {
public:
string DebugString() override { return "dummy"; }
void Increment() { ++value_; }
int Get() { return value_; }

private:
int value_ = 0;
};

class DummyReadResourceOp : public XlaOpKernel {
public:
explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
DummyResourceForTest* dummy;
OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
rm->default_container(), "dummy", &dummy));
dummy->Increment();
dummy->Unref();

ctx->SetOutput(0, ctx->Input(0));
}
};

class DummyReadResourceCC {
public:
DummyReadResourceCC(const Scope& scope, const Input& value) {
if (!scope.ok()) return;
auto _value = ops::AsNodeOut(scope, value);
if (!scope.ok()) return;
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
this->output_ = Output(ret, 0);
}
Node* node() const { return output_.node(); }

Output output_;
};

REGISTER_OP("DummyReadResource")
.Input("input: int32")
.Output("output: int32")
.Doc(R"doc(
A dummy Op.
input: dummy input.
output: dummy output.
)doc");

REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);

class XlaCompilerTest : public ::testing::Test {
protected:
void SetUp() override {
Expand Down Expand Up @@ -224,5 +285,45 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
}
}

// Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest, ResourceManager) {
// Builds a graph that calls the dummy resource Op.
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));

// Builds a description of the argument.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});

DummyResourceForTest* resource = new DummyResourceForTest();

// Compiles the graph.
auto options = DefaultOptions();
std::function<Status(ResourceMgr*)> populate_function =
[resource](ResourceMgr* rm) {
resource->Ref();
return rm->Create(rm->default_container(), "dummy", resource);
};
options.populate_resource_manager = &populate_function;
XlaCompiler compiler(options);
auto flr = BuildFunctionLibraryRuntime(compiler);

EXPECT_EQ(0, resource->Get());

XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args,
&result));

EXPECT_EQ(1, resource->Get());

resource->Unref();
}

} // namespace
} // namespace tensorflow
4 changes: 4 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() {
XlaContext::Get(context_).AddSideEffects();
}

const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const {
return XlaContext::Get(context_).compiler()->options();
}

void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
context_->CtxFailureWithWarning(s);
Expand Down
Loading

0 comments on commit 5ee21f2

Please sign in to comment.