diff --git a/android/pytorch_android/semantic_segmentation/build.gradle b/android/pytorch_android/semantic_segmentation/build.gradle index 8ec2c2c1..51f01e4e 100644 --- a/android/pytorch_android/semantic_segmentation/build.gradle +++ b/android/pytorch_android/semantic_segmentation/build.gradle @@ -40,7 +40,7 @@ dependencies { implementation "androidx.camera:camera-view:1.1.0" implementation "ai.djl:api:0.20.0-SNAPSHOT" - implementation "ai.djl.android:core:${djl_version}" + implementation "ai.djl.android:core:0.20.0-SNAPSHOT" runtimeOnly "ai.djl.pytorch:pytorch-engine:${djl_version}" runtimeOnly "ai.djl.android:pytorch-native:${djl_version}" } \ No newline at end of file diff --git a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java index 78a832a3..ca7a99f8 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java +++ b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticActivity.java @@ -49,7 +49,11 @@ import com.google.common.util.concurrent.ListenableFuture; import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; @@ -59,10 +63,9 @@ import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.modality.cv.output.Segmentation; +import ai.djl.modality.cv.output.CategoryMask; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.TranslateException; -import ai.djl.util.RandomUtils; public class SemanticActivity extends AppCompatActivity implements CameraXConfig.Provider { @@ -70,6 +73,8 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig private static final int CAMERA_REQUEST_CODE = 1; + public static final List list = Arrays.asList("foo", "bar"); + // Class names: 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', // 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', // 'sheep', 'sofa', 'train', 'tv' @@ -78,7 +83,7 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig private PreviewView mViewFinder; private ImageView mImagePreview; private ImageButton mCloseImageButton; - private ImageView mImageMask; + private ImageView mImagePredicted; private ProgressBar mProgressBar; private FloatingActionButton mCaptureButton; private SwitchCompat mSwitch; @@ -87,8 +92,8 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig private Bitmap bitmapBuffer; boolean isUseSelfTranslator = false; - ZooModel model; - Predictor predictor; + ZooModel model; + Predictor predictor; Executor executor = Executors.newSingleThreadExecutor(); @Override @@ -100,7 +105,7 @@ protected void onCreate(Bundle savedInstanceState) { mViewFinder = binding.viewFinder; mImagePreview = binding.imagePreview; mCloseImageButton = binding.closeImagePreview; - mImageMask = binding.imageMask; + mImagePredicted = binding.imagePredicted; mProgressBar = binding.progressBar; mCaptureButton = binding.captureButton; mSwitch = binding.useSelfTranslatorSwitch; @@ -117,7 +122,7 @@ public boolean onTouch(View view, MotionEvent motionEvent) { mCaptureButton.setOnClickListener(view -> { mCaptureButton.setEnabled(false); - mImageMask.setVisibility(View.GONE); + mImagePredicted.setVisibility(View.GONE); if (imageCapture != null) { imageCapture.takePicture(executor, new ImageCapture.OnImageCapturedCallback() { @@ -161,7 +166,7 @@ public void onError(@NonNull ImageCaptureException exception) { mCloseImageButton.setVisibility(View.GONE); mViewFinder.setVisibility(View.VISIBLE); mImagePreview.setVisibility(View.GONE); - mImageMask.setVisibility(View.GONE); + mImagePredicted.setVisibility(View.GONE); mProgressBar.setVisibility(View.GONE); mCaptureButton.setEnabled(true); }); @@ -295,37 +300,18 @@ public void run() { try { // Predict Image img = ImageFactory.getInstance().fromImage(bitmapBuffer); - Segmentation result = predictor.predict(img); - - int[][] mask = result.getMask(); - int width = mask.length; - int height = mask[0].length; - int[] colors = new int[CLASSNUM]; - for (int i = 0; i < CLASSNUM; i++) { - colors[i] = Color.argb(255, RandomUtils.nextInt(256), RandomUtils.nextInt(256), - RandomUtils.nextInt(256)); - } - - // change color of pixels in image array where objects have been detected - int[] pixels = new int[width * height]; - for (int i = 0; i < width; i++) { - for (int j = 0; j < height; j++) { - int idx = mask[i][j]; - if (idx > 0) { - pixels[j * width + i] = colors[idx]; - } else { // Make background transparent - pixels[j * width + i] = Color.TRANSPARENT; - } - } - } - Bitmap bmpSegmentation = Bitmap.createBitmap(pixels, width, height, - Bitmap.Config.ARGB_8888); - final Bitmap scaledBitmap = Bitmap.createScaledBitmap(bmpSegmentation, - bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), true); + CategoryMask mask = predictor.predict(img); + + Image backgroundImg = ImageFactory.getInstance().fromUrl(new URL("https://images.pexels.com/photos/924824/pexels-photo-924824.jpeg")); + // mask.drawMask(img, 0.5f); + // mask.drawMask(img, 0.5f, Color.TRANSPARENT); + mask.drawMask(img, 1, backgroundImg); + // mask.drawMask(img, 20, Color.BLUE); + Bitmap segmentBitmap = (Bitmap) img.getWrappedImage(); runOnUiThread(() -> { - mImageMask.setImageBitmap(scaledBitmap); - mImageMask.setVisibility(View.VISIBLE); + mImagePredicted.setImageBitmap(segmentBitmap); + mImagePredicted.setVisibility(View.VISIBLE); mCloseImageButton.setVisibility(View.VISIBLE); mProgressBar.setVisibility(View.GONE); }); @@ -337,6 +323,10 @@ public void run() { Snackbar.make(findViewById(android.R.id.content), "Inference failed.", Snackbar.LENGTH_LONG).show(); }); + } catch (MalformedURLException e) { + Log.d(TAG, "MalformedURLException: " + e.getMessage()); + } catch (IOException e) { + Log.d(TAG, "IOException: " + e.getMessage()); } } } diff --git a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java index 33cfab9a..808b18d9 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java +++ b/android/pytorch_android/semantic_segmentation/src/main/java/ai/djl/examples/semanticsegmentation/SemanticModel.java @@ -19,7 +19,7 @@ import ai.djl.ModelException; import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.output.Segmentation; +import ai.djl.modality.cv.output.CategoryMask; import ai.djl.modality.cv.translator.SemanticSegmentationTranslator; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; @@ -30,17 +30,17 @@ final class SemanticModel { private SemanticModel() { } - public static ZooModel loadModel() throws ModelException, IOException { + public static ZooModel loadModel() throws ModelException, IOException { String url = "https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip"; Map arguments = new ConcurrentHashMap<>(); arguments.put("toTensor", "true"); arguments.put("normalize", "true"); - Translator translator = SemanticSegmentationTranslator.builder(arguments).build(); + Translator translator = SemanticSegmentationTranslator.builder(arguments).build(); - Criteria criteria = + Criteria criteria = Criteria.builder() - .setTypes(Image.class, Segmentation.class) + .setTypes(Image.class, CategoryMask.class) .optModelUrls(url) .optTranslator(translator) .optEngine("PyTorch") diff --git a/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml b/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml index 29af3ba7..58ef22f6 100644 --- a/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml +++ b/android/pytorch_android/semantic_segmentation/src/main/res/layout/activity_main.xml @@ -11,7 +11,7 @@ android:id="@+id/view_finder" android:layout_width="@dimen/preview_width" android:layout_height="@dimen/preview_height" - android:layout_marginTop="@dimen/margin_xxxlarge" + android:layout_marginTop="@dimen/margin_large" app:layout_constraintEnd_toEndOf="parent" app:layout_constraintStart_toStartOf="parent" app:layout_constraintTop_toTopOf="parent" /> @@ -20,7 +20,7 @@ android:id="@+id/image_preview" android:layout_width="@dimen/preview_width" android:layout_height="@dimen/preview_height" - android:layout_marginTop="@dimen/margin_xxxlarge" + android:layout_marginTop="@dimen/margin_large" android:contentDescription="@string/app_name" android:elevation="10dp" android:scaleType="centerCrop" @@ -30,32 +30,31 @@ app:layout_constraintTop_toTopOf="parent" /> + app:layout_constraintTop_toBottomOf="@id/view_finder" /> + app:layout_constraintTop_toTopOf="@+id/image_predicted" />