Skip to content

Commit

Permalink
Add drawMask utility
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Sep 30, 2022
1 parent fd8afad commit c698fef
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 55 deletions.
2 changes: 1 addition & 1 deletion android/pytorch_android/semantic_segmentation/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies {
implementation "androidx.camera:camera-view:1.1.0"

implementation "ai.djl:api:0.20.0-SNAPSHOT"
implementation "ai.djl.android:core:${djl_version}"
implementation "ai.djl.android:core:0.20.0-SNAPSHOT"
runtimeOnly "ai.djl.pytorch:pytorch-engine:${djl_version}"
runtimeOnly "ai.djl.android:pytorch-native:${djl_version}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@
import com.google.common.util.concurrent.ListenableFuture;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
Expand All @@ -59,17 +63,18 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Segmentation;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.util.RandomUtils;

public class SemanticActivity extends AppCompatActivity implements CameraXConfig.Provider {

private static final String TAG = SemanticActivity.class.getSimpleName();

private static final int CAMERA_REQUEST_CODE = 1;

public static final List<String> list = Arrays.asList("foo", "bar");

// Class names: 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
// 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant',
// 'sheep', 'sofa', 'train', 'tv'
Expand All @@ -78,7 +83,7 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig
private PreviewView mViewFinder;
private ImageView mImagePreview;
private ImageButton mCloseImageButton;
private ImageView mImageMask;
private ImageView mImagePredicted;
private ProgressBar mProgressBar;
private FloatingActionButton mCaptureButton;
private SwitchCompat mSwitch;
Expand All @@ -87,8 +92,8 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig
private Bitmap bitmapBuffer;
boolean isUseSelfTranslator = false;

ZooModel<Image, Segmentation> model;
Predictor<Image, Segmentation> predictor;
ZooModel<Image, CategoryMask> model;
Predictor<Image, CategoryMask> predictor;
Executor executor = Executors.newSingleThreadExecutor();

@Override
Expand All @@ -100,7 +105,7 @@ protected void onCreate(Bundle savedInstanceState) {
mViewFinder = binding.viewFinder;
mImagePreview = binding.imagePreview;
mCloseImageButton = binding.closeImagePreview;
mImageMask = binding.imageMask;
mImagePredicted = binding.imagePredicted;
mProgressBar = binding.progressBar;
mCaptureButton = binding.captureButton;
mSwitch = binding.useSelfTranslatorSwitch;
Expand All @@ -117,7 +122,7 @@ public boolean onTouch(View view, MotionEvent motionEvent) {
mCaptureButton.setOnClickListener(view -> {

mCaptureButton.setEnabled(false);
mImageMask.setVisibility(View.GONE);
mImagePredicted.setVisibility(View.GONE);

if (imageCapture != null) {
imageCapture.takePicture(executor, new ImageCapture.OnImageCapturedCallback() {
Expand Down Expand Up @@ -161,7 +166,7 @@ public void onError(@NonNull ImageCaptureException exception) {
mCloseImageButton.setVisibility(View.GONE);
mViewFinder.setVisibility(View.VISIBLE);
mImagePreview.setVisibility(View.GONE);
mImageMask.setVisibility(View.GONE);
mImagePredicted.setVisibility(View.GONE);
mProgressBar.setVisibility(View.GONE);
mCaptureButton.setEnabled(true);
});
Expand Down Expand Up @@ -295,37 +300,18 @@ public void run() {
try {
// Predict
Image img = ImageFactory.getInstance().fromImage(bitmapBuffer);
Segmentation result = predictor.predict(img);

int[][] mask = result.getMask();
int width = mask.length;
int height = mask[0].length;
int[] colors = new int[CLASSNUM];
for (int i = 0; i < CLASSNUM; i++) {
colors[i] = Color.argb(255, RandomUtils.nextInt(256), RandomUtils.nextInt(256),
RandomUtils.nextInt(256));
}

// change color of pixels in image array where objects have been detected
int[] pixels = new int[width * height];
for (int i = 0; i < width; i++) {
for (int j = 0; j < height; j++) {
int idx = mask[i][j];
if (idx > 0) {
pixels[j * width + i] = colors[idx];
} else { // Make background transparent
pixels[j * width + i] = Color.TRANSPARENT;
}
}
}
Bitmap bmpSegmentation = Bitmap.createBitmap(pixels, width, height,
Bitmap.Config.ARGB_8888);
final Bitmap scaledBitmap = Bitmap.createScaledBitmap(bmpSegmentation,
bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), true);
CategoryMask mask = predictor.predict(img);

// mask.drawMask(img, 0.5f);
// mask.drawMask(img, 0.5f, Color.TRANSPARENT);
Image backgroundImg = ImageFactory.getInstance().fromUrl(new URL("https://images.pexels.com/photos/924824/pexels-photo-924824.jpeg"));
mask.drawMask(img, 1, backgroundImg);
// mask.drawMask(img, 20, Color.BLUE);
Bitmap segmentBitmap = (Bitmap) img.getWrappedImage();

runOnUiThread(() -> {
mImageMask.setImageBitmap(scaledBitmap);
mImageMask.setVisibility(View.VISIBLE);
mImagePredicted.setImageBitmap(segmentBitmap);
mImagePredicted.setVisibility(View.VISIBLE);
mCloseImageButton.setVisibility(View.VISIBLE);
mProgressBar.setVisibility(View.GONE);
});
Expand All @@ -337,6 +323,10 @@ public void run() {
Snackbar.make(findViewById(android.R.id.content), "Inference failed.",
Snackbar.LENGTH_LONG).show();
});
} catch (MalformedURLException e) {
Log.d(TAG, "MalformedURLException: " + e.getMessage());
} catch (IOException e) {
Log.d(TAG, "IOException: " + e.getMessage());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import ai.djl.ModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Segmentation;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.translator.SemanticSegmentationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
Expand All @@ -30,17 +30,17 @@ final class SemanticModel {
private SemanticModel() {
}

public static ZooModel<Image, Segmentation> loadModel() throws ModelException, IOException {
public static ZooModel<Image, CategoryMask> loadModel() throws ModelException, IOException {
String url =
"https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip";
Map<String, String> arguments = new ConcurrentHashMap<>();
arguments.put("toTensor", "true");
arguments.put("normalize", "true");
Translator<Image, Segmentation> translator = SemanticSegmentationTranslator.builder(arguments).build();
Translator<Image, CategoryMask> translator = SemanticSegmentationTranslator.builder(arguments).build();

Criteria<Image, Segmentation> criteria =
Criteria<Image, CategoryMask> criteria =
Criteria.builder()
.setTypes(Image.class, Segmentation.class)
.setTypes(Image.class, CategoryMask.class)
.optModelUrls(url)
.optTranslator(translator)
.optEngine("PyTorch")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
android:id="@+id/view_finder"
android:layout_width="@dimen/preview_width"
android:layout_height="@dimen/preview_height"
android:layout_marginTop="@dimen/margin_xxxlarge"
android:layout_marginTop="@dimen/margin_large"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
Expand All @@ -20,7 +20,7 @@
android:id="@+id/image_preview"
android:layout_width="@dimen/preview_width"
android:layout_height="@dimen/preview_height"
android:layout_marginTop="@dimen/margin_xxxlarge"
android:layout_marginTop="@dimen/margin_large"
android:contentDescription="@string/app_name"
android:elevation="10dp"
android:scaleType="centerCrop"
Expand All @@ -30,32 +30,31 @@
app:layout_constraintTop_toTopOf="parent" />

<ImageView
android:id="@+id/image_mask"
android:id="@+id/image_predicted"
android:layout_width="@dimen/preview_width"
android:layout_height="@dimen/preview_height"
android:layout_marginTop="@dimen/margin_xxxlarge"
android:layout_marginTop="@dimen/margin_large"
android:contentDescription="@string/app_name"
android:elevation="10dp"
android:scaleType="centerCrop"
android:alpha="0.5"
android:visibility="gone"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
app:layout_constraintTop_toBottomOf="@id/view_finder" />

<ImageButton
android:id="@+id/close_image_preview"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="@dimen/margin_small"
android:layout_marginEnd="@dimen/margin_small"
android:layout_marginTop="@dimen/margin_medium"
android:layout_marginEnd="@dimen/margin_medium"
android:background="@drawable/circle_background"
android:contentDescription="@string/app_name"
android:elevation="20dp"
android:src="@drawable/ic_close"
android:visibility="gone"
app:layout_constraintEnd_toEndOf="@id/image_preview"
app:layout_constraintTop_toTopOf="@id/image_preview"
app:layout_constraintEnd_toEndOf="@id/image_predicted"
app:layout_constraintTop_toTopOf="@id/image_predicted"
app:tint="@color/button_selector" />

<ProgressBar
Expand All @@ -66,14 +65,14 @@
android:visibility="invisible"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/view_finder" />
app:layout_constraintTop_toTopOf="@+id/image_predicted" />

<com.google.android.material.floatingactionbutton.FloatingActionButton
android:id="@+id/capture_button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="bottom"
android:layout_marginBottom="@dimen/margin_xxlarge"
android:layout_marginBottom="@dimen/margin_xlarge"
android:backgroundTint="@color/fab_selector"
android:contentDescription="@string/app_name"
android:enabled="false"
Expand Down

0 comments on commit c698fef

Please sign in to comment.