diff --git a/keras_cv/models/segmentation/segformer/segformer_test.py b/keras_cv/models/segmentation/segformer/segformer_test.py index 5f33b033da..bbd33ad472 100644 --- a/keras_cv/models/segmentation/segformer/segformer_test.py +++ b/keras_cv/models/segmentation/segformer/segformer_test.py @@ -54,12 +54,13 @@ def test_segformer_preset_error(self): def test_segformer_call(self): backbone = MiTBackbone.from_preset("mit_b0") mit_model = SegFormer(backbone=backbone, num_classes=1) - + mit_model.compile(loss=keras.losses.BinaryCrossentropy()) images = np.random.uniform(size=(2, 224, 224, 3)) mit_output = mit_model(images) mit_pred = mit_model.predict(images) seg_model = SegFormer.from_preset("segformer_b0", num_classes=1) + seg_model.compile(loss=keras.losses.BinaryCrossentropy()) seg_output = seg_model(images) seg_pred = seg_model.predict(images) @@ -98,7 +99,7 @@ def test_saved_model(self): target_size = [512, 512, 3] backbone = MiTBackbone.from_preset("mit_b0", input_shape=[512, 512, 3]) - model = SegFormer(backbone=backbone, num_classes=1) + model = SegFormer(backbone=backbone, num_classes=2) input_batch = np.ones(shape=[2] + target_size) model_output = model(input_batch) @@ -121,7 +122,7 @@ def test_saved_model(self): def test_preset_saved_model(self): target_size = [224, 224, 3] - model = SegFormer.from_preset("segformer_b0", num_classes=1) + model = SegFormer.from_preset("segformer_b0", num_classes=2) input_batch = np.ones(shape=[2] + target_size) model_output = model(input_batch)