diff --git a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SelfSemanticSegmentationTranslator.java b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SelfSemanticSegmentationTranslator.java new file mode 100644 index 00000000..82c98801 --- /dev/null +++ b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SelfSemanticSegmentationTranslator.java @@ -0,0 +1,195 @@ +package ai.djl.examples.semanticsegmentation; + +import java.nio.ByteBuffer; +import java.util.Map; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.BaseImageTranslator; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.Transform; +import ai.djl.translate.TranslatorContext; + +/** + * A {@link BaseImageTranslator} that post-process the {@link NDArray} into {@link DetectedObjects} + * with boundaries at the detailed pixel level. + */ +public class SelfSemanticSegmentationTranslator extends BaseImageTranslator { + + private final int shortEdge; + private final int maxEdge; + + private static final int CHANNEL = 3; + private static final int CLASSNUM = 21; + + /** + * Creates the Semantic Segmentation translator from the given builder. + * + * @param builder the builder for the translator + */ + public SelfSemanticSegmentationTranslator(Builder builder) { + super(builder); + this.shortEdge = builder.shortEdge; + this.maxEdge = builder.maxEdge; + + pipeline.insert(0, null, new ResizeShort()); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Image image) { + ctx.setAttachment("originalHeight", image.getHeight()); + ctx.setAttachment("originalWidth", image.getWidth()); + ctx.setAttachment("originalImage",image); + + float a = 128; + byte b = (byte) a; + System.out.println("**************************************************************: "+ b); + return super.processInput(ctx, image); + } + + /** {@inheritDoc} */ + @Override + public Image processOutput(TranslatorContext ctx, NDList list) { + // scores contains the probabilities of each pixel being a certain object + float[] scores = list.get(1).toFloatArray(); + Shape shape = list.get(1).getShape(); + int width = (int) shape.get(2); + int height = (int) shape.get(1); + Image originalImage = (Image) ctx.getAttachment("originalImage"); + + + // build image array + try (NDManager manager = NDManager.newBaseManager()) { + NDArray imageArray = originalImage.toNDArray(manager); + imageArray = NDImageUtils.resize(imageArray, width, height); + imageArray.toUint8Array(); + + int imageSize = width * height; + ByteBuffer bb = manager.allocateDirect(CHANNEL * imageSize); + int[] row = imageArray.toType(DataType.UINT8, false).toUint8Array(); + + // change color of pixels in image array where objects have been detected + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + int index = h * width + w; + int maxi = 0; + double maxnum = -Double.MAX_VALUE; + for (int i = 0; i < CLASSNUM; i++) { + // get score for each i at the h,w pixel of the image + float score = scores[i * (imageSize) + index]; + if (score > maxnum) { + maxnum = score; + maxi = i; + } + } + if (maxi > 0) { + + bb.put((byte) (row[3*index])); + bb.put((byte) (row[3*index+1])); + bb.put((byte) (row[3*index+2])); + + } else { + bb.position(bb.position() + 3); + } + } + } + bb.rewind(); + int originW = (int) ctx.getAttachment("originalWidth"); + int originH = (int) ctx.getAttachment("originalHeight"); + NDArray fullImage = + manager.create(bb, new Shape(height, width, CHANNEL), DataType.UINT8); + NDArray resized = NDImageUtils.resize(fullImage, originW, originH); + + return ImageFactory.getInstance().fromNDArray(resized); + } + } + + /** + * Creates a builder to build a {@code SelfSemanticSegmentationTranslator}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a builder to build a {@code SelfSemanticSegmentationTranslator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static Builder builder(Map arguments) { + Builder builder = new Builder(); + + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** Resizes the image based on the shorter edge or maximum edge length. */ + private class ResizeShort implements Transform { + /** {@inheritDoc} */ + @Override + public NDArray transform(NDArray array) { + Shape shape = array.getShape(); + int width = (int) shape.get(1); + int height = (int) shape.get(0); + int min = Math.min(width, height); + int max = Math.max(width, height); + float scale = shortEdge / (float) min; + if (Math.round(scale * max) > maxEdge) { + scale = maxEdge / (float) max; + } + int rescaledHeight = Math.round(height * scale); + int rescaledWidth = Math.round(width * scale); + + return NDImageUtils.resize(array, rescaledWidth, rescaledHeight); + } + } + + /** The builder for Semantic Segmentation translator. */ + public static class Builder extends ClassificationBuilder { + int shortEdge = 600; + int maxEdge = 1000; + + Builder() {} + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + super.configPostProcess(arguments); + shortEdge = ArgumentsUtil.intValue(arguments, "shortEdge", 600); + maxEdge = ArgumentsUtil.intValue(arguments, "maxEdge", 1000); + } + @Override + protected void configPreProcess(Map arguments) { + super.configPreProcess(arguments); + } + + /** + * Builds the translator. + * + * @return the new translator + */ + public SelfSemanticSegmentationTranslator build() { + validate(); + return new SelfSemanticSegmentationTranslator(this); + } + } +} diff --git a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java index fcbc0b0a..8132cfe5 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java +++ b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java @@ -21,6 +21,7 @@ import android.os.Bundle; import android.util.Log; import android.util.Size; +import android.view.MotionEvent; import android.view.Surface; import android.view.View; import android.widget.ImageButton; @@ -30,6 +31,7 @@ import androidx.annotation.NonNull; import androidx.appcompat.app.AlertDialog; import androidx.appcompat.app.AppCompatActivity; +import androidx.appcompat.widget.SwitchCompat; import androidx.camera.camera2.Camera2Config; import androidx.camera.core.CameraSelector; import androidx.camera.core.CameraXConfig; @@ -71,9 +73,11 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig private ImageView mImagePredicted; private ProgressBar mProgressBar; private FloatingActionButton mCaptureButton; + private SwitchCompat mSwitch; private ImageCapture imageCapture; private Bitmap bitmapBuffer; + boolean isUseSelfTranslator = false; ZooModel model; Predictor predictor; @@ -91,8 +95,19 @@ protected void onCreate(Bundle savedInstanceState) { mImagePredicted = binding.imagePredicted; mProgressBar = binding.progressBar; mCaptureButton = binding.captureButton; + mSwitch = binding.useSelfTranslatorSwitch; + mSwitch.setOnTouchListener(new View.OnTouchListener(){ + @Override + public boolean onTouch(View view, MotionEvent motionEvent) { + view.performClick(); + isUseSelfTranslator = mSwitch.isChecked(); + executor.execute(new LoadModelTask()); + return false; + } + }); mCaptureButton.setOnClickListener(view -> { + mCaptureButton.setEnabled(false); mImagePredicted.setVisibility(View.GONE); @@ -144,6 +159,7 @@ public void onError(@NonNull ImageCaptureException exception) { Snackbar.make(findViewById(android.R.id.content), R.string.message_download_model, Snackbar.LENGTH_LONG).show(); + //first initialize executor.execute(new LoadModelTask()); } @@ -240,7 +256,8 @@ private class LoadModelTask implements Runnable { @Override public void run() { try { - model = SemanticModel.loadModel(); + isUseSelfTranslator = mSwitch.isChecked(); + model = SemanticModel.loadModel(isUseSelfTranslator); predictor = model.newPredictor(); runOnUiThread(() -> { mCaptureButton.setEnabled(true); diff --git a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java index 15f45779..8e7ea1d5 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java +++ b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java @@ -22,20 +22,27 @@ import ai.djl.modality.cv.translator.SemanticSegmentationTranslator; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Translator; final class SemanticModel { private SemanticModel() { } - public static ZooModel loadModel() throws ModelException, IOException { + public static ZooModel loadModel(boolean isUseSelfTranslator) throws ModelException, IOException { String url = "https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip"; Map arguments = new ConcurrentHashMap<>(); arguments.put("toTensor", "true"); arguments.put("normalize", "true"); - SemanticSegmentationTranslator translator = - SemanticSegmentationTranslator.builder(arguments).build(); + Translator translator; + if(isUseSelfTranslator){ + translator = SelfSemanticSegmentationTranslator.builder(arguments).build(); + }else { + translator = SemanticSegmentationTranslator.builder(arguments).build(); + } + + Criteria criteria = Criteria.builder() diff --git a/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml b/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml index b09addc8..58ef22f6 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml +++ b/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml @@ -4,6 +4,7 @@ xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" + android:elevation="10dp" tools:context=".SemanticActivity"> + app:layout_constraintStart_toStartOf="parent" + app:layout_constraintTop_toTopOf="parent" /> + app:layout_constraintTop_toTopOf="@id/image_predicted" + app:tint="@color/button_selector" /> + app:layout_constraintStart_toStartOf="parent" + app:layout_constraintTop_toTopOf="@+id/image_predicted" /> + + + app:layout_constraintHorizontal_bias="0.498" + app:layout_constraintStart_toStartOf="parent" /> \ No newline at end of file diff --git a/android/pytorch_android/semantic_segmentation/src/main/res/values/strings.xml b/android/pytorch_android/semantic_segmentation/src/main/res/values/strings.xml index 4b136055..9a0b661a 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/res/values/strings.xml +++ b/android/pytorch_android/semantic_segmentation/src/main/res/values/strings.xml @@ -4,4 +4,5 @@ Download model complete Error Download model failed + use self-translator \ No newline at end of file