Skip to content

Commit

Permalink
Draw mask on the image
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Sep 29, 2022
1 parent 4e0fe42 commit c78d3b2
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 234 deletions.
2 changes: 1 addition & 1 deletion android/gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ org.gradle.jvmargs=-Xmx2048m --add-exports jdk.compiler/com.sun.tools.javac.api=
android.useAndroidX=true
# Automatically convert third-party libraries to use AndroidX
android.enableJetifier=true
djl_version=0.19.0
djl_version=0.20.0-SNAPSHOT
4 changes: 2 additions & 2 deletions android/pytorch_android/semantic_segmentation/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies {
implementation "androidx.camera:camera-view:1.1.0"

implementation "ai.djl:api:${djl_version}"
implementation "ai.djl.android:core:${djl_version}"
implementation "ai.djl.android:core:0.19.0"
runtimeOnly "ai.djl.pytorch:pytorch-engine:${djl_version}"
runtimeOnly "ai.djl.android:pytorch-native:${djl_version}"
runtimeOnly "ai.djl.android:pytorch-native:0.19.0"
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Color;
import android.graphics.Matrix;
import android.os.Bundle;
import android.util.Log;
Expand Down Expand Up @@ -58,19 +59,23 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Segmentation;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.util.RandomUtils;

public class SemanticActivity extends AppCompatActivity implements CameraXConfig.Provider {

private static final String TAG = SemanticActivity.class.getSimpleName();

private static final int CAMERA_REQUEST_CODE = 1;

private static final int CLASSNUM = 21;

private PreviewView mViewFinder;
private ImageView mImagePreview;
private ImageButton mCloseImageButton;
private ImageView mImagePredicted;
private ImageView mImageMask;
private ProgressBar mProgressBar;
private FloatingActionButton mCaptureButton;
private SwitchCompat mSwitch;
Expand All @@ -79,8 +84,8 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig
private Bitmap bitmapBuffer;
boolean isUseSelfTranslator = false;

ZooModel<Image, Image> model;
Predictor<Image, Image> predictor;
ZooModel<Image, Segmentation> model;
Predictor<Image, Segmentation> predictor;
Executor executor = Executors.newSingleThreadExecutor();

@Override
Expand All @@ -92,7 +97,7 @@ protected void onCreate(Bundle savedInstanceState) {
mViewFinder = binding.viewFinder;
mImagePreview = binding.imagePreview;
mCloseImageButton = binding.closeImagePreview;
mImagePredicted = binding.imagePredicted;
mImageMask = binding.imageMask;
mProgressBar = binding.progressBar;
mCaptureButton = binding.captureButton;
mSwitch = binding.useSelfTranslatorSwitch;
Expand All @@ -109,7 +114,7 @@ public boolean onTouch(View view, MotionEvent motionEvent) {
mCaptureButton.setOnClickListener(view -> {

mCaptureButton.setEnabled(false);
mImagePredicted.setVisibility(View.GONE);
mImageMask.setVisibility(View.GONE);

if (imageCapture != null) {
imageCapture.takePicture(executor, new ImageCapture.OnImageCapturedCallback() {
Expand All @@ -126,7 +131,8 @@ public void onCaptureSuccess(@NonNull ImageProxy image) {
if (image.getImageInfo().getRotationDegrees() != 0) {
Matrix matrix = new Matrix();
matrix.postRotate(image.getImageInfo().getRotationDegrees());
bitmapBuffer = Bitmap.createBitmap(bitmapBuffer, 0, 0, bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), matrix, true);
bitmapBuffer = Bitmap.createBitmap(bitmapBuffer, 0, 0,
bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), matrix, true);
}

runOnUiThread(() -> {
Expand All @@ -152,7 +158,7 @@ public void onError(@NonNull ImageCaptureException exception) {
mCloseImageButton.setVisibility(View.GONE);
mViewFinder.setVisibility(View.VISIBLE);
mImagePreview.setVisibility(View.GONE);
mImagePredicted.setVisibility(View.GONE);
mImageMask.setVisibility(View.GONE);
mProgressBar.setVisibility(View.GONE);
mCaptureButton.setEnabled(true);
});
Expand Down Expand Up @@ -186,7 +192,8 @@ protected void onDestroy() {
}

@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions,
@NonNull int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
if (requestCode == CAMERA_REQUEST_CODE) {
if (grantResults.length == 1 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
Expand Down Expand Up @@ -218,7 +225,8 @@ private void requestCameraPermission() {
}

private void setUpCamera() {
ListenableFuture<ProcessCameraProvider> cameraProviderFuture = ProcessCameraProvider.getInstance(SemanticActivity.this);
ListenableFuture<ProcessCameraProvider> cameraProviderFuture =
ProcessCameraProvider.getInstance(SemanticActivity.this);
cameraProviderFuture.addListener(() -> {
try {
ProcessCameraProvider cameraProvider = cameraProviderFuture.get();
Expand Down Expand Up @@ -256,12 +264,12 @@ private class LoadModelTask implements Runnable {
@Override
public void run() {
try {
isUseSelfTranslator = mSwitch.isChecked();
model = SemanticModel.loadModel(isUseSelfTranslator);
model = SemanticModel.loadModel();
predictor = model.newPredictor();
runOnUiThread(() -> {
mCaptureButton.setEnabled(true);
Snackbar.make(findViewById(android.R.id.content), R.string.message_download_model_complete, Snackbar.LENGTH_LONG).show();
Snackbar.make(findViewById(android.R.id.content),
R.string.message_download_model_complete, Snackbar.LENGTH_LONG).show();
});
} catch (IOException | ModelException e) {
Log.e(TAG, null, e);
Expand All @@ -284,10 +292,37 @@ public void run() {
try {
// Predict
Image img = ImageFactory.getInstance().fromImage(bitmapBuffer);
Bitmap transferredBitmap = (Bitmap) predictor.predict(img).getWrappedImage();
Segmentation result = predictor.predict(img);

int[][] mask = result.getMask();
int width = mask.length;
int height = mask[0].length;
int[] colors = new int[CLASSNUM];
for (int i = 0; i < CLASSNUM; i++) {
colors[i] = Color.argb(255, RandomUtils.nextInt(256), RandomUtils.nextInt(256),
RandomUtils.nextInt(256));
}

// change color of pixels in image array where objects have been detected
int[] pixels = new int[width * height];
for (int i = 0; i < width; i++) {
for (int j = 0; j < height; j++) {
int maxi = mask[i][j];
if (maxi > 0) {
pixels[j * width + i] = colors[maxi];
} else {
pixels[j * width + i] = Color.TRANSPARENT;
}
}
}
Bitmap bmpSegmentation = Bitmap.createBitmap(pixels, width, height,
Bitmap.Config.ARGB_8888);
final Bitmap scaledBitmap = Bitmap.createScaledBitmap(bmpSegmentation,
bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), true);

runOnUiThread(() -> {
mImagePredicted.setImageBitmap(transferredBitmap);
mImagePredicted.setVisibility(View.VISIBLE);
mImageMask.setImageBitmap(scaledBitmap);
mImageMask.setVisibility(View.VISIBLE);
mCloseImageButton.setVisibility(View.VISIBLE);
mProgressBar.setVisibility(View.GONE);
});
Expand All @@ -296,7 +331,8 @@ public void run() {
runOnUiThread(() -> {
mCaptureButton.setEnabled(true);
mProgressBar.setVisibility(View.GONE);
Snackbar.make(findViewById(android.R.id.content), "Inference failed.", Snackbar.LENGTH_LONG).show();
Snackbar.make(findViewById(android.R.id.content), "Inference failed.",
Snackbar.LENGTH_LONG).show();
});
}
}
Expand Down
Loading

0 comments on commit c78d3b2

Please sign in to comment.