Skip to content

Commit

Permalink
Fix formatting of switch statements and update URL in error message
Browse files Browse the repository at this point in the history
  • Loading branch information
saudet committed Jan 28, 2020
1 parent b8ff850 commit 55b7b5b
Showing 1 changed file with 74 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -893,74 +893,105 @@ private static void writeScalar(Object src, int dtype, BytePointer dst, long dst
+ " bytes) not compatible with allocated tensor (" + dstSize + " bytes)");
}
switch (dtype) {
case TF_FLOAT: dst.putFloat((Float)src); break;
case TF_DOUBLE: dst.putDouble((Double)src); break;
case TF_INT32: dst.putInt((Integer)src); break;
case TF_INT64: dst.putLong((Long)src); break;
case TF_UINT8: dst.put((Byte)src); break;
case TF_BOOL: dst.putBool((Boolean)src); break;
default: throw new IllegalStateException("invalid DataType(" + dtype + ")");
case TF_FLOAT:
dst.putFloat((Float)src);
break;
case TF_DOUBLE:
dst.putDouble((Double)src);
break;
case TF_INT32:
dst.putInt((Integer)src);
break;
case TF_INT64:
dst.putLong((Long)src);
break;
case TF_UINT8:
dst.put((Byte)src);
break;
case TF_BOOL:
dst.putBool((Boolean)src);
break;
default:
throw new IllegalStateException("invalid DataType(" + dtype + ")");
}
}

/** Copy a 1-D array of Java primitive types to the tensor buffer dst.
* Returns the number of bytes written to dst. */
private static long write1DArray(Object array, int dtype, BytePointer dst, long dstSize) {
int nelems;
private static int getArrayLength(Object array, int dtype) {
switch (dtype) {
case TF_FLOAT: nelems = ((float[])array).length; break;
case TF_DOUBLE: nelems = ((double[])array).length; break;
case TF_INT32: nelems = ((int[])array).length; break;
case TF_INT64: nelems = ((long[])array).length; break;
case TF_UINT8: nelems = ((byte[])array).length; break;
case TF_BOOL: nelems = ((boolean[])array).length; break;
case TF_FLOAT: return ((float[])array).length;
case TF_DOUBLE: return ((double[])array).length;
case TF_INT32: return ((int[])array).length;
case TF_INT64: return ((long[])array).length;
case TF_UINT8: return ((byte[])array).length;
case TF_BOOL: return ((boolean[])array).length;
default: throw new IllegalStateException("invalid DataType(" + dtype + ")");
}
}

/** Copy a 1-D array of Java primitive types to the tensor buffer dst.
* Returns the number of bytes written to dst. */
private static long write1DArray(Object array, int dtype, BytePointer dst, long dstSize) {
int nelems = getArrayLength(array, dtype);
long toCopy = nelems * elemByteSize(dtype);
if (toCopy > dstSize) {
throw new IllegalStateException(
"cannot write Java array of " + toCopy + " bytes to Tensor of " + dstSize + " bytes");
}
switch (dtype) {
case TF_FLOAT: dst.put(new FloatPointer((float[])array).capacity(nelems)); break;
case TF_DOUBLE: dst.put(new DoublePointer((double[])array).capacity(nelems)); break;
case TF_INT32: dst.put(new IntPointer((int[])array).capacity(nelems)); break;
case TF_INT64: dst.put(new LongPointer((long[])array).capacity(nelems)); break;
case TF_UINT8: dst.put(new BytePointer((byte[])array).capacity(nelems)); break;
case TF_BOOL: dst.put(new BooleanPointer((boolean[])array).capacity(nelems)); break;
default: throw new IllegalStateException("invalid DataType(" + dtype + ")");
case TF_FLOAT:
dst.put(new FloatPointer((float[])array).capacity(nelems));
break;
case TF_DOUBLE:
dst.put(new DoublePointer((double[])array).capacity(nelems));
break;
case TF_INT32:
dst.put(new IntPointer((int[])array).capacity(nelems));
break;
case TF_INT64:
dst.put(new LongPointer((long[])array).capacity(nelems));
break;
case TF_UINT8:
dst.put(new BytePointer((byte[])array).capacity(nelems));
break;
case TF_BOOL:
dst.put(new BooleanPointer((boolean[])array).capacity(nelems));
break;
default:
throw new IllegalStateException("invalid DataType(" + dtype + ")");
}
return toCopy;
}

/** Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
* Java primitive types. Returns the number of bytes read from src. */
private static long read1DArray(int dtype, BytePointer src, long srcSize, Object dst) {
int len;
switch (dtype) {
case TF_FLOAT: len = ((float[])dst).length; break;
case TF_DOUBLE: len = ((double[])dst).length; break;
case TF_INT32: len = ((int[])dst).length; break;
case TF_INT64: len = ((long[])dst).length; break;
case TF_UINT8: len = ((byte[])dst).length; break;
case TF_BOOL: len = ((boolean[])dst).length; break;
default: throw new IllegalStateException("invalid DataType(" + dtype + ")");
}

int len = getArrayLength(dst, dtype);
long sz = len * elemByteSize(dtype);
if (sz > srcSize) {
throw new IllegalStateException(
"cannot fill a Java array of " + sz + "bytes with a Tensor of " + srcSize + " bytes");
}
switch (dtype) {
case TF_FLOAT: new FloatPointer(src).position(src.position() / 4).get((float[])dst); break;
case TF_DOUBLE: new DoublePointer(src).position(src.position() / 8).get((double[])dst); break;
case TF_INT32: new IntPointer(src).position(src.position() / 4).get((int[])dst); break;
case TF_INT64: new LongPointer(src).position(src.position() / 8).get((long[])dst); break;
case TF_UINT8: src.get((byte[])dst); break;
case TF_BOOL: new BooleanPointer(src).position(src.position()).get((boolean[])dst); break;
default: throw new IllegalStateException("invalid DataType(" + dtype + ")");
case TF_FLOAT:
new FloatPointer(src).position(src.position() / 4).get((float[])dst);
break;
case TF_DOUBLE:
new DoublePointer(src).position(src.position() / 8).get((double[])dst);
break;
case TF_INT32:
new IntPointer(src).position(src.position() / 4).get((int[])dst);
break;
case TF_INT64:
new LongPointer(src).position(src.position() / 8).get((long[])dst);
break;
case TF_UINT8:
src.get((byte[])dst);
break;
case TF_BOOL:
new BooleanPointer(src).position(src.position()).get((boolean[])dst);
break;
default:
throw new IllegalStateException("invalid DataType(" + dtype + ")");
}
return sz;
}
Expand Down Expand Up @@ -1025,7 +1056,7 @@ void Add(BytePointer src, long len, TF_Status status) {
TF_SetStatus(status, TF_OUT_OF_RANGE,
"TF_STRING tensor encoding ran out of space for offsets, "
+ "this is likely a bug, please file an issue at "
+ "https://github.com/tensorflow/tensorflow/issues/new");
+ "https://github.com/tensorflow/java/issues/new");
return;
}
poffsets.putLong(offset);
Expand Down

0 comments on commit 55b7b5b

Please sign in to comment.