Skip to content

Commit

Permalink
add a translator to get the original photo pixels
Browse files Browse the repository at this point in the history
  • Loading branch information
wxm2018 committed Sep 21, 2022
1 parent cd3db85 commit 4e0fe42
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package ai.djl.examples.semanticsegmentation;

import java.nio.ByteBuffer;
import java.util.Map;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslatorContext;

/**
* A {@link BaseImageTranslator} that post-process the {@link NDArray} into {@link DetectedObjects}
* with boundaries at the detailed pixel level.
*/
public class SelfSemanticSegmentationTranslator extends BaseImageTranslator<Image> {

private final int shortEdge;
private final int maxEdge;

private static final int CHANNEL = 3;
private static final int CLASSNUM = 21;

/**
* Creates the Semantic Segmentation translator from the given builder.
*
* @param builder the builder for the translator
*/
public SelfSemanticSegmentationTranslator(Builder builder) {
super(builder);
this.shortEdge = builder.shortEdge;
this.maxEdge = builder.maxEdge;

pipeline.insert(0, null, new ResizeShort());
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Image image) {
ctx.setAttachment("originalHeight", image.getHeight());
ctx.setAttachment("originalWidth", image.getWidth());
ctx.setAttachment("originalImage",image);

float a = 128;
byte b = (byte) a;
System.out.println("**************************************************************: "+ b);
return super.processInput(ctx, image);
}

/** {@inheritDoc} */
@Override
public Image processOutput(TranslatorContext ctx, NDList list) {
// scores contains the probabilities of each pixel being a certain object
float[] scores = list.get(1).toFloatArray();
Shape shape = list.get(1).getShape();
int width = (int) shape.get(2);
int height = (int) shape.get(1);
Image originalImage = (Image) ctx.getAttachment("originalImage");


// build image array
try (NDManager manager = NDManager.newBaseManager()) {
NDArray imageArray = originalImage.toNDArray(manager);
imageArray = NDImageUtils.resize(imageArray, width, height);
imageArray.toUint8Array();

int imageSize = width * height;
ByteBuffer bb = manager.allocateDirect(CHANNEL * imageSize);
int[] row = imageArray.toType(DataType.UINT8, false).toUint8Array();

// change color of pixels in image array where objects have been detected
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
int index = h * width + w;
int maxi = 0;
double maxnum = -Double.MAX_VALUE;
for (int i = 0; i < CLASSNUM; i++) {
// get score for each i at the h,w pixel of the image
float score = scores[i * (imageSize) + index];
if (score > maxnum) {
maxnum = score;
maxi = i;
}
}
if (maxi > 0) {

bb.put((byte) (row[3*index]));
bb.put((byte) (row[3*index+1]));
bb.put((byte) (row[3*index+2]));

} else {
bb.position(bb.position() + 3);
}
}
}
bb.rewind();
int originW = (int) ctx.getAttachment("originalWidth");
int originH = (int) ctx.getAttachment("originalHeight");
NDArray fullImage =
manager.create(bb, new Shape(height, width, CHANNEL), DataType.UINT8);
NDArray resized = NDImageUtils.resize(fullImage, originW, originH);

return ImageFactory.getInstance().fromNDArray(resized);
}
}

/**
* Creates a builder to build a {@code SelfSemanticSegmentationTranslator}.
*
* @return a new builder
*/
public static Builder builder() {
return new Builder();
}

/**
* Creates a builder to build a {@code SelfSemanticSegmentationTranslator} with specified arguments.
*
* @param arguments arguments to specify builder options
* @return a new builder
*/
public static Builder builder(Map<String, ?> arguments) {
Builder builder = new Builder();

builder.configPreProcess(arguments);
builder.configPostProcess(arguments);

return builder;
}

