This repository has been archived by the owner on May 28, 2024. It is now read-only.

Custom Op TPU #53

Open
bhack opened this issue Apr 4, 2020 · 13 comments

@bhack
Contributor

bhack commented Apr 4, 2020

Can we add something related to TPUs to the example?
There is a FAQ entry about creating custom ops for TPUs: https://cloud.google.com/tpu/docs/faq

@yifeif
Contributor

yifeif commented Apr 21, 2020

Had a brief chat with @frankchn, and it sounds like custom ops are not supported on TPUs yet.
See https://cloud.google.com/tpu/docs/tpus#when_to_use_tpus

'''
Cloud TPUs are not suited to the following workloads:
...
Neural network workloads that contain custom TensorFlow operations written in C++. Specifically, custom operations in the body of the main training loop are not suitable for TPUs.
'''

cc: @frankchn @jhseu

@bhack
Contributor Author

bhack commented Apr 21, 2020

In the TPU FAQ I see:

How can I write a custom op for Compute Engine?
TensorFlow ops that run on Compute Engine are implemented in XLA HLO, a language for defining high-level tensor ops using a small set of low-level functions. XLA is included in TensorFlow's open source release, so it is technically possible to write your op in HLO. The majority of existing implementations can be found in the tf2xla directory.
XLA only allows for execution of a limited set of tensor ops on the TPU, not arbitrary C++ or Python code. Most common tensor ops that can be implemented in HLO have already been written.

@frankchn

While it is technically possible to write an XLA HLO op and get it to run on TPUs, we currently don't expose any way to load arbitrary user written HLO ops onto the TPU system itself. This may change in future releases, but we don't have anything to announce today.

@bhack
Contributor Author

bhack commented Apr 21, 2020

@frankchn OK, can you reach anyone internally to fix the FAQ? With that text, it seems there is currently an "undocumented" path to build such custom ops.

@frankchn

Yup, working on it. Thanks for bringing that to our attention!

@bhack
Contributor Author

bhack commented Apr 21, 2020

Thanks

@orsharir

Is there a way to specify a fall-back option for TPUs? I have an optimized custom op for CPUs and GPUs, but I want my custom op to be able to run, even inefficiently, on TPUs. The op can be specified with standard TF operations, so I'm just looking for a way to register a Python function as the TPU implementation of the op. Is this possible?

@frankchn

@orsharir Would it be possible to encapsulate the op you want in a Python function, and then switch between your custom op implementation and the TF default op implementation using a flag?

def my_custom_op(input1, input2):
  # Environment variables are strings and may be unset, so use
  # .get() with an explicit comparison rather than a truthiness test.
  if os.environ.get('USE_TPU') == '1':
    return tf.add(input1, input2)
  else:
    return custom_op(input1, input2)
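A runnable sketch of that dispatch pattern, with placeholder bodies standing in for the compiled custom op and the pure-TF fallback (the flag name `USE_TPU` and both op bodies are illustrative, not an official API):

```python
import os

def custom_op(a, b):
    # Placeholder for the optimized compiled custom op (CPU/GPU only).
    return a + b

def tf_fallback(a, b):
    # Placeholder for a composite implementation built from standard
    # TF ops, which XLA can compile for TPU.
    return a + b

def my_custom_op(a, b):
    # Env vars are strings and may be unset, so use .get() and
    # compare against an explicit value.
    if os.environ.get('USE_TPU', '0') == '1':
        return tf_fallback(a, b)
    return custom_op(a, b)

os.environ['USE_TPU'] = '1'
print(my_custom_op(2, 3))  # takes the fallback path, prints 5
```

Both branches return the same result; only the implementation chosen at runtime differs, which is the point of the flag.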

@XMaster96

XMaster96 commented Apr 7, 2021

@frankchn

we currently don't expose any way to load arbitrary user written HLO ops onto the TPU system itself.

This sounds to me like you could load a custom XLA HLO op on a TPU without modifying the TPU library on the TPU system itself.
Let's assume I have SSH access to the TPU system itself thanks to the JAX TPU beta; can I then write a custom op?

@frankchn

frankchn commented Apr 7, 2021

@XMaster96 Unfortunately not, because the underlying TPU ISA and the associated tools needed to write an XLA op aren't exposed, even with the TPU VM preview.

@XMaster96

XMaster96 commented Apr 17, 2021

@frankchn OK, thanks.
So I was thinking a bit: I can't write a custom op myself, but can I at least load a custom op someone else has written? And do I even need the VM preview to do so? I am not really sure, but I might be able to request a custom op.

@frankchn

You can load custom ops that someone else (or you) has written, so long as they are CPU custom ops (or you build them into a custom TF build). I don't think anyone outside of Google can write XLA custom ops that run on the TPU.

@edwardyehuang

@bhack Since JAX already has Pallas for writing TPU kernels, is there any plan for a similar feature in TensorFlow?
