diff --git a/.github/workflows/auto-assignment.yml b/.github/workflows/auto-assignment.yml
new file mode 100644
index 0000000000..de72da8ba2
--- /dev/null
+++ b/.github/workflows/auto-assignment.yml
@@ -0,0 +1,21 @@
+name: auto-assignment
+on:
+  issues:
+    types:
+      - opened
+
+permissions:
+  contents: read
+  issues: write
+  pull-requests: write
+
+jobs:
+  welcome:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/github-script@v7
+        with:
+          script: |
+            const script = require('./.github/workflows/scripts/auto-assignment.js')
+            script({github, context})
diff --git a/.github/workflows/scripts/auto-assignment.js b/.github/workflows/scripts/auto-assignment.js
new file mode 100644
index 0000000000..d08b06a8b7
--- /dev/null
+++ b/.github/workflows/scripts/auto-assignment.js
@@ -0,0 +1,43 @@
+/** Automatically assign issues and PRs to users in the `assigneesList`
+ * on a rotating basis.
+
+  @param {!object}
+  The `github` object can call GitHub APIs using its built-in library
+  functions. The `context` object contains the issue and PR details.
+*/
+
+module.exports = async ({ github, context }) => {
+  let issueNumber;
+  let assigneesList;
+  // Is this an issue? If so, use the issue number. Otherwise, use the PR number.
+  if (context.payload.issue) {
+    // Assignee list for issues.
+    assigneesList = ["SuryanarayanaY", "sachinprasadhs"];
+    issueNumber = context.payload.issue.number;
+  } else {
+    // Assignee list for PRs.
+    assigneesList = [];
+    issueNumber = context.payload.number;
+  }
+  console.log("assignee list", assigneesList);
+  console.log("entered auto assignment for this issue: ", issueNumber);
+  if (!assigneesList.length) {
+    console.log("No assignees found for this repo.");
+    return;
+  }
+  let noOfAssignees = assigneesList.length;
+  let selection = issueNumber % noOfAssignees;
+  let assigneeForIssue = assigneesList[selection];
+
+  console.log(
+    "issue Number = ",
+    issueNumber + " , assigning to: ",
+    assigneeForIssue
+  );
+  return github.rest.issues.addAssignees({
+    issue_number: context.issue.number,
+    owner: context.repo.owner,
+    repo: context.repo.repo,
+    assignees: [assigneeForIssue],
+  });
+};
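The rotation in `auto-assignment.js` is stateless round-robin: the issue number modulo the roster size picks the assignee, so consecutive issues cycle through the list without any stored bookkeeping. A minimal Python sketch of the same policy (the names below are illustrative placeholders, not the repo's actual roster):

```python
# Stateless round-robin: the issue number itself selects the assignee,
# mirroring the modulo selection in auto-assignment.js.
def pick_assignee(issue_number, assignees):
    """Rotate through `assignees` so consecutive issues go to different people."""
    if not assignees:
        raise ValueError("assignees must be non-empty")
    return assignees[issue_number % len(assignees)]

assert pick_assignee(100, ["reviewer_a", "reviewer_b"]) == "reviewer_a"  # 100 % 2 == 0
assert pick_assignee(101, ["reviewer_a", "reviewer_b"]) == "reviewer_b"  # 101 % 2 == 1
```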
diff --git a/.github/workflows/stale-issue-pr.yml b/.github/workflows/stale-issue-pr.yml
new file mode 100644
index 0000000000..034fb4c266
--- /dev/null
+++ b/.github/workflows/stale-issue-pr.yml
@@ -0,0 +1,50 @@
+name: Close inactive issues
+on:
+  schedule:
+    - cron: "30 1 * * *"
+jobs:
+  close-issues:
+    runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+    steps:
+      - name: Awaiting response issues
+        uses: actions/stale@v9
+        with:
+          days-before-issue-stale: 14
+          days-before-issue-close: 14
+          stale-issue-label: "stale"
+          # Reason for closing the issue; the default value is not_planned.
+          close-issue-reason: completed
+          only-labels: "stat:awaiting response from contributor"
+          stale-issue-message: >
+            This issue is stale because it has been open for 14 days with no activity.
+            It will be closed if no further activity occurs. Thank you.
+          # List of labels to remove when issues/PRs become unstale.
+          labels-to-remove-when-unstale: "stat:awaiting response from contributor"
+          close-issue-message: >
+            This issue was closed because it has been inactive for 28 days.
+            Please reopen if you'd like to work on this further.
+          days-before-pr-stale: 14
+          days-before-pr-close: 14
+          stale-pr-message: "This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you."
+          close-pr-message: "This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further."
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+      - name: Contribution issues
+        uses: actions/stale@v9
+        with:
+          days-before-issue-stale: 180
+          days-before-issue-close: 365
+          stale-issue-label: "stale"
+          # Reason for closing the issue; the default value is not_planned.
+          close-issue-reason: not_planned
+          any-of-labels: "stat:contributions welcome,good first issue"
+          # List of labels to remove when issues/PRs become unstale.
+          labels-to-remove-when-unstale: "stat:contributions welcome,good first issue"
+          stale-issue-message: >
+            This issue is stale because it has been open for 180 days with no activity.
+            It will be closed if no further activity occurs. Thank you.
+          close-issue-message: >
+            This issue was closed because it has been inactive for more than 1 year.
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index c145fae536..9d07218317 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -51,7 +51,7 @@ pip install --no-deps -e "." --progress-bar off
 # Run Extra Large Tests for Continuous builds
 if [ "${RUN_XLARGE:-0}" == "1" ]
 then
-   pytest --check_gpu --run_large --run_extra_large --durations 0 \
+   pytest --cache-clear --check_gpu --run_large --run_extra_large --durations 0 \
      keras_cv/bounding_box \
      keras_cv/callbacks \
      keras_cv/losses \
@@ -65,7 +65,7 @@
      keras_cv/models/segmentation \
      keras_cv/models/stable_diffusion
 else
-   pytest --check_gpu --run_large --durations 0 \
+   pytest --cache-clear --check_gpu --run_large --durations 0 \
      keras_cv/bounding_box \
      keras_cv/callbacks \
      keras_cv/losses \
diff --git a/.github/API_DESIGN.md b/API_DESIGN.md
similarity index 100%
rename from .github/API_DESIGN.md
rename to API_DESIGN.md
diff --git a/.github/CALL_FOR_CONTRIBUTIONS.md b/CALL_FOR_CONTRIBUTIONS.md
similarity index 100%
rename from .github/CALL_FOR_CONTRIBUTIONS.md
rename to CALL_FOR_CONTRIBUTIONS.md
diff --git a/.github/CONTRIBUTING.md b/CONTRIBUTING.md
similarity index 100%
rename from .github/CONTRIBUTING.md
rename to CONTRIBUTING.md
diff --git a/benchmarks/vectorized_randomly_zoomed_crop.py b/benchmarks/vectorized_randomly_zoomed_crop.py
index 4e807fd1ab..3a207ed2e3 100644
--- a/benchmarks/vectorized_randomly_zoomed_crop.py
+++ b/benchmarks/vectorized_randomly_zoomed_crop.py
@@ -249,10 +249,10 @@ def from_config(cls, config):
                 config["zoom_factor"]
             )
         if isinstance(config["aspect_ratio_factor"], dict):
-            config[
-                "aspect_ratio_factor"
-            ] = keras.utils.deserialize_keras_object(
-                config["aspect_ratio_factor"]
+            config["aspect_ratio_factor"] = (
+                keras.utils.deserialize_keras_object(
+                    config["aspect_ratio_factor"]
+                )
             )
         return cls(**config)
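This `from_config` hunk (and the matching ones in `random_crop_and_resize.py`, `squeeze_excite.py`, and `yolo_v8_detector.py` below) is formatting-only, but the pattern it reformats is worth spelling out: nested Keras objects come back from a saved config as plain dicts and must be revived before `cls(**config)` runs. A self-contained sketch of that round-trip, using a built-in layer so no custom registration is needed:

```python
# Round-trip behind the from_config pattern: nested objects arrive as dicts
# after serialization and must be revived before the constructor is called.
import keras

layer = keras.layers.ReLU()
cfg = keras.saving.serialize_keras_object(layer)
assert isinstance(cfg, dict)  # what from_config receives for a nested object

# The guard used by the from_config methods in this diff:
if isinstance(cfg, dict):
    layer = keras.saving.deserialize_keras_object(cfg)
assert isinstance(layer, keras.layers.ReLU)
```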
-"""random_resized_crop_demo.py.py shows how to use the RandomResizedCrop -preprocessing layer. Operates on an image of elephant. In this script the image +"""This demo example shows how to use the RandomCropAndResize preprocessing +layer. Operates on an image of elephant. In this script the image is loaded, then are passed through the preprocessing layers. Finally, they are shown using matplotlib. """ diff --git a/keras_cv/__init__.py b/keras_cv/__init__.py index 6d6ef3704c..36c7d3511b 100644 --- a/keras_cv/__init__.py +++ b/keras_cv/__init__.py @@ -41,5 +41,5 @@ from keras_cv.core import FactorSampler from keras_cv.core import NormalFactorSampler from keras_cv.core import UniformFactorSampler - -__version__ = "0.8.1" +from keras_cv.version_utils import __version__ +from keras_cv.version_utils import version diff --git a/keras_cv/layers/object_detection/anchor_generator.py b/keras_cv/layers/object_detection/anchor_generator.py index 30dd421afd..effc125143 100644 --- a/keras_cv/layers/object_detection/anchor_generator.py +++ b/keras_cv/layers/object_detection/anchor_generator.py @@ -172,7 +172,7 @@ def __call__(self, image=None, image_shape=None): "Expected `image` to be a Tensor of rank 3. Got " f"image.shape.rank={len(image.shape)}" ) - image_shape = image.shape + image_shape = tuple(image.shape) results = {} for key, generator in self.anchor_generators.items(): diff --git a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py index ef2e9cefe7..167da7ad0b 100644 --- a/keras_cv/layers/preprocessing/base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/base_image_augmentation_layer.py @@ -236,15 +236,15 @@ def _compute_output_signature(self, inputs): bounding_boxes = inputs.get(BOUNDING_BOXES, None) if bounding_boxes is not None: - fn_output_signature[ - BOUNDING_BOXES - ] = self._compute_bounding_box_signature(bounding_boxes) + fn_output_signature[BOUNDING_BOXES] = ( + self._compute_bounding_box_signature(bounding_boxes) + ) segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) if segmentation_masks is not None: - fn_output_signature[ - SEGMENTATION_MASKS - ] = self.compute_image_signature(segmentation_masks) + fn_output_signature[SEGMENTATION_MASKS] = ( + self.compute_image_signature(segmentation_masks) + ) keypoints = inputs.get(KEYPOINTS, None) if keypoints is not None: diff --git a/keras_cv/layers/preprocessing/random_crop_and_resize.py b/keras_cv/layers/preprocessing/random_crop_and_resize.py index 593515ad09..cd947d5835 100644 --- a/keras_cv/layers/preprocessing/random_crop_and_resize.py +++ b/keras_cv/layers/preprocessing/random_crop_and_resize.py @@ -272,10 +272,10 @@ def from_config(cls, config): config["crop_area_factor"] ) if isinstance(config["aspect_ratio_factor"], dict): - config[ - "aspect_ratio_factor" - ] = keras.utils.deserialize_keras_object( - config["aspect_ratio_factor"] + config["aspect_ratio_factor"] = ( + keras.utils.deserialize_keras_object( + config["aspect_ratio_factor"] + ) ) return cls(**config) diff --git a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py index 3d9fc8e52a..fd36e22065 100644 --- a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py +++ b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py @@ -17,6 +17,7 @@ from keras_cv import bounding_box from keras_cv.api_export import keras_cv_export +from 
diff --git a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer_test.py b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer_test.py
index 3ebdfdb820..c2d0daa840 100644
--- a/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer_test.py
+++ b/keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer_test.py
@@ -549,3 +549,15 @@ def test_converts_ragged_to_dense_segmentation_masks(self):
             {"images": images, "segmentation_masks": segmentation_masks}
         )
         self.assertTrue(isinstance(result["segmentation_masks"], tf.Tensor))
+
+    def test_in_tf_data_pipeline(self):
+        images = np.random.randn(4, 100, 100, 3).astype("float32")
+        train_ds = tf.data.Dataset.from_tensor_slices(images)
+        train_ds = train_ds.map(lambda x: {"images": x})
+        train_ds = train_ds.map(
+            VectorizedRandomAddLayer(fixed_value=2.0)
+        ).batch(4)
+        for output in train_ds.take(1):
+            pass
+        self.assertTrue(isinstance(output["images"], tf.Tensor))
+        self.assertAllClose(output["images"], images + 2.0)
diff --git a/keras_cv/layers/regularization/squeeze_excite.py b/keras_cv/layers/regularization/squeeze_excite.py
index cb03cc6942..8cbcc5bd94 100644
--- a/keras_cv/layers/regularization/squeeze_excite.py
+++ b/keras_cv/layers/regularization/squeeze_excite.py
@@ -118,10 +118,10 @@ def get_config(self):
     @classmethod
     def from_config(cls, config):
         if isinstance(config["squeeze_activation"], dict):
-            config[
-                "squeeze_activation"
-            ] = keras.saving.deserialize_keras_object(
-                config["squeeze_activation"]
+            config["squeeze_activation"] = (
+                keras.saving.deserialize_keras_object(
+                    config["squeeze_activation"]
+                )
             )
         if isinstance(config["excite_activation"], dict):
             config["excite_activation"] = keras.saving.deserialize_keras_object(
diff --git a/keras_cv/layers/vit_det_layers.py b/keras_cv/layers/vit_det_layers.py
index 9311a957f5..2e053db4cb 100644
--- a/keras_cv/layers/vit_det_layers.py
+++ b/keras_cv/layers/vit_det_layers.py
@@ -430,9 +430,9 @@
def __init__( key_dim=self.project_dim // self.num_heads, use_bias=use_bias, use_rel_pos=use_rel_pos, - input_size=input_size - if window_size == 0 - else (window_size, window_size), + input_size=( + input_size if window_size == 0 else (window_size, window_size) + ), ) self.mlp_block = MLP( mlp_dim, diff --git a/keras_cv/metrics/coco/pycoco_wrapper.py b/keras_cv/metrics/coco/pycoco_wrapper.py index 3c09784388..659cdef0a0 100644 --- a/keras_cv/metrics/coco/pycoco_wrapper.py +++ b/keras_cv/metrics/coco/pycoco_wrapper.py @@ -125,6 +125,9 @@ def _convert_predictions_to_coco_annotations(predictions): num_batches = len(predictions["source_id"]) for i in range(num_batches): batch_size = predictions["source_id"][i].shape[0] + predictions["detection_boxes"][i] = predictions["detection_boxes"][ + i + ].copy() for j in range(batch_size): max_num_detections = predictions["num_detections"][i][j] predictions["detection_boxes"][i][j] = _yxyx_to_xywh( diff --git a/keras_cv/metrics/object_detection/box_coco_metrics.py b/keras_cv/metrics/object_detection/box_coco_metrics.py index a59af8c767..47d86ba1c2 100644 --- a/keras_cv/metrics/object_detection/box_coco_metrics.py +++ b/keras_cv/metrics/object_detection/box_coco_metrics.py @@ -212,9 +212,9 @@ def result_fn(self, force=False): ) result = {} for i, key in enumerate(METRIC_NAMES): - result[ - self.name_prefix() + METRIC_MAPPING[key] - ] = py_func_result[i] + result[self.name_prefix() + METRIC_MAPPING[key]] = ( + py_func_result[i] + ) return result obj.result = types.MethodType(result_fn, obj) diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index ae775ed824..b9b90b946a 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -190,6 +190,7 @@ from keras_cv.models.object_detection.yolo_v8.yolo_v8_detector import ( YOLOV8Detector, ) +from keras_cv.models.segmentation import BASNet from keras_cv.models.segmentation import DeepLabV3Plus from keras_cv.models.segmentation import SAMMaskDecoder from keras_cv.models.segmentation import SAMPromptEncoder diff --git a/keras_cv/models/backbones/densenet/densenet_backbone.py b/keras_cv/models/backbones/densenet/densenet_backbone.py index 28109b64fa..251f3601ec 100644 --- a/keras_cv/models/backbones/densenet/densenet_backbone.py +++ b/keras_cv/models/backbones/densenet/densenet_backbone.py @@ -119,9 +119,9 @@ def __init__( name=f"conv{len(stackwise_num_repeats) + 1}", ) - pyramid_level_inputs[ - f"P{len(stackwise_num_repeats) + 1}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{len(stackwise_num_repeats) + 1}"] = ( + utils.get_tensor_input_name(x) + ) x = keras.layers.BatchNormalization( axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" )(x) diff --git a/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py b/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py index 61046234d3..07c896613c 100644 --- a/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py +++ b/keras_cv/models/backbones/resnet_v1/resnet_v1_backbone.py @@ -130,9 +130,9 @@ def __init__( first_shortcut=(block_type == "block" or stack_index > 0), name=f"v2_stack_{stack_index}", ) - pyramid_level_inputs[ - f"P{stack_index + 2}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{stack_index + 2}"] = ( + utils.get_tensor_input_name(x) + ) # Create model. 
super().__init__(inputs=inputs, outputs=x, **kwargs) diff --git a/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py b/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py index a31841f7fc..6a0cc74740 100644 --- a/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py +++ b/keras_cv/models/backbones/resnet_v2/resnet_v2_backbone.py @@ -136,9 +136,9 @@ def __init__( first_shortcut=(block_type == "block" or stack_index > 0), name=f"v2_stack_{stack_index}", ) - pyramid_level_inputs[ - f"P{stack_index + 2}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{stack_index + 2}"] = ( + utils.get_tensor_input_name(x) + ) x = keras.layers.BatchNormalization( axis=BN_AXIS, epsilon=BN_EPSILON, name="post_bn" diff --git a/keras_cv/models/backbones/vit_det/vit_det_backbone.py b/keras_cv/models/backbones/vit_det/vit_det_backbone.py index c2c21ab98e..beb730f4df 100644 --- a/keras_cv/models/backbones/vit_det/vit_det_backbone.py +++ b/keras_cv/models/backbones/vit_det/vit_det_backbone.py @@ -144,9 +144,9 @@ def __init__( num_heads=num_heads, use_bias=use_bias, use_rel_pos=use_rel_pos, - window_size=window_size - if i not in global_attention_indices - else 0, + window_size=( + window_size if i not in global_attention_indices else 0 + ), input_size=(img_size // patch_size, img_size // patch_size), )(x) x = keras.models.Sequential( diff --git a/keras_cv/models/legacy/darknet.py b/keras_cv/models/legacy/darknet.py index ea7fd429f2..2dc14d499d 100644 --- a/keras_cv/models/legacy/darknet.py +++ b/keras_cv/models/legacy/darknet.py @@ -76,7 +76,6 @@ @keras.utils.register_keras_serializable(package="keras_cv.models") class DarkNet(keras.Model): - """Represents the DarkNet architecture. The DarkNet architecture is commonly used for detection tasks. It is diff --git a/keras_cv/models/legacy/mlp_mixer.py b/keras_cv/models/legacy/mlp_mixer.py index a48544f905..170d0a4c6f 100644 --- a/keras_cv/models/legacy/mlp_mixer.py +++ b/keras_cv/models/legacy/mlp_mixer.py @@ -143,7 +143,6 @@ def apply_mixer_block(x, tokens_mlp_dim, channels_mlp_dim, name=None): @keras.utils.register_keras_serializable(package="keras_cv.models") class MLPMixer(keras.Model): - """Instantiates the MLP Mixer architecture. 
Args: diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py index f4bd99fafa..a2bf4bdd3b 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_backbone.py @@ -178,9 +178,9 @@ def __init__( activation=activation, name=f"{stack_name}_spp_fast", ) - pyramid_level_inputs[ - f"P{stack_id + 2}" - ] = utils.get_tensor_input_name(x) + pyramid_level_inputs[f"P{stack_id + 2}"] = ( + utils.get_tensor_input_name(x) + ) super().__init__(inputs=inputs, outputs=x, **kwargs) self.pyramid_level_inputs = pyramid_level_inputs diff --git a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py index bfba44945c..6c17c71a72 100644 --- a/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py +++ b/keras_cv/models/object_detection/yolo_v8/yolo_v8_detector.py @@ -663,9 +663,9 @@ def from_config(cls, config): if prediction_decoder is not None and isinstance( prediction_decoder, dict ): - config[ - "prediction_decoder" - ] = keras.saving.deserialize_keras_object(prediction_decoder) + config["prediction_decoder"] = ( + keras.saving.deserialize_keras_object(prediction_decoder) + ) return cls(**config) @classproperty diff --git a/keras_cv/models/segmentation/__init__.py b/keras_cv/models/segmentation/__init__.py index aa4ffab4a4..13a9795dda 100644 --- a/keras_cv/models/segmentation/__init__.py +++ b/keras_cv/models/segmentation/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_cv.models.segmentation.basnet import BASNet from keras_cv.models.segmentation.deeplab_v3_plus import DeepLabV3Plus from keras_cv.models.segmentation.segformer import SegFormer from keras_cv.models.segmentation.segment_anything import SAMMaskDecoder diff --git a/keras_cv/models/segmentation/basnet/__init__.py b/keras_cv/models/segmentation/basnet/__init__.py new file mode 100644 index 0000000000..b51fd6c004 --- /dev/null +++ b/keras_cv/models/segmentation/basnet/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.models.segmentation.basnet.basnet import BASNet diff --git a/keras_cv/models/segmentation/basnet/basnet.py b/keras_cv/models/segmentation/basnet/basnet.py new file mode 100644 index 0000000000..2803d4425c --- /dev/null +++ b/keras_cv/models/segmentation/basnet/basnet.py @@ -0,0 +1,454 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import keras
+from keras_cv.models import utils
+from keras_cv.models.backbones.backbone_presets import backbone_presets
+from keras_cv.models.backbones.resnet_v1.resnet_v1_backbone import (
+    apply_basic_block as resnet_basic_block,
+)
+from keras_cv.models.segmentation.basnet.basnet_presets import basnet_presets
+from keras_cv.models.segmentation.basnet.basnet_presets import (
+    presets_no_weights,
+)
+from keras_cv.models.segmentation.basnet.basnet_presets import (
+    presets_with_weights,
+)
+from keras_cv.models.task import Task
+from keras_cv.utils.python_utils import classproperty
+
+
+@keras_cv_export(
+    [
+        "keras_cv.models.BASNet",
+        "keras_cv.models.segmentation.BASNet",
+    ]
+)
+class BASNet(Task):
+    """
+    A Keras model implementing the BASNet architecture for semantic
+    segmentation.
+
+    References:
+        - [BASNet: Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704)
+
+    Args:
+        backbone: `keras.Model`. The backbone network for the model, used as
+            a feature extractor for the BASNet prediction encoder. Currently
+            supported backbones are ResNet18 and ResNet34, for example
+            `keras_cv.models.ResNet34Backbone()`.
+            (Note: Do not specify 'input_shape', 'input_tensor', or
+            'include_rescaling' within the backbone. Please provide these
+            while initializing the 'BASNet' model.)
+        num_classes: int, the number of classes for the segmentation model.
+        input_shape: optional shape tuple, defaults to (None, None, 3).
+        input_tensor: optional Keras tensor (i.e., output of `layers.Input()`)
+            to use as image input for the model.
+        include_rescaling: bool, whether to rescale the inputs. If set
+            to `True`, inputs will be passed through a `Rescaling(1/255.0)`
+            layer.
+        projection_filters: int, number of filters in the convolution layer
+            projecting low-level features from the `backbone`.
+        prediction_heads: (Optional) List of `keras.layers.Layer` defining
+            the prediction module heads for the model. If not provided, a
+            default head is created with a Conv2D layer followed by resizing.
+        refinement_head: (Optional) a `keras.layers.Layer` defining the
+            refinement module head for the model. If not provided, a default
+            head is created with a Conv2D layer.
+
+    Examples:
+    ```python
+
+    import keras_cv
+
+    images = np.ones(shape=(1, 288, 288, 3))
+    labels = np.zeros(shape=(1, 288, 288, 1))
+
+    # Note: Do not specify 'input_shape', 'input_tensor', or
+    # 'include_rescaling' within the backbone.
+    backbone = keras_cv.models.ResNet34Backbone()
+    model = keras_cv.models.segmentation.BASNet(
+        backbone=backbone,
+        num_classes=1,
+        input_shape=[288, 288, 3],
+        include_rescaling=False
+    )
+
+    # Evaluate model
+    output = model(images)
+    pred_labels = output[0]
+
+    # Train model
+    model.compile(
+        optimizer="adam",
+        loss=keras.losses.BinaryCrossentropy(from_logits=False),
+        metrics=["accuracy"],
+    )
+    model.fit(images, labels, epochs=3)
+    ```
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        input_shape=(None, None, 3),
+        input_tensor=None,
+        include_rescaling=False,
+        projection_filters=64,
+        prediction_heads=None,
+        refinement_head=None,
+        **kwargs,
+    ):
+        if not isinstance(backbone, keras.Model):
+            raise ValueError(
+                "Argument `backbone` must be a `keras.Model` instance."
+                f" Received backbone={backbone}"
+                f" (of type {type(backbone)}) instead."
+            )
+
+        if backbone.input_shape != (None, None, None, 3):
+            raise ValueError(
+                "Do not specify 'input_shape' or 'input_tensor' within the"
+                " 'BASNet' backbone. \nPlease provide 'input_shape' or"
+                " 'input_tensor' while initializing the 'BASNet' model."
+            )
+
+        inputs = utils.parse_model_inputs(input_shape, input_tensor)
+        x = inputs
+
+        if include_rescaling:
+            x = keras.layers.Rescaling(1 / 255.0)(x)
+
+        if prediction_heads is None:
+            prediction_heads = []
+            for size in (1, 2, 4, 8, 16, 32, 32):
+                head_layers = [
+                    keras.layers.Conv2D(
+                        num_classes, kernel_size=(3, 3), padding="same"
+                    )
+                ]
+                if size != 1:
+                    head_layers.append(
+                        keras.layers.UpSampling2D(
+                            size=size, interpolation="bilinear"
+                        )
+                    )
+                prediction_heads.append(keras.Sequential(head_layers))
+
+        if refinement_head is None:
+            refinement_head = keras.Sequential(
+                [
+                    keras.layers.Conv2D(
+                        num_classes, kernel_size=(3, 3), padding="same"
+                    ),
+                ]
+            )
+
+        # Prediction model.
+        predict_model = basnet_predict(
+            x, backbone, projection_filters, prediction_heads
+        )
+
+        # Refinement model.
+        refine_model = basnet_rrm(
+            predict_model, projection_filters, refinement_head
+        )
+
+        outputs = refine_model.outputs  # Combine outputs.
+        outputs.extend(predict_model.outputs)
+
+        outputs = [
+            keras.layers.Activation("sigmoid", dtype="float32")(_)
+            for _ in outputs
+        ]  # Activations.
+
+        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+        self.backbone = backbone
+        self.num_classes = num_classes
+        self.input_tensor = input_tensor
+        self.include_rescaling = include_rescaling
+        self.projection_filters = projection_filters
+        self.prediction_heads = prediction_heads
+        self.refinement_head = refinement_head
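Because `refine_model.outputs` is placed first, `output[0]` is the refined map and the remaining seven entries are the coarse side outputs from the prediction heads; every head ends in a sigmoid, so all eight outputs can be trained against the same binary mask (the test file later in this diff compiles with eight metrics for exactly this reason). A usage sketch under the 288x288 setup from the docstring example above:

```python
# Sketch: BASNet returns [refined, side_1, ..., side_7]; all eight outputs
# share one binary target. Shapes follow the docstring example above.
import numpy as np
import keras_cv

model = keras_cv.models.BASNet(
    backbone=keras_cv.models.ResNet34Backbone(),
    num_classes=1,
    input_shape=[288, 288, 3],
)
images = np.ones((1, 288, 288, 3), "float32")
outputs = model(images)
assert len(outputs) == 8
refined = outputs[0]  # final prediction: the refinement module's output
```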
+    def get_config(self):
+        return {
+            "backbone": keras.saving.serialize_keras_object(self.backbone),
+            "num_classes": self.num_classes,
+            "input_shape": self.input_shape[1:],
+            "input_tensor": keras.saving.serialize_keras_object(
+                self.input_tensor
+            ),
+            "include_rescaling": self.include_rescaling,
+            "projection_filters": self.projection_filters,
+            "prediction_heads": [
+                keras.saving.serialize_keras_object(prediction_head)
+                for prediction_head in self.prediction_heads
+            ],
+            "refinement_head": keras.saving.serialize_keras_object(
+                self.refinement_head
+            ),
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        if "backbone" in config and isinstance(config["backbone"], dict):
+            input_shape = (None, None, 3)
+            if isinstance(config["backbone"]["config"]["input_shape"], list):
+                input_shape = list(input_shape)
+            if config["backbone"]["config"]["input_shape"] != input_shape:
+                config["input_shape"] = config["backbone"]["config"][
+                    "input_shape"
+                ]
+                config["backbone"]["config"]["input_shape"] = input_shape
+            config["backbone"] = keras.layers.deserialize(config["backbone"])
+
+        if "input_tensor" in config and isinstance(
+            config["input_tensor"], dict
+        ):
+            config["input_tensor"] = keras.layers.deserialize(
+                config["input_tensor"]
+            )
+
+        if "prediction_heads" in config and isinstance(
+            config["prediction_heads"], list
+        ):
+            for i in range(len(config["prediction_heads"])):
+                if isinstance(config["prediction_heads"][i], dict):
+                    config["prediction_heads"][i] = keras.layers.deserialize(
+                        config["prediction_heads"][i]
+                    )
+
+        if "refinement_head" in config and isinstance(
+            config["refinement_head"], dict
+        ):
+            config["refinement_head"] = keras.layers.deserialize(
+                config["refinement_head"]
+            )
+        return super().from_config(config)
+
+    @classproperty
+    def presets(cls):
+        """Dictionary of preset names and configurations."""
+        filtered_backbone_presets = copy.deepcopy(
+            {
+                k: v
+                for k, v in backbone_presets.items()
+                if k in ("resnet18", "resnet34")
+            }
+        )
+
+        return copy.deepcopy({**filtered_backbone_presets, **basnet_presets})
+
+    @classproperty
+    def presets_with_weights(cls):
+        """
+        Dictionary of preset names and configurations that include weights.
+        """
+        return copy.deepcopy(presets_with_weights)
+
+    @classproperty
+    def presets_without_weights(cls):
+        """
+        Dictionary of preset names and configurations that have no weights.
+        """
+        return copy.deepcopy(presets_no_weights)
+
+    @classproperty
+    def backbone_presets(cls):
+        """
+        Dictionary of preset names and configurations of compatible backbones.
+        """
+        filtered_backbone_presets = copy.deepcopy(
+            {
+                k: v
+                for k, v in backbone_presets.items()
+                if k in ("resnet18", "resnet34")
+            }
+        )
+        filtered_presets = copy.deepcopy(filtered_backbone_presets)
+        return filtered_presets
+
+
+def convolution_block(x_input, filters, dilation=1):
+    """
+    Apply convolution + batch normalization + ReLU activation.
+
+    Args:
+        x_input: Input keras tensor.
+        filters: int, number of output filters in the convolution.
+        dilation: int, dilation rate for the convolution operation.
+            Defaults to 1.
+
+    Returns:
+        A tensor with convolution, batch normalization, and ReLU
+        activation applied.
+ """ + x = keras.layers.Conv2D( + filters, (3, 3), padding="same", dilation_rate=dilation + )(x_input) + x = keras.layers.BatchNormalization()(x) + return keras.layers.Activation("relu")(x) + + +def get_resnet_block(_resnet, block_num): + """ + Extract and return a specific ResNet block. + + Args: + _resnet: `keras.Model`. ResNet model instance. + block_num: int, block number to extract. + + Returns: + A Keras Model representing the specified ResNet block. + """ + + extractor_levels = ["P2", "P3", "P4", "P5"] + return keras.models.Model( + inputs=_resnet.get_layer(f"v2_stack_{block_num}_block1_1_conv").input, + outputs=_resnet.get_layer( + _resnet.pyramid_level_inputs[extractor_levels[block_num]] + ).output, + name=f"resnet_block{block_num + 1}", + ) + + +def basnet_predict(x_input, backbone, filters, segmentation_heads): + """ + BASNet Prediction Module. + + This module outputs a coarse label map by integrating heavy + encoder, bridge, and decoder blocks. + + Args: + x_input: Input keras tensor. + backbone: `keras.Model`. The backbone network used as a feature + extractor for BASNet prediction encoder. + filters: int, the number of filters. + segmentation_heads: List of `keras.layers.Layer`, A list of Keras + layers serving as the segmentation head for prediction module. + + + Returns: + A Keras Model that integrates the encoder, bridge, and decoder + blocks for coarse label map prediction. + """ + num_stages = 6 + + x = x_input + + # -------------Encoder-------------- + x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x) + + encoder_blocks = [] + for i in range(num_stages): + if i < 4: # First four stages are adopted from ResNet backbone. + x = get_resnet_block(backbone, i)(x) + encoder_blocks.append(x) + else: # Last 2 stages consist of three basic resnet blocks. + x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) + for j in range(3): + x = resnet_basic_block( + x, + filters=x.shape[3], + conv_shortcut=False, + name=f"v1_basic_block_{i + 1}_{j + 1}", + ) + encoder_blocks.append(x) + + # -------------Bridge------------- + x = convolution_block(x, filters=filters * 8, dilation=2) + x = convolution_block(x, filters=filters * 8, dilation=2) + x = convolution_block(x, filters=filters * 8, dilation=2) + encoder_blocks.append(x) + + # -------------Decoder------------- + decoder_blocks = [] + for i in reversed(range(num_stages)): + if i != (num_stages - 1): # Except first, scale other decoder stages. + x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) + + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters * 8) + x = convolution_block(x, filters=filters * 8) + x = convolution_block(x, filters=filters * 8) + decoder_blocks.append(x) + + decoder_blocks.reverse() # Change order from last to first decoder stage. + decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder. + + # -------------Side Outputs-------------- + decoder_blocks = [ + segmentation_head(decoder_block) # Prediction segmentation head. + for segmentation_head, decoder_block in zip( + segmentation_heads, decoder_blocks + ) + ] + + return keras.models.Model(inputs=[x_input], outputs=decoder_blocks) + + +def basnet_rrm(base_model, filters, segmentation_head): + """ + BASNet Residual Refinement Module (RRM). + + This module outputs a fine label map by integrating light encoder, + bridge, and decoder blocks. + + Args: + base_model: Keras model used as the base or coarse label map. + filters: int, the number of filters. 
+ segmentation_head: a `keras.layers.Layer`, A Keras layer serving + as the segmentation head for refinement module. + + Returns: + A Keras Model that constructs the Residual Refinement Module (RRM). + """ + num_stages = 4 + + x_input = base_model.output[0] + + # -------------Encoder-------------- + x = keras.layers.Conv2D(filters, kernel_size=(3, 3), padding="same")( + x_input + ) + + encoder_blocks = [] + for _ in range(num_stages): + x = convolution_block(x, filters=filters) + encoder_blocks.append(x) + x = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x) + + # -------------Bridge-------------- + x = convolution_block(x, filters=filters) + + # -------------Decoder-------------- + for i in reversed(range(num_stages)): + x = keras.layers.UpSampling2D(size=2, interpolation="bilinear")(x) + x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1) + x = convolution_block(x, filters=filters) + + x = segmentation_head(x) # Refinement segmentation head. + + # ------------- refined = coarse + residual + x = keras.layers.Add()([x_input, x]) # Add prediction + refinement output + + return keras.models.Model(inputs=base_model.input, outputs=[x]) diff --git a/keras_cv/models/segmentation/basnet/basnet_presets.py b/keras_cv/models/segmentation/basnet/basnet_presets.py new file mode 100644 index 0000000000..69d323fd0f --- /dev/null +++ b/keras_cv/models/segmentation/basnet/basnet_presets.py @@ -0,0 +1,51 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BASNet model preset configurations.""" + +from keras_cv.models.backbones.resnet_v1 import resnet_v1_backbone_presets + +presets_no_weights = { + "basnet_resnet18": { + "metadata": { + "description": "BASNet with a ResNet18 v1 backbone.", + "params": 98780872, + "official_name": "BASNet", + "path": "basnet_resnet18", + }, + "config": { + "backbone": resnet_v1_backbone_presets.backbone_presets["resnet18"], + "num_classes": 1, + "input_shape": (288, 288, 3), + }, + }, + "basnet_resnet34": { + "metadata": { + "description": "BASNet with a ResNet34 v1 backbone.", + "params": 108896456, + "official_name": "BASNet", + "path": "basnet_resnet34", + }, + "config": { + "backbone": resnet_v1_backbone_presets.backbone_presets["resnet34"], + "num_classes": 1, + "input_shape": (288, 288, 3), + }, + }, +} + +presets_with_weights = { + # TODO: Add BASNet preset with weights +} + +basnet_presets = {**presets_no_weights, **presets_with_weights} diff --git a/keras_cv/models/segmentation/basnet/basnet_test.py b/keras_cv/models/segmentation/basnet/basnet_test.py new file mode 100644 index 0000000000..88408c134c --- /dev/null +++ b/keras_cv/models/segmentation/basnet/basnet_test.py @@ -0,0 +1,142 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.backend.config import keras_3 +from keras_cv.models import BASNet +from keras_cv.models import ResNet18Backbone +from keras_cv.tests.test_case import TestCase + + +class BASNetTest(TestCase): + def test_basnet_construction(self): + backbone = ResNet18Backbone() + model = BASNet( + input_shape=[64, 64, 3], backbone=backbone, num_classes=1 + ) + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(), + metrics=["accuracy"], + ) + + @pytest.mark.large + def test_basnet_call(self): + backbone = ResNet18Backbone() + model = BASNet( + input_shape=[64, 64, 3], backbone=backbone, num_classes=1 + ) + images = np.random.uniform(size=(2, 64, 64, 3)) + _ = model(images) + _ = model.predict(images) + + @pytest.mark.large + @pytest.mark.filterwarnings("ignore::UserWarning") + def test_weights_change(self): + input_size = [64, 64, 3] + target_size = [64, 64, 1] + + images = np.ones([1] + input_size) + labels = np.random.uniform(size=[1] + target_size) + ds = tf.data.Dataset.from_tensor_slices((images, labels)) + ds = ds.repeat(2) + ds = ds.batch(2) + + backbone = ResNet18Backbone() + model = BASNet( + input_shape=[64, 64, 3], backbone=backbone, num_classes=1 + ) + model_metrics = ["accuracy"] + if keras_3(): + model_metrics = ["accuracy" for _ in range(8)] + + model.compile( + optimizer="adam", + loss=keras.losses.BinaryCrossentropy(), + metrics=model_metrics, + ) + + original_weights = model.refinement_head.get_weights() + model.fit(ds, epochs=1, batch_size=1) + updated_weights = model.refinement_head.get_weights() + + for w1, w2 in zip(original_weights, updated_weights): + self.assertNotAllEqual(w1, w2) + self.assertFalse(ops.any(ops.isnan(w2))) + + @pytest.mark.large + def test_with_model_preset_forward_pass(self): + self.skipTest("Skipping preset test until BASNet weights are added.") + model = BASNet.from_preset( + "basnet_resnet34", + ) + image = np.ones((1, 288, 288, 3)) + output = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1) + output = output[0] + expected_output = np.zeros((1, 288, 288, 1)) + self.assertAllClose(output, expected_output) + + @pytest.mark.large + def test_saved_model(self): + target_size = [64, 64, 3] + + backbone = ResNet18Backbone() + model = BASNet( + input_shape=[64, 64, 3], backbone=backbone, num_classes=1 + ) + + input_batch = np.ones(shape=[2] + target_size) + model_output = model(input_batch) + + save_path = os.path.join(self.get_temp_dir(), "model.keras") + if keras_3(): + model.save(save_path) + else: + model.save(save_path, save_format="keras_v3") + # Free up model memory + del model + gc.collect() + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, BASNet) + + # Check that output matches. 
+ restored_output = restored_model(input_batch) + self.assertAllClose(model_output, restored_output) + + +@pytest.mark.large +class BASNetSmokeTest(TestCase): + @parameterized.named_parameters( + *[(preset, preset) for preset in ["resnet18", "resnet34"]] + ) + def test_backbone_preset(self, preset): + model = BASNet.from_preset( + preset, + num_classes=1, + ) + xs = np.random.uniform(size=(1, 128, 128, 3)) + output = model(xs)[0] + + self.assertEqual(output.shape, (1, 128, 128, 1)) diff --git a/keras_cv/models/stable_diffusion/noise_scheduler.py b/keras_cv/models/stable_diffusion/noise_scheduler.py index bd1c0dc51e..c5c100848c 100644 --- a/keras_cv/models/stable_diffusion/noise_scheduler.py +++ b/keras_cv/models/stable_diffusion/noise_scheduler.py @@ -54,9 +54,7 @@ def __init__( elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( - ops.linspace( - beta_start**0.5, beta_end**0.5, train_timesteps - ) + ops.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) ** 2 ) else: diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py index 299f44d3d0..a68923dc78 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion.py @@ -209,7 +209,10 @@ def generate_image( latent = self._get_initial_diffusion_noise(batch_size, seed) # Iterative reverse diffusion stage - timesteps = np.arange(1, 1000, 1000 // num_steps) + num_timesteps = 1000 + ratio = (num_timesteps - 1) / (num_steps - 1) + timesteps = (np.arange(0, num_steps) * ratio).round().astype(np.int64) + alphas, alphas_prev = self._get_initial_alphas(timesteps) progbar = keras.utils.Progbar(len(timesteps)) iteration = 0 diff --git a/keras_cv/version_utils.py b/keras_cv/version_utils.py new file mode 100644 index 0000000000..527546c643 --- /dev/null +++ b/keras_cv/version_utils.py @@ -0,0 +1,23 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.api_export import keras_cv_export + +# Unique source of truth for the version number. +__version__ = "0.8.2" + + +@keras_cv_export("keras_cv.version") +def version(): + return __version__ diff --git a/pip_build.py b/pip_build.py index 29f574001f..ba61963697 100644 --- a/pip_build.py +++ b/pip_build.py @@ -64,11 +64,22 @@ def export_version_string(version, is_nightly=False): ) f.write(setup_contents) + # Overwrite the version string with our package version. 
+ with open(os.path.join(package, "src", "version_utils.py")) as f: + version_contents = f.readlines() + with open(os.path.join(package, "src", "version_utils.py"), "w") as f: + for line in version_contents: + if line.startswith("__version__"): + f.write(f'__version__ = "{version}"\n') + else: + f.write(line) + # Make sure to export the __version__ string with open(os.path.join(package, "__init__.py")) as f: init_contents = f.read() with open(os.path.join(package, "__init__.py"), "w") as f: - f.write(init_contents + "\n\n" + f'__version__ = "{version}"\n') + f.write(init_contents) + f.write("from keras_cv.src.version_utils import __version__\n") def copy_source_to_build_directory(root_path): diff --git a/setup.py b/setup.py index ffe7cbb4a8..19dc42248c 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,21 @@ from setuptools import setup from setuptools.dist import Distribution + +def read(rel_path): + here = os.path.abspath(os.path.dirname(__file__)) + with open(os.path.join(here, rel_path)) as fp: + return fp.read() + + +def get_version(rel_path): + for line in read(rel_path).splitlines(): + if line.startswith("__version__"): + delim = '"' if '"' in line else "'" + return line.split(delim)[1] + raise RuntimeError("Unable to find version string.") + + BUILD_WITH_CUSTOM_OPS = ( "BUILD_WITH_CUSTOM_OPS" in os.environ and os.environ["BUILD_WITH_CUSTOM_OPS"] == "true" @@ -28,6 +43,10 @@ HERE = pathlib.Path(__file__).parent README = (HERE / "README.md").read_text() +if os.path.exists("keras_cv/version_utils.py"): + VERSION = get_version("keras_cv/version_utils.py") +else: + VERSION = get_version("keras_cv/src/version_utils.py") class BinaryDistribution(Distribution): @@ -45,6 +64,7 @@ def is_pure(self): description="Industry-strength computer Vision extensions for Keras.", long_description=README, long_description_content_type="text/markdown", + version=VERSION, url="https://github.com/keras-team/keras-cv", author="Keras team", author_email="keras-cv@google.com",
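The `get_version` helper in `setup.py` and the rewrite loop in `pip_build.py` both key off the literal `__version__ = "..."` assignment in `version_utils.py`, which keeps that one line the single source of truth for the package version. A self-contained sketch of the same parsing, run against an in-memory copy of the file instead of a path:

```python
# Sketch of the version-string parsing used by setup.py's get_version(),
# applied to a string here instead of a file path, purely for illustration.
sample = '''
# Unique source of truth for the version number.
__version__ = "0.8.2"
'''

def get_version(text):
    for line in text.splitlines():
        if line.startswith("__version__"):
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")

assert get_version(sample) == "0.8.2"
```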