/** Resizes the image based on the shorter edge or maximum edge length. */
private class ResizeShort implements Transform {
/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
Shape shape = array.getShape();
int width = (int) shape.get(1);
int height = (int) shape.get(0);
int min = Math.min(width, height);
int max = Math.max(width, height);
float scale = shortEdge / (float) min;
if (Math.round(scale * max) > maxEdge) {
scale = maxEdge / (float) max;
}
int rescaledHeight = Math.round(height * scale);
int rescaledWidth = Math.round(width * scale);

return NDImageUtils.resize(array, rescaledWidth, rescaledHeight);
}
}

/** The builder for Semantic Segmentation translator. */
public static class Builder extends ClassificationBuilder<Builder> {
int shortEdge = 600;
int maxEdge = 1000;

Builder() {}

/** {@inheritDoc} */
@Override
protected Builder self() {
return this;
}

/** {@inheritDoc} */
@Override
protected void configPostProcess(Map<String, ?> arguments) {
super.configPostProcess(arguments);
shortEdge = ArgumentsUtil.intValue(arguments, "shortEdge", 600);
maxEdge = ArgumentsUtil.intValue(arguments, "maxEdge", 1000);
}
@Override
protected void configPreProcess(Map<String, ?> arguments) {
super.configPreProcess(arguments);
}

/**
* Builds the translator.
*
* @return the new translator
*/
public SelfSemanticSegmentationTranslator build() {
validate();
return new SelfSemanticSegmentationTranslator(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import android.os.Bundle;
import android.util.Log;
import android.util.Size;
import android.view.MotionEvent;
import android.view.Surface;
import android.view.View;
import android.widget.ImageButton;
Expand All @@ -30,6 +31,7 @@
import androidx.annotation.NonNull;
import androidx.appcompat.app.AlertDialog;
import androidx.appcompat.app.AppCompatActivity;
import androidx.appcompat.widget.SwitchCompat;
import androidx.camera.camera2.Camera2Config;
import androidx.camera.core.CameraSelector;
import androidx.camera.core.CameraXConfig;
Expand Down Expand Up @@ -71,9 +73,11 @@ public class SemanticActivity extends AppCompatActivity implements CameraXConfig
private ImageView mImagePredicted;
private ProgressBar mProgressBar;
private FloatingActionButton mCaptureButton;
private SwitchCompat mSwitch;

private ImageCapture imageCapture;
private Bitmap bitmapBuffer;
boolean isUseSelfTranslator = false;

ZooModel<Image, Image> model;
Predictor<Image, Image> predictor;
Expand All @@ -91,8 +95,19 @@ protected void onCreate(Bundle savedInstanceState) {
mImagePredicted = binding.imagePredicted;
mProgressBar = binding.progressBar;
mCaptureButton = binding.captureButton;
mSwitch = binding.useSelfTranslatorSwitch;
mSwitch.setOnTouchListener(new View.OnTouchListener(){
@Override
public boolean onTouch(View view, MotionEvent motionEvent) {
view.performClick();
isUseSelfTranslator = mSwitch.isChecked();
executor.execute(new LoadModelTask());
return false;
}
});

mCaptureButton.setOnClickListener(view -> {

mCaptureButton.setEnabled(false);
mImagePredicted.setVisibility(View.GONE);

Expand Down Expand Up @@ -144,6 +159,7 @@ public void onError(@NonNull ImageCaptureException exception) {

Snackbar.make(findViewById(android.R.id.content), R.string.message_download_model,
Snackbar.LENGTH_LONG).show();
//first initialize
executor.execute(new LoadModelTask());
}

Expand Down Expand Up @@ -240,7 +256,8 @@ private class LoadModelTask implements Runnable {
@Override
public void run() {
try {
model = SemanticModel.loadModel();
isUseSelfTranslator = mSwitch.isChecked();
model = SemanticModel.loadModel(isUseSelfTranslator);
predictor = model.newPredictor();
runOnUiThread(() -> {
mCaptureButton.setEnabled(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,27 @@
import ai.djl.modality.cv.translator.SemanticSegmentationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;

final class SemanticModel {

private SemanticModel() {
}

public static ZooModel<Image, Image> loadModel() throws ModelException, IOException {
public static ZooModel<Image, Image> loadModel(boolean isUseSelfTranslator) throws ModelException, IOException {
String url =
"https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip";
Map<String, String> arguments = new ConcurrentHashMap<>();
arguments.put("toTensor", "true");
arguments.put("normalize", "true");
SemanticSegmentationTranslator translator =
SemanticSegmentationTranslator.builder(arguments).build();
Translator<Image, Image> translator;
if(isUseSelfTranslator){
translator = SelfSemanticSegmentationTranslator.builder(arguments).build();
}else {
translator = SemanticSegmentationTranslator.builder(arguments).build();
}



Criteria<Image, Image> criteria =
Criteria.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,29 @@
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:elevation="10dp"
tools:context=".SemanticActivity">

<androidx.camera.view.PreviewView
android:id="@+id/view_finder"
android:layout_width="@dimen/preview_width"
android:layout_height="@dimen/preview_height"
android:layout_marginTop="@dimen/margin_large"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintTop_toTopOf="parent"/>
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />

<ImageView
android:id="@+id/image_preview"
android:layout_width="@dimen/preview_width"
android:layout_height="@dimen/preview_height"
android:layout_marginTop="@dimen/margin_large"
android:contentDescription="@string/app_name"
android:elevation="10dp"
android:scaleType="centerCrop"
android:visibility="gone"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />

<ImageView
Expand All @@ -33,10 +35,11 @@
android:layout_height="@dimen/preview_height"
android:layout_marginTop="@dimen/margin_large"
android:contentDescription="@string/app_name"
android:elevation="10dp"
android:scaleType="centerCrop"
android:visibility="gone"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@id/view_finder" />

<ImageButton
Expand All @@ -45,35 +48,48 @@
android:layout_height="wrap_content"
android:layout_marginTop="@dimen/margin_medium"
android:layout_marginEnd="@dimen/margin_medium"
android:contentDescription="@string/app_name"
android:visibility="gone"
android:background="@drawable/circle_background"
android:contentDescription="@string/app_name"
android:elevation="20dp"
android:src="@drawable/ic_close"
app:tint="@color/button_selector"
android:visibility="gone"
app:layout_constraintEnd_toEndOf="@id/image_predicted"
app:layout_constraintTop_toTopOf="@id/image_predicted" />
app:layout_constraintTop_toTopOf="@id/image_predicted"
app:tint="@color/button_selector" />

<ProgressBar
android:id="@+id/progress_bar"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="@dimen/margin_xxlarge"
android:visibility="invisible"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintTop_toTopOf="@+id/image_predicted" />
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="@+id/image_predicted" />

<com.google.android.material.floatingactionbutton.FloatingActionButton
android:id="@+id/capture_button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginBottom="@dimen/margin_xlarge"
android:layout_gravity="bottom"
android:contentDescription="@string/app_name"
android:src="@drawable/ic_camera"
android:layout_marginBottom="@dimen/margin_xlarge"
android:backgroundTint="@color/fab_selector"
android:contentDescription="@string/app_name"
android:enabled="false"
app:layout_constraintStart_toStartOf="parent"
android:src="@drawable/ic_camera"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent" />

<androidx.appcompat.widget.SwitchCompat
android:id="@+id/use_self_translator_switch"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginBottom="16dp"
android:minHeight="48dp"
android:text="@string/use_self_translator"
app:layout_constraintBottom_toTopOf="@+id/capture_button"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintBottom_toBottomOf="parent" />
app:layout_constraintHorizontal_bias="0.498"
app:layout_constraintStart_toStartOf="parent" />
</androidx.constraintlayout.widget.ConstraintLayout>
Loading

0 comments on commit 4e0fe42

Please sign in to comment.