TensorFlow's tensor_scatter_nd_update expects the update values to have the same dtype as the tensor being updated. Currently, this causes TensorFlowTensor.index_update to fail, because it builds the updates with tf.fill, which produces int32 or float32 tensors from a Python scalar regardless of the tensor's own dtype.
Reproducible example:
import eagerpy as ep
import tensorflow as tf
x_int32 = ep.astensor(tf.range(4, dtype=tf.int32))
x_int64 = ep.astensor(tf.range(4, dtype=tf.int64))
indices = (ep.astensor(tf.constant([0, 2])),)
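x_int32.index_update(indices, 0)    # this works (int32 updates, int32 tensor)
x_int32.index_update(indices, 0.0)  # this fails (float32 updates, int32 tensor)
x_int64.index_update(indices, 0)    # this fails (int32 updates, int64 tensor)
x_int64.index_update(indices, 0.0)  # this fails (float32 updates, int64 tensor)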
The same applies to float32/float64 tensors.
Solution:
Cast values to the dtype of the raw tensor.
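A minimal sketch of what such a cast could look like, written against plain TensorFlow rather than EagerPy's internals. The helper name index_update_1d is hypothetical and only handles 1-D tensors with a scalar value; it is meant to illustrate the idea, not to reproduce the actual index_update implementation.

```python
import tensorflow as tf


def index_update_1d(x, idx, value):
    """Hypothetical helper: return a copy of 1-D tensor x with x[idx] = value."""
    idx = tf.convert_to_tensor(idx)
    # tensor_scatter_nd_update expects indices of shape (num_updates, index_depth).
    indices = tf.expand_dims(idx, axis=-1)
    # Cast the scalar to x.dtype *before* filling, so the updates tensor
    # has the same dtype as x instead of defaulting to int32/float32.
    updates = tf.fill(tf.shape(idx), tf.cast(value, x.dtype))
    return tf.tensor_scatter_nd_update(x, indices, updates)


x_int64 = tf.range(4, dtype=tf.int64)
idx = tf.constant([0, 2])
y = index_update_1d(x_int64, idx, 0.0)  # float scalar, int64 tensor
print(y.dtype)
```

With the cast in place, the dtype of the scalar passed by the caller no longer matters; the result keeps the dtype of the input tensor.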