Skip to content

Commit

Permalink
[pytorch] Add BigGAN demo
Browse files Browse the repository at this point in the history
  • Loading branch information
AzizZayed committed Jun 15, 2021
1 parent f145614 commit a6ded8c
Show file tree
Hide file tree
Showing 8 changed files with 1,394 additions and 2 deletions.
15 changes: 14 additions & 1 deletion api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import ai.djl.util.RandomUtils;
import java.nio.Buffer;
import java.nio.file.Path;
import java.util.UUID;
Expand Down Expand Up @@ -156,7 +157,19 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy
/** {@inheritDoc} */
@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException("Not supported!");
int sampleSize = (int) shape.size();
double[] dist = new double[sampleSize];

for (int i = 0; i < sampleSize; i++) {
double sample = RandomUtils.nextGaussian();
while (sample < -2 || sample > 2) {
sample = RandomUtils.nextGaussian();
}

dist[i] = sample;
}

return create(dist).addi(loc).muli(scale).reshape(shape).toType(dataType, false);
}

/** {@inheritDoc} */
Expand Down
2 changes: 1 addition & 1 deletion examples/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies {
}

application {
mainClassName = System.getProperty("main", "ai.djl.examples.inference.ObjectDetection")
mainClassName = System.getProperty("main", "ai.djl.examples.inference.biggan.Generator")
}

run {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.biggan;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class BigGANCategory {
private static final Logger logger = LoggerFactory.getLogger(BigGANCategory.class);

public static final int NUMBER_OF_CATEGORIES = 1000;
private static final Map<String, BigGANCategory> CATEGORIES_BY_NAME =
new ConcurrentHashMap<>(NUMBER_OF_CATEGORIES);
private static String[] categoriesById;

private int id;
private String[] names;

static {
try {
parseCategories();
} catch (IOException e) {
logger.error("Error parsing the ImageNet categories: {}", e);
}
createCategoriesByName();
}

private BigGANCategory(int id, String[] names) {
this.id = id;
this.names = names;
}

public int getId() {
return id;
}

public String[] getNames() {
return names.clone();
}

public static BigGANCategory id(int id) {
String names = categoriesById[id];
int index = names.indexOf(',');
if (index < 0) {
return of(names);
} else {
return of(names.substring(0, index));
}
}

public static BigGANCategory of(String name) {
if (!CATEGORIES_BY_NAME.containsKey(name)) {
throw new IllegalArgumentException(name + " is not a valid category.");
}
return CATEGORIES_BY_NAME.get(name);
}

private static void createCategoriesByName() {
for (int i = 0; i < NUMBER_OF_CATEGORIES; i++) {
String[] categoryNames = categoriesById[i].split(", ");
BigGANCategory category = new BigGANCategory(i, categoryNames);

for (String name : categoryNames) {
CATEGORIES_BY_NAME.put(name, category);
}
}
}

private static void parseCategories() throws IOException {
String filePath = "src/main/resources/categories.txt";

List<String> fileLines = Files.readAllLines(Paths.get(filePath));
List<String> categories = new ArrayList<>(NUMBER_OF_CATEGORIES);
for (String line : fileLines) {
int nameIndex = line.indexOf(':') + 2;
categories.add(line.substring(nameIndex));
}

categoriesById = categories.toArray(new String[] {});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.biggan;

public final class BigGANInput {
private int sampleSize;
private float truncation;
private BigGANCategory category;

public BigGANInput(int sampleSize, float truncation, BigGANCategory category) {
this.sampleSize = sampleSize;
this.truncation = truncation;
this.category = category;
}

BigGANInput(Builder builder) {
this.sampleSize = builder.sampleSize;
this.truncation = builder.truncation;
this.category = builder.category;
}

public int getSampleSize() {
return sampleSize;
}

public float getTruncation() {
return truncation;
}

public BigGANCategory getCategory() {
return category;
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {
private int sampleSize = 5;
private float truncation = 0.5f;
private BigGANCategory category;

Builder() {
category = BigGANCategory.of("Egyptian cat");
}

public Builder optSampleSize(int sampleSize) {
this.sampleSize = sampleSize;
return this;
}

public Builder optTruncation(float truncation) {
this.truncation = truncation;
return this;
}

public Builder setCategory(BigGANCategory category) {
this.category = category;
return this;
}

public BigGANInput build() {
return new BigGANInput(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.biggan;

import ai.djl.engine.Engine;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

final class BigGANTranslator implements Translator<BigGANInput, Image[]> {
private static final Logger logger = LoggerFactory.getLogger(BigGANTranslator.class);
private static final int SEED_COLUMN_SIZE = 128;

@Override
public Image[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
logOutputList(list);

NDArray output = list.get(0).addi(1).muli(128).clip(0, 255).toType(DataType.UINT8, false);

int sampleSize = (int) output.getShape().get(0);
Image[] images = new Image[sampleSize];

for (int i = 0; i < sampleSize; i++) {
images[i] = ImageFactory.getInstance().fromNDArray(output.get(i));
}

return images;
}

private void logOutputList(NDList list) {
logger.info("");
logger.info("MY OUTPUT:");
list.forEach(array -> logger.info(" out: {}", array.getShape()));
}

@Override
public NDList processInput(TranslatorContext ctx, BigGANInput input) throws Exception {
Engine.getInstance().setRandomSeed(0);
NDManager manager = ctx.getNDManager();

NDArray categoryArray = createCategoryArray(manager, input);
NDArray seed =
manager.truncatedNormal(new Shape(input.getSampleSize(), SEED_COLUMN_SIZE))
.muli(input.getTruncation());
NDArray truncation = manager.create(input.getTruncation());

logInputArrays(categoryArray, seed, truncation);
return new NDList(seed, categoryArray, truncation);
}

private NDArray createCategoryArray(NDManager manager, BigGANInput input) {
int categoryId = input.getCategory().getId();
int sampleSize = input.getSampleSize();

int[] indices = new int[sampleSize];
for (int i = 0; i < sampleSize; i++) {
indices[i] = categoryId;
}
return manager.create(indices).oneHot(BigGANCategory.NUMBER_OF_CATEGORIES);
}

private void logInputArrays(NDArray categoryArray, NDArray seed, NDArray truncation) {
logger.info("");
logger.info("MY INPUTS: ");
logger.info(" y: {}", categoryArray.getShape());
logger.info(" z: {}", seed.get(":, :10"));
logger.info(" truncation: {}", truncation.getShape());
}

@Override
public Batchifier getBatchifier() {
return null;
}
}
Loading

0 comments on commit a6ded8c

Please sign in to comment.