Skip to content

Commit

Permalink
Register OnesLike kernel in XLA.
Browse files Browse the repository at this point in the history
Change: 152079500
  • Loading branch information
Suharsh Sivakumar authored and tensorflower-gardener committed Apr 4, 2017
1 parent 8e833fb commit d139cf3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tensorflow/compiler/tests/randomized_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
});
}

TEST_F(OpTest, OnesLike) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
});
}

} // anonymous namespace
} // namespace tensorflow

Expand Down
5 changes: 5 additions & 0 deletions tensorflow/compiler/tests/unary_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def testNumericOps(self):
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[0, 0], [0, 0]], dtype=dtype))

self._assertOpOutputMatchesExpected(
array_ops.ones_like,
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))

def testLogicalOps(self):
self._assertOpOutputMatchesExpected(
math_ops.logical_not,
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/shape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {

REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);

class OnesLikeOp : public XlaOpKernel {
public:
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);

auto one = XlaHelpers::One(ctx->builder(), input_type(0));
ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
}
};

REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);

} // namespace
} // namespace tensorflow

0 comments on commit d139cf3

Please sign in to comment.