Skip to content

Commit

Permalink
Simplify running KerasCV with Keras 3 (#2179)
Browse files Browse the repository at this point in the history
* remove keras_core dependency

* update init

* update readme

* fix model None error (#2176) (#2177)

* Update pycoco_callback.py

* Update waymo_evaluation_callback.py

* fix model None error (#2176) (#2178)

* Update pycoco_callback.py

* Update waymo_evaluation_callback.py

* update readme and conftest

* update readme

* update citation list

* fix mix transformer tests

* fix lint error

* fix all failing tests
  • Loading branch information
divyashreepathihalli authored Dec 1, 2023
1 parent 2604105 commit 431e97c
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 121 deletions.
61 changes: 34 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
[![Contributions Welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/keras-team/keras-cv/issues)

KerasCV is a library of modular computer vision components that work natively
with TensorFlow, JAX, or PyTorch. Built on [Keras Core](https://keras.io/keras_core/announcement/),
these models, layers, metrics, callbacks, etc., can be trained and serialized
in any framework and re-used in another without costly migrations. See
"Configuring your backend" below for more details on multi-framework KerasCV.
with TensorFlow, JAX, or PyTorch. Built on Keras 3, these models, layers,
metrics, callbacks, etc., can be trained and serialized in any framework and
re-used in another without costly migrations. See "Configuring your backend"
below for more details on multi-framework KerasCV.

<img style="width: 440px; max-width: 90%;" src="https://storage.googleapis.com/keras-cv/guides/keras-cv-augmentations.gif">

Expand All @@ -34,29 +34,44 @@ these common tasks.
- [API Design Guidelines](.github/API_DESIGN.md)

## Installation
KerasCV supports both Keras 2 and Keras 3. We recommend Keras 3 for all new
users, as it enables using KerasCV models and layers with JAX, TensorFlow and
PyTorch.

To install the latest official release:
### Keras 2 Installation

To install the latest KerasCV release with Keras 2, simply run:

```
pip install keras-cv tensorflow --upgrade
pip install --upgrade keras-cv tensorflow
```

To install the latest unreleased changes to the library, we recommend using
pip to install directly from the master branch on github:
### Keras 3 Installation

There are currently two ways to install Keras 3 with KerasCV. To install the
latest changes for KerasCV and Keras, you can use our nightly package.


```
pip install git+https://github.com/keras-team/keras-cv.git tensorflow --upgrade
pip install --upgrade keras-cv-nightly tf-nightly
```

## Configuring your backend
To install the stable versions of KerasCV and Keras 3, you should install Keras
3 **after** installing KerasCV. This is a temporary step while TensorFlow is
pinned to Keras 2, and will no longer be necessary after TensorFlow 2.16.

**Keras 3** is an upcoming release of the Keras library which supports
TensorFlow, Jax or Torch as backends. This is supported today in KerasNLP,
but will not be enabled by default until the official release of Keras 3. If you
`pip install keras-cv` and run a script or notebook without changes, you will
be using TensorFlow and **Keras 2**.
```
pip install --upgrade keras-cv tensorflow
pip install keras>=3
```
> [!IMPORTANT]
> Keras 3 will not function with TensorFlow 2.14 or earlier.
## Configuring your backend

If you would like to enable a preview of the Keras 3 behavior, you can do
If you have Keras 3 installed in your environment (see installation above),
you can use KerasCV with any of JAX, TensorFlow and PyTorch. To do so, set the
`KERAS_BACKEND` environment variable. For example:
so by setting the `KERAS_BACKEND` environment variable. For example:

```shell
Expand All @@ -75,21 +90,13 @@ import keras_cv
> [!IMPORTANT]
> Make sure to set the `KERAS_BACKEND` before import any Keras libraries, it
> will be used to set up Keras when it is first imported.
Until the Keras 3 release, KerasCV will use a preview of Keras 3 on PyPI named
[keras-core](https://pypi.org/project/keras-core/).

> [!IMPORTANT]
> If you set `KERAS_BACKEND` variable, you should `import keras_core as keras`
> instead of `import keras`. This is a temporary step until Keras 3 is out!
To restore the default **Keras 2** behavior, `unset KERAS_BACKEND` before
importing Keras and KerasCV.
Once that configuration step is done, you can just import KerasCV and start
using it on top of your backend of choice:

```python
import keras_cv
from keras_cv.backend import keras
import keras

filepath = keras.utils.get_file(origin="https://i.imgur.com/gCNcJJI.jpg")
image = np.array(keras.utils.load_img(filepath))
Expand All @@ -108,7 +115,7 @@ predictions = model.predict(image_resized)
import tensorflow as tf
import keras_cv
import tensorflow_datasets as tfds
from keras_cv.backend import keras
import keras

# Create a preprocessing pipeline with augmentations
BATCH_SIZE = 16
Expand Down Expand Up @@ -260,7 +267,7 @@ Here is the BibTeX entry:
```bibtex
@misc{wood2022kerascv,
title={KerasCV},
author={Wood, Luke and Tan, Zhenyu and Stenbit, Ian and Bischof, Jonathan and Zhu, Scott and Chollet, Fran\c{c}ois and others},
author={Wood, Luke and Tan, Zhenyu and Stenbit, Ian and Bischof, Jonathan and Zhu, Scott and Chollet, Fran\c{c}ois and Sreepathihalli, Divyashree and Sampath, Ramesh and others},
year={2022},
howpublished={\url{https://github.com/keras-team/keras-cv}},
}
Expand Down
20 changes: 17 additions & 3 deletions keras_cv/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Keras backend module.
This module adds a temporary Keras API surface that is fully under KerasCV
control. The goal is to allow us to write Keras 3-like code everywhere, while
still supporting Keras 2. We do this by using the `keras_core` package with
Keras 2 to backport Keras 3 numerics APIs (`keras.ops` and `keras.random`) into
Keras 2. The sub-modules exposed are as follows:
- `config`: check which version of Keras is being run.
- `keras`: The full `keras` API with compat shims for older Keras versions.
- `ops`: `keras.ops` for Keras 3 or `keras_core.ops` for Keras 2.
- `random`: `keras.random` for Keras 3 or `keras_core.ops` for Keras 2.
"""
from keras_cv.backend import config # noqa: E402
from keras_cv.backend import keras # noqa: E402
from keras_cv.backend import ops # noqa: E402
Expand All @@ -19,12 +33,12 @@


def assert_tf_keras(src):
if config.multi_backend():
if config.keras_3():
raise NotImplementedError(
f"KerasCV component {src} does not yet support Keras Core, and can "
f"KerasCV component {src} does not yet support Keras 3, and can "
"only be used in `tf.keras`."
)


def supports_ragged():
return not config.multi_backend()
return not config.keras_3()
66 changes: 9 additions & 57 deletions keras_cv/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

_MULTI_BACKEND = False

# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
if "KERAS_HOME" in os.environ:
_keras_dir = os.environ.get("KERAS_HOME")
else:
_keras_base_dir = os.path.expanduser("~")
if not os.access(_keras_base_dir, os.W_OK):
_keras_base_dir = "/tmp"
_keras_dir = os.path.join(_keras_base_dir, ".keras")
from tensorflow import keras

# We follow the version of keras that tensorflow is configured to use.
_USE_KERAS_3 = False

# Note that only recent versions of keras have a `version()` function.
if hasattr(keras, "version") and keras.version().startswith("3."):
_USE_KERAS_3 = True


def detect_if_tensorflow_uses_keras_3():
Expand Down Expand Up @@ -57,53 +53,9 @@ def keras_3():
return _USE_KERAS_3


# Attempt to read KerasCV config file.
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_cv.json"))
if os.path.exists(_config_path):
try:
with open(_config_path) as f:
_config = json.load(f)
except ValueError:
_config = {}
_MULTI_BACKEND = _config.get("multi_backend", _MULTI_BACKEND)

# Save config file, if possible.
if not os.path.exists(_keras_dir):
try:
os.makedirs(_keras_dir)
except OSError:
# Except permission denied and potential race conditions
# in multi-threaded environments.
pass

if not os.path.exists(_config_path):
_config = {
"multi_backend": _MULTI_BACKEND,
}
try:
with open(_config_path, "w") as f:
f.write(json.dumps(_config, indent=4))
except IOError:
# Except permission denied.
pass

if "KERAS_BACKEND" in os.environ and os.environ["KERAS_BACKEND"]:
_MULTI_BACKEND = True


def multi_backend():
return _MULTI_BACKEND


def backend():
"""Check the backend framework."""
if not multi_backend():
return "tensorflow"
if not keras_3():
import keras_core

return keras_core.config.backend()

from tensorflow import keras
return "tensorflow"

return keras.config.backend()
5 changes: 0 additions & 5 deletions keras_cv/backend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@
import keras # noqa: F403, F401
from keras import * # noqa: F403, F401

keras.backend.name_scope = keras.name_scope
elif config.multi_backend():
import keras_core as keras # noqa: F403, F401
from keras_core import * # noqa: F403, F401

keras.backend.name_scope = keras.name_scope
else:
from tensorflow import keras # noqa: F403, F401
Expand Down
9 changes: 4 additions & 5 deletions keras_cv/backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.backend.config import keras_3
from keras_cv.backend.config import multi_backend
from keras_cv.backend import config

if keras_3():
if config.keras_3():
from keras.ops import * # noqa: F403, F401
from keras.preprocessing.image import smart_resize # noqa: F403, F401

Expand All @@ -32,5 +31,5 @@
from keras_core.src.utils.image_utils import ( # noqa: F403, F401
smart_resize,
)
if not multi_backend():
from keras_cv.backend.tf_ops import * # noqa: F403, F401
if config.backend() == "tensorflow":
from keras_cv.backend.tf_ops import * # noqa: F403, F401
4 changes: 2 additions & 2 deletions keras_cv/backend/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import tf_ops
from keras_cv.backend.config import multi_backend
from keras_cv.backend.config import keras_3

_ORIGINAL_OPS = copy.copy(backend.ops.__dict__)
_ORIGINAL_SUPPORTS_RAGGED = backend.supports_ragged
Expand All @@ -30,7 +30,7 @@
def tf_data(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):
if multi_backend() and keras.src.utils.backend_utils.in_tf_graph():
if keras_3() and keras.src.utils.backend_utils.in_tf_graph():
with TFDataScope():
return function(*args, **kwargs)
else:
Expand Down
14 changes: 7 additions & 7 deletions keras_cv/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from packaging import version

from keras_cv.backend.config import multi_backend
from keras_cv.backend.config import keras_3


def pytest_addoption(parser):
Expand Down Expand Up @@ -45,7 +45,7 @@ def pytest_configure(config):
)
config.addinivalue_line(
"markers",
"tf_keras_only: mark test as a tf.keras-only test",
"tf_keras_only: mark test as a Keras 2-only test",
)
config.addinivalue_line(
"markers",
Expand All @@ -69,12 +69,12 @@ def pytest_collection_modifyitems(config, items):
skip_extra_large = pytest.mark.skipif(
not run_extra_large_tests, reason="need --run_extra_large option to run"
)
skip_tf_keras_only = pytest.mark.skipif(
multi_backend(),
reason="This test is only supported on tf.keras",
skip_keras_2_only = pytest.mark.skipif(
keras_3(),
reason="This test is only supported on Keras 2",
)
skip_tf_only = pytest.mark.skipif(
multi_backend() and keras_core.backend.backend() != "tensorflow",
keras_3() and keras_core.backend.backend() != "tensorflow",
reason="This test is only supported on TensorFlow",
)
for item in items:
Expand All @@ -87,6 +87,6 @@ def pytest_collection_modifyitems(config, items):
if "extra_large" in item.keywords:
item.add_marker(skip_extra_large)
if "tf_keras_only" in item.keywords:
item.add_marker(skip_tf_keras_only)
item.add_marker(skip_keras_2_only)
if "tf_only" in item.keywords:
item.add_marker(skip_tf_only)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend.config import multi_backend
from keras_cv.backend.config import keras_3


@keras_cv_export("keras_cv.layers.MultiClassNonMaxSuppression")
Expand Down Expand Up @@ -73,7 +73,7 @@ def call(
`bounding_box_format` specified in the constructor.
class_prediction: Dense Tensor of shape [batch, boxes, num_classes].
"""
if multi_backend() and keras.backend.backend() != "tensorflow":
if keras_3() and keras.backend.backend() != "tensorflow":
raise NotImplementedError(
"MultiClassNonMaxSuppression does not support non-TensorFlow "
"backends. Consider using NonMaxSuppression instead."
Expand Down
6 changes: 3 additions & 3 deletions keras_cv/layers/object_detection/non_max_suppression.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend.config import multi_backend
from keras_cv.backend.config import keras_3

EPSILON = 1e-8

Expand Down Expand Up @@ -89,7 +89,7 @@ def call(

confidence_prediction = ops.max(class_prediction, axis=-1)

if not multi_backend() or keras.backend.backend() == "tensorflow":
if not keras_3() or keras.backend.backend() == "tensorflow":
idx, valid_det = tf.image.non_max_suppression_padded(
box_prediction,
confidence_prediction,
Expand Down Expand Up @@ -318,7 +318,7 @@ def suppression_loop_body(boxes, iou_threshold, output_size, idx):

# TODO(ianstenbit): Fix bug in tfnp.take_along_axis that causes this hack.
# (This will be removed anyway when we use built-in NMS for TF.)
if multi_backend() and keras.backend.backend() != "tensorflow":
if keras_3() and keras.backend.backend() != "tensorflow":
idx = ops.take_along_axis(
ops.reshape(sorted_indices, [-1]), take_along_axis_idx
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import scope
from keras_cv.backend.config import multi_backend
from keras_cv.backend.config import keras_3
from keras_cv.utils import preprocessing

H_AXIS = -3
Expand All @@ -44,7 +44,7 @@

base_class = (
keras.src.layers.preprocessing.tf_data_layer.TFDataLayer
if multi_backend()
if keras_3()
else keras.layers.Layer
)

Expand Down
Loading

0 comments on commit 431e97c

Please sign in to comment.