Can a universal function be compiled in TensorFlow? #36
Well, I would say it's mostly a bug in TensorFlow. Calling this fails:

```python
import tensorflow as tf

@tf.function
def f(a):
    return a.ndim
```

Calling this works:

```python
@tf.function
def f(a):
    return len(a.shape)
```

We run into this problem because our …
@eserie I filed a bug in the TensorFlow repository. Let's see what they think. If you need a temporary workaround, you can comment out the shape checks that use …
Thank you very much for posting the issue in the TensorFlow repository! Another remark: if compilation makes sense in EagerPy, we could make it available in a universal way through a `compile=True` argument to the `eager_function` proposed in #34. What do you think?
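A minimal sketch of how such a `compile=True` option could be wired up. The decorator name `eager_function` comes from #34, but its signature here is an assumption, and the `compiler` parameter is hypothetical: it stands in for a backend transform such as `tf.function` or `jax.jit`, so the sketch runs without either library.

```python
import functools
from typing import Callable, Optional


def eager_function(func: Optional[Callable] = None, *, compile: bool = False,
                   compiler: Optional[Callable[[Callable], Callable]] = None):
    """Hypothetical decorator: optionally compile `func` with a backend compiler.

    `compiler` is an assumption of this sketch (e.g. tf.function or jax.jit).
    """
    def decorate(f: Callable) -> Callable:
        target = f
        if compile:
            if compiler is None:
                raise ValueError("compile=True requires a compiler callable")
            # Wrap once at decoration time, as tf.function and jax.jit do.
            target = compiler(f)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return target(*args, **kwargs)

        return wrapper

    # Support both @eager_function and @eager_function(compile=True, ...).
    if func is not None:
        return decorate(func)
    return decorate


# Usage with a stand-in "compiler" that just records which function it wrapped.
compiled = []

def recording_compiler(f):
    compiled.append(f.__name__)
    return f

@eager_function(compile=True, compiler=recording_compiler)
def square(x):
    return x * x

print(square(3), compiled)  # 9 ['square']
```

Injecting the compiler keeps the decorator backend-agnostic; whether a single flag can hide the different tracing rules of TensorFlow, PyTorch, and JAX is exactly the open question raised in the reply below.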
I have to say, I haven't really thought enough about compilation, and I am not sure it can be abstracted away enough to unify it between TensorFlow, PyTorch, and JAX. I think it could be interesting, but it requires careful testing of all the special cases and limitations.
…tensorflow (see issue tensorflow/tensorflow#48612 and jonasrauber#36)
* Correct the implementation of `ndim` so that it works in compile mode with TensorFlow (see issue tensorflow/tensorflow#48612 and #36)
* Fix flake8 warnings
* Implement `__len__` with `shape` so that it works with `tf.function`
* Import `cast`
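The commits above derive `ndim` and `len()` from `shape` instead of delegating to the raw tensor. A minimal sketch of that idea follows. `ShapeBasedTensor` is a hypothetical wrapper, not EagerPy's actual TensorFlow tensor class, and it assumes `raw.shape` is iterable and yields ints, or `None` for unknown sizes.

```python
from typing import Any


class ShapeBasedTensor:
    """Hypothetical wrapper that computes ndim and len() from `raw.shape`.

    Assumes `raw.shape` is iterable with int (or None for unknown) entries.
    """

    def __init__(self, raw: Any) -> None:
        self.raw = raw

    @property
    def shape(self) -> tuple:
        return tuple(self.raw.shape)

    @property
    def ndim(self) -> int:
        # Mirror the `len(a.shape)` workaround instead of reading raw.ndim.
        return len(self.shape)

    def __len__(self) -> int:
        if self.ndim == 0:
            raise TypeError("len() of a 0-d tensor")
        first = self.shape[0]
        if first is None:
            # Unknown leading dimension: no static length is available.
            raise TypeError("len() of a tensor with unknown leading dimension")
        return int(first)


class _FakeRaw:  # stand-in for a traced tensor with a fully known static shape
    shape = (4, 3)


x = ShapeBasedTensor(_FakeRaw())
print(x.ndim, len(x))  # 2 4
```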
Thanks to #40 this is resolved, but I'll leave this issue open for now while the TensorFlow project discusses what to do about it.
Let's consider a simple compiled function in TensorFlow.
This code works.
However, its "universal" version:
does not work and raises the error:
(but it works if we comment out the `@tf.function` decorator). Note that the equivalent code with JAX seems to work:
Is this a problem with the integration of EagerPy with TensorFlow?