Skip to content

Commit

Permalink
Fix nits
Browse files Browse the repository at this point in the history
  • Loading branch information
saudet authored and karllessard committed Jan 29, 2020
1 parent 577994d commit b966d37
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ private static void requireOp(TFE_Op handle) {

private static void requireTensorHandle(TFE_TensorHandle handle) {
if (handle == null || handle.isNull()) {
throw new IllegalStateException("EagerSession has been closed");
throw new IllegalStateException("Eager session has been closed");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ static void resolveOutputs(String type, TF_Operation[] srcOps,
}
for (int i = 0; i < n; ++i) {
if (srcOps[i] == null || srcOps[i].isNull()) {
throw new NullPointerException("invalid " + type + " (#" + i + " of " + n + ")");
throw new IllegalStateException("invalid " + type + " (#" + i + " of " + n + ")");
}
dst.position(i).oper(srcOps[i]).index(srcIndices[i]);
}
Expand Down Expand Up @@ -633,7 +633,7 @@ private static Object[] whileLoop(
condOutputHandles[0] = condOutputOutput.oper();
condOutputIndices[0] = condOutputOutput.index();

Object[] cond_output_handles_and_indices =
Object[] condOutputHandlesAndIndices =
buildSubgraph(condGraphBuilder, params.cond_graph(),
condInputHandles, condInputIndices,
condOutputHandles, condOutputIndices);
Expand All @@ -652,23 +652,23 @@ private static Object[] whileLoop(
bodyOutputIndices[i] = bodyOutputsOutput.position(i).index();
}

Object[] body_output_handles_and_indices =
Object[] bodyOutputHandlesAndIndices =
buildSubgraph(bodyGraphBuilder, params.body_graph(),
bodyInputHandles, bodyInputIndices,
bodyOutputHandles, bodyOutputIndices);

if (cond_output_handles_and_indices == null ||
body_output_handles_and_indices == null)
if (condOutputHandlesAndIndices == null ||
bodyOutputHandlesAndIndices == null)
return null;

// set cond_output param to output of the conditional subgraph
condOutputOutput.oper((TF_Operation)cond_output_handles_and_indices[0])
.index((Integer)cond_output_handles_and_indices[1]);
condOutputOutput.oper((TF_Operation)condOutputHandlesAndIndices[0])
.index((Integer)condOutputHandlesAndIndices[1]);

// set body_outputs param to outputs of the body subgraph
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
bodyOutputsOutput.position(i).oper((TF_Operation)body_output_handles_and_indices[i])
.index((Integer)body_output_handles_and_indices[j]);
bodyOutputsOutput.position(i).oper((TF_Operation)bodyOutputHandlesAndIndices[i])
.index((Integer)bodyOutputHandlesAndIndices[j]);
}

// set loop name param
Expand All @@ -681,14 +681,14 @@ private static Object[] whileLoop(
status.throwExceptionIfNotOK();

// returned array contains both op handles and output indices, in pair
Object[] output_handles_and_indices = new Object[ninputs * 2];
Object[] outputHandlesAndIndices = new Object[ninputs * 2];
for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
TF_Output output = outputs.position(i);
output_handles_and_indices[i] = output.oper();
output_handles_and_indices[j] = output.index();
outputHandlesAndIndices[i] = output.oper();
outputHandlesAndIndices[j] = output.index();
}

return output_handles_and_indices;
return outputHandlesAndIndices;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ private static void resolveHandles(String type, Pointer[] src, PointerPointer ds
}
for (int i = 0; i < n; ++i) {
if (src[i] == null || src[i].isNull()) {
throw new NullPointerException("invalid " + type + " (#" + i + " of " + n + ")");
throw new IllegalStateException("invalid " + type + " (#" + i + " of " + n + ")");
}
dst.put(i, src[i]);
}
Expand All @@ -487,7 +487,7 @@ private static TF_Session allocate(TF_Graph graphHandle) {

private static TF_Session allocate2(TF_Graph graphHandle, String target, byte[] config) {
if (graphHandle == null || graphHandle.isNull()) {
throw new NullPointerException("Graph has been close()d");
throw new IllegalStateException("Graph has been close()d");
}

try (PointerScope scope = new PointerScope()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ private void throwExceptionIfTypeIsIncompatible(Object o) {

private static void requireHandle(TF_Tensor handle) {
if (handle == null || handle.isNull()) {
throw new NullPointerException("close() was called on the Tensor");
throw new IllegalStateException("close() was called on the Tensor");
}
}

Expand Down Expand Up @@ -1000,32 +1000,30 @@ private static long writeNDArray(Object src, int dtype, int dimsLeft,
BytePointer dst, long dstSize) {
if (dimsLeft == 1) {
return write1DArray(src, dtype, dst, dstSize);
} else {
Object[] ndarray = (Object[])src;
long sz = 0;
for (int i = 0; i < ndarray.length; ++i) {
Object row = ndarray[i];
sz += writeNDArray(row, dtype, dimsLeft - 1,
new BytePointer(dst).position(dst.position() + sz), dstSize - sz);
}
return sz;
}
Object[] ndarray = (Object[])src;
long sz = 0;
for (int i = 0; i < ndarray.length; ++i) {
Object row = ndarray[i];
sz += writeNDArray(row, dtype, dimsLeft - 1,
new BytePointer(dst).position(dst.position() + sz), dstSize - sz);
}
return sz;
}

private static long readNDArray(int dtype, BytePointer src, long srcSize,
int dimsLeft, Object dst) {
if (dimsLeft == 1) {
return read1DArray(dtype, src, srcSize, dst);
} else {
Object[] ndarray = (Object[])dst;
long sz = 0;
for (int i = 0; i < ndarray.length; ++i) {
Object row = ndarray[i];
sz += readNDArray(dtype, new BytePointer(src).position(src.position() + sz),
srcSize - sz, dimsLeft - 1, row);
}
return sz;
}
Object[] ndarray = (Object[])dst;
long sz = 0;
for (int i = 0; i < ndarray.length; ++i) {
Object row = ndarray[i];
sz += readNDArray(dtype, new BytePointer(src).position(src.position() + sz),
srcSize - sz, dimsLeft - 1, row);
}
return sz;
}

private static byte[] TF_StringDecodeToArray(BytePointer src, long srcLen, TF_Status status) {
Expand Down Expand Up @@ -1126,7 +1124,7 @@ private static void readNDStringArray(StringTensorReader reader, int dimsLeft,
private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) {
TF_Tensor t = TF_AllocateTensor(dtype, shape, shape.length, byteSize);
if (t == null || t.isNull()) {
throw new NullPointerException("unable to allocate memory for the Tensor");
throw new IllegalStateException("unable to allocate memory for the Tensor");
}
return t;
}
Expand Down Expand Up @@ -1157,7 +1155,7 @@ private static long nonScalarStringTensorSize(Object value, int numDims) {
for (int i = 0; i < array.length; ++i) {
Object elem = array[i];
if (elem == null) {
throw new NullPointerException("null entries in provided array");
throw new IllegalStateException("null entries in provided array");
}
ret += nonScalarStringTensorSize(elem, numDims - 1);
}
Expand All @@ -1175,7 +1173,7 @@ private static void fillNonScalarStringTensorData(Object value, int numDims,
for (int i = 0; i < array.length; ++i) {
Object elem = array[i];
if (elem == null) {
throw new NullPointerException("null entries in provided array");
throw new IllegalStateException("null entries in provided array");
}
fillNonScalarStringTensorData(elem, numDims - 1, writer, status);
if (TF_GetCode(status) != TF_OK) return;
Expand All @@ -1194,7 +1192,7 @@ private static TF_Tensor allocateNonScalarBytes(long[] shape, Object[] value) {
TF_Tensor t = TF_AllocateTensor(TF_STRING, shape, numDims,
8 * numElements + encodedSize);
if (t == null || t.isNull()) {
throw new NullPointerException("unable to allocate memory for the Tensor");
throw new IllegalStateException("unable to allocate memory for the Tensor");
}
TF_Status status = TF_Status.newStatus();
try (PointerScope scope = new PointerScope()) {
Expand Down Expand Up @@ -1247,55 +1245,55 @@ private static float scalarFloat(TF_Tensor handle) {
requireHandle(handle);
if (TF_NumDims(handle) != 0) {
throw new IllegalStateException("Tensor is not a scalar");
} else if (TF_TensorType(handle) != TF_FLOAT) {
}
if (TF_TensorType(handle) != TF_FLOAT) {
throw new IllegalStateException("Tensor is not a float scalar");
} else {
return new FloatPointer(TF_TensorData(handle)).get();
}
return new FloatPointer(TF_TensorData(handle)).get();
}

private static double scalarDouble(TF_Tensor handle) {
requireHandle(handle);
if (TF_NumDims(handle) != 0) {
throw new IllegalStateException("Tensor is not a scalar");
} else if (TF_TensorType(handle) != TF_DOUBLE) {
}
if (TF_TensorType(handle) != TF_DOUBLE) {
throw new IllegalStateException("Tensor is not a double scalar");
} else {
return new DoublePointer(TF_TensorData(handle)).get();
}
return new DoublePointer(TF_TensorData(handle)).get();
}

private static int scalarInt(TF_Tensor handle) {
requireHandle(handle);
if (TF_NumDims(handle) != 0) {
throw new IllegalStateException("Tensor is not a scalar");
} else if (TF_TensorType(handle) != TF_INT32) {
}
if (TF_TensorType(handle) != TF_INT32) {
throw new IllegalStateException("Tensor is not a int scalar");
} else {
return new IntPointer(TF_TensorData(handle)).get();
}
return new IntPointer(TF_TensorData(handle)).get();
}

private static long scalarLong(TF_Tensor handle) {
requireHandle(handle);
if (TF_NumDims(handle) != 0) {
throw new IllegalStateException("Tensor is not a scalar");
} else if (TF_TensorType(handle) != TF_INT64) {
}
if (TF_TensorType(handle) != TF_INT64) {
throw new IllegalStateException("Tensor is not a long scalar");
} else {
return new LongPointer(TF_TensorData(handle)).get();
}
return new LongPointer(TF_TensorData(handle)).get();
}

private static boolean scalarBoolean(TF_Tensor handle) {
requireHandle(handle);
if (TF_NumDims(handle) != 0) {
throw new IllegalStateException("Tensor is not a scalar");
} else if (TF_TensorType(handle) != TF_BOOL) {
}
if (TF_TensorType(handle) != TF_BOOL) {
throw new IllegalStateException("Tensor is not a boolean scalar");
} else {
return new BooleanPointer(TF_TensorData(handle)).get();
}
return new BooleanPointer(TF_TensorData(handle)).get();
}

private static byte[] scalarBytes(TF_Tensor handle) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand Down Expand Up @@ -523,7 +524,7 @@ public void useAfterClose() {
t.close();
try {
t.intValue();
} catch (NullPointerException e) {
} catch (IllegalStateException e) {
// The expected exception.
}
}
Expand All @@ -535,14 +536,14 @@ public void eagerTensorIsReleasedAfterSessionIsClosed() {
Output<?> x = TestUtil.constant(session, "Const1", 10);
Output<?> y = TestUtil.constant(session, "Const2", 20);
sum = TestUtil.<TInt32>addN(session, x, y).tensor();
assertNotEquals(null, sum.getNativeHandle());
assertNotNull(sum.getNativeHandle());
assertEquals(30, sum.intValue());
}
assertEquals(null, sum.getNativeHandle());
assertNull(sum.getNativeHandle());
try {
sum.intValue();
fail();
} catch (NullPointerException e) {
} catch (IllegalStateException e) {
// expected.
}
}
Expand Down Expand Up @@ -571,7 +572,7 @@ public void gracefullyFailCreationFromNullArrayForStringTensor() {
byte[][] array = new byte[1][];
try {
Tensors.create(array);
} catch (NullPointerException e) {
} catch (IllegalStateException e) {
// expected.
}
}
Expand Down

0 comments on commit b966d37

Please sign in to comment.