From 757e7531da6077250920d2b6b4d7766a402eda2a Mon Sep 17 00:00:00 2001 From: afeenster Date: Tue, 2 Mar 2021 13:07:21 -0800 Subject: [PATCH 01/25] This creates the component which will populate the Download Tab with Download Buttons. --- .../main/webapp/components/DowloadButtons.jsx | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 central/src/main/webapp/components/DowloadButtons.jsx diff --git a/central/src/main/webapp/components/DowloadButtons.jsx b/central/src/main/webapp/components/DowloadButtons.jsx new file mode 100644 index 00000000000..166fdf1a27f --- /dev/null +++ b/central/src/main/webapp/components/DowloadButtons.jsx @@ -0,0 +1,55 @@ +import React, { Component, useState, useEffect, useRef } from "react"; +import Button from '@material-ui/core/Button'; +import ReactDOM from 'react-dom'; + +import { makeStyles } from '@material-ui/core/styles'; +import { theme } from '../css/useStyles' +import axios from 'axios' + + +const useFetch = (modelName) => { + const [data, setData] = useState([]); + + useEffect(() => { + async function fetchData() { + axios.get("http://"+window.location.host+"/serving/models?modelName="+modelName) + .then(function(response) { + let appdata = Object.keys(response.data).map(function(key) { + console.log(key) + return { + key: key, + link: response.data[key] + }; + }); + setData(appdata); + console.log(appdata) + }) + .catch(function(error) { + console.log(error); + }) + .then(function() { + // always executed + }); + + } + fetchData(); + }, ["http://"+window.location.host+"/serving/models?modelName="+modelName]); + + return data; +}; + + + +export default function ModelDownloadButtons(props) { + const modelLinks = useFetch(props.modelName); + return ( + <> + {Object.keys(modelLinks).map((keys) => ( + + + ) + )} + + ); + +} \ No newline at end of file From 3563314e19ca3944de0b38203c99f756c082088a Mon Sep 17 00:00:00 2001 From: afeenster Date: Tue, 2 Mar 2021 13:10:50 -0800 Subject: [PATCH 02/25] Making a place for the download buttons. --- central/src/main/webapp/components/ModelView.jsx | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/central/src/main/webapp/components/ModelView.jsx b/central/src/main/webapp/components/ModelView.jsx index 3cfeb6d0f51..9e9ccadd3e7 100644 --- a/central/src/main/webapp/components/ModelView.jsx +++ b/central/src/main/webapp/components/ModelView.jsx @@ -19,6 +19,7 @@ import Chip from '@material-ui/core/Chip'; import Divider from '@material-ui/core/Divider'; import DynForm from './DynForm'; +import ModelDownloadButtons from './DownloadButtons'; import axios from 'axios' @@ -186,6 +187,7 @@ export default function ModelView(props) { + @@ -250,6 +252,9 @@ export default function ModelView(props) { : } + + + From 3dc1bf20d1d26baf08d02dfa7928f82e3b9c97e5 Mon Sep 17 00:00:00 2001 From: afeenster Date: Tue, 2 Mar 2021 13:53:59 -0800 Subject: [PATCH 03/25] Adding the Model Download Handler allowing the backend to feed the links into the Model View and making slight changes for readablity. --- .../HttpStaticFileServerInitializer.java | 2 + .../central/handler/ModelDownloadHandler.java | 123 +++++++++++++ .../central/http/BadRequestException.java | 40 +++++ .../responseencoder/HttpRequestResponse.java | 125 +++++++++++++ .../djl/serving/central/utils/NettyUtils.java | 167 ++++++++++++++++++ ...DowloadButtons.jsx => DownloadButtons.jsx} | 6 +- 6 files changed, 459 insertions(+), 4 deletions(-) create mode 100644 central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java create mode 100644 central/src/main/java/ai/djl/serving/central/http/BadRequestException.java create mode 100644 central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java create mode 100644 central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java rename central/src/main/webapp/components/{DowloadButtons.jsx => DownloadButtons.jsx} (96%) diff --git a/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java b/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java index e54a1303280..464eaf62527 100644 --- a/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java +++ b/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java @@ -14,6 +14,7 @@ import ai.djl.serving.central.handler.HttpStaticFileServerHandler; import ai.djl.serving.central.handler.ModelMetaDataHandler; +import ai.djl.serving.central.handler.ModelDownloadHandler; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.socket.SocketChannel; @@ -54,6 +55,7 @@ public void initChannel(SocketChannel ch) { pipeline.addLast(new HttpServerCodec()); pipeline.addLast(new HttpObjectAggregator(65536)); pipeline.addLast(new ChunkedWriteHandler()); + pipeline.addLast(new ModelDownloadHandler()); pipeline.addLast(new ModelMetaDataHandler()); pipeline.addLast(new HttpStaticFileServerHandler()); } diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java new file mode 100644 index 00000000000..e7bc0d7b0a4 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -0,0 +1,123 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.handler; + +import ai.djl.Application; +import ai.djl.repository.Artifact; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.serving.central.responseencoder.HttpRequestResponse; +import ai.djl.serving.central.utils.NettyUtils; +import ai.djl.serving.http.BadRequestException; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.util.*; +import java.util.concurrent.CompletableFuture; + +final class ModelLink { + + private static final Logger logger = LoggerFactory.getLogger(ModelLink.class); + private static Map links = new HashMap(); + private ModelLink() {} + + private static URI BASE_URI = URI.create("https://mlrepo.djl.ai/"); + + public static Map linkFinder(String modelName) throws IOException, ModelNotFoundException { + Map> models = ModelZoo.listModels(); + models.forEach( + (app, list) -> { + list.forEach( + artifact -> { + if (artifact.getName().equals(modelName)){ + for (Map.Entry entry : + artifact.getFiles().entrySet()) { + URI fileUri = URI.create(entry.getValue().getUri()); + URI baseUri = artifact.getMetadata().getRepositoryUri(); + if (!fileUri.isAbsolute()) { + fileUri = BASE_URI.resolve(baseUri).resolve(fileUri); + } + try { + links.put(entry.getKey(),fileUri); + } catch(Exception e){ + logger.info(String.valueOf(e)); + } + } + }}); + }); + return links; + } + + public static void main(String[] args) throws IOException, ModelNotFoundException { + logger.info("Output:"); + logger.info(String.valueOf(linkFinder("simple_pose_resnet50_v1b"))); + } +} + + +/** + * A handler to handle deployment requests from the UI/ + * @author erik.bamberg@web.de + * + */ +public class ModelDownloadHandler extends SimpleChannelInboundHandler { + + HttpRequestResponse jsonResponse; + public ModelDownloadHandler() { jsonResponse = new HttpRequestResponse(); } + + /** + * handle the deployment request by forwarding the request to the serving-instance. + * + * @param ctx the context + * @param request the full request + */ + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) throws IOException, ModelNotFoundException { + final Logger logger = LoggerFactory.getLogger(ModelDownloadHandler.class); + QueryStringDecoder decoder = new QueryStringDecoder(request.uri()); + String modelName=NettyUtils.getParameter(decoder, "modelName", null); + CompletableFuture.supplyAsync( + () -> { + try { + if (modelName!=null) { + logger.info(String.valueOf(ModelLink.linkFinder(modelName))); + return ModelLink.linkFinder(modelName); + } else { + throw new BadRequestException("modelName and url is mandatory."); + } + + } catch (IOException | ModelNotFoundException ex) { + throw new IllegalArgumentException(ex.getMessage(), ex); + } + }) + .exceptionally((ex) -> Collections.emptyMap()) + .thenAccept(linksMap -> jsonResponse.sendAsJson(ctx, request, linksMap)); + } + + + /** {@inheritDoc} */ + @Override + public boolean acceptInboundMessage(Object msg) { + FullHttpRequest request = (FullHttpRequest) msg; + + String uri = request.uri(); + return uri.startsWith("/serving/models?"); + } + +} diff --git a/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java b/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java new file mode 100644 index 00000000000..5aac81f7868 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.http; + +/** Thrown when a bad HTTP request is received. */ +public class BadRequestException extends IllegalArgumentException { + + static final long serialVersionUID = 1L; + + /** + * Constructs an {@code BadRequestException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public BadRequestException(String message) { + super(message); + } + + /** + * Constructs an {@code BadRequestException} with the specified detail message and a root cause. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause root cause + */ + public BadRequestException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java new file mode 100644 index 00000000000..a7a5cf42f6d --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java @@ -0,0 +1,125 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.responseencoder; + +import ai.djl.modality.Classifications; +import ai.djl.modality.Classifications.ClassificationsSerializer; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.repository.Metadata; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializer; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import java.lang.reflect.Modifier; + +/** + * serialize to json and send the response to the client. + * + * @author erik.bamberg@web.de + */ +public class HttpRequestResponse { + + private static final Gson GSON_WITH_TRANSIENT_FIELDS = + new GsonBuilder() + .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") + .setPrettyPrinting() + .excludeFieldsWithModifiers(Modifier.STATIC) + .registerTypeAdapter(Classifications.class, new ClassificationsSerializer()) + .registerTypeAdapter(DetectedObjects.class, new ClassificationsSerializer()) + .registerTypeAdapter(Metadata.class, new MetaDataSerializer()) + .registerTypeAdapter( + Double.class, + (JsonSerializer) + (src, t, ctx) -> { + long v = src.longValue(); + if (src.equals(Double.valueOf(String.valueOf(v)))) { + return new JsonPrimitive(v); + } + return new JsonPrimitive(src); + }) + .create(); + + /** + * send a response to the client. + * + * @param ctx channel context + * @param request full request + * @param entity the response + */ + public void sendAsJson(ChannelHandlerContext ctx, FullHttpRequest request, Object entity) { + + String serialized = GSON_WITH_TRANSIENT_FIELDS.toJson(entity); + ByteBuf buffer = ctx.alloc().buffer(serialized.length()); + buffer.writeCharSequence(serialized, CharsetUtil.UTF_8); + + FullHttpResponse response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=UTF-8"); + boolean keepAlive = HttpUtil.isKeepAlive(request); + this.sendAndCleanupConnection(ctx, response, keepAlive); + } + + /** + * send content of a ByteBuffer as + * response to the client. + * + * @param ctx channel context + * @param request full request + * @param entity the response + */ + public void sendByteBuffer(ChannelHandlerContext ctx, ByteBuf buffer) { + + FullHttpResponse response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=UTF-8"); + this.sendAndCleanupConnection(ctx, response, false); + } + + /** + * If Keep-Alive is disabled, attaches "Connection: close" header to the response and closes the + * connection after the response being sent. + * + * @param ctx context + * @param request full request + * @param response full response + */ + private void sendAndCleanupConnection( + ChannelHandlerContext ctx, FullHttpResponse response, boolean keepAlive ) { + HttpUtil.setContentLength(response, response.content().readableBytes()); + if (!keepAlive) { + // We're going to close the connection as soon as the response is sent, + // so we should also make it clear for the client. + response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); + } + + ChannelFuture flushPromise = ctx.writeAndFlush(response); + + if (!keepAlive) { + // Close the connection as soon as the response is sent. + flushPromise.addListener(ChannelFutureListener.CLOSE); + } + } +} diff --git a/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java new file mode 100644 index 00000000000..6d24256551e --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java @@ -0,0 +1,167 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.utils; + +import ai.djl.ModelException; +import ai.djl.modality.Input; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.util.JsonUtils; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.QueryStringDecoder; +import io.netty.handler.codec.http.multipart.Attribute; +import io.netty.handler.codec.http.multipart.FileUpload; +import io.netty.handler.codec.http.multipart.InterfaceHttpData; +import io.netty.util.AttributeKey; +import io.netty.util.CharsetUtil; +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A utility class that handling Netty request and response. */ +public final class NettyUtils { + + private static final Logger logger = LoggerFactory.getLogger("NettyUtils"); + + private NettyUtils() {} + + /** + * Sends the json object to client. + * + * @param ctx the connection context + * @param json the json object + */ + public static void sendJsonResponse(ChannelHandlerContext ctx, Object json) { + sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), HttpResponseStatus.OK); + } + + /** + * Sends the json string to client with specified status. + * + * @param ctx the connection context + * @param json the json string + * @param status the HTTP status + */ + public static void sendJsonResponse( + ChannelHandlerContext ctx, Object json, HttpResponseStatus status) { + sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), status); + } + + /** + * Sends the json string to client. + * + * @param ctx the connection context + * @param json the json string + */ + public static void sendJsonResponse(ChannelHandlerContext ctx, String json) { + sendJsonResponse(ctx, json, HttpResponseStatus.OK); + } + + + + /** + * Returns the bytes for the specified {@code ByteBuf}. + * + * @param buf the {@code ByteBuf} to read + * @return the bytes for the specified {@code ByteBuf} + */ + public static byte[] getBytes(ByteBuf buf) { + if (buf.hasArray()) { + return buf.array(); + } + + byte[] ret = new byte[buf.readableBytes()]; + int readerIndex = buf.readerIndex(); + buf.getBytes(readerIndex, ret); + return ret; + } + + /** + * Reads the parameter's value for the key from the uri. + * + * @param decoder the {@code QueryStringDecoder} parsed from uri + * @param key the parameter key + * @param def the default value + * @return the parameter's value + */ + public static String getParameter(QueryStringDecoder decoder, String key, String def) { + List param = decoder.parameters().get(key); + if (param != null && !param.isEmpty()) { + return param.get(0); + } + return def; + } + + /** + * Read the parameter's integer value for the key from the uri. + * + * @param decoder the {@code QueryStringDecoder} parsed from uri + * @param key the parameter key + * @param def the default value + * @return the parameter's integer value + * @throws NumberFormatException exception is thrown when the parameter-value is not numeric. + */ + public static int getIntParameter(QueryStringDecoder decoder, String key, int def) { + String value = getParameter(decoder, key, null); + if (value == null || value.isEmpty()) { + return def; + } + return Integer.parseInt(value); + } + + /** + * Parses form data and added to the {@link Input} object. + * + * @param data the form data + * @param input the {@link Input} object to be added to + */ + public static void addFormData(InterfaceHttpData data, Input input) { + if (data == null) { + return; + } + try { + String name = data.getName(); + switch (data.getHttpDataType()) { + case Attribute: + Attribute attribute = (Attribute) data; + input.addData(name, attribute.getValue().getBytes(StandardCharsets.UTF_8)); + break; + case FileUpload: + FileUpload fileUpload = (FileUpload) data; + input.addData(name, getBytes(fileUpload.getByteBuf())); + break; + default: + throw new IllegalArgumentException( + "Except form field, but got " + data.getHttpDataType()); + } + } catch (IOException e) { + throw new AssertionError(e); + } + } +} diff --git a/central/src/main/webapp/components/DowloadButtons.jsx b/central/src/main/webapp/components/DownloadButtons.jsx similarity index 96% rename from central/src/main/webapp/components/DowloadButtons.jsx rename to central/src/main/webapp/components/DownloadButtons.jsx index 166fdf1a27f..7914d4a3723 100644 --- a/central/src/main/webapp/components/DowloadButtons.jsx +++ b/central/src/main/webapp/components/DownloadButtons.jsx @@ -3,7 +3,6 @@ import Button from '@material-ui/core/Button'; import ReactDOM from 'react-dom'; import { makeStyles } from '@material-ui/core/styles'; -import { theme } from '../css/useStyles' import axios from 'axios' @@ -50,6 +49,5 @@ export default function ModelDownloadButtons(props) { ) )} - ); - -} \ No newline at end of file + ); +} From 89f48c4351cc61ec62fa647a69bca172eef535b8 Mon Sep 17 00:00:00 2001 From: afeenster Date: Tue, 2 Mar 2021 14:07:46 -0800 Subject: [PATCH 04/25] Getting rid of some of the test code. --- .../serving/central/handler/ModelDownloadHandler.java | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java index e7bc0d7b0a4..8224661de32 100644 --- a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -63,17 +63,12 @@ public static Map linkFinder(String modelName) throws IOException, }); return links; } - - public static void main(String[] args) throws IOException, ModelNotFoundException { - logger.info("Output:"); - logger.info(String.valueOf(linkFinder("simple_pose_resnet50_v1b"))); - } } /** - * A handler to handle deployment requests from the UI/ - * @author erik.bamberg@web.de + * A handler to handle download requests from the UI + * @author anfee1@morgan.edu * */ public class ModelDownloadHandler extends SimpleChannelInboundHandler { @@ -99,7 +94,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) logger.info(String.valueOf(ModelLink.linkFinder(modelName))); return ModelLink.linkFinder(modelName); } else { - throw new BadRequestException("modelName and url is mandatory."); + throw new BadRequestException("modelName is mandatory."); } } catch (IOException | ModelNotFoundException ex) { From a491d98ad70610c42210799a15f7a8bb8ed367e5 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Tue, 2 Mar 2021 15:41:42 -0800 Subject: [PATCH 05/25] Improve Block usability (#712) * Use builder pattern for Parameter (#661) * Make XavierInitializer default value & Improve setInitializer (#664) * Refactor initialize (#675) * Remove NDManager on getOutputShapes (#710) --- .../java/ai/djl/modality/nlp/Decoder.java | 4 +- .../java/ai/djl/modality/nlp/Encoder.java | 4 +- .../ai/djl/modality/nlp/EncoderDecoder.java | 9 +- .../nlp/embedding/TrainableTextEmbedding.java | 4 +- .../main/java/ai/djl/nn/AbstractBlock.java | 145 ++++------- .../java/ai/djl/nn/AbstractSymbolBlock.java | 3 +- api/src/main/java/ai/djl/nn/Block.java | 31 +-- api/src/main/java/ai/djl/nn/LambdaBlock.java | 6 +- .../main/java/ai/djl/nn/ParallelBlock.java | 8 +- api/src/main/java/ai/djl/nn/Parameter.java | 246 ++++++++++++------ .../main/java/ai/djl/nn/ParameterType.java | 41 --- .../main/java/ai/djl/nn/SequentialBlock.java | 7 +- .../ai/djl/nn/convolutional/Convolution.java | 33 ++- .../djl/nn/convolutional/Deconvolution.java | 33 ++- .../ai/djl/nn/core/ConstantEmbedding.java | 2 +- .../main/java/ai/djl/nn/core/Embedding.java | 24 +- api/src/main/java/ai/djl/nn/core/Linear.java | 34 ++- api/src/main/java/ai/djl/nn/core/Prelu.java | 12 +- .../main/java/ai/djl/nn/norm/BatchNorm.java | 46 +++- api/src/main/java/ai/djl/nn/norm/Dropout.java | 3 +- .../ai/djl/nn/recurrent/RecurrentBlock.java | 56 ++-- .../java/ai/djl/nn/transformer/BertBlock.java | 16 +- .../BertMaskedLanguageModelBlock.java | 10 +- .../nn/transformer/BertNextSentenceBlock.java | 2 +- .../nn/transformer/BertPretrainingBlock.java | 5 +- .../ai/djl/nn/transformer/IdEmbedding.java | 10 +- .../PointwiseFeedForwardBlock.java | 17 +- .../ScaledDotProductAttentionBlock.java | 2 +- .../transformer/TransformerEncoderBlock.java | 2 +- .../djl/training/DefaultTrainingConfig.java | 51 +++- .../java/ai/djl/training/TrainingConfig.java | 7 +- .../initializer/XavierInitializer.java | 2 +- .../ai/djl/nn/convolutional/ShapeUtils.java | 2 +- .../basicdataset/AirfoilRandomAccessTest.java | 3 +- .../basicdataset/AmesRandomAccessTest.java | 3 +- .../java/ai/djl/basicdataset/PikachuTest.java | 3 +- .../examples/training/TrainBertOnCode.java | 4 +- .../examples/training/TrainMnistWithLSTM.java | 2 - .../ai/djl/fasttext/FtTrainingConfig.java | 5 +- .../modality/cv/SingleShotDetectionTest.java | 52 ++-- .../modality/nlp/SimpleTextEncoderTest.java | 2 - .../model_zoo/classification/AlexNetTest.java | 6 +- .../classification/GoogLeNetTest.java | 6 +- .../model_zoo/classification/LeNetTest.java | 6 +- .../model_zoo/classification/NiNTest.java | 9 +- .../model_zoo/classification/ResnetTest.java | 7 +- .../classification/SqueezenetTest.java | 2 +- .../model_zoo/classification/VGGTest.java | 10 +- .../NDArrayElementArithmeticOpTest.java | 5 +- .../integration/tests/nn/BlockCoreTest.java | 51 ++-- .../tests/nn/PoolingOperationsTest.java | 4 +- .../ScaledDotProductAttentionBlockTest.java | 3 +- .../tests/training/ActivationTest.java | 4 +- .../tests/training/BlocksTest.java | 4 +- .../tests/training/DatasetTest.java | 4 +- .../GradientCollectorIntegrationTest.java | 5 +- .../integration/tests/training/ModelTest.java | 2 - .../tests/training/OptimizerTest.java | 17 +- .../ssd/SingleShotDetection.java | 84 +++--- .../java/ai/djl/mxnet/engine/MxModel.java | 11 +- .../ai/djl/mxnet/engine/MxSymbolBlock.java | 32 ++- .../MxGradientCollectorIntegrationTest.java | 3 +- .../mxnet/integration/MxSymbolBlockTest.java | 8 +- .../paddlepaddle/engine/PpSymbolBlock.java | 2 +- .../java/ai/djl/pytorch/engine/PtModel.java | 12 +- .../ai/djl/pytorch/engine/PtSymbolBlock.java | 2 +- .../djl/tensorflow/engine/TfSymbolBlock.java | 6 +- 67 files changed, 705 insertions(+), 551 deletions(-) delete mode 100644 api/src/main/java/ai/djl/nn/ParameterType.java diff --git a/api/src/main/java/ai/djl/modality/nlp/Decoder.java b/api/src/main/java/ai/djl/modality/nlp/Decoder.java index 72a6ead095f..e6dd38fedf3 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Decoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Decoder.java @@ -64,8 +64,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return block.getOutputShapes(manager, inputShapes); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return block.getOutputShapes(inputShapes); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/nlp/Encoder.java b/api/src/main/java/ai/djl/modality/nlp/Encoder.java index 7ad4fbdcf6f..e3b1631322f 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Encoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Encoder.java @@ -79,8 +79,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return block.getOutputShapes(manager, inputShapes); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return block.getOutputShapes(inputShapes); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index afc0958ace9..a7c47380a9d 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -97,19 +97,18 @@ public NDList forward( * @param manager the NDManager to initialize the parameters * @param dataType the datatype of the parameters * @param inputShapes the shapes of the inputs to the block - * @return the shapes of the outputs of the block */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); encoder.initialize(manager, dataType, inputShapes[0]); - return decoder.initialize(manager, dataType, inputShapes[1]); + decoder.initialize(manager, dataType, inputShapes[1]); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return decoder.getOutputShapes(manager, new Shape[] {inputShapes[1]}); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return decoder.getOutputShapes(new Shape[] {inputShapes[1]}); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java index f848f25f3e5..4fd7dc71277 100644 --- a/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java +++ b/api/src/main/java/ai/djl/modality/nlp/embedding/TrainableTextEmbedding.java @@ -88,7 +88,7 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return trainableWordEmbedding.getOutputShapes(manager, inputShapes); + public Shape[] getOutputShapes(Shape[] inputShapes) { + return trainableWordEmbedding.getOutputShapes(inputShapes); } } diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java index 6d4c2c4ecd2..7f67d8431a6 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java @@ -30,7 +30,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; -import java.util.function.Function; +import java.util.function.Predicate; /** * {@code AbstractBlock} is an abstract implementation of {@link Block}. @@ -43,12 +43,11 @@ *
    *
  • Define a version for serializing parameter and metadata and pass it to the parent * constructor - *
  • Use {@link AbstractBlock#addParameter(Parameter, Shape)} or {@link - * AbstractBlock#addParameter(Parameter, Function)} to add parameters to your block in the + *
  • Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the * constructor if necessary. *
  • Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary. - *
  • Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape - * of your custom block's output based on the input it will receive. + *
  • Override {@link Block#getOutputShapes(Shape[])} to determine the shape of your custom + * block's output based on the input it will receive. *
  • Override {@link AbstractBlock#initializeChildBlocks(NDManager, DataType, Shape...)} if you * added child blocks to initialize them based on the input shape your block will receive. You * can skip this if your block does not contain child blocks @@ -61,9 +60,9 @@ *
* *

If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take - * care of parameter initialization yourself. In this case, you need to override {@link - * AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If - * you use the other variants of {@code addParameter} this is done for you. + * care of parameter initialization yourself. In this case, you need to setShape to your parameters + * if you know the shape of Parameter or you can implement prepare to setShape when you see the + * input shape. */ // Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers // of this API know the children and parameters are always iterated over in insertion order. @@ -99,14 +98,6 @@ public abstract class AbstractBlock implements Block { */ protected LinkedHashMap parameters = new LinkedHashMap<>(); - /** - * Callbacks to determine the shape of a parameter. Values may be null in which case extending - * classes need to override {@link Block#getParameterShape(String, Shape[])} and implement - * parameter shape resolution manually. - */ - protected LinkedHashMap> parameterShapeCallbacks = - new LinkedHashMap<>(); - /** * Builds an empty block with the given version for parameter serialization. * @@ -195,73 +186,20 @@ protected final B addChildBlock(String name, B block) { return block; } - /** - * Adds a parameter to this block. If parameters are added with this method, subclasses need to - * override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters - * themselves. - * - * @param parameter the parameter to add, not null - * @param

the specific parameter subclass - * @return the parameter passed as arguments to make it easier to create and assign paramters in - * one line - */ - protected final

P addParameter(P parameter) { - return addParameter(parameter, (Function) null); - } - /** * Adds a parameter to this block. If parameters are added with this method, intialization of * the parameter works out of the box * - * @param parameter the parameter to add, not null - * @param shape the shape of the parameter * @param

the specific parameter subclass - * @return the parameter passed as arguments to make it easier to create and assign paramters in - * one line - */ - protected final

P addParameter(P parameter, Shape shape) { - return addParameter(parameter, (inputShapes) -> shape); - } - - /** - * Adds a parameter to this block. If parameters are added with this method, intialization of - * the parameter works out of the box - * * @param parameter the parameter to add, not null - * @param shapeCallback the method to call once the input shape of this block is known to - * determine the shape of the given parameter - * @param

the specific parameter subclass * @return the parameter passed as arguments to make it easier to create and assign parameters * in one line */ - protected final

P addParameter( - P parameter, Function shapeCallback) { + protected final

P addParameter(P parameter) { parameters.put(parameter.getName(), parameter); - parameterShapeCallbacks.put(parameter.getName(), shapeCallback); return parameter; } - /** {@inheritDoc} */ - @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { - Function callback = parameterShapeCallbacks.get(name); - if (callback == null) { - Parameter parameter = parameters.get(name); - if (parameter == null) { - throw new IllegalArgumentException( - "No parameter named " + name + " found in this block."); - } else { - throw new IllegalStateException( - "No shape initializer for parameter " - + name - + "found. " - + "Either pass an initializer for the shape when adding the " - + "parameter or override getParameterShape in the subclass."); - } - } - return callback.apply(inputShapes); - } - /** {@inheritDoc} */ @Override public BlockList getChildren() { @@ -285,13 +223,9 @@ public PairList describeInput() { /** {@inheritDoc} */ @Override - public void setInitializer(Initializer initializer) { - for (Parameter parameter : parameters.values()) { - parameter.setInitializer(initializer, false); - } - for (Block child : children.values()) { - child.setInitializer(initializer); - } + public void setInitializer(Initializer initializer, Parameter.Type params) { + Predicate predicate = parameter -> parameter.getType().equals(params); + setInitializer(initializer, predicate); } /** {@inheritDoc} */ @@ -301,18 +235,50 @@ public void setInitializer(Initializer initializer, String paramName) { if (parameter == null) { throw new IllegalArgumentException("Could not find parameter " + paramName); } - parameter.setInitializer(initializer, true); + parameter.setInitializer(initializer); } /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void setInitializer(Initializer initializer, Predicate predicate) { + List params = getParameters().values(); + for (Parameter param : params) { + if (predicate.test(param)) { + param.setInitializer(initializer); + } + } + } + + /** {@inheritDoc} */ + @Override + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); + // if parameters are initialized, skip it + if (!isInitialized()) { + // setShape for all params + prepare(inputShapes); + } for (Parameter parameter : parameters.values()) { - parameter.initialize(manager, dataType, inputShapes); + parameter.initialize(manager, dataType); } initializeChildBlocks(manager, dataType, inputShapes); - return getOutputShapes(manager, inputShapes); + } + + /** + * Performs any action necessary before initialization. For example, keep the input information + * or verify the layout. + * + * @param inputShapes the expected shapes of the input + */ + protected void beforeInitialize(Shape... inputShapes) { + if (inputNames.isEmpty()) { + // automatically assign input names + inputNames = new ArrayList<>(); + for (int i = 0; i < inputShapes.length; ++i) { + inputNames.add("data" + i); + } + } + this.inputShapes = inputShapes; } /** @@ -355,20 +321,11 @@ public ParameterList getDirectParameters() { } /** - * Performs any action necessary before initialization. + * Sets the shape of {@link Parameter}s. * - * @param inputShapes the expected shapes of the input + * @param inputShapes the shapes of inputs */ - protected void beforeInitialize(Shape[] inputShapes) { - if (inputNames.isEmpty()) { - // automatically assign input names - inputNames = new ArrayList<>(); - for (int i = 0; i < inputShapes.length; ++i) { - inputNames.add("data" + i); - } - } - this.inputShapes = inputShapes; - } + protected void prepare(Shape[] inputShapes) {} /** {@inheritDoc} */ @Override @@ -494,7 +451,7 @@ public String toString() { appendShape(sb, inputShapeDescription.values().toArray(new Shape[0])); sb.append(" -> "); Shape[] outputShapes = - getOutputShapes(null, inputShapeDescription.values().toArray(new Shape[0])); + getOutputShapes(inputShapeDescription.values().toArray(new Shape[0])); appendShape(sb, outputShapes); } else { sb.append("Uninitialized"); diff --git a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java index f870eea148f..fdd7fcc0d6c 100644 --- a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java @@ -12,7 +12,6 @@ */ package ai.djl.nn; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; /** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */ @@ -29,7 +28,7 @@ public AbstractSymbolBlock(byte version) { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { throw new UnsupportedOperationException("not implement!"); } } diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 6febeff5f6f..4bd3fb09e2b 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -25,6 +25,7 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.function.Predicate; /** * A {@code Block} is a composable function that forms a neural network. @@ -158,11 +159,12 @@ default NDList forward( } /** - * Sets an {@link Initializer} to the block. + * Sets an {@link Initializer} to all the parameters that match parameter type in the block. * * @param initializer the initializer to set + * @param type the Parameter Type we want to setInitializer */ - void setInitializer(Initializer initializer); + void setInitializer(Initializer initializer, Parameter.Type type); /** * Sets an {@link Initializer} to the specified direct parameter of the block, overriding the @@ -173,15 +175,22 @@ default NDList forward( */ void setInitializer(Initializer initializer, String paramName); + /** + * Sets an {@link Initializer} to all the parameters that match Predicate in the block. + * + * @param initializer the initializer to be set + * @param predicate predicate function to indicate parameters you want to set + */ + void setInitializer(Initializer initializer, Predicate predicate); + /** * Initializes the parameters of the block. This method must be called before calling `forward`. * * @param manager the NDManager to initialize the parameters * @param dataType the datatype of the parameters * @param inputShapes the shapes of the inputs to the block - * @return the shapes of the outputs of the block */ - Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes); + void initialize(NDManager manager, DataType dataType, Shape... inputShapes); /** * Returns a boolean whether the block is initialized. @@ -232,25 +241,13 @@ default NDList forward( */ ParameterList getParameters(); - /** - * Returns the shape of the specified direct parameter of this block given the shapes of the - * input to the block. - * - * @param name the name of the parameter - * @param inputShapes the shapes of the input to the block - * @return the shape of the parameter specified - * @throws IllegalArgumentException if the parameter name specified is invalid - */ - Shape getParameterShape(String name, Shape[] inputShapes); - /** * Returns the expected output shapes of the block for the specified input shapes. * - * @param manager an NDManager * @param inputShapes the shapes of the inputs * @return the expected output shapes of the block */ - Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes); + Shape[] getOutputShapes(Shape[] inputShapes); /** * Writes the parameters of the block to the given outputStream. diff --git a/api/src/main/java/ai/djl/nn/LambdaBlock.java b/api/src/main/java/ai/djl/nn/LambdaBlock.java index 939e0e09b29..1b6fc0d3bc5 100644 --- a/api/src/main/java/ai/djl/nn/LambdaBlock.java +++ b/api/src/main/java/ai/djl/nn/LambdaBlock.java @@ -70,11 +70,11 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - try (NDManager subManager = manager.newSubManager()) { + public Shape[] getOutputShapes(Shape[] inputShapes) { + try (NDManager manager = NDManager.newBaseManager()) { NDList input = new NDList(inputShapes.length); for (Shape shape : inputShapes) { - input.add(subManager.zeros(shape)); + input.add(manager.zeros(shape)); } NDList output = lambda.apply(input); Shape[] outputShapes = new Shape[output.size()]; diff --git a/api/src/main/java/ai/djl/nn/ParallelBlock.java b/api/src/main/java/ai/djl/nn/ParallelBlock.java index 92ce9f6349e..9d04569a63f 100644 --- a/api/src/main/java/ai/djl/nn/ParallelBlock.java +++ b/api/src/main/java/ai/djl/nn/ParallelBlock.java @@ -149,16 +149,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { Preconditions.checkArgument(!children.isEmpty(), "The parallel block is empty"); - try (NDManager subManager = manager.newSubManager()) { + try (NDManager manager = NDManager.newBaseManager()) { List inputs = new ArrayList<>(); for (Block block : children.values()) { - Shape[] shapes = block.getOutputShapes(manager, inputShapes); + Shape[] shapes = block.getOutputShapes(inputShapes); NDList output = new NDList(shapes.length); for (Shape shape : shapes) { - output.add(subManager.create(shape)); + output.add(manager.create(shape)); } inputs.add(output); } diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java index 7d3c2954cde..369111057c7 100644 --- a/api/src/main/java/ai/djl/nn/Parameter.java +++ b/api/src/main/java/ai/djl/nn/Parameter.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; import ai.djl.training.initializer.Initializer; +import ai.djl.training.initializer.XavierInitializer; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -42,62 +43,23 @@ public class Parameter implements AutoCloseable { private String id; private String name; - private Block block; - private ParameterType type; - private DataType mandatoryDataType; + private Shape shape; + private Type type; private Initializer initializer; private NDArray array; private boolean requiresGrad; private SparseFormat gradientFormat; - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - */ - public Parameter(String name, Block block, ParameterType type) { - this(name, block, type, true, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requiresGrad whether this {@code Parameter} needs to compute gradients - */ - public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) { - this(name, block, type, requiresGrad, SparseFormat.DENSE); - } - - /** - * Creates a {@code Parameter} with the given name, and parameter type, and associated with the - * given {@link Block}. - * - * @param name the name of the {@code Parameter} - * @param block the block with which this {@code Parameter} is associated - * @param type the type of this {@code Parameter} - * @param requireGrad whether this {@code Parameter} needs to compute gradients - * @param gradientFormat the {@link SparseFormat} of the gradient array - */ - public Parameter( - String name, - Block block, - ParameterType type, - boolean requireGrad, - SparseFormat gradientFormat) { + Parameter(Builder builder) { this.id = UUID.randomUUID().toString(); - this.name = name; - this.block = block; - this.type = type; - this.requiresGrad = requireGrad; - this.initializer = type.getInitializer(); - this.gradientFormat = gradientFormat; + this.name = builder.name; + this.shape = builder.shape; + this.type = builder.type; + this.array = builder.array; + this.requiresGrad = builder.requiresGrad; + this.initializer = + (builder.initializer != null) ? builder.initializer : type.getInitializer(); + this.gradientFormat = builder.gradientFormat; } /** @@ -123,7 +85,7 @@ public String getName() { * * @return the type of this {@code Parameter} */ - public ParameterType getType() { + public Type getType() { return type; } @@ -133,10 +95,26 @@ public ParameterType getType() { * @param array the {@link NDArray} that contains values of this {@code Parameter} */ public void setArray(NDArray array) { + if (shape != null) { + throw new IllegalStateException("array has been set! Use either setArray or setShape"); + } this.array = array; + shape = array.getShape(); array.setName(name); } + /** + * Sets the shape of this {@code Parameter}. + * + * @param shape the shape of this {@code Parameter} + */ + public void setShape(Shape shape) { + if (array != null) { + throw new IllegalStateException("array has been set! Use either setArray or setShape"); + } + this.shape = shape; + } + /** * Gets the values of this {@code Parameter} as an {@link NDArray}. * @@ -158,15 +136,6 @@ public boolean requireGradient() { return requiresGrad; } - /** - * Sets the mandatory data type for this {@code Parameter}. - * - * @param mandatoryDataType the mandatory data type for this {@code Parameter} - */ - public void setMandatoryDataType(DataType mandatoryDataType) { - this.mandatoryDataType = mandatoryDataType; - } - /** * Checks if this {@code Parameter} is initialized. * @@ -181,12 +150,9 @@ public boolean isInitialized() { * flag is true, sets the initializer regardless. * * @param initializer the initializer to be set - * @param overwrite if true, set the initializer regardless of whether its already set or not */ - public void setInitializer(Initializer initializer, boolean overwrite) { - if (overwrite || this.initializer == null) { - this.initializer = initializer; - } + public void setInitializer(Initializer initializer) { + this.initializer = initializer; } /** @@ -195,17 +161,12 @@ public void setInitializer(Initializer initializer, boolean overwrite) { * * @param manager an NDManager to create the arrays * @param dataType the datatype of the {@code Parameter} - * @param inputShapes the expected input shapes */ - public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) { + public void initialize(NDManager manager, DataType dataType) { Objects.requireNonNull(initializer, "No initializer has been set"); + Objects.requireNonNull(shape, "No parameter shape has been set"); if (!isInitialized()) { - Shape shape = block.getParameterShape(name, inputShapes); - array = - initializer.initialize( - manager, - shape, - mandatoryDataType == null ? dataType : mandatoryDataType); + array = initializer.initialize(manager, shape, dataType); array.setName(name); } @@ -266,6 +227,8 @@ public void load(NDManager manager, DataInputStream dis) } array = manager.decode(dis); + // set the shape of the parameter and prepare() can be skipped + shape = array.getShape(); } /** {@inheritDoc} */ @@ -276,4 +239,141 @@ public void close() { array = null; } } + + /** + * Creates a builder to build a {@code Parameter}. + * + *

The methods start with {@code set} are required fields, and {@code opt} for optional + * fields. + * + * @return a new builder + */ + public static Parameter.Builder builder() { + return new Parameter.Builder(); + } + + /** Enumerates the types of {@link Parameter}. */ + public enum Type { + WEIGHT( + new XavierInitializer( + XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2)), + BIAS(Initializer.ZEROS), + GAMMA(Initializer.ONES), + BETA(Initializer.ZEROS), + RUNNING_MEAN(Initializer.ZEROS), + RUNNING_VAR(Initializer.ONES), + OTHER(null); + + private final transient Initializer initializer; + + Type(Initializer initializer) { + this.initializer = initializer; + } + + /** + * Gets the {@link Initializer} of this {@code ParameterType}. + * + * @return the {@link Initializer} of this {@code ParameterType} + */ + public Initializer getInitializer() { + return initializer; + } + } + + /** A Builder to construct a {@code Parameter}. */ + public static final class Builder { + String name; + Shape shape; + Type type; + Initializer initializer; + NDArray array; + boolean requiresGrad = true; + SparseFormat gradientFormat; + + /** + * Sets the name of the {@code Parameter}. + * + * @param name the name of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setName(String name) { + this.name = name; + return this; + } + + /** + * Sets the {@code Type} of the {@code Parameter}. + * + * @param type the {@code Type} of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder setType(Type type) { + this.type = type; + return this; + } + + /** + * Sets the shape of the {@code Parameter}. + * + * @param shape the shape of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optShape(Shape shape) { + this.shape = shape; + return this; + } + + /** + * Sets the Initializer of the {@code Parameter}. + * + * @param initializer the Initializer of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optInitializer(Initializer initializer) { + this.initializer = initializer; + return this; + } + + /** + * Sets the array of the {@code Parameter}. + * + * @param array the array of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optArray(NDArray array) { + this.array = array; + return this; + } + + /** + * Sets if the {@code Parameter} requires gradient. + * + * @param requiresGrad if the {@code Parameter} requires gradient + * @return this {@code Parameter} + */ + public Builder optRequiresGrad(boolean requiresGrad) { + this.requiresGrad = requiresGrad; + return this; + } + + /** + * Sets the {@link SparseFormat} of the {@code Parameter}. + * + * @param gradientFormat the {@link SparseFormat} of the {@code Parameter} + * @return this {@code Parameter} + */ + public Builder optGradientFormat(SparseFormat gradientFormat) { + this.gradientFormat = gradientFormat; + return this; + } + + /** + * Builds a {@code Parameter} instance. + * + * @return the {@code Parameter} instance + */ + public Parameter build() { + return new Parameter(this); + } + } } diff --git a/api/src/main/java/ai/djl/nn/ParameterType.java b/api/src/main/java/ai/djl/nn/ParameterType.java deleted file mode 100644 index dde3cb23ab8..00000000000 --- a/api/src/main/java/ai/djl/nn/ParameterType.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.nn; - -import ai.djl.training.initializer.Initializer; - -/** Enumerates the types of {@link Parameter}. */ -public enum ParameterType { - WEIGHT(null), - BIAS(Initializer.ZEROS), - GAMMA(Initializer.ONES), - BETA(Initializer.ZEROS), - RUNNING_MEAN(Initializer.ZEROS), - RUNNING_VAR(Initializer.ONES), - OTHER(null); - - private final transient Initializer initializer; - - ParameterType(Initializer initializer) { - this.initializer = initializer; - } - - /** - * Gets the {@link Initializer} of this {@code ParameterType}. - * - * @return the {@link Initializer} of this {@code ParameterType} - */ - public Initializer getInitializer() { - return initializer; - } -} diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java index d0c6f168fb8..8908c1ad6c8 100644 --- a/api/src/main/java/ai/djl/nn/SequentialBlock.java +++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java @@ -154,19 +154,20 @@ protected NDList forwardInternal( public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { Shape[] shapes = inputShapes; for (Block child : getChildren().values()) { - shapes = child.initialize(manager, dataType, shapes); + child.initialize(manager, dataType, shapes); + shapes = child.getOutputShapes(shapes); } } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + public Shape[] getOutputShapes(Shape[] inputs) { if (children.isEmpty()) { throw new IllegalArgumentException("The sequential block is empty"); } Shape[] current = inputs; for (Block block : children.values()) { - current = block.getOutputShapes(manager, current); + current = block.getOutputShapes(current); } return current; } diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index 6334935804c..1d211f1d903 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -16,13 +16,11 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -102,13 +100,17 @@ public Convolution(ConvolutionBuilder builder) { weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), - (inputShapes) -> - new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); + Parameter.builder() + .setName("weight") + .setType(Parameter.Type.WEIGHT) + .build()); if (includeBias) { bias = addParameter( - new Parameter("bias", this, ParameterType.BIAS), new Shape(filters)); + Parameter.builder() + .setName("bias") + .setType(Parameter.Type.BIAS) + .build()); } } @@ -149,15 +151,24 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - protected void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(getExpectedLayout(), inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + protected void prepare(Shape[] inputs) { + long inputChannel = inputs[0].get(1); + weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape)); + if (bias != null) { + bias.setShape(new Shape(filters)); + } + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputs) { long[] shape = new long[numDimensions()]; shape[0] = inputs[0].get(0); shape[1] = filters; diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 890a4db9eff..10059208a5c 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -16,13 +16,11 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -78,13 +76,17 @@ public Deconvolution(DeconvolutionBuilder builder) { weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), - (inputShapes) -> - new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); + Parameter.builder() + .setName("weight") + .setType(Parameter.Type.WEIGHT) + .build()); if (includeBias) { bias = addParameter( - new Parameter("bias", this, ParameterType.BIAS), new Shape(filters)); + Parameter.builder() + .setName("bias") + .setType(Parameter.Type.BIAS) + .build()); } } @@ -126,15 +128,24 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - protected void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(getExpectedLayout(), inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + protected void prepare(Shape[] inputs) { + long inputChannel = inputs[0].get(1); + weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape)); + if (bias != null) { + bias.setShape(new Shape(filters)); + } + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputs) { long[] shape = new long[numDimensions()]; shape[0] = inputs[0].get(0); shape[1] = filters; diff --git a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java index e2dceddcef4..ce5c486d701 100644 --- a/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java +++ b/api/src/main/java/ai/djl/nn/core/ConstantEmbedding.java @@ -57,7 +57,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(embedding.getShape())}; } diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index 48736aaa8e1..2e65295c1f7 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -23,7 +23,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -56,8 +55,11 @@ protected Embedding(BaseBuilder baseBuilder) { sparseFormat = baseBuilder.sparseFormat; embedding = addParameter( - new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseFormat), - (inputShapes) -> new Shape(numEmbeddings, embeddingSize)); + Parameter.builder() + .setName("embedding") + .setType(Parameter.Type.WEIGHT) + .optGradientFormat(sparseFormat) + .build()); if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) { throw new IllegalArgumentException( "You can not specify both a fallthrough and a defaultItem"); @@ -93,15 +95,25 @@ public Embedding(NDArray embedding, SparseFormat format) { this.sparseFormat = format; this.embedding = addParameter( - new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseFormat), - (inputShapes) -> new Shape(numEmbeddings, embeddingSize)); + Parameter.builder() + .setName("embedding") + .setType(Parameter.Type.WEIGHT) + .optGradientFormat(sparseFormat) + .build()); this.embedding.setArray(embedding); inputShapes = new Shape[] {new Shape(-1)}; } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public void prepare(Shape[] inputShapes) { + // numItems will be adjusted by embedding array or fallthroughEmbedding + embedding.setShape(new Shape(numEmbeddings, embeddingSize)); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))}; } diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 34a8c020a18..79459348fb2 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -16,12 +16,10 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import ai.djl.util.Preconditions; @@ -59,14 +57,19 @@ public class Linear extends AbstractBlock { Linear(Builder builder) { super(VERSION); units = builder.units; - // "inputFeatures" is only known after "beforeInitialize" is called, hence we need - // a callback, even if we do not used the callback parameter weight = addParameter( - new Parameter("weight", this, ParameterType.WEIGHT), - inputShapes -> new Shape(units, inputFeatures)); + Parameter.builder() + .setName("weight") + .setType(Parameter.Type.WEIGHT) + .build()); if (builder.bias) { - bias = addParameter(new Parameter("bias", this, ParameterType.BIAS), new Shape(units)); + bias = + addParameter( + Parameter.builder() + .setName("bias") + .setType(Parameter.Type.BIAS) + .build()); } } @@ -86,8 +89,8 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { - return new Shape[] {inputShape.addAll(new Shape(units))}; + public Shape[] getOutputShapes(Shape[] inputs) { + return new Shape[] {inputs[0].slice(0, inputs[0].dimension() - 1).add(units)}; } /** {@inheritDoc} */ @@ -99,13 +102,24 @@ public PairList describeInput() { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputShapes) { + protected void beforeInitialize(Shape... inputShapes) { super.beforeInitialize(inputShapes); + Preconditions.checkArgument(inputShapes.length == 1, "Linear block only support 1 input"); Shape input = inputShapes[0]; inputFeatures = input.get(input.dimension() - 1); inputShape = input.slice(0, input.dimension() - 1); } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + Shape input = inputShapes[0]; + weight.setShape(new Shape(units, input.get(input.dimension() - 1))); + if (bias != null) { + bias.setShape(new Shape(units)); + } + } + /** {@inheritDoc} */ @Override protected void saveMetadata(DataOutputStream os) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index abb268f9446..ced9d709e87 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -15,12 +15,10 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -44,7 +42,13 @@ public class Prelu extends AbstractBlock { /** Creates a Parametric ReLU Block. */ public Prelu() { super(VERSION); - alpha = addParameter(new Parameter("alpha", this, ParameterType.OTHER), new Shape()); + alpha = + addParameter( + Parameter.builder() + .setName("alpha") + .setType(Parameter.Type.WEIGHT) + .optShape(new Shape()) + .build()); } /** {@inheritDoc} */ @@ -61,7 +65,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + public Shape[] getOutputShapes(Shape[] inputs) { return new Shape[] {inputs[0]}; } diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java index 7f5f243646b..dc93360811a 100644 --- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java @@ -16,12 +16,10 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.io.DataInputStream; @@ -86,26 +84,37 @@ public class BatchNorm extends AbstractBlock { momentum = builder.momentum; center = builder.center; scale = builder.scale; - // When creating parameters we use a callback as "inChannels" is set before initialization, - // it is not known yet. + // make gamma trainable if scale gamma = addParameter( - new Parameter("gamma", this, ParameterType.GAMMA, scale), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("gamma") + .setType(Parameter.Type.GAMMA) + .optRequiresGrad(scale) + .build()); // make beta trainable if center beta = addParameter( - new Parameter("beta", this, ParameterType.BETA, center), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("beta") + .setType(Parameter.Type.BETA) + .optRequiresGrad(center) + .build()); runningMean = addParameter( - new Parameter("runningMean", this, ParameterType.RUNNING_MEAN, false), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("runningMean") + .setType(Parameter.Type.RUNNING_MEAN) + .optRequiresGrad(false) + .build()); runningVar = addParameter( - new Parameter("runningVar", this, ParameterType.RUNNING_VAR, false), - (inputShapes) -> new Shape(inChannels)); + Parameter.builder() + .setName("runningVar") + .setType(Parameter.Type.RUNNING_VAR) + .optRequiresGrad(false) + .build()); } /** {@inheritDoc} */ @@ -135,17 +144,26 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0]}; } /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputShapes) { + protected void beforeInitialize(Shape... inputShapes) { super.beforeInitialize(inputShapes); inChannels = inputShapes[0].size(axis); } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + gamma.setShape(new Shape(inChannels)); + beta.setShape(new Shape(inChannels)); + runningMean.setShape(new Shape(inChannels)); + runningVar.setShape(new Shape(inChannels)); + } + /** {@inheritDoc} */ @Override protected void saveMetadata(DataOutputStream os) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/norm/Dropout.java b/api/src/main/java/ai/djl/nn/norm/Dropout.java index 7a44c954288..20bcd201138 100644 --- a/api/src/main/java/ai/djl/nn/norm/Dropout.java +++ b/api/src/main/java/ai/djl/nn/norm/Dropout.java @@ -15,7 +15,6 @@ import ai.djl.MalformedModelException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; @@ -76,7 +75,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0]}; } diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index 1251c32ee49..c51620f5dc5 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -13,13 +13,13 @@ package ai.djl.nn.recurrent; import ai.djl.MalformedModelException; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; +import ai.djl.nn.ParameterList; +import ai.djl.util.Pair; import java.io.DataInputStream; import java.io.IOException; @@ -67,7 +67,7 @@ public RecurrentBlock(BaseBuilder builder) { bidirectional = builder.bidirectional; returnState = builder.returnState; - ParameterType[] parameterTypes = {ParameterType.WEIGHT, ParameterType.BIAS}; + Parameter.Type[] parameterTypes = {Parameter.Type.WEIGHT, Parameter.Type.BIAS}; String[] directions = {"l"}; if (builder.bidirectional) { directions = new String[] {"l", "r"}; @@ -75,12 +75,13 @@ public RecurrentBlock(BaseBuilder builder) { String[] gateStrings = {"i2h", "h2h"}; for (int i = 0; i < numLayers; i++) { - for (ParameterType parameterType : parameterTypes) { + for (Parameter.Type parameterType : parameterTypes) { for (String direction : directions) { for (String gateString : gateStrings) { String name = direction + '_' + i + '_' + gateString + '_' + parameterType.name(); - addParameter(new Parameter(name, this, parameterType)); + addParameter( + Parameter.builder().setName(name).setType(parameterType).build()); } } } @@ -89,7 +90,7 @@ public RecurrentBlock(BaseBuilder builder) { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { + public Shape[] getOutputShapes(Shape[] inputs) { Shape inputShape = inputs[0]; Shape outputShape = new Shape(inputShape.get(0), inputShape.get(1), stateSize * getNumDirections()); @@ -109,31 +110,34 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(EXPECTED_LAYOUT, inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { - int layer = Integer.parseInt(name.split("_")[1]); - Shape shape = inputShapes[0]; - long inputs = shape.get(2); - if (layer > 0) { - inputs = stateSize * getNumDirections(); - } - if (name.contains("BIAS")) { - return new Shape(gates * stateSize); - } - if (name.contains("i2h")) { - return new Shape(gates * stateSize, inputs); - } - if (name.contains("h2h")) { - return new Shape(gates * stateSize, stateSize); + public void prepare(Shape[] inputs) { + Shape inputShape = inputs[0]; + ParameterList parameters = getDirectParameters(); + for (Pair pair : parameters) { + String name = pair.getKey(); + Parameter parameter = pair.getValue(); + int layer = Integer.parseInt(name.split("_")[1]); + long inputSize = inputShape.get(2); + if (layer > 0) { + inputSize = stateSize * getNumDirections(); + } + if (name.contains("BIAS")) { + parameter.setShape(new Shape(gates * stateSize)); + } else if (name.contains("i2h")) { + parameter.setShape(new Shape(gates * stateSize, inputSize)); + } else if (name.contains("h2h")) { + parameter.setShape(new Shape(gates * stateSize, stateSize)); + } else { + throw new IllegalArgumentException("Invalid parameter name"); + } } - throw new IllegalArgumentException("Invalid parameter name"); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index 066fc0c4229..b8ca2caf120 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -22,7 +22,6 @@ import ai.djl.nn.Activation; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.core.Linear; import ai.djl.nn.norm.BatchNorm; import ai.djl.nn.norm.Dropout; @@ -77,8 +76,12 @@ private BertBlock(Builder builder) { // embedding for the position this.positionEmebdding = addParameter( - new Parameter(PARAM_POSITION_EMBEDDING, this, ParameterType.WEIGHT), - new Shape(builder.maxSequenceLength, builder.embeddingSize)); + Parameter.builder() + .setName(PARAM_POSITION_EMBEDDING) + .setType(Parameter.Type.WEIGHT) + .optShape( + new Shape(builder.maxSequenceLength, builder.embeddingSize)) + .build()); // embedding for the input types this.typeEmbedding = addChildBlock( @@ -153,7 +156,7 @@ public int getTypeDictionarySize() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { long batch = inputShapes[0].get(0); long seqLength = inputShapes[0].get(1); return new Shape[] { @@ -164,11 +167,12 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { /** {@inheritDoc} */ @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { - beforeInitialize(inputShapes); + super.beforeInitialize(inputShapes); inputNames = Arrays.asList("tokenIds", "typeIds", "masks"); Shape[] tokenShape = {inputShapes[0]}; Shape[] typeShape = {inputShapes[1]}; - Shape[] embeddingOutput = this.tokenEmbedding.initialize(manager, dataType, tokenShape); + this.tokenEmbedding.initialize(manager, dataType, tokenShape); + Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(tokenShape); this.typeEmbedding.initialize(manager, dataType, typeShape); this.embeddingNorm.initialize(manager, dataType, embeddingOutput); this.embeddingDropout.initialize(manager, dataType, embeddingOutput); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index ab9a7216803..7fc336434b6 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -19,7 +19,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.core.Linear; import ai.djl.nn.norm.BatchNorm; import ai.djl.training.ParameterStore; @@ -59,8 +58,11 @@ public BertMaskedLanguageModelBlock( this.sequenceNorm = addChildBlock("sequenceNorm", BatchNorm.builder().optAxis(1).build()); this.dictionaryBias = addParameter( - new Parameter("dictionaryBias", this, ParameterType.BIAS), - new Shape(bertBlock.getTokenDictionarySize())); + Parameter.builder() + .setName("dictionaryBias") + .setType(Parameter.Type.BIAS) + .optShape(new Shape(bertBlock.getTokenDictionarySize())) + .build()); this.hiddenActivation = hiddenActivation; } @@ -140,7 +142,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(final NDManager manager, final Shape[] inputShapes) { + public Shape[] getOutputShapes(final Shape[] inputShapes) { int batchSize = (int) inputShapes[0].get(0); int indexCount = (int) inputShapes[1].get(1); int dictionarySize = (int) inputShapes[2].get(0); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java index 40f9a717b65..aac91ab5745 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceBlock.java @@ -53,7 +53,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {new Shape(inputShapes[0].get(0), 2)}; } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index 2478811d037..d7f44089564 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -51,7 +51,8 @@ public BertPretrainingBlock(final BertBlock.Builder builder) { @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices"); - Shape[] bertOutputShapes = bertBlock.initialize(manager, dataType, inputShapes); + bertBlock.initialize(manager, dataType, inputShapes); + Shape[] bertOutputShapes = bertBlock.getOutputShapes(inputShapes); Shape embeddedSequence = bertOutputShapes[0]; Shape pooledOutput = bertOutputShapes[1]; Shape maskedIndices = inputShapes[2]; @@ -97,7 +98,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { long batchSize = inputShapes[0].get(0); long maskedIndexCount = inputShapes[3].get(1); return new Shape[] { diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java index 4aa6484f766..2b3c52d3526 100644 --- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java +++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java @@ -21,7 +21,6 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; import java.util.Arrays; @@ -47,13 +46,16 @@ private IdEmbedding(Builder builder) { this.embeddingSize = builder.embeddingSize; this.embedding = addParameter( - new Parameter(EMBEDDING_PARAM_NAME, this, ParameterType.WEIGHT), - new Shape(dictionarySize, embeddingSize)); + Parameter.builder() + .setName(EMBEDDING_PARAM_NAME) + .setType(Parameter.Type.WEIGHT) + .optShape(new Shape(dictionarySize, embeddingSize)) + .build()); } /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))}; } diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java index 10d4aa20b40..9d3367f9cf5 100644 --- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java @@ -32,8 +32,6 @@ public class PointwiseFeedForwardBlock extends AbstractBlock { private static final byte VERSION = 1; - private Shape outputShape; - /** * Creates a pointwise feed-forward block. * @@ -49,7 +47,7 @@ public PointwiseFeedForwardBlock( super(VERSION); // add hidden layers with activation int count = 0; - for (final int hiddenSize : hiddenSizes) { + for (int hiddenSize : hiddenSizes) { addChildBlock( "linear_" + count, Linear.builder().optBias(true).setUnits(hiddenSize).build()); addChildBlock("activation_" + count, new LambdaBlock(activationFunction)); @@ -61,8 +59,11 @@ public PointwiseFeedForwardBlock( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - return new Shape[] {outputShape}; + public Shape[] getOutputShapes(Shape[] inputShapes) { + for (Block child : children.values()) { + inputShapes = child.getOutputShapes(inputShapes); + } + return inputShapes; } /** {@inheritDoc} */ @@ -75,16 +76,16 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... } // Now that we know the input shape, we can determine the reshape necessary // to shape the input and re-shape the output - final Shape inputShape = inputShapes[0]; + Shape inputShape = inputShapes[0]; if (inputShape.dimension() < 2) { throw new IllegalArgumentException( "Pointwise feed forward blocks need an input of at least dimension 2."); } Shape lastShape = inputShape; for (Block child : children.values()) { - lastShape = child.initialize(manager, dataType, lastShape)[0]; + child.initialize(manager, dataType, lastShape); + lastShape = getOutputShapes(new Shape[] {lastShape})[0]; } - outputShape = lastShape; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java index ffee4df219c..fa293f4b37a 100644 --- a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java @@ -146,7 +146,7 @@ public Linear getResultProjection() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { // Return shape is the shape of the query. For 2 or less inputs we have self-attention, i.e. // the shape of the output is the shape of the input if (inputShapes.length == 1 || inputShapes.length == 2) { diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java index 83fb87446e9..4cbabecada7 100644 --- a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java @@ -82,7 +82,7 @@ public TransformerEncoderBlock( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return inputShapes; } diff --git a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java index be9f7bbfec6..b9580914201 100644 --- a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java +++ b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java @@ -13,23 +13,23 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; -import ai.djl.training.initializer.XavierInitializer; -import ai.djl.training.initializer.XavierInitializer.FactorType; -import ai.djl.training.initializer.XavierInitializer.RandomType; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Adam; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Predicate; /** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */ public class DefaultTrainingConfig implements TrainingConfig { - private Initializer initializer; + private PairList> initializers = new PairList<>(); private Optimizer optimizer; private Device[] devices; private Loss loss; @@ -38,15 +38,12 @@ public class DefaultTrainingConfig implements TrainingConfig { /** * Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code - * DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link - * XavierInitializer} as initialiser, {@link Adam} as optimiser, and the given {@link Loss}. The - * evaluators and listeners are left to the user's discretion. + * DefaultTrainingConfig} creates a default {@link TrainingConfig}, {@link Adam} as optimiser, + * and the given {@link Loss}. The evaluators and listeners are left to the user's discretion. * * @param loss the loss to use for training */ public DefaultTrainingConfig(Loss loss) { - // Defaults to initializer defined in https://arxiv.org/abs/1502.01852 - this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2); optimizer = Adam.builder().build(); this.loss = loss; evaluators = new ArrayList<>(); @@ -58,10 +55,38 @@ public DefaultTrainingConfig(Loss loss) { * href="https://arxiv.org/abs/1502.01852">paper). * * @param initializer the initialer to use for the parameters + * @param type the {@link Parameter.Type} of the parameters * @return this {@code DefaultTrainingConfig} */ - public DefaultTrainingConfig optInitializer(Initializer initializer) { - this.initializer = initializer; + public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type) { + initializers.add(initializer, parameter -> parameter.getType().equals(type)); + return this; + } + + /** + * Sets the {@link Initializer} to use for the parameters (default from paper). + * + * @param initializer the initialer to use for the parameters + * @param name the name of the parameter + * @return this {@code DefaultTrainingConfig} + */ + public DefaultTrainingConfig optInitializer(Initializer initializer, String name) { + initializers.add(initializer, parameter -> parameter.getName().equals(name)); + return this; + } + + /** + * Sets the {@link Initializer} to use for the parameters (default from paper). + * + * @param initializer the initialer to use for the parameters + * @param predicate the predicate to identify parameter + * @return this {@code DefaultTrainingConfig} + */ + public DefaultTrainingConfig optInitializer( + Initializer initializer, Predicate predicate) { + initializers.add(initializer, predicate); return this; } @@ -120,8 +145,8 @@ public Device[] getDevices() { /** {@inheritDoc} */ @Override - public Initializer getInitializer() { - return initializer; + public PairList> getInitializers() { + return initializers; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/TrainingConfig.java b/api/src/main/java/ai/djl/training/TrainingConfig.java index 46748232035..0a4b5928266 100644 --- a/api/src/main/java/ai/djl/training/TrainingConfig.java +++ b/api/src/main/java/ai/djl/training/TrainingConfig.java @@ -13,12 +13,15 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.List; +import java.util.function.Predicate; /** * An interface that is responsible for holding the configuration required by {@link Trainer}. @@ -64,11 +67,11 @@ public interface TrainingConfig { Device[] getDevices(); /** - * Gets the {@link Initializer} to initialize the parameters of the model. + * Gets a list of {@link Initializer} and Predicate to initialize the parameters of the model. * * @return an {@link Initializer} */ - Initializer getInitializer(); + PairList> getInitializers(); /** * Gets the {@link Optimizer} to use during training. diff --git a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java index 64ee6279f71..8cf76c7b8f0 100644 --- a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java +++ b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java @@ -80,7 +80,7 @@ public XavierInitializer(RandomType randomType, FactorType factorType, float mag /** Creates a new instance of {@code XavierInitializer}. */ public XavierInitializer() { - this(RandomType.UNIFORM, FactorType.AVG, 3f); + this(RandomType.UNIFORM, FactorType.AVG, 6f); } /** {@inheritDoc} */ diff --git a/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java b/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java index abd5e9ec2a2..562dd3a0de9 100644 --- a/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java +++ b/api/src/test/java/ai/djl/nn/convolutional/ShapeUtils.java @@ -31,7 +31,7 @@ private ShapeUtils() {} * @return the corresponding output shape for the provided input */ public static Shape outputShapeForBlock(NDManager manager, Block block, Shape inputShape) { - Shape[] outputs = block.getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputs = block.getOutputShapes(new Shape[] {inputShape}); return outputs[0]; } diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java index 3efdd00af77..dc8028c5c92 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -48,7 +49,7 @@ public class AirfoilRandomAccessTest { public void testAirfoilRemote() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java index 20bf7f8ac56..716f04c578f 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -48,7 +49,7 @@ public class AmesRandomAccessTest { public void testAmesRandomAccessRemote() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java index 75bcb3e914a..063ec07fb13 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java @@ -16,6 +16,7 @@ import ai.djl.basicdataset.cv.PikachuDetection; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -42,7 +43,7 @@ public void testPikachuRemote() throws IOException, TranslateException { .build(); TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(new NormalInitializer(0.01f)); + .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); try (Trainer trainer = model.newTrainer(config)) { diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java index 53eca2e14b4..7d1cda2adc2 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.transformer.BertBlock; import ai.djl.nn.transformer.BertPretrainingBlock; import ai.djl.nn.transformer.BertPretrainingLoss; @@ -135,7 +136,8 @@ private static Model createBertPretrainingModel(Dictionary dictionary) { model.setBlock( new BertPretrainingBlock( BERT_BUILDER.setTokenDictionarySize(dictionary.tokens.size()))); - model.getBlock().setInitializer(new TruncatedNormalInitializer(0.02f)); + model.getBlock() + .setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT); return model; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java index 13ce0f2f742..76078106a22 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java @@ -31,7 +31,6 @@ import ai.djl.training.dataset.Dataset; import ai.djl.training.dataset.RandomAccessDataset; import ai.djl.training.evaluator.Accuracy; -import ai.djl.training.initializer.XavierInitializer; import ai.djl.training.listener.SaveModelTrainingListener; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; @@ -119,7 +118,6 @@ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optInitializer(new XavierInitializer()) .optDevices(Device.getDevices(arguments.getMaxGpus())) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java index b1631b6b73f..b68fc81e1c7 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java @@ -13,15 +13,18 @@ package ai.djl.fasttext; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.TrainingConfig; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; /** An interface that is responsible for holding the configuration required by fastText training. */ public class FtTrainingConfig implements TrainingConfig { @@ -247,7 +250,7 @@ public Device[] getDevices() { /** {@inheritDoc} */ @Override - public Initializer getInitializer() { + public PairList> getInitializers() { return null; } diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java index f105d46c8cd..2f6e41cda36 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java @@ -20,7 +20,6 @@ import ai.djl.nn.Block; import ai.djl.nn.SequentialBlock; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.XavierInitializer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -31,41 +30,34 @@ public class SingleShotDetectionTest { @Test public void testClassPredictorBlocks() { - try (NDManager manager = NDManager.newBaseManager()) { - Block block = SingleShotDetection.getClassPredictionBlock(5, 10); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 8, 20, 20)})[0], - new Shape(2, 55, 20, 20)); - block = SingleShotDetection.getClassPredictionBlock(3, 10); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 16, 10, 10)})[0], - new Shape(2, 33, 10, 10)); - } + Block block = SingleShotDetection.getClassPredictionBlock(5, 10); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 8, 20, 20)})[0], + new Shape(2, 55, 20, 20)); + block = SingleShotDetection.getClassPredictionBlock(3, 10); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 16, 10, 10)})[0], + new Shape(2, 33, 10, 10)); } @Test public void testAnchorPredictorBlocks() { - try (NDManager manager = NDManager.newBaseManager()) { - Block block = SingleShotDetection.getAnchorPredictionBlock(5); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 8, 20, 20)})[0], - new Shape(2, 20, 20, 20)); - block = SingleShotDetection.getClassPredictionBlock(3, 10); - Assert.assertEquals( - block.getOutputShapes(manager, new Shape[] {new Shape(2, 16, 10, 10)})[0], - new Shape(2, 33, 10, 10)); - } + Block block = SingleShotDetection.getAnchorPredictionBlock(5); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 8, 20, 20)})[0], + new Shape(2, 20, 20, 20)); + block = SingleShotDetection.getClassPredictionBlock(3, 10); + Assert.assertEquals( + block.getOutputShapes(new Shape[] {new Shape(2, 16, 10, 10)})[0], + new Shape(2, 33, 10, 10)); } @Test public void testDownSamplingBlock() { - try (NDManager manager = NDManager.newBaseManager()) { - Block sequentialBlock = SingleShotDetection.getDownSamplingBlock(10); - Assert.assertEquals( - sequentialBlock - .getOutputShapes(manager, new Shape[] {new Shape(2, 3, 20, 20)})[0], - new Shape(2, 10, 10, 10)); - } + Block sequentialBlock = SingleShotDetection.getDownSamplingBlock(10); + Assert.assertEquals( + sequentialBlock.getOutputShapes(new Shape[] {new Shape(2, 3, 20, 20)})[0], + new Shape(2, 10, 10, 10)); } @Test @@ -97,7 +89,6 @@ public void testSingleShotDetectionShape() { .setSizes(sizes) .setBaseNetwork(block) .build(); - ssd.setInitializer(new XavierInitializer()); ssd.initialize(manager, DataType.FLOAT32, new Shape(32, 3, 256, 256)); ParameterStore ps = new ParameterStore(manager, false); NDList output = @@ -105,8 +96,7 @@ public void testSingleShotDetectionShape() { Assert.assertEquals(output.get(0).getShape(), new Shape(1, 5444, 4)); Assert.assertEquals(output.get(1).getShape(), new Shape(32, 5444, 2)); Assert.assertEquals(output.get(2).getShape(), new Shape(32, 21776)); - Shape[] outputShapes = - ssd.getOutputShapes(manager, new Shape[] {new Shape(32, 3, 256, 256)}); + Shape[] outputShapes = ssd.getOutputShapes(new Shape[] {new Shape(32, 3, 256, 256)}); Assert.assertEquals(outputShapes[0], new Shape(1, 5444, 4)); Assert.assertEquals(outputShapes[1], new Shape(32, 5444, 2)); Assert.assertEquals(outputShapes[2], new Shape(32, 21776)); diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java index 131df85c53e..cc6d43853c8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java @@ -23,7 +23,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.recurrent.LSTM; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.XavierInitializer; import java.util.Arrays; import org.testng.Assert; import org.testng.annotations.Test; @@ -50,7 +49,6 @@ public void testEncoder() { .optReturnState(true) .build()); try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) { - encoder.setInitializer(new XavierInitializer()); encoder.initialize(manager, DataType.FLOAT32, new Shape(4, 7)); NDList output = encoder.forward( diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java index f956bfb46bc..ae0ea251e24 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java @@ -159,7 +159,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block alexNet = AlexNet.builder().build(); - alexNet.setInitializer(Initializer.ONES); + alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); alexNet.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -169,7 +169,7 @@ public void testOutputShapes() { alexNet.getChildren() .get(i) .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + .getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(alexNet.getChildren().get(i).getKey(), currentShape); } @@ -188,7 +188,7 @@ public void testForwardMethod() { Block alexNet = AlexNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - alexNet.setInitializer(Initializer.ONES); + alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); alexNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = alexNet.forward(new ParameterStore(manager, true), new NDList(x), false) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java index a149b7e40e3..4befb517a2f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java @@ -100,7 +100,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block googLeNet = GoogLeNet.builder().build(); - googLeNet.setInitializer(Initializer.ONES); + googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); googLeNet.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -111,7 +111,7 @@ public void testOutputShapes() { .getChildren() .get(i) .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + .getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(googLeNet.getChildren().get(i).getKey(), currentShape); } @@ -130,7 +130,7 @@ public void testForwardMethod() { Block googLeNet = GoogLeNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28)); - googLeNet.setInitializer(Initializer.ONES); + googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); googLeNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = googLeNet diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java index 39744eeefe6..6160ed5f7e4 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java @@ -137,7 +137,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block leNet = LeNet.builder().build(); - leNet.setInitializer(Initializer.ONES); + leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); leNet.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -147,7 +147,7 @@ public void testOutputShapes() { leNet.getChildren() .get(i) .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + .getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(leNet.getChildren().get(i).getKey(), currentShape); } @@ -165,7 +165,7 @@ public void testForwardMethod() { Block leNet = LeNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28)); - leNet.setInitializer(Initializer.ONES); + leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); leNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = leNet.forward(new ParameterStore(manager, true), new NDList(x), true) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java index cac77221d69..f8544debc70 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java @@ -152,17 +152,14 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block nin = NiN.builder().build(); - nin.setInitializer(Initializer.ONES); + nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); nin.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); for (int i = 0; i < nin.getChildren().size(); i++) { Shape[] newShape = - nin.getChildren() - .get(i) - .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + nin.getChildren().get(i).getValue().getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(nin.getChildren().get(i).getKey(), currentShape); } @@ -180,7 +177,7 @@ public void testForwardMethod() { Block nin = NiN.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - nin.setInitializer(Initializer.ONES); + nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); nin.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = nin.forward(new ParameterStore(manager, true), new NDList(x), false) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java index fd50e39943f..21c6779f36f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java @@ -57,7 +57,7 @@ public void testTrain() { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block resNet50 = ResNetV1.builder() @@ -123,7 +123,7 @@ public void testLoadTrain() TrainingConfig config = new DefaultTrainingConfig(Loss.l1Loss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Trainer trainer = model.newTrainer(config)) { int batchSize = 2; Shape inputShape = new Shape(batchSize, 3, 32, 32); @@ -131,8 +131,7 @@ public void testLoadTrain() trainer.initialize(inputShape); NDManager manager = trainer.getManager(); - Shape[] outputShape = - model.getBlock().getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputShape = model.getBlock().getOutputShapes(new Shape[] {inputShape}); NDArray data = manager.ones(new Shape(batchSize, 3, 32, 32)); NDArray label = manager.ones(outputShape[0]); diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java index 267bd41cf80..114907144aa 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java @@ -40,7 +40,7 @@ public void testTrain() { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block squeezeNet = SqueezeNet.squeezenet(10); try (Model model = Model.newInstance("squeezenet")) { model.setBlock(squeezeNet); diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java index b8520edd8f9..5c22bacecf1 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java @@ -107,17 +107,14 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block vgg = VGG.builder().build(); - vgg.setInitializer(Initializer.ONES); + vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); vgg.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); for (int i = 0; i < vgg.getChildren().size(); i++) { Shape[] newShape = - vgg.getChildren() - .get(i) - .getValue() - .getOutputShapes(manager, new Shape[] {currentShape}); + vgg.getChildren().get(i).getValue().getOutputShapes(new Shape[] {currentShape}); currentShape = newShape[0]; shapeMap.put(vgg.getChildren().get(i).getKey(), currentShape); } @@ -137,8 +134,9 @@ public void testForwardMethod() { Block vgg = VGG.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - vgg.setInitializer(Initializer.ONES); + vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); vgg.initialize(manager, DataType.FLOAT32, x.getShape()); + NDArray xHat = vgg.forward(new ParameterStore(manager, true), new NDList(x), false) .singletonOrThrow(); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java index 8e97ed38b87..2dbfac8ef2e 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.GradientCollector; @@ -139,7 +140,7 @@ public void testAddScalar() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { lhs.attachGradient(); result = NDArrays.add(lhs, 2); @@ -360,7 +361,7 @@ public void testDot() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { lhs.attachGradient(); result = NDArrays.dot(lhs, rhs); diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java index b0289856ae9..c133b382327 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java @@ -62,7 +62,8 @@ public class BlockCoreTest { @Test public void testLinear() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); long outSize = 3; Block block = Linear.builder().setUnits(outSize).build(); @@ -124,7 +125,8 @@ public void testLinear() throws IOException, MalformedModelException { @Test public void testLinearWithDefinedLayout() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); long outSize = 3; Block block = Linear.builder().setUnits(outSize).build(); @@ -176,7 +178,8 @@ public void testLinearWithDefinedLayout() throws IOException, MalformedModelExce @Test public void testBatchNorm() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = BatchNorm.builder().build(); try (Model model = Model.newInstance("model")) { @@ -203,7 +206,8 @@ public void testBatchNorm() throws IOException, MalformedModelException { @Test public void testDropout() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Dropout.builder().optRate(.5f).build(); try (Model model = Model.newInstance("model")) { @@ -229,7 +233,8 @@ public void testDropout() throws IOException, MalformedModelException { @Test public void testEmbedding() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); TrainableWordEmbedding block = TrainableWordEmbedding.builder() @@ -262,7 +267,8 @@ public void testEmbedding() throws IOException, MalformedModelException { @Test public void testConv1d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv1d.builder().setKernelShape(new Shape(2)).setFilters(1).optBias(false).build(); @@ -283,7 +289,7 @@ public void testConv1d() throws IOException, MalformedModelException { NDArray out = trainer.forward(new NDList(data)).singletonOrThrow(); Assert.assertEquals(out, expected); - Shape[] outputShape = block.getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputShape = block.getOutputShapes(new Shape[] {inputShape}); Assert.assertEquals(out.getShape(), outputShape[0]); testEncode(manager, block); @@ -294,7 +300,8 @@ public void testConv1d() throws IOException, MalformedModelException { @Test public void testConv1dTranspose() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv1dTranspose.builder() @@ -317,7 +324,7 @@ public void testConv1dTranspose() throws IOException, MalformedModelException { NDArray out = trainer.forward(new NDList(data)).singletonOrThrow(); Assert.assertEquals(out, expected); - Shape[] outputShape = block.getOutputShapes(manager, new Shape[] {inputShape}); + Shape[] outputShape = block.getOutputShapes(new Shape[] {inputShape}); Assert.assertEquals(out.getShape(), outputShape[0]); testEncode(manager, block); @@ -328,7 +335,8 @@ public void testConv1dTranspose() throws IOException, MalformedModelException { @Test public void testConv2d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv2d.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build(); try (Model model = Model.newInstance("model")) { @@ -359,7 +367,8 @@ public void testConv2d() throws IOException, MalformedModelException { @Test public void testConv2dTranspose() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv2dTranspose.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build(); @@ -396,7 +405,8 @@ public void testConv2dTranspose() throws IOException, MalformedModelException { @Test public void testConv3d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv3d.builder().setKernelShape(new Shape(2, 2, 2)).setFilters(1).build(); try (Model model = Model.newInstance("model")) { @@ -422,8 +432,7 @@ public void testConv3d() throws IOException, MalformedModelException { NDArray result = trainer.forward(new NDList(data)).singletonOrThrow(); Assert.assertEquals(result, expected); - Shape[] outputShape = - block.getOutputShapes(manager, new Shape[] {new Shape(1, 1, 3, 3, 3)}); + Shape[] outputShape = block.getOutputShapes(new Shape[] {new Shape(1, 1, 3, 3, 3)}); Assert.assertEquals(result.getShape(), outputShape[0]); testEncode(manager, block); @@ -437,7 +446,7 @@ public void testRNNTanh() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = RNN.builder() @@ -484,7 +493,7 @@ public void testRNNRelu() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = RNN.builder() @@ -534,7 +543,7 @@ public void testLstm() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = LSTM.builder() @@ -585,7 +594,7 @@ public void testGRU() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); GRU block = GRU.builder() @@ -638,7 +647,8 @@ public void testGRU() throws IOException, MalformedModelException { @Test public void testSequentialBlock() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); SequentialBlock block = new SequentialBlock(); block.addSingleton(x -> x.mul(6.5f)); block.add(Linear.builder().setUnits(10).build()); @@ -678,7 +688,8 @@ public void testSequentialBlock() throws IOException, MalformedModelException { @Test public void testParallelBlock() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); ParallelBlock block = new ParallelBlock( list -> diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java index 6f5f4e1ebc3..44a4e9816dd 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.pooling.Pool; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; @@ -29,7 +30,8 @@ public class PoolingOperationsTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testMaxPool1d() { diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java index 9595e88ef15..df8a64f4803 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.transformer.ScaledDotProductAttentionBlock; import ai.djl.training.GradientCollector; import ai.djl.training.ParameterStore; @@ -752,7 +753,7 @@ public void testMaskedAttention() { .optAttentionProbsDropoutProb(0.0f) .build(); - block.setInitializer(new NormalInitializer()); + block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT); block.getKeyProjection().setInitializer(keyKernelInitializer, "weight"); block.getValueProjection().setInitializer(valueKernelInitializer, "weight"); block.getQueryProjection().setInitializer(queryKernelInitializer, "weight"); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java index 443370b10d8..b3e8502a9ca 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Activation; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; @@ -30,7 +31,8 @@ public class ActivationTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testRelu() { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java index a86705162ef..630ad57d3e6 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.ParameterStore; @@ -30,7 +31,8 @@ public class BlocksTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testFlattenBlock() { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java index 4b5dd466c31..8d06c85efd8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -49,7 +50,8 @@ public class DatasetTest { private TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testSequenceSampler() throws IOException, TranslateException { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java index 154583e59e3..96680659f78 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.nn.core.Linear; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; @@ -49,7 +50,7 @@ public void testAutograd() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { NDArray lhs = manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3)); @@ -87,7 +88,7 @@ public void testTrain() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) .addTrainingListeners(new EvaluatorTrainingListener()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optimizer); try (Model model = Model.newInstance("linear")) { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java index 2e919ca78d5..6023b1b054f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java @@ -21,7 +21,6 @@ import ai.djl.nn.convolutional.Conv2d; import ai.djl.nn.norm.BatchNorm; import ai.djl.testing.Assertions; -import ai.djl.training.initializer.XavierInitializer; import java.io.IOException; import java.nio.file.Paths; import org.testng.Assert; @@ -36,7 +35,6 @@ public void testModelSaveAndLoad() throws IOException, MalformedModelException { block.add(BatchNorm.builder().build()); try (Model saveModel = Model.newInstance("saveModel"); Model loadModel = Model.newInstance("loadModel")) { - block.setInitializer(new XavierInitializer()); block.initialize(saveModel.getNDManager(), DataType.FLOAT32, new Shape(1, 3, 32, 32)); ParameterList savedParameters = block.getParameters(); saveModel.setBlock(block); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java index 8e472c626f2..5657f584805 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; import ai.djl.nn.core.Linear; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; @@ -46,7 +47,7 @@ public void testSgd() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(sgd) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -78,7 +79,7 @@ public void testSgdWithMomentum() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -118,7 +119,7 @@ public void testNag() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -147,7 +148,7 @@ public void testAdam() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -176,7 +177,7 @@ public void testAdagrad() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -209,7 +210,7 @@ public void testRMSProp() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -243,7 +244,7 @@ public void testRMSPropAlex() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -273,7 +274,7 @@ public void testAdadelta() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java index 4d9486c70cd..430c1aa66aa 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java @@ -111,46 +111,48 @@ private NDArray concatPredictions(NDList output) { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { - // TODO: output shape is wrong - Shape[] childInputShapes = inputShapes; - Shape[] anchorShapes = new Shape[features.size()]; - Shape[] classPredictionShapes = new Shape[features.size()]; - Shape[] anchorPredictionShapes = new Shape[features.size()]; - for (int i = 0; i < features.size(); i++) { - childInputShapes = features.get(i).getOutputShapes(manager, childInputShapes); - anchorShapes[i] = - multiBoxPriors - .get(i) - .generateAnchorBoxes(manager.ones(childInputShapes[0])) - .getShape(); - classPredictionShapes[i] = - classPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0]; - anchorPredictionShapes[i] = - anchorPredictionBlocks.get(i).getOutputShapes(manager, childInputShapes)[0]; - } - Shape anchorOutputShape = new Shape(); - for (Shape shape : anchorShapes) { - anchorOutputShape = concatShape(anchorOutputShape, shape, 1); - } + public Shape[] getOutputShapes(Shape[] inputShapes) { + try (NDManager manager = NDManager.newBaseManager()) { + // TODO: output shape is wrong + Shape[] childInputShapes = inputShapes; + Shape[] anchorShapes = new Shape[features.size()]; + Shape[] classPredictionShapes = new Shape[features.size()]; + Shape[] anchorPredictionShapes = new Shape[features.size()]; + for (int i = 0; i < features.size(); i++) { + childInputShapes = features.get(i).getOutputShapes(childInputShapes); + anchorShapes[i] = + multiBoxPriors + .get(i) + .generateAnchorBoxes(manager.ones(childInputShapes[0])) + .getShape(); + classPredictionShapes[i] = + classPredictionBlocks.get(i).getOutputShapes(childInputShapes)[0]; + anchorPredictionShapes[i] = + anchorPredictionBlocks.get(i).getOutputShapes(childInputShapes)[0]; + } + Shape anchorOutputShape = new Shape(); + for (Shape shape : anchorShapes) { + anchorOutputShape = concatShape(anchorOutputShape, shape, 1); + } - NDList classPredictions = new NDList(); - for (Shape shape : classPredictionShapes) { - classPredictions.add(manager.ones(shape)); - } - NDArray classPredictionOutput = concatPredictions(classPredictions); - Shape classPredictionOutputShape = - classPredictionOutput - .reshape(classPredictionOutput.size(0), -1, numClasses + 1) - .getShape(); - NDList anchorPredictions = new NDList(); - for (Shape shape : anchorPredictionShapes) { - anchorPredictions.add(manager.ones(shape)); + NDList classPredictions = new NDList(); + for (Shape shape : classPredictionShapes) { + classPredictions.add(manager.ones(shape)); + } + NDArray classPredictionOutput = concatPredictions(classPredictions); + Shape classPredictionOutputShape = + classPredictionOutput + .reshape(classPredictionOutput.size(0), -1, numClasses + 1) + .getShape(); + NDList anchorPredictions = new NDList(); + for (Shape shape : anchorPredictionShapes) { + anchorPredictions.add(manager.ones(shape)); + } + Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape(); + return new Shape[] { + anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape + }; } - Shape anchorPredictionOutputShape = concatPredictions(anchorPredictions).getShape(); - return new Shape[] { - anchorOutputShape, classPredictionOutputShape, anchorPredictionOutputShape - }; } private Shape concatShape(Shape shape, Shape concat, int axis) { @@ -177,15 +179,15 @@ private Shape concatShape(Shape shape, Shape concat, int axis) { /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); Shape[] shapes = inputShapes; for (int i = 0; i < features.size(); i++) { - shapes = features.get(i).initialize(manager, dataType, shapes); + features.get(i).initialize(manager, dataType, shapes); + shapes = features.get(i).getOutputShapes(shapes); classPredictionBlocks.get(i).initialize(manager, dataType, shapes); anchorPredictionBlocks.get(i).initialize(manager, dataType, shapes); } - return getOutputShapes(manager, inputShapes); } /** {@inheritDoc} */ diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index ee8afeba541..54c58bf1a01 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -24,6 +24,8 @@ import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; import java.nio.file.Files; @@ -33,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -122,12 +125,16 @@ public void load(Path modelPath, String prefix, Map options) /** {@inheritDoc} */ @Override public Trainer newTrainer(TrainingConfig trainingConfig) { - Initializer initializer = trainingConfig.getInitializer(); + PairList> initializer = trainingConfig.getInitializers(); if (block == null) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } - block.setInitializer(initializer); + for (Pair> pair : initializer) { + if (pair.getKey() != null && pair.getValue() != null) { + block.setInitializer(pair.getKey(), pair.getValue()); + } + } return new Trainer(this, trainingConfig); } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index ffc04996a4b..3b63b96a928 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -21,7 +21,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; import ai.djl.nn.Parameter; -import ai.djl.nn.ParameterType; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -201,7 +200,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { if (outputShapes == null) { String[] outputNames = symbol.getOutputNames(); outputShapes = new Shape[outputNames.length]; @@ -232,9 +231,7 @@ public void removeLastBlock() { } } - /** {@inheritDoc} */ - @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { + private Shape getParameterShape(String name, Shape[] inputShapes) { if (paramShapes == null) { PairList pairs = new PairList<>(); for (int i = 0; i < inputNames.size(); i++) { @@ -314,25 +311,32 @@ private void initBlock() { Set auxNameSet = new HashSet<>(Arrays.asList(symbol.getAuxNames())); for (String name : allNames) { - ParameterType type = inferType(name); + Parameter.Type type = inferType(name); boolean requireGrad = !auxNameSet.contains(name); - mxNetParams.add(new Parameter(name, this, type, requireGrad)); + mxNetParams.add( + Parameter.builder() + .setName(name) + .setType(type) + .optRequiresGrad(requireGrad) + .build()); } first = true; } - private static ParameterType inferType(String name) { + private static Parameter.Type inferType(String name) { if (name.endsWith("bias")) { - return ParameterType.BIAS; + return Parameter.Type.BIAS; } else if (name.endsWith("gamma")) { - return ParameterType.GAMMA; + return Parameter.Type.GAMMA; } else if (name.endsWith("beta")) { - return ParameterType.BETA; + return Parameter.Type.BETA; } else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) { - return ParameterType.RUNNING_MEAN; + return Parameter.Type.RUNNING_MEAN; } else if (name.endsWith("moving_var") || name.endsWith("running_var")) { - return ParameterType.RUNNING_VAR; + return Parameter.Type.RUNNING_VAR; + } else if (name.endsWith("weight")) { + return Parameter.Type.WEIGHT; } - return ParameterType.OTHER; + return Parameter.Type.OTHER; } } diff --git a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java index 0ba9177404f..299136587c7 100644 --- a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java +++ b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.GradientCollector; @@ -38,7 +39,7 @@ public void testMxAutograd() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { NDArray lhs = manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3)); diff --git a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java index d855bb90ced..74fde4bf89a 100644 --- a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java +++ b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java @@ -25,6 +25,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; import ai.djl.nn.SequentialBlock; import ai.djl.nn.SymbolBlock; import ai.djl.nn.core.Linear; @@ -85,7 +86,7 @@ public void trainWithNewParam() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { model.getBlock().clear(); try (Trainer trainer = model.newTrainer(config)) { @@ -113,7 +114,7 @@ public void trainWithExistParam() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { try (Trainer trainer = model.newTrainer(config)) { NDManager manager = trainer.getManager(); @@ -140,7 +141,7 @@ public void trainWithCustomLayer() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { NDManager manager = model.getNDManager(); SymbolBlock mlp = (SymbolBlock) model.getBlock(); @@ -149,7 +150,6 @@ public void trainWithCustomLayer() newMlp.add(mlp); Linear linear = Linear.builder().setUnits(10).build(); - linear.setInitializer(Initializer.ONES); newMlp.add(linear); model.setBlock(newMlp); diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java index 522a928b32b..43b3a960f4c 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java @@ -98,7 +98,7 @@ private NDList getOutputs(PpNDArray[] outputs, boolean foreignEngine, NDManager /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[0]; } } diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 51ec90497ac..3d56947e34d 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -17,10 +17,13 @@ import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.ndarray.types.DataType; +import ai.djl.nn.Parameter; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; import java.nio.file.Files; @@ -28,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Predicate; import java.util.stream.Collectors; /** @@ -125,12 +129,16 @@ private Path findModelFile(String prefix) { /** {@inheritDoc} */ @Override public Trainer newTrainer(TrainingConfig trainingConfig) { - Initializer initializer = trainingConfig.getInitializer(); + PairList> initializer = trainingConfig.getInitializers(); if (block == null) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } - block.setInitializer(initializer); + for (Pair> pair : initializer) { + if (pair.getKey() != null && pair.getValue() != null) { + block.setInitializer(pair.getKey(), pair.getValue()); + } + } return new Trainer(this, trainingConfig); } diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 7cf0e5e3ac8..964649cd5a1 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -155,7 +155,7 @@ public PairList describeOutput() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[0]; } diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java index 43f667d2dc2..7a6f8bdb753 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java @@ -131,8 +131,8 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { - return new Shape[0]; + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + throw new IllegalStateException("TfSymbolBlock can't be initialized"); } /** {@inheritDoc} */ @@ -197,7 +197,7 @@ public final PairList describeOutput() { /** {@inheritDoc} */ @Override - public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { + public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[0]; } From 5ca07f4f8092422e8e85054ab76af6e1a5c82783 Mon Sep 17 00:00:00 2001 From: afeenster Date: Wed, 3 Mar 2021 09:43:39 -0800 Subject: [PATCH 06/25] Removing unnecessary logging messages. --- .../ai/djl/serving/central/handler/ModelDownloadHandler.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java index 8224661de32..0e4313cc7d0 100644 --- a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -84,14 +84,12 @@ public class ModelDownloadHandler extends SimpleChannelInboundHandler { try { if (modelName!=null) { - logger.info(String.valueOf(ModelLink.linkFinder(modelName))); return ModelLink.linkFinder(modelName); } else { throw new BadRequestException("modelName is mandatory."); From eb8d51d187e8aca12147d28e21e0746e848278ee Mon Sep 17 00:00:00 2001 From: Lanking Date: Wed, 3 Mar 2021 11:17:14 -0800 Subject: [PATCH 07/25] block factory init commit (#697) --- api/src/main/java/ai/djl/engine/Engine.java | 9 ++ api/src/main/java/ai/djl/nn/BlockFactory.java | 34 +++++ api/src/main/java/ai/djl/nn/SymbolBlock.java | 11 ++ .../java/ai/djl/dlr/engine/DlrEngine.java | 7 ++ .../tests/nn/BlockFactoryTest.java | 116 ++++++++++++++++++ .../java/ai/djl/mxnet/engine/MxEngine.java | 7 ++ .../ai/djl/mxnet/engine/MxSymbolBlock.java | 19 ++- .../ai/djl/onnxruntime/engine/OrtEngine.java | 7 ++ .../ai/djl/paddlepaddle/engine/PpEngine.java | 7 ++ .../java/ai/djl/pytorch/engine/PtEngine.java | 7 ++ .../ai/djl/tensorflow/engine/TfEngine.java | 7 ++ .../ai/djl/tflite/engine/TfLiteEngine.java | 7 ++ 12 files changed, 227 insertions(+), 11 deletions(-) create mode 100644 api/src/main/java/ai/djl/nn/BlockFactory.java create mode 100644 integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 7673673a116..336aad53648 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -15,6 +15,7 @@ import ai.djl.Device; import ai.djl.Model; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; import ai.djl.training.LocalParameterServer; import ai.djl.training.ParameterServer; @@ -190,6 +191,14 @@ public Device defaultDevice() { return defaultDevice; } + /** + * Construct an empty SymbolBlock for loading. + * + * @param manager the manager to manage parameters + * @return Empty {@link SymbolBlock} for static graph + */ + public abstract SymbolBlock newSymbolBlock(NDManager manager); + /** * Constructs a new model. * diff --git a/api/src/main/java/ai/djl/nn/BlockFactory.java b/api/src/main/java/ai/djl/nn/BlockFactory.java new file mode 100644 index 00000000000..c6747b0fe64 --- /dev/null +++ b/api/src/main/java/ai/djl/nn/BlockFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.nn; + +import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.ModelZoo; +import java.io.Serializable; + +/** + * Block factory is a component to make standard for block creating and saving procedure. Block + * factory design is intended to bypass the serialization of the blocks. This class can be used by + * {@link ModelZoo} or DJL Serving to recover the block to its uninitialized states. User should + * combine this method with the block.loadParameter to get the block with all parameters. + */ +public interface BlockFactory extends Serializable { + + /** + * Constructs the uninitialized block. + * + * @param manager the manager to assign to block + * @return the uninitialized block + */ + Block newBlock(NDManager manager); +} diff --git a/api/src/main/java/ai/djl/nn/SymbolBlock.java b/api/src/main/java/ai/djl/nn/SymbolBlock.java index b9e2ec73568..dca341e1eed 100644 --- a/api/src/main/java/ai/djl/nn/SymbolBlock.java +++ b/api/src/main/java/ai/djl/nn/SymbolBlock.java @@ -12,6 +12,7 @@ */ package ai.djl.nn; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.util.PairList; @@ -21,6 +22,16 @@ */ public interface SymbolBlock extends Block { + /** + * Creates an empty SymbolBlock instance. + * + * @param manager the manager to be applied in the SymbolBlock + * @return a new Model instance + */ + static SymbolBlock newInstance(NDManager manager) { + return manager.getEngine().newSymbolBlock(manager); + } + /** Removes the last block in the symbolic graph. */ default void removeLastBlock() { throw new UnsupportedOperationException("not supported"); diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java index 7a05b901a2d..e28426e048f 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java @@ -19,6 +19,7 @@ import ai.djl.engine.Engine; import ai.djl.engine.EngineException; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; /** @@ -80,6 +81,12 @@ public boolean hasCapability(String capability) { return false; } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + throw new UnsupportedOperationException("DLR does not support empty SymbolBlock"); + } + /** {@inheritDoc} */ @Override public Model newModel(String name, Device device) { diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java new file mode 100644 index 00000000000..f65e3052be0 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java @@ -0,0 +1,116 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.integration.tests.nn; + +import ai.djl.Application; +import ai.djl.MalformedModelException; +import ai.djl.Model; +import ai.djl.engine.Engine; +import ai.djl.inference.Predictor; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; +import ai.djl.nn.Blocks; +import ai.djl.nn.SequentialBlock; +import ai.djl.nn.SymbolBlock; +import ai.djl.nn.core.Linear; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.testing.Assertions; +import ai.djl.training.ParameterStore; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.NoopTranslator; +import ai.djl.translate.TranslateException; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import org.testng.annotations.Test; + +public class BlockFactoryTest { + + @Test + public void testBlockLoadingSaving() + throws IOException, ModelNotFoundException, MalformedModelException, + TranslateException { + TestBlockFactory factory = new TestBlockFactory(); + Model model = factory.getRemoveLastBlockModel(); + try (NDManager manager = NDManager.newBaseManager()) { + Block block = model.getBlock(); + block.forward( + new ParameterStore(manager, true), + new NDList(manager.ones(new Shape(1, 3, 32, 32))), + true); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + block.saveParameters(new DataOutputStream(os)); + ByteArrayInputStream bis = new ByteArrayInputStream(os.toByteArray()); + Block newBlock = factory.newBlock(manager); + newBlock.loadParameters(manager, new DataInputStream(bis)); + try (Model test = Model.newInstance("test")) { + test.setBlock(newBlock); + try (Predictor predOrigin = + model.newPredictor(new NoopTranslator()); + Predictor predDest = + test.newPredictor(new NoopTranslator())) { + NDList input = new NDList(manager.ones(new Shape(1, 3, 32, 32))); + NDList originOut = predOrigin.predict(input); + NDList destOut = predDest.predict(input); + Assertions.assertAlmostEquals(originOut, destOut); + } + } + } + } + + static class TestBlockFactory implements BlockFactory { + + private static final long serialVersionUID = 1234567L; + + @Override + public Block newBlock(NDManager manager) { + SequentialBlock newBlock = new SequentialBlock(); + newBlock.add(SymbolBlock.newInstance(manager)); + newBlock.add(Blocks.batchFlattenBlock()); + newBlock.add(Linear.builder().setUnits(10).build()); + return newBlock; + } + + public Model getRemoveLastBlockModel() + throws MalformedModelException, ModelNotFoundException, IOException { + String name = Engine.getInstance().getEngineName(); + Criteria.Builder builder = + Criteria.builder() + .optApplication(Application.CV.IMAGE_CLASSIFICATION) + .setTypes(Image.class, Classifications.class) + .optProgress(new ProgressBar()) + .optArtifactId("resnet") + .optEngine(name) + .optGroupId("ai.djl." + name.toLowerCase()) + .optFilter("layers", "50"); + Model model = ModelZoo.loadModel(builder.build()); + SequentialBlock newBlock = new SequentialBlock(); + SymbolBlock block = (SymbolBlock) model.getBlock(); + block.removeLastBlock(); + newBlock.add(block); + newBlock.add(Blocks.batchFlattenBlock()); + newBlock.add(Linear.builder().setUnits(10).build()); + model.setBlock(newBlock); + return model; + } + } +} diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java index c0275e1bdf5..2106c8d2f52 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java @@ -19,6 +19,7 @@ import ai.djl.mxnet.jna.JnaUtils; import ai.djl.mxnet.jna.LibUtils; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; import ai.djl.training.LocalParameterServer; import ai.djl.training.ParameterServer; @@ -101,6 +102,12 @@ public boolean hasCapability(String capability) { return JnaUtils.getFeatures().contains(capability); } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + return new MxSymbolBlock(manager); + } + /** {@inheritDoc} */ @Override public Model newModel(String name, Device device) { diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index 3b63b96a928..b2c8d6cdb67 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -260,10 +260,8 @@ public void saveParameters(DataOutputStream os) throws IOException { for (String name : inputNames) { os.writeUTF(name); } - for (Parameter parameter : parameters.values()) { - if (!inputNames.contains(parameter.getName())) { - parameter.save(os); - } + for (Parameter parameter : mxNetParams) { + parameter.save(os); } } @@ -286,21 +284,20 @@ public void loadParameters(NDManager manager, DataInputStream is) throw new MalformedModelException("InputStream ends at symbol loading!"); } // init block only if it is not set - if (symbol == null) { - symbol = - Symbol.loadJson( - (MxNDManager) manager, new String(bytes, StandardCharsets.UTF_8)); - initBlock(); - } + symbol = + Symbol.loadJson( + (MxNDManager) manager, new String(bytes, StandardCharsets.UTF_8)); + initBlock(); } int size = is.readInt(); for (int i = 0; i < size; ++i) { inputNames.add(is.readUTF()); } - for (Parameter parameter : parameters.values()) { + for (Parameter parameter : mxNetParams) { parameter.load(this.manager, is); } + setInputNames(inputNames); } private void initBlock() { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 9b95bc9249a..d5c37e00d47 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -17,6 +17,7 @@ import ai.djl.Model; import ai.djl.engine.Engine; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; import ai.onnxruntime.OrtEnvironment; @@ -85,6 +86,12 @@ public Model newModel(String name, Device device) { return new OrtModel(name, newBaseManager(device), env); } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + throw new UnsupportedOperationException("ONNXRuntime does not support empty SymbolBlock"); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index 22faec1e947..c08c0e31304 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -16,6 +16,7 @@ import ai.djl.Model; import ai.djl.engine.Engine; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.paddlepaddle.jni.JniUtils; import ai.djl.paddlepaddle.jni.LibUtils; import ai.djl.training.GradientCollector; @@ -85,6 +86,12 @@ public Model newModel(String name, Device device) { return new PpModel(name, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + throw new UnsupportedOperationException("PaddlePaddle does not support empty SymbolBlock"); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index 4830addd0b8..4002ef76308 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -17,6 +17,7 @@ import ai.djl.engine.Engine; import ai.djl.engine.EngineException; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.pytorch.jni.JniUtils; import ai.djl.pytorch.jni.LibUtils; import ai.djl.training.GradientCollector; @@ -81,6 +82,12 @@ public boolean hasCapability(String capability) { return JniUtils.getFeatures().contains(capability); } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + return new PtSymbolBlock((PtNDManager) manager); + } + /** {@inheritDoc} */ @Override public Model newModel(String name, Device device) { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java index 6b2629a72f6..9ab49d24b87 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java @@ -18,6 +18,7 @@ import ai.djl.engine.EngineException; import ai.djl.engine.StandardCapabilities; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; import ai.djl.util.RandomUtils; import org.tensorflow.EagerSession; @@ -56,6 +57,12 @@ public Model newModel(String name, Device device) { return new TfModel(name, device); } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + throw new UnsupportedOperationException("TensorFlow does not support empty SymbolBlock"); + } + /** {@inheritDoc} */ @Override public String getEngineName() { diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java index 6d74876c825..a28ec7c9fe1 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java @@ -17,6 +17,7 @@ import ai.djl.Model; import ai.djl.engine.Engine; import ai.djl.ndarray.NDManager; +import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; /** @@ -83,6 +84,12 @@ public Model newModel(String name, Device device) { return new TfLiteModel(name, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public SymbolBlock newSymbolBlock(NDManager manager) { + throw new UnsupportedOperationException("TFLite does not support empty SymbolBlock"); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager() { From cb352adaf58d2517aad2ca01098ed0126f3612aa Mon Sep 17 00:00:00 2001 From: aksrajvanshi Date: Wed, 3 Mar 2021 14:15:28 -0800 Subject: [PATCH 08/25] [DOCS] Fixing TrainingListener documentation (#718) * Fixing TrainingListener documentation * Fixing PR reviews --- .../ai/djl/training/listener/MemoryTrainingListener.java | 8 +++++++- .../djl/training/listener/SaveModelTrainingListener.java | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/api/src/main/java/ai/djl/training/listener/MemoryTrainingListener.java b/api/src/main/java/ai/djl/training/listener/MemoryTrainingListener.java index 829b0d9994b..5cb6303840c 100644 --- a/api/src/main/java/ai/djl/training/listener/MemoryTrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/MemoryTrainingListener.java @@ -53,6 +53,10 @@ public MemoryTrainingListener() {} /** * Constructs a {@link MemoryTrainingListener} that outputs data in the given directory. * + *

If an output directory is provided, the file "$outputDir/memory.log" will be created after + * training with the memory usage results. The log file consists of heap bytes, non-heap bytes, + * cpu percentage and rss bytes consumption along with the timestamps. + * * @param outputDir the directory to output the tracked memory data in */ public MemoryTrainingListener(String outputDir) { @@ -81,7 +85,9 @@ public void onTrainingEnd(Trainer trainer) { } /** - * Collect memory information. + * Collects memory information. In order to collect metrics, the {@link Trainer} must set + * metrics. Monitor the metrics by enabling the following flag in the command line arguments: + * -Dcollect-memory=true * * @param metrics {@link Metrics} to store memory information */ diff --git a/api/src/main/java/ai/djl/training/listener/SaveModelTrainingListener.java b/api/src/main/java/ai/djl/training/listener/SaveModelTrainingListener.java index 1ae2916aea2..0f128031e8d 100644 --- a/api/src/main/java/ai/djl/training/listener/SaveModelTrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/SaveModelTrainingListener.java @@ -108,7 +108,8 @@ public void setOverrideModelName(String overrideModelName) { } /** - * Returns the checkpoint frequency (or -1 for no checkpointing). + * Returns the checkpoint frequency (or -1 for no checkpointing) in {@link + * SaveModelTrainingListener}. * * @return the checkpoint frequency (or -1 for no checkpointing) */ @@ -117,7 +118,7 @@ public int getCheckpoint() { } /** - * Sets the checkpoint frequency. + * Sets the checkpoint frequency in {@link SaveModelTrainingListener}. * * @param checkpoint how many epochs between checkpoints (or -1 for no checkpoints) */ From 48cf6630c11810a0aab441c590dc61e5a4ffa9b4 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 3 Mar 2021 17:29:58 -0800 Subject: [PATCH 09/25] Fix DJL serving flaky test for mac (#721) Change-Id: I9eccc84b0c34652e50c5fe5a4fe42f2b82d65a3d --- .../java/ai/djl/serving/ModelServerTest.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index fe6eef7ec25..1195af84b25 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -434,6 +434,7 @@ private void testInvalidRootRequest() throws InterruptedException { HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -452,6 +453,7 @@ private void testInvalidUri() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/InvalidUrl"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -471,6 +473,7 @@ private void testInvalidDescribeModel() throws InterruptedException { HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/predictions/InvalidModel"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -489,6 +492,7 @@ private void testInvalidPredictionsUri() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -508,6 +512,7 @@ private void testPredictionsModelNotFound() throws InterruptedException { HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions/InvalidModel"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -526,6 +531,7 @@ private void testInvalidManagementUri() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/InvalidUrl"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -544,6 +550,7 @@ private void testInvalidManagementMethod() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -562,6 +569,7 @@ private void testInvalidPredictionsMethod() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models/noop"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -581,6 +589,7 @@ private void testDescribeModelNotFound() throws InterruptedException { HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/InvalidModel"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -599,6 +608,7 @@ private void testRegisterModelMissingUrl() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -618,6 +628,7 @@ private void testRegisterModelNotFound() throws InterruptedException { HttpVersion.HTTP_1_1, HttpMethod.POST, "/models?url=InvalidUrl"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -643,6 +654,7 @@ private void testRegisterModelConflict() + URLEncoder.encode(url, StandardCharsets.UTF_8.name())); channel.writeAndFlush(req); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -663,6 +675,7 @@ private void testInvalidScaleModel() throws InterruptedException { "/models/mlp?min_worker=10&max_worker=1"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -681,6 +694,7 @@ private void testScaleModelNotFound() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models/fake"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -699,6 +713,7 @@ private void testUnregisterModelNotFound() throws InterruptedException { new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/fake"); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); @@ -731,6 +746,7 @@ private void testServiceUnavailable() throws InterruptedException { req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM); channel.writeAndFlush(req).sync(); latch.await(); + channel.closeFuture().sync(); channel.close(); ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class); From 28a32ffcdf01072b60773f1bca2655e0f10f81ae Mon Sep 17 00:00:00 2001 From: afeenster Date: Thu, 4 Mar 2021 14:46:05 -0800 Subject: [PATCH 10/25] Fixing all of the nits. --- .../HttpStaticFileServerInitializer.java | 2 +- .../serving/central/classes/ModelLink.java | 74 ++++++++++++++++ .../serving/central/classes/package-info.java | 14 +++ .../central/handler/ModelDownloadHandler.java | 86 ++++++------------- .../central/http/BadRequestException.java | 2 +- .../serving/central/http/package-info.java | 14 +++ .../responseencoder/HttpRequestResponse.java | 10 +-- .../djl/serving/central/utils/NettyUtils.java | 23 ----- .../serving/central/utils/package-info.java | 14 +++ 9 files changed, 146 insertions(+), 93 deletions(-) create mode 100644 central/src/main/java/ai/djl/serving/central/classes/ModelLink.java create mode 100644 central/src/main/java/ai/djl/serving/central/classes/package-info.java create mode 100644 central/src/main/java/ai/djl/serving/central/http/package-info.java create mode 100644 central/src/main/java/ai/djl/serving/central/utils/package-info.java diff --git a/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java b/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java index 464eaf62527..5a7ef1fc1b8 100644 --- a/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java +++ b/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java @@ -13,8 +13,8 @@ package ai.djl.serving.central; import ai.djl.serving.central.handler.HttpStaticFileServerHandler; -import ai.djl.serving.central.handler.ModelMetaDataHandler; import ai.djl.serving.central.handler.ModelDownloadHandler; +import ai.djl.serving.central.handler.ModelMetaDataHandler; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.socket.SocketChannel; diff --git a/central/src/main/java/ai/djl/serving/central/classes/ModelLink.java b/central/src/main/java/ai/djl/serving/central/classes/ModelLink.java new file mode 100644 index 00000000000..92c762e12f7 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/classes/ModelLink.java @@ -0,0 +1,74 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.classes; + +import ai.djl.Application; +import ai.djl.repository.Artifact; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A class to find the URL link when given a model name. + * + * @author anfee1@morgan.edu + */ +public final class ModelLink { + + private static URI base = URI.create("https://mlrepo.djl.ai/"); + private static Map links = new ConcurrentHashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(ModelLink.class); + + private ModelLink() {} + + /** + * Takes in a model name and returns a Map of download links. + * + * @param modelName the connection context + * @return This returns a map of download links + * @throws IOException throws an exception + * @throws ModelNotFoundException throws an exception + */ + public static Map linkFinder(String modelName) + throws IOException, ModelNotFoundException { + Map> models = ModelZoo.listModels(); + models.forEach( + (app, list) -> { + list.forEach( + artifact -> { + if (artifact.getName().equals(modelName)) { + for (Map.Entry entry : + artifact.getFiles().entrySet()) { + URI fileUri = URI.create(entry.getValue().getUri()); + URI baseUri = artifact.getMetadata().getRepositoryUri(); + if (!fileUri.isAbsolute()) { + fileUri = base.resolve(baseUri).resolve(fileUri); + } + try { + links.put(entry.getKey(), fileUri); + } catch (Exception e) { + logger.info(String.valueOf(e)); + } + } + } + }); + }); + return links; + } +} diff --git a/central/src/main/java/ai/djl/serving/central/classes/package-info.java b/central/src/main/java/ai/djl/serving/central/classes/package-info.java new file mode 100644 index 00000000000..4fb08a7697d --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/classes/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains a classes that do specific server functions. */ +package ai.djl.serving.central.classes; diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java index 0e4313cc7d0..ede04a307b9 100644 --- a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -12,69 +12,32 @@ */ package ai.djl.serving.central.handler; -import ai.djl.Application; -import ai.djl.repository.Artifact; import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ModelZoo; +import ai.djl.serving.central.classes.ModelLink; +import ai.djl.serving.central.http.BadRequestException; import ai.djl.serving.central.responseencoder.HttpRequestResponse; import ai.djl.serving.central.utils.NettyUtils; -import ai.djl.serving.http.BadRequestException; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.QueryStringDecoder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; -import java.net.URI; -import java.util.*; +import java.util.Collections; import java.util.concurrent.CompletableFuture; -final class ModelLink { - - private static final Logger logger = LoggerFactory.getLogger(ModelLink.class); - private static Map links = new HashMap(); - private ModelLink() {} - - private static URI BASE_URI = URI.create("https://mlrepo.djl.ai/"); - - public static Map linkFinder(String modelName) throws IOException, ModelNotFoundException { - Map> models = ModelZoo.listModels(); - models.forEach( - (app, list) -> { - list.forEach( - artifact -> { - if (artifact.getName().equals(modelName)){ - for (Map.Entry entry : - artifact.getFiles().entrySet()) { - URI fileUri = URI.create(entry.getValue().getUri()); - URI baseUri = artifact.getMetadata().getRepositoryUri(); - if (!fileUri.isAbsolute()) { - fileUri = BASE_URI.resolve(baseUri).resolve(fileUri); - } - try { - links.put(entry.getKey(),fileUri); - } catch(Exception e){ - logger.info(String.valueOf(e)); - } - } - }}); - }); - return links; - } -} - - /** - * A handler to handle download requests from the UI - * @author anfee1@morgan.edu + * A handler to handle download requests from the ModelView. * + * @author anfee1@morgan.edu */ public class ModelDownloadHandler extends SimpleChannelInboundHandler { HttpRequestResponse jsonResponse; - public ModelDownloadHandler() { jsonResponse = new HttpRequestResponse(); } + + /** constructing a ModelDownloadHandler. */ + public ModelDownloadHandler() { + jsonResponse = new HttpRequestResponse(); + } /** * handle the deployment request by forwarding the request to the serving-instance. @@ -83,27 +46,27 @@ public class ModelDownloadHandler extends SimpleChannelInboundHandler { - try { - if (modelName!=null) { - return ModelLink.linkFinder(modelName); - } else { - throw new BadRequestException("modelName is mandatory."); - } + () -> { + try { + if (modelName != null) { + return ModelLink.linkFinder(modelName); + } else { + throw new BadRequestException("modelName is mandatory."); + } - } catch (IOException | ModelNotFoundException ex) { - throw new IllegalArgumentException(ex.getMessage(), ex); - } - }) + } catch (IOException | ModelNotFoundException ex) { + throw new IllegalArgumentException(ex.getMessage(), ex); + } + }) .exceptionally((ex) -> Collections.emptyMap()) .thenAccept(linksMap -> jsonResponse.sendAsJson(ctx, request, linksMap)); } - /** {@inheritDoc} */ @Override public boolean acceptInboundMessage(Object msg) { @@ -112,5 +75,4 @@ public boolean acceptInboundMessage(Object msg) { String uri = request.uri(); return uri.startsWith("/serving/models?"); } - } diff --git a/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java b/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java index 5aac81f7868..c4905078b63 100644 --- a/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java +++ b/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.serving.http; +package ai.djl.serving.central.http; /** Thrown when a bad HTTP request is received. */ public class BadRequestException extends IllegalArgumentException { diff --git a/central/src/main/java/ai/djl/serving/central/http/package-info.java b/central/src/main/java/ai/djl/serving/central/http/package-info.java new file mode 100644 index 00000000000..fb26e5c6c82 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/http/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains HTTP codes. */ +package ai.djl.serving.central.http; diff --git a/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java index a7a5cf42f6d..664a3773417 100644 --- a/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java +++ b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java @@ -83,12 +83,10 @@ public void sendAsJson(ChannelHandlerContext ctx, FullHttpRequest request, Objec } /** - * send content of a ByteBuffer as - * response to the client. + * send content of a ByteBuffer as response to the client. * * @param ctx channel context - * @param request full request - * @param entity the response + * @param buffer response buffer */ public void sendByteBuffer(ChannelHandlerContext ctx, ByteBuf buffer) { @@ -103,11 +101,11 @@ public void sendByteBuffer(ChannelHandlerContext ctx, ByteBuf buffer) { * connection after the response being sent. * * @param ctx context - * @param request full request * @param response full response + * @param keepAlive is alive or not */ private void sendAndCleanupConnection( - ChannelHandlerContext ctx, FullHttpResponse response, boolean keepAlive ) { + ChannelHandlerContext ctx, FullHttpResponse response, boolean keepAlive) { HttpUtil.setContentLength(response, response.content().readableBytes()); if (!keepAlive) { // We're going to close the connection as soon as the response is sent, diff --git a/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java index 6d24256551e..44ddd83b6fd 100644 --- a/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java +++ b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java @@ -12,43 +12,22 @@ */ package ai.djl.serving.central.utils; -import ai.djl.ModelException; import ai.djl.modality.Input; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.util.JsonUtils; import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpUtil; -import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.multipart.Attribute; import io.netty.handler.codec.http.multipart.FileUpload; import io.netty.handler.codec.http.multipart.InterfaceHttpData; -import io.netty.util.AttributeKey; -import io.netty.util.CharsetUtil; import java.io.IOException; -import java.net.SocketAddress; import java.nio.charset.StandardCharsets; import java.util.List; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** A utility class that handling Netty request and response. */ public final class NettyUtils { - private static final Logger logger = LoggerFactory.getLogger("NettyUtils"); - private NettyUtils() {} /** @@ -83,8 +62,6 @@ public static void sendJsonResponse(ChannelHandlerContext ctx, String json) { sendJsonResponse(ctx, json, HttpResponseStatus.OK); } - - /** * Returns the bytes for the specified {@code ByteBuf}. * diff --git a/central/src/main/java/ai/djl/serving/central/utils/package-info.java b/central/src/main/java/ai/djl/serving/central/utils/package-info.java new file mode 100644 index 00000000000..8bee987b03f --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/utils/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains utility classes that hand response and requests. */ +package ai.djl.serving.central.utils; From c9d28c8306427053b6c9c5f4a7d529a48ffcc86e Mon Sep 17 00:00:00 2001 From: afeenster Date: Thu, 4 Mar 2021 15:25:31 -0800 Subject: [PATCH 11/25] Getting rid of unnecessary methods. --- .../djl/serving/central/utils/NettyUtils.java | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java index 44ddd83b6fd..ce0e3623cf5 100644 --- a/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java +++ b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java @@ -13,10 +13,7 @@ package ai.djl.serving.central.utils; import ai.djl.modality.Input; -import ai.djl.util.JsonUtils; import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.multipart.Attribute; import io.netty.handler.codec.http.multipart.FileUpload; @@ -30,38 +27,6 @@ public final class NettyUtils { private NettyUtils() {} - /** - * Sends the json object to client. - * - * @param ctx the connection context - * @param json the json object - */ - public static void sendJsonResponse(ChannelHandlerContext ctx, Object json) { - sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), HttpResponseStatus.OK); - } - - /** - * Sends the json string to client with specified status. - * - * @param ctx the connection context - * @param json the json string - * @param status the HTTP status - */ - public static void sendJsonResponse( - ChannelHandlerContext ctx, Object json, HttpResponseStatus status) { - sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), status); - } - - /** - * Sends the json string to client. - * - * @param ctx the connection context - * @param json the json string - */ - public static void sendJsonResponse(ChannelHandlerContext ctx, String json) { - sendJsonResponse(ctx, json, HttpResponseStatus.OK); - } - /** * Returns the bytes for the specified {@code ByteBuf}. * From a059417104217b22d4fc800e50c0602401f827ee Mon Sep 17 00:00:00 2001 From: Lanking Date: Thu, 4 Mar 2021 20:25:47 -0800 Subject: [PATCH 12/25] update onnxruntime along with String tensor (#724) --- .../java/ai/djl/ndarray/BaseNDManager.java | 6 +++++ .../main/java/ai/djl/ndarray/NDManager.java | 11 ++++++-- gradle.properties | 2 +- onnxruntime/onnxruntime-engine/README.md | 4 +-- .../ai/djl/onnxruntime/engine/OrtEngine.java | 2 +- .../djl/onnxruntime/engine/OrtNDManager.java | 27 +++++++++++++++++++ .../ai/djl/onnxruntime/engine/OrtUtils.java | 6 +++++ .../ai/djl/onnxruntime/engine/OrtTest.java | 27 +++++++++++++++++++ .../ai/djl/paddlepaddle/engine/PpEngine.java | 2 +- .../ai/djl/paddlepaddle/engine/PpNDArray.java | 17 ------------ .../ai/djl/tensorflow/engine/TfNDManager.java | 9 ++++++- 11 files changed, 88 insertions(+), 25 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 1b80170d908..da16422c5f0 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -49,6 +49,12 @@ public NDArray create(String data) { throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ + @Override + public NDArray create(String[] data) { + throw new UnsupportedOperationException("Not supported!"); + } + /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index cfcf1609b1c..6f95ee035eb 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -235,14 +235,21 @@ default NDArray create(boolean data) { } /** - * Creates and initializes a scalar {@link NDArray}. NDArray of String DataType only supports - * scalar. + * Creates and initializes a scalar {@link NDArray}. * * @param data the String data that needs to be set * @return a new instance of {@link NDArray} */ NDArray create(String data); + /** + * Creates and initializes 1D {@link NDArray}. + * + * @param data the String data that needs to be set + * @return a new instance of {@link NDArray} + */ + NDArray create(String[] data); + /** * Creates and initializes a 1D {@link NDArray}. * diff --git a/gradle.properties b/gradle.properties index cbc4bbf41c5..7b7b13cab0d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,7 +13,7 @@ pytorch_version=1.7.1 tensorflow_version=2.3.1 tflite_version=2.4.1 dlr_version=1.6.0 -onnxruntime_version=1.5.2 +onnxruntime_version=1.7.0 paddlepaddle_version=2.0.0 sentencepiece_version=0.1.92 fasttext_version=0.9.2 diff --git a/onnxruntime/onnxruntime-engine/README.md b/onnxruntime/onnxruntime-engine/README.md index e81711fec0b..40a24d0ce20 100644 --- a/onnxruntime/onnxruntime-engine/README.md +++ b/onnxruntime/onnxruntime-engine/README.md @@ -73,7 +73,7 @@ Maven: com.microsoft.onnxruntime onnxruntime_gpu - 1.5.2 + 1.7.0 runtime ``` @@ -83,5 +83,5 @@ Gradle: implementation("ai.djl.onnxruntime:onnxruntime-engine:0.10.0") { exclude group: "com.microsoft.onnxruntime", module: "onnxruntime" } - implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.5.2" + implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.7.0" ``` diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index d5c37e00d47..5396d0e4e2d 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -70,7 +70,7 @@ private Engine getAlternativeEngine() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.5.2"; + return "1.7.0"; } /** {@inheritDoc} */ diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index ee427153209..b0f4dd19e02 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -63,6 +63,33 @@ public OrtNDArray create(Buffer data, Shape shape, DataType dataType) { } } + /** {@inheritDoc} */ + @Override + public NDArray create(String data) { + return create(new String[] {data}, new Shape(1)); + } + + /** {@inheritDoc} */ + @Override + public NDArray create(String[] data) { + return create(data, new Shape(data.length)); + } + + /** + * Create A String tensor based on the provided shape. + * + * @param data the flattened String array + * @param shape the shape of the String NDArray + * @return a new instance of {@link NDArray} + */ + public NDArray create(String[] data, Shape shape) { + try { + return new OrtNDArray(this, OrtUtils.toTensor(env, data, shape)); + } catch (OrtException e) { + throw new EngineException(e); + } + } + /** {@inheritDoc} */ @Override public NDArray zeros(Shape shape, DataType dataType) { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java index 44d06580bb6..a1a0408f99c 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java @@ -63,6 +63,12 @@ public static OnnxTensor toTensor( } } + public static OnnxTensor toTensor(OrtEnvironment env, String[] inputs, Shape shape) + throws OrtException { + long[] sh = shape.getShape(); + return OnnxTensor.createTensor(env, inputs, sh); + } + public static NDArray toNDArray(NDManager manager, OnnxTensor tensor) { if (manager instanceof OrtNDManager) { return ((OrtNDManager) manager).create(tensor); diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index d867cbb324b..dbcbe7013cb 100644 --- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -12,12 +12,17 @@ */ package ai.djl.onnxruntime.engine; +import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisFlower; import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.TranslateException; @@ -63,4 +68,26 @@ public void testOrt() throws TranslateException, ModelException, IOException { throw new SkipException("Ignore missing libgomp.so.1 error."); } } + + @Test + public void testStringTensor() + throws MalformedModelException, ModelNotFoundException, IOException, + TranslateException { + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optEngine("OnnxRuntime") + .optModelUrls( + "https://resources.djl.ai/test-models/onnxruntime/pipeline_tfidf.zip") + .build(); + try (ZooModel model = ModelZoo.loadModel(criteria); + Predictor predictor = model.newPredictor()) { + OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager(); + NDArray stringNd = + manager.create( + new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"}, + new Shape(1, 2)); + predictor.predict(new NDList(stringNd)); + } + } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index c08c0e31304..20468243717 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -60,7 +60,7 @@ Engine getAlternativeEngine() { if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { - // alternativeEngine should not have the same rank as ORT + // alternativeEngine should not have the same rank as Paddle alternativeEngine = engine; } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java index 121213161ee..9e160e52213 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java @@ -42,23 +42,6 @@ public PpNDArray(PpNDManager manager, long handle) { manager.attach(getUid(), this); } - /** - * Constructs an PaddlePaddle NDArray from a {@link PpNDManager} (internal. Use {@link - * NDManager} instead). - * - * @param manager the manager to attach the new array to - * @param pointer the native tensor handle - * @param shape the shape of {@code PpNDArray} - * @param dataType the data type of {@code PpNDArray} - */ - public PpNDArray(PpNDManager manager, long pointer, Shape shape, DataType dataType) { - super(pointer); - this.manager = manager; - this.shape = shape; - this.dataType = dataType; - manager.attach(getUid(), this); - } - /** {@inheritDoc} */ @Override public NDManager getManager() { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index efa991c8008..2bfac11167c 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -149,12 +149,19 @@ public NDArray create(float data) { /** {@inheritDoc} */ @Override public NDArray create(String data) { - // create scalar tensor with float try (Tensor tensor = TString.scalarOf(data)) { return new TfNDArray(this, tensor); } } + /** {@inheritDoc} */ + @Override + public NDArray create(String[] data) { + try (Tensor tensor = TString.vectorOf(data)) { + return new TfNDArray(this, tensor); + } + } + /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) { From 347eb0767ffa4506a25e04acda1ae0c0248db558 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Fri, 5 Mar 2021 11:10:28 -0800 Subject: [PATCH 13/25] Add profiler doc (#722) --- docs/development/profiler.md | 62 ++++++++++++++++++++++++++++++++++++ docs/mkdocs.yml | 1 + 2 files changed, 63 insertions(+) create mode 100644 docs/development/profiler.md diff --git a/docs/development/profiler.md b/docs/development/profiler.md new file mode 100644 index 00000000000..3a02ae8f0e5 --- /dev/null +++ b/docs/development/profiler.md @@ -0,0 +1,62 @@ +## Profiler (Experimental) + +Currently, DJL supports experimental profilers for developers that +investigate the performance of operator execution as well as memory consumption. +The profilers are from engines directly and DJL just expose them. +So different engines have different APIs and produce different output format. +We are still working in progress on the feature. +In the future, we are considering to design a unified APIs and output unified format. + +### MXNet + +By setting the following environment variable, it generates `profile.json` after executing the code. + +``` +export MXNET_PROFILER_AUTOSTART=1 +``` + +You can view it in a browser using trace consumer like `chrome://tracing `. Here is a snapshot that shows the sample output. +![img](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tutorials/python/profiler/profiler_output_chrome.png) + +### PyTorch + +DJL have integrated PyTorch C++ profiler API and expose `JniUtils.startProfile` and `JniUtils.stopProfile(outputFile)` Java APIs. +`JniUtils.startProfile` takes `useCuda(boolean)`, `recordShape(boolean)` and `profileMemory(boolean)` arguments respectively. +`useCuda` indicates if profiler enables timing of CUDA events using the cudaEvent API. +`recordShape` indicates if information about input dimensions will be collected or not. +`profileMemory` indicates if profiler report memory usage or not. +`JniUtils.stopProfile` takes a outputFile of String type. + +Wrap the code snippet you want to profile in between `JniUtils.startProfile` and `JniUtils.stopProfile`. +Here is an example. + +``` +try (ZooModel model = ModelZoo.loadModel(criteria)) { + try (Predictor predictor = model.newPredictor()) { + Image image = ImageFactory.getInstance() + .fromNDArray(manager.zeros(new Shape(3, 224, 224), DataType.UINT8)); + + JniUtils.startProfile(false, true, true); + predictor.predict(image); + JniUtils.stopProfile(outputFile); + } catch (TranslateException e) { + e.printStackTrace(); +} +``` + +The output format is composed of operator execution record. +Each record contains `name`(operator name), `dur`(time duration), `shape`(input shapes), `cpu mem`(cpu memory footprint). + +``` +{ + "name": "aten::empty", + "ph": "X", + "ts": 24528.313000, + "dur": 5.246000, + "tid": 1, + "pid": "CPU Functions", + "shape": [[], [], [], [], [], []], + "cpu mem": "0 b", + "args": {} +} +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 5699539d81a..6bf4f9c0cf7 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -76,6 +76,7 @@ nav: - 'docs/development/memory_management.md' - 'docs/development/inference_performance_optimization.md' - 'docs/development/benchmark_with_djl.md' + - 'docs/development/profiler.md' - DJL Community: - 'docs/forums.md' - 'leaders.md' From a363db7b2f85263895b4dec262a0278d175a0ccd Mon Sep 17 00:00:00 2001 From: afeenster Date: Fri, 5 Mar 2021 11:18:47 -0800 Subject: [PATCH 14/25] Resolving some comments. --- .../serving/central/classes/package-info.java | 14 ------- .../central/handler/ModelDownloadHandler.java | 10 ++--- .../responseencoder/HttpRequestResponse.java | 2 +- .../ModelLink.java => utils/ModelUri.java} | 39 +++++++------------ .../webapp/components/DownloadButtons.jsx | 16 ++------ 5 files changed, 25 insertions(+), 56 deletions(-) delete mode 100644 central/src/main/java/ai/djl/serving/central/classes/package-info.java rename central/src/main/java/ai/djl/serving/central/{classes/ModelLink.java => utils/ModelUri.java} (66%) diff --git a/central/src/main/java/ai/djl/serving/central/classes/package-info.java b/central/src/main/java/ai/djl/serving/central/classes/package-info.java deleted file mode 100644 index 4fb08a7697d..00000000000 --- a/central/src/main/java/ai/djl/serving/central/classes/package-info.java +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -/** Contains a classes that do specific server functions. */ -package ai.djl.serving.central.classes; diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java index ede04a307b9..d01915f5cf2 100644 --- a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -13,9 +13,9 @@ package ai.djl.serving.central.handler; import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.serving.central.classes.ModelLink; import ai.djl.serving.central.http.BadRequestException; import ai.djl.serving.central.responseencoder.HttpRequestResponse; +import ai.djl.serving.central.utils.ModelUri; import ai.djl.serving.central.utils.NettyUtils; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; @@ -34,13 +34,13 @@ public class ModelDownloadHandler extends SimpleChannelInboundHandler { try { if (modelName != null) { - return ModelLink.linkFinder(modelName); + return ModelUri.uriFinder(modelName); } else { throw new BadRequestException("modelName is mandatory."); } @@ -64,7 +64,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) } }) .exceptionally((ex) -> Collections.emptyMap()) - .thenAccept(linksMap -> jsonResponse.sendAsJson(ctx, request, linksMap)); + .thenAccept(uriMap -> jsonResponse.sendAsJson(ctx, request, uriMap)); } /** {@inheritDoc} */ diff --git a/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java index 664a3773417..59e97dffdfd 100644 --- a/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java +++ b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java @@ -36,7 +36,7 @@ import java.lang.reflect.Modifier; /** - * serialize to json and send the response to the client. + * Serialize to json and send the response to the client. * * @author erik.bamberg@web.de */ diff --git a/central/src/main/java/ai/djl/serving/central/classes/ModelLink.java b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java similarity index 66% rename from central/src/main/java/ai/djl/serving/central/classes/ModelLink.java rename to central/src/main/java/ai/djl/serving/central/utils/ModelUri.java index 92c762e12f7..d1329935ee3 100644 --- a/central/src/main/java/ai/djl/serving/central/classes/ModelLink.java +++ b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java @@ -10,10 +10,11 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.serving.central.classes; +package ai.djl.serving.central.utils; import ai.djl.Application; import ai.djl.repository.Artifact; +import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; import java.io.IOException; @@ -21,33 +22,27 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -/** - * A class to find the URL link when given a model name. - * - * @author anfee1@morgan.edu - */ -public final class ModelLink { +/** A class to find the URIs when given a model name. */ +public final class ModelUri { private static URI base = URI.create("https://mlrepo.djl.ai/"); - private static Map links = new ConcurrentHashMap<>(); - private static final Logger logger = LoggerFactory.getLogger(ModelLink.class); - private ModelLink() {} + private ModelUri() {} /** - * Takes in a model name and returns a Map of download links. + * Takes in a model name and returns a Map of download URIs. * * @param modelName the connection context - * @return This returns a map of download links - * @throws IOException throws an exception - * @throws ModelNotFoundException throws an exception + * @return a map of download URIs + * @throws IOException if the uri could not be found + * @throws ModelNotFoundException if Model can not be found */ - public static Map linkFinder(String modelName) + public static Map uriFinder(String modelName) throws IOException, ModelNotFoundException { - Map> models = ModelZoo.listModels(); + Criteria criteria = Criteria.builder().optModelName(modelName).build(); + Map> models = ModelZoo.listModels(criteria); + Map uris = new ConcurrentHashMap<>(); models.forEach( (app, list) -> { list.forEach( @@ -60,15 +55,11 @@ public static Map linkFinder(String modelName) if (!fileUri.isAbsolute()) { fileUri = base.resolve(baseUri).resolve(fileUri); } - try { - links.put(entry.getKey(), fileUri); - } catch (Exception e) { - logger.info(String.valueOf(e)); - } + uris.put(entry.getKey(), fileUri); } } }); }); - return links; + return uris; } } diff --git a/central/src/main/webapp/components/DownloadButtons.jsx b/central/src/main/webapp/components/DownloadButtons.jsx index 7914d4a3723..924d79d10d1 100644 --- a/central/src/main/webapp/components/DownloadButtons.jsx +++ b/central/src/main/webapp/components/DownloadButtons.jsx @@ -14,7 +14,6 @@ const useFetch = (modelName) => { axios.get("http://"+window.location.host+"/serving/models?modelName="+modelName) .then(function(response) { let appdata = Object.keys(response.data).map(function(key) { - console.log(key) return { key: key, link: response.data[key] @@ -23,16 +22,9 @@ const useFetch = (modelName) => { setData(appdata); console.log(appdata) }) - .catch(function(error) { - console.log(error); - }) - .then(function() { - // always executed - }); - } fetchData(); - }, ["http://"+window.location.host+"/serving/models?modelName="+modelName]); + }, [modelName]); return data; }; @@ -40,11 +32,11 @@ const useFetch = (modelName) => { export default function ModelDownloadButtons(props) { - const modelLinks = useFetch(props.modelName); + const modelUris = useFetch(props.modelName); return ( <> - {Object.keys(modelLinks).map((keys) => ( - + {Object.keys(modelUris).map((keys) => ( + ) )} From 9d55e6e4571e60161c3d86136d3fb2cec2032a7d Mon Sep 17 00:00:00 2001 From: afeenster Date: Fri, 5 Mar 2021 15:35:56 -0800 Subject: [PATCH 15/25] Using a better criteria incase multiple models have the same name. --- .../central/handler/ModelDownloadHandler.java | 5 +++- .../djl/serving/central/utils/ModelUri.java | 29 +++++++++++-------- .../webapp/components/DownloadButtons.jsx | 9 +++--- .../src/main/webapp/components/ModelView.jsx | 2 +- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java index d01915f5cf2..d6daf2944d6 100644 --- a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -50,11 +50,14 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) throws IOException, ModelNotFoundException { QueryStringDecoder decoder = new QueryStringDecoder(request.uri()); String modelName = NettyUtils.getParameter(decoder, "modelName", null); + String modelGroupId = NettyUtils.getParameter(decoder, "groupId", null); + String modelArtifactId = NettyUtils.getParameter(decoder, "artifactId", null); CompletableFuture.supplyAsync( () -> { try { if (modelName != null) { - return ModelUri.uriFinder(modelName); + return ModelUri.uriFinder( + modelArtifactId, modelGroupId, modelName); } else { throw new BadRequestException("modelName is mandatory."); } diff --git a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java index d1329935ee3..8be24b50402 100644 --- a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java +++ b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java @@ -33,30 +33,35 @@ private ModelUri() {} /** * Takes in a model name and returns a Map of download URIs. * - * @param modelName the connection context + * @param artifactId is the artifactId of the model + * @param groupId is the groupId of the model + * @param name is the name of the model * @return a map of download URIs * @throws IOException if the uri could not be found * @throws ModelNotFoundException if Model can not be found */ - public static Map uriFinder(String modelName) + public static Map uriFinder(String artifactId, String groupId, String name) throws IOException, ModelNotFoundException { - Criteria criteria = Criteria.builder().optModelName(modelName).build(); + Criteria criteria = + Criteria.builder() + .optModelName(name) + .optGroupId(groupId) + .optArtifactId(artifactId) + .build(); Map> models = ModelZoo.listModels(criteria); Map uris = new ConcurrentHashMap<>(); models.forEach( (app, list) -> { list.forEach( artifact -> { - if (artifact.getName().equals(modelName)) { - for (Map.Entry entry : - artifact.getFiles().entrySet()) { - URI fileUri = URI.create(entry.getValue().getUri()); - URI baseUri = artifact.getMetadata().getRepositoryUri(); - if (!fileUri.isAbsolute()) { - fileUri = base.resolve(baseUri).resolve(fileUri); - } - uris.put(entry.getKey(), fileUri); + for (Map.Entry entry : + artifact.getFiles().entrySet()) { + URI fileUri = URI.create(entry.getValue().getUri()); + URI baseUri = artifact.getMetadata().getRepositoryUri(); + if (!fileUri.isAbsolute()) { + fileUri = base.resolve(baseUri).resolve(fileUri); } + uris.put(entry.getKey(), fileUri); } }); }); diff --git a/central/src/main/webapp/components/DownloadButtons.jsx b/central/src/main/webapp/components/DownloadButtons.jsx index 924d79d10d1..97374cf0e5d 100644 --- a/central/src/main/webapp/components/DownloadButtons.jsx +++ b/central/src/main/webapp/components/DownloadButtons.jsx @@ -6,12 +6,13 @@ import { makeStyles } from '@material-ui/core/styles'; import axios from 'axios' -const useFetch = (modelName) => { +const useFetch = (model) => { const [data, setData] = useState([]); useEffect(() => { async function fetchData() { - axios.get("http://"+window.location.host+"/serving/models?modelName="+modelName) + + axios.get("http://"+window.location.host+"/serving/models?modelName="+model.name+"&artifactId="+model.metadata.artifactId+"&groupId="+model.metadata.groupId) .then(function(response) { let appdata = Object.keys(response.data).map(function(key) { return { @@ -24,7 +25,7 @@ const useFetch = (modelName) => { }) } fetchData(); - }, [modelName]); + }, [model.modelName,model.metadata.artifactId,model.metadata.groupId]); return data; }; @@ -32,7 +33,7 @@ const useFetch = (modelName) => { export default function ModelDownloadButtons(props) { - const modelUris = useFetch(props.modelName); + const modelUris = useFetch(props.model); return ( <> {Object.keys(modelUris).map((keys) => ( diff --git a/central/src/main/webapp/components/ModelView.jsx b/central/src/main/webapp/components/ModelView.jsx index 9e9ccadd3e7..f5db8853ae9 100644 --- a/central/src/main/webapp/components/ModelView.jsx +++ b/central/src/main/webapp/components/ModelView.jsx @@ -253,7 +253,7 @@ export default function ModelView(props) { } - + From b4a1cc0fd669cb47f25debad7e110adc48d4f871 Mon Sep 17 00:00:00 2001 From: afeenster Date: Fri, 5 Mar 2021 15:41:19 -0800 Subject: [PATCH 16/25] Fixing the java doc. --- .../src/main/java/ai/djl/serving/central/utils/ModelUri.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java index 8be24b50402..8ebabeaa211 100644 --- a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java +++ b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java @@ -31,7 +31,7 @@ public final class ModelUri { private ModelUri() {} /** - * Takes in a model name and returns a Map of download URIs. + * Takes in a model name, artifactId, and groupId to return a Map of download URIs. * * @param artifactId is the artifactId of the model * @param groupId is the groupId of the model From a66e168ed76d301dadedc8299bffd94c7e398e3b Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Mon, 8 Mar 2021 12:58:11 -0600 Subject: [PATCH 17/25] Configure verbose of mxnet extra libraries (#728) Change-Id: I66d54aa496cccbb9e8c0a89eeaa458605958d9c6 --- .../src/main/java/ai/djl/mxnet/engine/MxEngine.java | 6 +++++- .../src/main/java/ai/djl/mxnet/jna/JnaUtils.java | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java index 2106c8d2f52..527a1d1f693 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java @@ -40,6 +40,7 @@ public final class MxEngine extends Engine { public static final String ENGINE_NAME = "MXNet"; + private static final String MXNET_EXTRA_LIBRARY_VERBOSE = "MXNET_EXTRA_LIBRARY_VERBOSE"; /** Constructs an MXNet Engine. */ private MxEngine() {} @@ -56,6 +57,9 @@ static Engine newInstance() { // load extra MXNet library String paths = System.getenv("MXNET_EXTRA_LIBRARY_PATH"); + boolean extraLibVerbose = + System.getenv().containsKey(MXNET_EXTRA_LIBRARY_VERBOSE) + && System.getenv(MXNET_EXTRA_LIBRARY_VERBOSE).equals("true"); if (paths != null) { String[] files = paths.split(","); for (String file : files) { @@ -63,7 +67,7 @@ static Engine newInstance() { if (Files.notExists(path)) { throw new FileNotFoundException("Extra Library not found: " + file); } - JnaUtils.loadLib(path.toAbsolutePath().toString(), 1); + JnaUtils.loadLib(path.toAbsolutePath().toString(), extraLibVerbose); } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java index 0527049a25a..43daa3e388f 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java @@ -1254,8 +1254,9 @@ public static List> inferShape(Symbol symbol, PairList Date: Mon, 8 Mar 2021 13:21:52 -0800 Subject: [PATCH 18/25] Added a TODO for using the artifact repo to get the base uri. --- central/src/main/java/ai/djl/serving/central/utils/ModelUri.java | 1 + 1 file changed, 1 insertion(+) diff --git a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java index 8ebabeaa211..b1bc9a6a691 100644 --- a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java +++ b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java @@ -26,6 +26,7 @@ /** A class to find the URIs when given a model name. */ public final class ModelUri { + // TODO: Use the artifact repository to create base URI private static URI base = URI.create("https://mlrepo.djl.ai/"); private ModelUri() {} From f881e4d4cc97d8fcc3a55aaff2c262dc4b3f73c2 Mon Sep 17 00:00:00 2001 From: Lanking Date: Mon, 8 Mar 2021 13:42:26 -0800 Subject: [PATCH 19/25] paddlepaddle CN notebook (#730) * paddlepaddle CN notebook * install font Change-Id: I2d749e617b0bf78ecbcd168b82c53a1fab49a2c0 * refactor on name Change-Id: I9e379eee51ceae16391850b3ba9782acb04c4021 * Refine the text Co-authored-by: gstu1130 --- .github/workflows/docs.yml | 2 + docs/mkdocs.yml | 7 +- .../face_mask_detection_paddlepaddle.ipynb | 2 +- .../face_mask_detection_paddlepaddle_zh.ipynb | 365 ++++++++++++++++++ jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb | 322 +++++++++++++++ 5 files changed, 696 insertions(+), 2 deletions(-) create mode 100644 jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb create mode 100644 jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 10b7cf9a432..a19542f1b04 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,6 +18,8 @@ jobs: uses: actions/setup-python@v1 with: python-version: '3.x' + - name: Install CN fonts + run: apt-get update && apt-get install fonts-arphic-uming - name: install Python Dependencies run: pip3 install nbconvert==5.6.1 mkdocs mkdocs-exclude mknotebooks==0.4.1 mkdocs-material jupyter Pygments Markdown==3.2.2 - name: Install IJava kernel diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 6bf4f9c0cf7..b6da69a60ac 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -110,7 +110,12 @@ nav: - TensorFlow Model Zoo: 'tensorflow/tensorflow-model-zoo/README.md' - PaddlePaddle: - Overview: 'paddlepaddle/README.md' - - Load a PaddlePaddle Model: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb' + - Facemask detection using PaddlePaddle: + - English: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb' + - 中文: 'jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb' + - PaddleOCR example: + - English: 'jupyter/paddlepaddle/paddle_ocr_java.ipynb' + - 中文: 'jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb' - Modules: - PaddlePaddle Engine: 'paddlepaddle/paddlepaddle-engine/README.md' - PaddlePaddle Model Zoo: 'paddlepaddle/paddlepaddle-model-zoo/README.md' diff --git a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb index ef713ea9719..10e15035cb5 100644 --- a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb +++ b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle.ipynb @@ -8,7 +8,7 @@ "\n", "In this tutorial, we will be using pretrained PaddlePaddle model from [PaddleHub](https://github.com/PaddlePaddle/PaddleHub/tree/release/v1.5/demo/mask_detection/cpp) to do mask detection on the sample image. To complete this procedure, there are two steps needs to be done:\n", "\n", - "- Recognize face on the image (no maater wearing mask or not) using Face object detection model\n", + "- Recognize face on the image (no matter wearing mask or not) using Face object detection model\n", "- classify the face is wearing mask or not\n", "\n", "These two steps will involve two paddle models. We will implement the corresponding preprocess and postprocess logic to it.\n", diff --git a/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb new file mode 100644 index 00000000000..0d6a9851179 --- /dev/null +++ b/jupyter/paddlepaddle/face_mask_detection_paddlepaddle_zh.ipynb @@ -0,0 +1,365 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 用飛槳+ DJL 實作人臉口罩辨識\n", + "在這個教學中我們將會展示利用 PaddleHub 下載預訓練好的 PaddlePaddle 模型並針對範例照片做人臉口罩辨識。這個範例總共會分成兩個步驟:\n", + "\n", + "- 用臉部檢測模型識別圖片中的人臉(無論是否有戴口罩) \n", + "- 確認圖片中的臉是否有戴口罩\n", + "\n", + "這兩個步驟會包含使用兩個 Paddle 模型,我們會在接下來的內容介紹兩個模型對應需要做的前後處理邏輯\n", + "\n", + "## 導入相關環境依賴及子類別\n", + "在這個例子中的前處理飛槳深度學習引擎需要搭配 DJL 混合模式進行深度學習推理,原因是引擎本身沒有包含 NDArray 操作,因此需要藉用其他引擎的 NDArray 操作能力來完成。這邊我們導入 PyTorch 來做協同的前處理工作:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", + "\n", + "%maven ai.djl:api:0.10.0\n", + "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.10.0\n", + "%maven ai.djl.paddlepaddle:paddlepaddle-native-auto:2.0.0\n", + "%maven org.slf4j:slf4j-api:1.7.26\n", + "%maven org.slf4j:slf4j-simple:1.7.26\n", + "\n", + "// second engine to do preprocessing and postprocessing\n", + "%maven ai.djl.pytorch:pytorch-engine:0.10.0\n", + "%maven ai.djl.pytorch:pytorch-native-auto:1.7.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ai.djl.Application;\n", + "import ai.djl.MalformedModelException;\n", + "import ai.djl.ModelException;\n", + "import ai.djl.inference.Predictor;\n", + "import ai.djl.modality.Classifications;\n", + "import ai.djl.modality.cv.*;\n", + "import ai.djl.modality.cv.output.*;\n", + "import ai.djl.modality.cv.transform.*;\n", + "import ai.djl.modality.cv.translator.ImageClassificationTranslator;\n", + "import ai.djl.modality.cv.util.NDImageUtils;\n", + "import ai.djl.ndarray.*;\n", + "import ai.djl.ndarray.types.Shape;\n", + "import ai.djl.repository.zoo.*;\n", + "import ai.djl.translate.*;\n", + "\n", + "import java.io.IOException;\n", + "import java.nio.file.*;\n", + "import java.util.*;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 臉部偵測模型\n", + "現在我們可以開始處理第一個模型,在將圖片輸入臉部檢測模型前我們必須先做一些預處理:\n", + "•\t調整圖片尺寸: 以特定比例縮小圖片\n", + "•\t用一個數值對縮小後圖片正規化\n", + "對開發者來說好消息是,DJL 提供了 Translator 介面來幫助開發做這樣的預處理. 一個比較粗略的 Translator 架構如下:\n", + "\n", + "![](https://github.com/awslabs/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n", + "\n", + "在接下來的段落,我們會利用一個 FaceTranslator 子類別實作來完成工作\n", + "### 預處理\n", + "在這個階段我們會讀取一張圖片並且對其做一些事先的預處理,讓我們先示範讀取一張圖片:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "String url = \"https://raw.githubusercontent.com/PaddlePaddle/PaddleHub/release/v1.5/demo/mask_detection/python/images/mask.jpg\";\n", + "Image img = ImageFactory.getInstance().fromUrl(url);\n", + "img.getWrappedImage();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接著,讓我們試著對圖片做一些預處理的轉換:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "NDList processImageInput(NDManager manager, Image input, float shrink) {\n", + " NDArray array = input.toNDArray(manager);\n", + " Shape shape = array.getShape();\n", + " array = NDImageUtils.resize(\n", + " array, (int) (shape.get(1) * shrink), (int) (shape.get(0) * shrink));\n", + " array = array.transpose(2, 0, 1).flip(0); // HWC -> CHW BGR -> RGB\n", + " NDArray mean = manager.create(new float[] {104f, 117f, 123f}, new Shape(3, 1, 1));\n", + " array = array.sub(mean).mul(0.007843f); // normalization\n", + " array = array.expandDims(0); // make batch dimension\n", + " return new NDList(array);\n", + "}\n", + "\n", + "processImageInput(NDManager.newBaseManager(), img, 0.5f);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "如上述所見,我們已經把圖片轉成如下尺寸的 NDArray: (披量, 通道(RGB), 高度, 寬度). 這是物件檢測模型輸入的格式\n", + "### 後處理\n", + "當我們做後處理時, 模型輸出的格式是 (number_of_boxes, (class_id, probability, xmin, ymin, xmax, ymax)). 我們可以將其存入預先建立好的 DJL 子類別 DetectedObjects 以便做後續操作. 我們假設有一組推論後的輸出是 ((1, 0.99, 0.2, 0.4, 0.5, 0.8)) 並且試著把人像框顯示在圖片上" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "DetectedObjects processImageOutput(NDList list, List className, float threshold) {\n", + " NDArray result = list.singletonOrThrow();\n", + " float[] probabilities = result.get(\":,1\").toFloatArray();\n", + " List names = new ArrayList<>();\n", + " List prob = new ArrayList<>();\n", + " List boxes = new ArrayList<>();\n", + " for (int i = 0; i < probabilities.length; i++) {\n", + " if (probabilities[i] >= threshold) {\n", + " float[] array = result.get(i).toFloatArray();\n", + " names.add(className.get((int) array[0]));\n", + " prob.add((double) probabilities[i]);\n", + " boxes.add(\n", + " new Rectangle(\n", + " array[2], array[3], array[4] - array[2], array[5] - array[3]));\n", + " }\n", + " }\n", + " return new DetectedObjects(names, prob, boxes);\n", + "}\n", + "\n", + "NDArray tempOutput = NDManager.newBaseManager().create(new float[]{1f, 0.99f, 0.1f, 0.1f, 0.2f, 0.2f}, new Shape(1, 6));\n", + "DetectedObjects testBox = processImageOutput(new NDList(tempOutput), Arrays.asList(\"Not Face\", \"Face\"), 0.7f);\n", + "Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n", + "newImage.drawBoundingBoxes(testBox);\n", + "newImage.getWrappedImage();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 生成一個翻譯器並執行推理任務\n", + "透過這個步驟,你會理解 DJL 中的前後處理如何運作,現在讓我們把前數的幾個步驟串在一起並對真實圖片進行操作:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class FaceTranslator implements Translator {\n", + "\n", + " private float shrink;\n", + " private float threshold;\n", + " private List className;\n", + "\n", + " FaceTranslator(float shrink, float threshold) {\n", + " this.shrink = shrink;\n", + " this.threshold = threshold;\n", + " className = Arrays.asList(\"Not Face\", \"Face\");\n", + " }\n", + "\n", + " @Override\n", + " public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {\n", + " return processImageOutput(list, className, threshold);\n", + " }\n", + "\n", + " @Override\n", + " public NDList processInput(TranslatorContext ctx, Image input) {\n", + " return processImageInput(ctx.getNDManager(), input, shrink);\n", + " }\n", + "\n", + " @Override\n", + " public Batchifier getBatchifier() {\n", + " return null;\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "要執行這個人臉檢測推理,我們必須先從 DJL 的 Paddle Model Zoo 讀取模型,在讀取模型之前我們必須指定好 `Crieteria` . `Crieteria` 是用來確認要從哪邊讀取模型而後執行 `Translator` 來進行模型導入. 接著,我們只要利用 `Predictor` 就可以開始進行推論" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Criteria criteria =\n", + " Criteria.builder()\n", + " .optApplication(Application.CV.OBJECT_DETECTION)\n", + " .setTypes(Image.class, DetectedObjects.class)\n", + " .optArtifactId(\"face_detection\")\n", + " .optTranslator(new FaceTranslator(0.5f, 0.7f))\n", + " .optFilter(\"flavor\", \"server\")\n", + " .build();\n", + " \n", + "var model = ModelZoo.loadModel(criteria);\n", + "var predictor = model.newPredictor();\n", + "\n", + "DetectedObjects inferenceResult = predictor.predict(img);\n", + "newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n", + "newImage.drawBoundingBoxes(inferenceResult);\n", + "newImage.getWrappedImage();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "如圖片所示,這個推論服務已經可以正確的辨識出圖片中的三張人臉\n", + "## 口罩分類模型\n", + "一旦有了圖片的座標,我們就可以將圖片裁剪到適當大小並且將其傳給口罩分類模型做後續的推論\n", + "### 圖片裁剪\n", + "圖中方框位置的數值範圍從0到1, 只要將這個數值乘上圖片的長寬我們就可以將方框對應到圖片中的準確位置. 為了使裁剪後的圖片有更好的精確度,我們將圖片裁剪成方形,讓我們示範一下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "int[] extendSquare(\n", + " double xmin, double ymin, double width, double height, double percentage) {\n", + " double centerx = xmin + width / 2;\n", + " double centery = ymin + height / 2;\n", + " double maxDist = Math.max(width / 2, height / 2) * (1 + percentage);\n", + " return new int[] {\n", + " (int) (centerx - maxDist), (int) (centery - maxDist), (int) (2 * maxDist)\n", + " };\n", + "}\n", + "\n", + "Image getSubImage(Image img, BoundingBox box) {\n", + " Rectangle rect = box.getBounds();\n", + " int width = img.getWidth();\n", + " int height = img.getHeight();\n", + " int[] squareBox =\n", + " extendSquare(\n", + " rect.getX() * width,\n", + " rect.getY() * height,\n", + " rect.getWidth() * width,\n", + " rect.getHeight() * height,\n", + " 0.18);\n", + " return img.getSubimage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);\n", + "}\n", + "\n", + "List faces = inferenceResult.items();\n", + "getSubImage(img, faces.get(2).getBoundingBox()).getWrappedImage();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 事先準備 Translator 並讀取模型\n", + "在使用臉部檢測模型的時候,我們可以利用 DJL 預先建好的 `ImageClassificationTranslator` 並且加上一些轉換。這個 Translator 提供了一些基礎的圖片翻譯處理並且同時包含一些進階的標準化圖片處理。以這個例子來說, 我們不需要額外建立新的 `Translator` 而使用預先建立的就可以" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "var criteria = Criteria.builder()\n", + " .optApplication(Application.CV.IMAGE_CLASSIFICATION)\n", + " .setTypes(Image.class, Classifications.class)\n", + " .optTranslator(\n", + " ImageClassificationTranslator.builder()\n", + " .addTransform(new Resize(128, 128))\n", + " .addTransform(new ToTensor()) // HWC -> CHW div(255)\n", + " .addTransform(\n", + " new Normalize(\n", + " new float[] {0.5f, 0.5f, 0.5f},\n", + " new float[] {1.0f, 1.0f, 1.0f}))\n", + " .addTransform(nd -> nd.flip(0)) // RGB -> GBR\n", + " .build())\n", + " .optArtifactId(\"mask_classification\")\n", + " .optFilter(\"flavor\", \"server\")\n", + " .build();\n", + "\n", + "var classifyModel = ModelZoo.loadModel(criteria);\n", + "var classifier = classifyModel.newPredictor();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 執行推論任務\n", + "最後,要完成一個口罩識別的任務,我們只需要將上述的步驟合在一起即可。我們先將圖片做裁剪後並對其做上述的推論操作,結束之後再生成一個新的分類子類別 `DetectedObjects`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "List names = new ArrayList<>();\n", + "List prob = new ArrayList<>();\n", + "List rect = new ArrayList<>();\n", + "for (DetectedObjects.DetectedObject face : faces) {\n", + " Image subImg = getSubImage(img, face.getBoundingBox());\n", + " Classifications classifications = classifier.predict(subImg);\n", + " names.add(classifications.best().getClassName());\n", + " prob.add(face.getProbability());\n", + " rect.add(face.getBoundingBox());\n", + "}\n", + "\n", + "newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n", + "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", + "newImage.getWrappedImage();" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Java", + "language": "java", + "name": "java" + }, + "language_info": { + "codemirror_mode": "java", + "file_extension": ".jshell", + "mimetype": "text/x-java-source", + "name": "Java", + "pygments_lexer": "java", + "version": "12.0.2+10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb b/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb new file mode 100644 index 00000000000..1df7e59ec2d --- /dev/null +++ b/jupyter/paddlepaddle/paddle_ocr_java_zh.ipynb @@ -0,0 +1,322 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PaddleOCR在DJL 上的實現\n", + "在這個教程裡,我們會展示利用 PaddleOCR 下載預訓練好文字處理模型並對指定的照片進行文學文字檢測 (OCR)。這個教程總共會分成三個部分:\n", + "\n", + "- 文字區塊檢測: 從圖片檢測出文字區塊\n", + "- 文字角度檢測: 確認文字是否需要旋轉\n", + "- 文字識別: 確認區塊內的文字\n", + "\n", + "## 導入相關環境依賴及子類別\n", + "在這個例子中的前處理飛槳深度學習引擎需要搭配DJL混合模式進行深度學習推理,原因是引擎本身沒有包含ND數組操作,因此需要藉用其他引擎的數組操作能力來完成。這邊我們導入Pytorch來做協同的前處理工作:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", + "\n", + "%maven ai.djl:api:0.10.0\n", + "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.10.0\n", + "%maven ai.djl.paddlepaddle:paddlepaddle-native-auto:2.0.0\n", + "%maven org.slf4j:slf4j-api:1.7.26\n", + "%maven org.slf4j:slf4j-simple:1.7.26\n", + "\n", + "// second engine to do preprocessing and postprocessing\n", + "%maven ai.djl.pytorch:pytorch-engine:0.10.0\n", + "%maven ai.djl.pytorch:pytorch-native-auto:1.7.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ai.djl.*;\n", + "import ai.djl.inference.Predictor;\n", + "import ai.djl.modality.Classifications;\n", + "import ai.djl.modality.cv.Image;\n", + "import ai.djl.modality.cv.ImageFactory;\n", + "import ai.djl.modality.cv.output.*;\n", + "import ai.djl.modality.cv.util.NDImageUtils;\n", + "import ai.djl.ndarray.*;\n", + "import ai.djl.ndarray.types.DataType;\n", + "import ai.djl.ndarray.types.Shape;\n", + "import ai.djl.repository.zoo.*;\n", + "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n", + "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n", + "import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n", + "import ai.djl.translate.*;\n", + "import java.util.concurrent.ConcurrentHashMap;" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 圖片讀取\n", + "首先讓我們載入這次教程會用到的機票範例圖片:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n", + "Image img = ImageFactory.getInstance().fromUrl(url);\n", + "img.getWrappedImage();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 文字區塊檢測\n", + "我們首先從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model) 開發套件中讀取文字檢測的模型,之後我們可以生成一個DJL `Predictor` 並將其命名為 `detector`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "var criteria1 = Criteria.builder()\n", + " .optEngine(\"PaddlePaddle\")\n", + " .setTypes(Image.class, DetectedObjects.class)\n", + " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n", + " .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap()))\n", + " .build();\n", + "var detectionModel = ModelZoo.loadModel(criteria1);\n", + "var detector = detectionModel.newPredictor();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接著我們檢測出圖片中的文字區塊,這個模型的原始輸出是含有標註所有文字區域的圖算法(Bitmap),我們可以利用`PpWordDetectionTranslator` 函式將圖算法的輸出轉成長方形的方框來裁剪圖片" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "var detectedObj = detector.predict(img);\n", + "Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n", + "newImage.drawBoundingBoxes(detectedObj);\n", + "newImage.getWrappedImage();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "如上所示,所標註的文字區塊都非常窄,且沒有包住所有完整的文字區塊。讓我們嘗試使用`extendRect`函式來擴展文字框的長寬到需要的大小, 再利用 `getSubImage` 裁剪並擷取出文子區塊。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Image getSubImage(Image img, BoundingBox box) {\n", + " Rectangle rect = box.getBounds();\n", + " double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n", + " int width = img.getWidth();\n", + " int height = img.getHeight();\n", + " int[] recovered = {\n", + " (int) (extended[0] * width),\n", + " (int) (extended[1] * height),\n", + " (int) (extended[2] * width),\n", + " (int) (extended[3] * height)\n", + " };\n", + " return img.getSubimage(recovered[0], recovered[1], recovered[2], recovered[3]);\n", + "}\n", + "\n", + "double[] extendRect(double xmin, double ymin, double width, double height) {\n", + " double centerx = xmin + width / 2;\n", + " double centery = ymin + height / 2;\n", + " if (width > height) {\n", + " width += height * 2.0;\n", + " height *= 3.0;\n", + " } else {\n", + " height += width * 2.0;\n", + " width *= 3.0;\n", + " }\n", + " double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n", + " double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n", + " double newWidth = newX + width > 1 ? 1 - newX : width;\n", + " double newHeight = newY + height > 1 ? 1 - newY : height;\n", + " return new double[] {newX, newY, newWidth, newHeight};\n", + "}" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "讓我們輸出其中一個文字區塊" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "List boxes = detectedObj.items();\n", + "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n", + "sample.getWrappedImage();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 文字角度檢測\n", + "我們從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) 輸出這個模型並確認圖片及文字是否需要旋轉。以下的代碼會讀入這個模型並生成a `rotateClassifier` 子類別" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "var criteria2 = Criteria.builder()\n", + " .optEngine(\"PaddlePaddle\")\n", + " .setTypes(Image.class, Classifications.class)\n", + " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n", + " .optTranslator(new PpWordRotateTranslator())\n", + " .build();\n", + "var rotateModel = ModelZoo.loadModel(criteria2);\n", + "var rotateClassifier = rotateModel.newPredictor();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 文字識別\n", + "\n", + "我們從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) 輸出這個模型並識別圖片中的文字, 我們一樣仿造上述的步驟讀取這個模型\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "var criteria3 = Criteria.builder()\n", + " .optEngine(\"PaddlePaddle\")\n", + " .setTypes(Image.class, String.class)\n", + " .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n", + " .optTranslator(new PpWordRecognitionTranslator())\n", + " .build();\n", + "var recognitionModel = ModelZoo.loadModel(criteria3);\n", + "var recognizer = recognitionModel.newPredictor();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接著我們可以試著套用這兩個模型在先前剪裁好的文字區塊上" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "System.out.println(rotateClassifier.predict(sample));\n", + "recognizer.predict(sample);" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最後我們把這些模型串連在一起並套用在整張圖片上看看結果會如何。DJL提供了豐富的影像工具包讓你可以從圖片中擷取出文字並且完美呈現" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Image rotateImg(Image image) {\n", + " try (NDManager manager = NDManager.newBaseManager()) {\n", + " NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n", + " return ImageFactory.getInstance().fromNDArray(rotated);\n", + " }\n", + "}\n", + "\n", + "List names = new ArrayList<>();\n", + "List prob = new ArrayList<>();\n", + "List rect = new ArrayList<>();\n", + "\n", + "for (int i = 0; i < boxes.size(); i++) {\n", + " Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n", + " if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n", + " subImg = rotateImg(subImg);\n", + " }\n", + " Classifications.Classification result = rotateClassifier.predict(subImg).best();\n", + " if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n", + " subImg = rotateImg(subImg);\n", + " }\n", + " String name = recognizer.predict(subImg);\n", + " names.add(name);\n", + " prob.add(-1.0);\n", + " rect.add(boxes.get(i).getBoundingBox());\n", + "}\n", + "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n", + "newImage.getWrappedImage();" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Java", + "language": "java", + "name": "java" + }, + "language_info": { + "codemirror_mode": "java", + "file_extension": ".jshell", + "mimetype": "text/x-java-source", + "name": "Java", + "pygments_lexer": "java", + "version": "11.0.5+10-LTS" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From a6a2232446723e839cf27e395992bb168735f330 Mon Sep 17 00:00:00 2001 From: Lanking Date: Mon, 8 Mar 2021 16:09:59 -0800 Subject: [PATCH 20/25] add EI documentation (#733) * add EI documentation * fix pmd rules Change-Id: Ieee5577c26f6df2843781f8f9180de35069a5de3 --- docs/mkdocs.yml | 1 + docs/mxnet/mxnet_backend_optimizer.md | 33 +++++++++++++++++++ .../java/ai/djl/mxnet/engine/MxEngine.java | 7 ++-- 3 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 docs/mxnet/mxnet_backend_optimizer.md diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index b6da69a60ac..92b01552145 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -84,6 +84,7 @@ nav: - Overview: 'mxnet/README.md' - Import Gluon Model: 'docs/mxnet/how_to_convert_your_model_to_symbol.md' - Load a MXNet Model: 'jupyter/load_mxnet_model.ipynb' + - Backend Optimizer for MXNet: 'docs/mxnet/mxnet_backend_optimizer.md' - Modules: - MXNet Engine: 'mxnet/mxnet-engine/README.md' - MXNet Model Zoo: 'mxnet/mxnet-model-zoo/README.md' diff --git a/docs/mxnet/mxnet_backend_optimizer.md b/docs/mxnet/mxnet_backend_optimizer.md new file mode 100644 index 00000000000..e6e80423fbf --- /dev/null +++ b/docs/mxnet/mxnet_backend_optimizer.md @@ -0,0 +1,33 @@ +# Custom backend optimizer support on Apache MXNet + +Apache MXNet currently implemented a method that allowing third-party +backend optimizer to accelerate the inference result. DJL currently +also exposed this functionality through the `MxOptimizeFor` option +of the Criteria. + +``` +.optOption("MxOptimizeFor", "optimizer_name") +``` + +After a name is passed, DJL will try to find the party library from + the environment variable called `MXNET_EXTRA_LIBRARY_PATH`. Users are required to +set this environment variable to locate the library. After that, you should see the messages from the inference to see if the library is enabled. + +Here is a list of supporting backend optimizers: + +## AWS [Elastic Inference Accelerator](https://docs.aws.amazon.com/elastic-inference/latest/developerguide/what-is-ei.html) (EIA) + +Currently, you can use EIA library for DJL on all EI enabled instance. + +You can follow the instruction to start your EI application with DJL: + +``` +> https://docs.aws.amazon.com/elastic-inference/latest/developerguide/ei-mxnet.html +``` + +Currently, the EI logging is disabled. For debugging purpose, you can enable that through +setting the `MXNET_EXTRA_LIBRARY_VERBOSE` environment variable: + +``` +export MXNET_EXTRA_LIBRARY_VERBOSE=true +``` diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java index 527a1d1f693..fb9c16df8ec 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java @@ -57,9 +57,10 @@ static Engine newInstance() { // load extra MXNet library String paths = System.getenv("MXNET_EXTRA_LIBRARY_PATH"); - boolean extraLibVerbose = - System.getenv().containsKey(MXNET_EXTRA_LIBRARY_VERBOSE) - && System.getenv(MXNET_EXTRA_LIBRARY_VERBOSE).equals("true"); + boolean extraLibVerbose = false; + if (System.getenv().containsKey(MXNET_EXTRA_LIBRARY_VERBOSE)) { + extraLibVerbose = Boolean.parseBoolean(System.getenv(MXNET_EXTRA_LIBRARY_VERBOSE)); + } if (paths != null) { String[] files = paths.split(","); for (String file : files) { From a90129ee518c091f0b83184d3ce470a2f84969d4 Mon Sep 17 00:00:00 2001 From: Lanking Date: Mon, 8 Mar 2021 18:01:49 -0800 Subject: [PATCH 21/25] allow pytorch stream model loading (#729) * allow pytorch stream model loading * updates Change-Id: Ibc26261b90de673712e90de0d640a8f32f23763e --- .github/workflows/docs.yml | 2 +- api/src/main/java/ai/djl/ndarray/NDArray.java | 9 ++++ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 +++ .../java/ai/djl/dlr/engine/DlrEngine.java | 3 ++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 7 +++ .../ai/djl/onnxruntime/engine/OrtEngine.java | 3 ++ .../ai/djl/onnxruntime/engine/OrtNDArray.java | 35 +++++++++++---- .../ai/djl/onnxruntime/engine/OrtUtils.java | 2 + .../ai/djl/onnxruntime/engine/OrtTest.java | 8 +++- .../ai/djl/paddlepaddle/engine/PpEngine.java | 3 ++ .../java/ai/djl/pytorch/engine/PtModel.java | 13 ++++++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 6 +++ .../djl/pytorch/integration/PtModelTest.java | 43 +++++++++++++++++++ .../ai/djl/tensorflow/engine/TfNDArray.java | 7 +++ .../ai/djl/tflite/engine/TfLiteEngine.java | 5 ++- 15 files changed, 140 insertions(+), 12 deletions(-) create mode 100644 pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a19542f1b04..7b3ab8f9925 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,7 +19,7 @@ jobs: with: python-version: '3.x' - name: Install CN fonts - run: apt-get update && apt-get install fonts-arphic-uming + run: sudo apt-get update && sudo apt-get install fonts-arphic-uming - name: install Python Dependencies run: pip3 install nbconvert==5.6.1 mkdocs mkdocs-exclude mknotebooks==0.4.1 mkdocs-material jupyter Pygments Markdown==3.2.2 - name: Install IJava kernel diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 758c052537f..f7a307be806 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -371,6 +371,15 @@ default boolean[] toBooleanArray() { return ret; } + /** + * Converts this {@code NDArray} to a String array. + * + *

This method is only applicable to the String typed NDArray and not for printing purpose + * + * @return Array of Strings + */ + String[] toStringArray(); + /** * Converts this {@code NDArray} to a Number array based on its {@link DataType}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 8afb9ca3ec0..56ffa83c8d1 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -135,6 +135,12 @@ default NDArray stopGradient() { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } + /** {@inheritDoc} */ + @Override + default String[] toStringArray() { + throw new UnsupportedOperationException(UNSUPPORTED_MSG); + } + /** {@inheritDoc} */ @Override default ByteBuffer toByteBuffer() { diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java index e28426e048f..f0975ff8296 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java @@ -47,6 +47,9 @@ static Engine newInstance() { } private Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.dlr.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index e8ecde776a9..c1277dca17a 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -268,11 +268,18 @@ public boolean hasGradient() { return hasGradient; } + /** {@inheritDoc} */ @Override public NDArray stopGradient() { return manager.invoke("stop_gradient", this, null); } + /** {@inheritDoc} */ + @Override + public String[] toStringArray() { + throw new UnsupportedOperationException("String NDArray is not supported!"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 5396d0e4e2d..0be98f3f722 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -57,6 +57,9 @@ public int getRank() { } private Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.onnx.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java index 42556a788ca..806fb96714e 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java @@ -13,12 +13,16 @@ package ai.djl.onnxruntime.engine; import ai.djl.Device; +import ai.djl.engine.EngineException; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDArrayAdapter; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; import java.util.UUID; @@ -117,20 +121,35 @@ public void detach() { manager = OrtNDManager.getSystemManager(); } + /** {@inheritDoc} */ + @Override + public String[] toStringArray() { + try { + return (String[]) tensor.getValue(); + } catch (OrtException e) { + throw new EngineException(e); + } + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + return tensor.getByteBuffer().order(ByteOrder.nativeOrder()); + } + /** {@inheritDoc} */ @Override public String toString() { if (isClosed) { return "This array is already closed"; } - return "ND: " - + getShape() - + ' ' - + getDevice() - + ' ' - + getDataType() - + '\n' - + Arrays.toString(toArray()); + String arrStr; + if (getDataType() == DataType.STRING) { + arrStr = Arrays.toString(toStringArray()); + } else { + arrStr = Arrays.toString(toArray()); + } + return "ND: " + getShape() + ' ' + getDevice() + ' ' + getDataType() + '\n' + arrStr; } /** {@inheritDoc} */ diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java index a1a0408f99c..90d8e200e98 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java @@ -98,6 +98,8 @@ public static DataType toDataType(OnnxJavaType javaType) { return DataType.BOOLEAN; case UNKNOWN: return DataType.UNKNOWN; + case STRING: + return DataType.STRING; default: throw new UnsupportedOperationException("type is not supported: " + javaType); } diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index dbcbe7013cb..4d8df458605 100644 --- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException { public void testStringTensor() throws MalformedModelException, ModelNotFoundException, IOException, TranslateException { + System.setProperty("ai.djl.onnx.disable_alternative", "true"); Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) @@ -82,12 +83,15 @@ public void testStringTensor() .build(); try (ZooModel model = ModelZoo.loadModel(criteria); Predictor predictor = model.newPredictor()) { - OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager(); + OrtNDManager manager = (OrtNDManager) model.getNDManager(); NDArray stringNd = manager.create( new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"}, new Shape(1, 2)); - predictor.predict(new NDList(stringNd)); + NDList result = predictor.predict(new NDList(stringNd)); + Assert.assertEquals(result.size(), 2); + Assert.assertEquals(result.get(0).toLongArray(), new long[] {1}); } + System.clearProperty("ai.djl.onnx.disable_alternative"); } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index 20468243717..d2913b2eeca 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -57,6 +57,9 @@ public int getRank() { } Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.paddlepaddle.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 3d56947e34d..20ee7430fb2 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -26,6 +26,7 @@ import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; @@ -101,6 +102,18 @@ public void load(Path modelPath, String prefix, Map options) } } + /** + * Load PyTorch model from {@link InputStream}. + * + *

Currently, only TorchScript file are supported + * + * @param modelStream the stream of the model file + * @throws IOException model loading error + */ + public void load(InputStream modelStream) throws IOException { + block = JniUtils.loadModule((PtNDManager) manager, modelStream, manager.getDevice(), false); + } + private Path findModelFile(String prefix) { if (Files.isRegularFile(modelDir)) { Path file = modelDir; diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index dedf1988561..1a9e6613797 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -211,6 +211,12 @@ public ByteBuffer toByteBuffer() { return JniUtils.getByteBuffer(this); } + /** {@inheritDoc} */ + @Override + public String[] toStringArray() { + throw new UnsupportedOperationException("String NDArray is not supported!"); + } + /** {@inheritDoc} */ @Override public void set(Buffer data) { diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java new file mode 100644 index 00000000000..cda9b86d8dd --- /dev/null +++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration; + +import ai.djl.Model; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.pytorch.engine.PtModel; +import ai.djl.translate.NoopTranslator; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import java.net.URL; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class PtModelTest { + + @Test + public void testLoadFromStream() throws IOException, TranslateException { + URL url = + new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt"); + try (PtModel model = (PtModel) Model.newInstance("test model")) { + model.load(url.openStream()); + try (Predictor predictor = model.newPredictor(new NoopTranslator())) { + NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224)); + NDArray result = predictor.predict(new NDList(array)).singletonOrThrow(); + Assert.assertEquals(result.getShape(), new Shape(1, 1000)); + } + } + } +} diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 478257d23b7..8707eac015c 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -257,6 +257,13 @@ public boolean[] toBooleanArray() { return result; } + @Override + public String[] toStringArray() { + // TODO: Parse String Array from bytes[] + throw new UnsupportedOperationException( + "TensorFlow does not supporting printing String NDArray"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java index a28ec7c9fe1..140d6369f36 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java @@ -54,6 +54,9 @@ public int getRank() { } private Engine getAlternativeEngine() { + if (Boolean.getBoolean("ai.djl.tflite.disable_alternative")) { + return null; + } if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { @@ -67,7 +70,7 @@ private Engine getAlternativeEngine() { /** {@inheritDoc} */ @Override public String getVersion() { - return "1.4.0"; + return "2.4.1"; } /** {@inheritDoc} */ From c6aebe027077aa78d5bb89af886c2154ad885707 Mon Sep 17 00:00:00 2001 From: Lanking Date: Tue, 9 Mar 2021 11:20:30 -0800 Subject: [PATCH 22/25] add NDList decode from inputStream (#734) Change-Id: I6a31d8b0b955f2dbb762220b101e3928a34699c1 --- api/src/main/java/ai/djl/ndarray/NDList.java | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index 88f727cf807..19a4b3888da 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -19,6 +19,7 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -77,7 +78,18 @@ public NDList(Collection other) { * @return {@code NDList} */ public static NDList decode(NDManager manager, byte[] byteArray) { - try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(byteArray))) { + return decode(manager, new ByteArrayInputStream(byteArray)); + } + + /** + * Decodes NDList from {@link InputStream}. + * + * @param manager manager assigned to {@link NDArray} + * @param is input stream contains the ndlist information + * @return {@code NDList} + */ + public static NDList decode(NDManager manager, InputStream is) { + try (DataInputStream dis = new DataInputStream(is)) { int size = dis.readInt(); if (size < 0) { throw new IllegalArgumentException("Invalid NDList size: " + size); From 8342d443250019a30e8e4bb95708032661b57710 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 9 Mar 2021 11:23:36 -0800 Subject: [PATCH 23/25] Remove memory scope and improve memory management (#695) The MemoryScope reveals a number of shortcomings within the DJL memory management. While the MemoryScope is deleted, many of them are fixed as part of this PR. First, the NDManager.{attach, detach} were renamed to xxxInternal. This is to differentiate them from the attach and detach methods that are intended to be used. There are two new concepts in memory management. An NDResource interface was created to combine the concepts of managed memory that was used in NDArray and NDList. It could also be used in more classes in the future. This includes the getManager, attach, and detach. Within the NDManager, it gains a second "management convention". The first convention of normal resources are added to the manager and then closed when the manager closes. This works for small numbers of things on the NDArray, but not when operations transitively create. So, the second convention is a tempResource. Instead of freeing them when the manager is closed, they are returned to their original manager. This is used to create a temporary scope, do operations within it, and then the inputs and return value are returned to the parent while the intermediate work is cleaned. This also matches the concepts of ownership/borrowing as well. Using these, a few additional helper methods were created. There is `NDManager.from(resource)` to ease creation of managers based on a resource. There is also `scopeManager.ret(returnValue)` to help with returning values outside of the scopeManager. Lastly, there is a `scopeManager.{temp,}AttachAll` to attach a number of resources to a manager within a single call. Using these improvements, the new method were applied to the old locations where MemoryScope was used as well as an additional case in NDManagerEx. Also, the old attach methods were altered to be `void`. Because the return values are no longer used anywhere and are not as necessary in the current scheme, I figured it would simplify things. It also helps for things like `NDList.attach` which does not have a single original NDManager when attaching. Change-Id: I91d109cd14d70fa64fd8fffa0b50d88ab053013e --- .../java/ai/djl/ndarray/BaseNDManager.java | 26 ++- api/src/main/java/ai/djl/ndarray/NDArray.java | 30 +--- .../java/ai/djl/ndarray/NDArrayAdapter.java | 2 +- api/src/main/java/ai/djl/ndarray/NDList.java | 47 ++--- .../main/java/ai/djl/ndarray/NDManager.java | 78 +++++++- .../main/java/ai/djl/ndarray/NDResource.java | 57 ++++++ .../java/ai/djl/nn/transformer/BertBlock.java | 18 +- .../BertMaskedLanguageModelBlock.java | 48 ++--- .../BertMaskedLanguageModelLoss.java | 68 +++---- .../nn/transformer/BertNextSentenceLoss.java | 49 ++--- .../nn/transformer/BertPretrainingBlock.java | 47 ++--- .../ai/djl/nn/transformer/MemoryScope.java | 168 ------------------ .../java/ai/djl/dlr/engine/DlrNDArray.java | 17 +- .../java/ai/djl/dlr/engine/DlrNDManager.java | 6 +- .../java/ai/djl/mxnet/engine/CachedOp.java | 4 +- .../java/ai/djl/mxnet/engine/MxNDArray.java | 19 +- .../java/ai/djl/mxnet/engine/MxNDArrayEx.java | 7 +- .../java/ai/djl/mxnet/engine/MxNDManager.java | 6 +- .../main/java/ai/djl/mxnet/engine/Symbol.java | 4 +- .../ai/djl/onnxruntime/engine/OrtNDArray.java | 17 +- .../djl/onnxruntime/engine/OrtNDManager.java | 6 +- .../ai/djl/paddlepaddle/engine/PpNDArray.java | 17 +- .../djl/paddlepaddle/engine/PpNDManager.java | 6 +- .../java/ai/djl/pytorch/engine/PtNDArray.java | 21 ++- .../ai/djl/pytorch/engine/PtNDManager.java | 6 +- .../ai/djl/pytorch/engine/PtSymbolBlock.java | 6 +- .../ai/djl/tensorflow/engine/TfNDArray.java | 17 +- .../ai/djl/tensorflow/engine/TfNDManager.java | 6 +- .../ai/djl/tflite/engine/TfLiteNDArray.java | 17 +- .../ai/djl/tflite/engine/TfLiteNDManager.java | 6 +- 30 files changed, 409 insertions(+), 417 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/NDResource.java delete mode 100644 api/src/main/java/ai/djl/nn/transformer/MemoryScope.java diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index da16422c5f0..2459b37f333 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -15,6 +15,7 @@ import ai.djl.Device; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.util.Pair; import ai.djl.util.PairList; import java.nio.Buffer; import java.nio.file.Path; @@ -34,12 +35,14 @@ public abstract class BaseNDManager implements NDManager { protected String name; protected Device device; protected ConcurrentHashMap resources; + protected ConcurrentHashMap> tempResources; protected AtomicBoolean closed = new AtomicBoolean(false); protected BaseNDManager(NDManager parent, Device device) { this.parent = parent; this.device = Device.defaultIfNull(device, getEngine()); resources = new ConcurrentHashMap<>(); + tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); } @@ -203,7 +206,7 @@ public String toString() { /** {@inheritDoc} */ @Override - public synchronized void attach(String resourceId, AutoCloseable resource) { + public synchronized void attachInternal(String resourceId, AutoCloseable resource) { if (closed.get()) { throw new IllegalStateException("NDManager has been closed already."); } @@ -212,7 +215,17 @@ public synchronized void attach(String resourceId, AutoCloseable resource) { /** {@inheritDoc} */ @Override - public synchronized void detach(String resourceId) { + public void tempAttachInternal( + NDManager originalManager, String resourceId, NDResource resource) { + if (closed.get()) { + throw new IllegalStateException("NDManager has been closed already."); + } + tempResources.put(resourceId, new Pair<>(resource, originalManager)); + } + + /** {@inheritDoc} */ + @Override + public synchronized void detachInternal(String resourceId) { if (closed.get()) { // This may happen in the middle of BaseNDManager.close() return; @@ -244,7 +257,14 @@ public synchronized void close() { logger.error("Resource close failed.", e); } } - parent.detach(uid); + for (Pair resource : tempResources.values()) { + try { + resource.getKey().attach(resource.getValue()); + } catch (Exception e) { + logger.error("Temporary resource return failed.", e); + } + } + parent.detachInternal(uid); resources.clear(); } } diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index f7a307be806..559439e68b3 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -40,7 +40,7 @@ * href="https://github.com/awslabs/djl/blob/master/docs/development/memory_management.md">NDArray * Memory Management Guide */ -public interface NDArray extends AutoCloseable { +public interface NDArray extends NDResource { /** * Decodes {@code NDArray} from bytes. @@ -53,13 +53,6 @@ static NDArray decode(NDManager manager, byte[] byteArray) { return manager.decode(byteArray); } - /** - * Returns the {@link NDManager} used to create this {@code NDArray}. - * - * @return the {@link NDManager} used to create this {@code NDArray} - */ - NDManager getManager(); - /** * Returns the name of this {@code NDArray}. * @@ -146,27 +139,6 @@ default byte[] encode() { return NDSerializer.encode(this); } - /** - * Attaches this {@code NDArray} to the specified {@link NDManager}. - * - *

Attached resource will be closed when the {@link NDManager} is closed. - * - * @param manager the {@link NDManager} to be attached - * @return the original {@link NDManager} - */ - NDManager attach(NDManager manager); - - /** - * Detaches the {@code NDArray} from current {@link NDManager}'s lifecycle. - * - *

The {@code NDArray} becomes un-managed, it is the user's responsibility to close the - * {@code NDArray}. Failure to close the resource might cause your machine to run out of native - * memory. - * - * @see NDManager - */ - void detach(); - /** * Moves this {@code NDArray} to a different {@link Device}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 56ffa83c8d1..613338ad025 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -83,7 +83,7 @@ default SparseFormat getSparseFormat() { /** {@inheritDoc} */ @Override - default NDManager attach(NDManager manager) { + default void attach(NDManager manager) { throw new UnsupportedOperationException(UNSUPPORTED_MSG); } diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index 19a4b3888da..021de74e810 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -23,9 +23,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; /** * An {@code NDList} represents a sequence of {@link NDArray}s with names. @@ -35,7 +32,7 @@ * * @see NDArray */ -public class NDList extends ArrayList implements AutoCloseable { +public class NDList extends ArrayList implements NDResource { private static final long serialVersionUID = 1L; @@ -212,36 +209,28 @@ public NDList toDevice(Device device, boolean copy) { return newNDList; } - /** - * Attaches each ndarray in this list to the specified manager. - * - * @param manager the manager to attach the lists to - * @return a list of {@code NDManager} with which original NDArray are attached - * @see NDArray#attach(NDManager) - */ - public List attach(NDManager manager) { - return stream().map(array -> array.attach(manager)).collect(Collectors.toList()); + /** {@inheritDoc} */ + @Override + public NDManager getManager() { + return head().getManager(); } - /** - * Attaches each ndarray in this list to the specified manager. - * - * @param managers the list of managers to attach - * @return a list of {@code NDManager} with which original NDArray are attached - */ - public List attach(List managers) { - return IntStream.range(0, size()) - .mapToObj(i -> get(i).attach(managers.get(i))) - .collect(Collectors.toList()); + /** {@inheritDoc} */ + @Override + public void attach(NDManager manager) { + stream().forEach(array -> array.attach(manager)); } - /** - * Detaches each ndarray in this list from their current managers. - * - * @see NDArray#detach() - */ + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { + stream().forEach(array -> array.tempAttach(manager)); + } + + /** {@inheritDoc} */ + @Override public void detach() { - forEach(NDArray::detach); + stream().forEach(NDResource::detach); } /** diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 6f95ee035eb..b7beb80c2b2 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -133,6 +133,16 @@ static NDManager newBaseManager(Device device, String engineName) { return Engine.getEngine(engineName).newBaseManager(device); } + /** + * Creates a new manager based on the given resource. + * + * @param resource the resource to use + * @return a new memory scrope containing the array + */ + static NDManager from(NDResource resource) { + return resource.getManager().newSubManager(); + } + /** * Allocates a new engine specific direct byte buffer. * @@ -1281,14 +1291,34 @@ default NDArray randomNormal( Device getDevice(); /** - * Attaches a {@link NDArray} or {@code NDManager} to this {@code NDManager}. + * Attaches a resource to this {@code NDManager}. + * + *

The attached resource will be closed when this {@code NDManager} is closed. + * + *

This attachment is internal. Many resources will internally track which manager they are + * attached to. In that case, you should call {@link NDResource#attach(NDManager)} instead and + * that should then call attachInternal. + * + * @param resourceId the unique resourceId + * @param resource the {@link AutoCloseable} resource to be attached + */ + void attachInternal(String resourceId, AutoCloseable resource); + + /** + * Temporarily attaches a resource to this {@code NDManager} to be returned when this is closed. + * + *

The attached resource will be returned to it's original manager when this {@code + * NDManager} is closed. * - *

Attached resource will be closed when this {@code NDManager} is closed. + *

This attachment is internal. Many resources will internally track which manager they are + * attached to. In that case, you should call {@link NDResource#attach(NDManager)} instead and + * that should then call tempAttachInternal. * + * @param originalManager the original manager to return the resource to * @param resourceId the unique resourceId * @param resource the {@link AutoCloseable} resource to be attached */ - void attach(String resourceId, AutoCloseable resource); + void tempAttachInternal(NDManager originalManager, String resourceId, NDResource resource); /** * Detaches a {@link NDArray} from this {@code NDManager}'s lifecycle. @@ -1297,9 +1327,49 @@ default NDArray randomNormal( * resource. Failed to close the resource has to wait on GC to be freed, and might cause out of * native memory. * + *

This detach is internal. Many resources will internally track which manager they are + * attached to. In that case, you should call {@link NDResource#detach()} instead and that + * should then call detachInternal. + * * @param resourceId the resourceId to be removed from this {@code NDManager}'s lifecycle */ - void detach(String resourceId); + void detachInternal(String resourceId); + + /** + * Returns a value outside of this manager by attaching to this manager's parent. + * + * @param resource the resource to return + * @param the type of the resource + * @return the passed in resource, after attaching to a new manager + */ + default T ret(T resource) { + resource.attach(getParentManager()); + return resource; + } + + /** + * Attaches all resources to this manager. + * + * @param resources the resources to attach + * @see NDResource#attach(NDManager) + */ + default void attachAll(NDResource... resources) { + for (NDResource resource : resources) { + resource.attach(this); + } + } + + /** + * Temporarily attaches all resources to this manager. + * + * @param resources the resources to attach + * @see NDResource#tempAttach(NDManager) + */ + default void tempAttachAll(NDResource... resources) { + for (NDResource resource : resources) { + resource.tempAttach(this); + } + } /** * An engine specific generic invocation to native operation. diff --git a/api/src/main/java/ai/djl/ndarray/NDResource.java b/api/src/main/java/ai/djl/ndarray/NDResource.java new file mode 100644 index 00000000000..8033d022608 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/NDResource.java @@ -0,0 +1,57 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray; + +/** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */ +public interface NDResource extends AutoCloseable { + + /** + * Returns the {@link NDManager} that manages this. + * + * @return the {@link NDManager} that manages this. + */ + NDManager getManager(); + + /** + * Attaches this {@link NDResource} to the specified {@link NDManager}. + * + *

Attached resource will be closed when the {@link NDManager} is closed. + * + * @param manager the {@link NDManager} to be attached to + */ + void attach(NDManager manager); + + /** + * Temporarily attaches this {@link NDResource} to the specified {@link NDManager}. + * + *

Attached resource will be returned to the original manager when the {@link NDManager} is + * closed. + * + * @param manager the {@link NDManager} to be attached to + */ + void tempAttach(NDManager manager); + + /** + * Detaches the {@link NDResource} from current {@link NDManager}'s lifecycle. + * + *

This becomes un-managed and it is the user's responsibility to close this. Failure to + * close the resource might cause your machine to run out of native memory. + * + * @see NDManager + */ + void detach(); + + /** {@inheritDoc} */ + @Override + void close(); +} diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index b8ca2caf120..868f6ace5fc 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -213,7 +213,8 @@ protected NDList forwardInternal( NDArray typeIds = inputs.get(1); // Third are the masks for the input NDArray masks = inputs.get(2); - MemoryScope initScope = MemoryScope.from(tokenIds).add(typeIds, masks); + NDManager initScope = NDManager.from(tokenIds); + initScope.tempAttachAll(inputs); // Create embeddings for inputs NDArray embeddedTokens = tokenEmbedding.forward(ps, new NDList(tokenIds), training).singletonOrThrow(); @@ -241,16 +242,15 @@ protected NDList forwardInternal( .mul(-100000f); // turn 1s (original 0s) into -100000 // Run through all transformer blocks NDList lastOutput = dropoutEmbedding; - initScope - .remove(tokenIds, typeIds, masks) - .waitToRead(dropoutEmbedding) - .waitToRead(offsetMask) - .close(); + initScope.ret(lastOutput); + initScope.ret(offsetMask); + initScope.close(); for (final TransformerEncoderBlock block : transformerEncoderBlocks) { NDList input = new NDList(lastOutput.head(), offsetMask); - MemoryScope innerScope = MemoryScope.from(input); - lastOutput = block.forward(ps, input, training); - innerScope.remove(offsetMask).waitToRead(lastOutput).close(); + try (NDManager innerScope = NDManager.from(input)) { + innerScope.tempAttachAll(input); + lastOutput = innerScope.ret(block.forward(ps, input, training)); + } } // We also return the pooled output - this is an additional fully connected layer // only applied to the first token, assumed to be the CLS token to be used for training diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index 7fc336434b6..4c52e1e5f0a 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -114,30 +114,32 @@ protected NDList forwardInternal( NDArray sequenceOutput = inputs.get(0); // (B, S, E) NDArray maskedIndices = inputs.get(1); // (B, I) NDArray embeddingTable = inputs.get(2); // (D, E) - MemoryScope scope = MemoryScope.from(sequenceOutput).add(maskedIndices); - NDArray gatheredTokens = gatherFromIndices(sequenceOutput, maskedIndices); // (B * I, E) - NDArray projectedTokens = - hiddenActivation.apply( - sequenceProjection - .forward(ps, new NDList(gatheredTokens), training) - .head()); // (B * I, E) - NDArray normalizedTokens = - sequenceNorm - .forward(ps, new NDList(projectedTokens), training) - .head(); // (B * I, E) - // raw logits for each position to correspond to an entry in the embedding table - NDArray embeddingTransposed = embeddingTable.transpose(); - embeddingTransposed.attach(gatheredTokens.getManager()); - NDArray logits = normalizedTokens.dot(embeddingTransposed); // (B * I, D) - // we add an offset for each dictionary entry - NDArray logitsWithBias = - logits.add(ps.getValue(dictionaryBias, logits.getDevice(), training)); // (B * I, D) - // now we apply log Softmax to get proper log probabilities - NDArray logProbs = logitsWithBias.logSoftmax(1); // (B * I, D) + try (NDManager scope = NDManager.from(sequenceOutput)) { + scope.tempAttachAll(sequenceOutput, maskedIndices); + NDArray gatheredTokens = gatherFromIndices(sequenceOutput, maskedIndices); // (B * I, E) + NDArray projectedTokens = + hiddenActivation.apply( + sequenceProjection + .forward(ps, new NDList(gatheredTokens), training) + .head()); // (B * I, E) + NDArray normalizedTokens = + sequenceNorm + .forward(ps, new NDList(projectedTokens), training) + .head(); // (B * I, E) + // raw logits for each position to correspond to an entry in the embedding table + NDArray embeddingTransposed = embeddingTable.transpose(); + embeddingTransposed.attach(gatheredTokens.getManager()); + NDArray logits = normalizedTokens.dot(embeddingTransposed); // (B * I, D) + // we add an offset for each dictionary entry + NDArray logitsWithBias = + logits.add( + ps.getValue( + dictionaryBias, logits.getDevice(), training)); // (B * I, D) + // now we apply log Softmax to get proper log probabilities + NDArray logProbs = logitsWithBias.logSoftmax(1); // (B * I, D) - scope.remove(sequenceOutput, maskedIndices).waitToRead(logProbs).close(); - - return new NDList(logProbs); + return scope.ret(new NDList(logProbs)); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java index 203fa1cf857..dc640e3d171 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelLoss.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.training.loss.Loss; @@ -40,29 +41,30 @@ public BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx) { @Override public NDArray evaluate(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); - NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) - int dictionarySize = (int) logProbs.getShape().get(1); - NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) - NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false); // (B * I) - NDArray targetOneHots = targetIds.oneHot(dictionarySize); - // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct - // entries. - // By summing we get the total predicition quality. We want to minimize the error, - // so we negate the value - as we have logarithms, probability = 1 means log(prob) = 0, - // the less sure we are the smaller the log value. - NDArray perExampleLoss = logProbs.mul(targetOneHots).sum(new int[] {1}).mul(-1); - // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct - // entries. - // By summing we get the total prediction quality. - NDArray numerator = perExampleLoss.mul(mask).sum(); - // We normalize the loss by the actual number of predictions we had to make - NDArray denominator = mask.sum().add(1e-5f); - NDArray result = numerator.div(denominator); + NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) + int dictionarySize = (int) logProbs.getShape().get(1); + NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) + NDArray mask = labels.get(maskIdx).flatten().toType(DataType.FLOAT32, false); // (B * I) + NDArray targetOneHots = targetIds.oneHot(dictionarySize); + // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct + // entries. + // By summing we get the total predicition quality. We want to minimize the error, + // so we negate the value - as we have logarithms, probability = 1 means log(prob) = 0, + // the less sure we are the smaller the log value. + NDArray perExampleLoss = logProbs.mul(targetOneHots).sum(new int[] {1}).mul(-1); + // Multiplying log_probs and one_hot_labels leaves the log probabilities of the correct + // entries. + // By summing we get the total prediction quality. + NDArray numerator = perExampleLoss.mul(mask).sum(); + // We normalize the loss by the actual number of predictions we had to make + NDArray denominator = mask.sum().add(1e-5f); + NDArray result = numerator.div(denominator); - scope.remove(labels, predictions).waitToRead(result).close(); - return result; + return scope.ret(result); + } } /** @@ -73,19 +75,19 @@ public NDArray evaluate(NDList labels, NDList predictions) { * @return the percentage of correctly predicted masked tokens */ public NDArray accuracy(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); - NDArray mask = labels.get(maskIdx).flatten(); // (B * I) - NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) - NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) - NDArray predictedIs = logProbs.argMax(1).toType(DataType.INT32, false); // (B * I) - NDArray equal = predictedIs.eq(targetIds).mul(mask); - NDArray equalCount = equal.sum().toType(DataType.FLOAT32, false); - NDArray count = mask.sum().toType(DataType.FLOAT32, false); - NDArray result = equalCount.div(count); + NDArray mask = labels.get(maskIdx).flatten(); // (B * I) + NDArray targetIds = labels.get(labelIdx).flatten(); // (B * I) + NDArray logProbs = predictions.get(logProbsIdx); // (B * I, D) + NDArray predictedIs = logProbs.argMax(1).toType(DataType.INT32, false); // (B * I) + NDArray equal = predictedIs.eq(targetIds).mul(mask); + NDArray equalCount = equal.sum().toType(DataType.FLOAT32, false); + NDArray count = mask.sum().toType(DataType.FLOAT32, false); + NDArray result = equalCount.div(count); - scope.remove(labels, predictions).waitToRead(result); - - return result; + return scope.ret(result); + } } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java index 0916e096c1c..b11e2828d82 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertNextSentenceLoss.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.training.loss.Loss; @@ -38,20 +39,21 @@ public BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx) { @Override public NDArray evaluate(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); - NDArray label = labels.get(labelIdx).toType(DataType.FLOAT32, false); - // predictions are log(softmax) - NDArray logPredictions = predictions.get(nextSentencePredictionIdx); - NDArray oneHotLabels = label.oneHot(2); - // we use negative log likelihood as loss: log(softmax) turns high confidence into - // negative values near one, low confidence into negative values near -inf, - // negating gives almost 0 for high confidence and near +inf for very low confidence - NDArray logPredictionForLabels = oneHotLabels.mul(logPredictions); - NDArray summedPredictions = logPredictionForLabels.sum(new int[] {1}); - NDArray perExampleLoss = summedPredictions.mul(-1f); - NDArray result = perExampleLoss.mean(); - scope.remove(labels, predictions).waitToRead(result).close(); - return result; + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); + NDArray label = labels.get(labelIdx).toType(DataType.FLOAT32, false); + // predictions are log(softmax) + NDArray logPredictions = predictions.get(nextSentencePredictionIdx); + NDArray oneHotLabels = label.oneHot(2); + // we use negative log likelihood as loss: log(softmax) turns high confidence into + // negative values near one, low confidence into negative values near -inf, + // negating gives almost 0 for high confidence and near +inf for very low confidence + NDArray logPredictionForLabels = oneHotLabels.mul(logPredictions); + NDArray summedPredictions = logPredictionForLabels.sum(new int[] {1}); + NDArray perExampleLoss = summedPredictions.mul(-1f); + NDArray result = perExampleLoss.mean(); + return scope.ret(result); + } } /** @@ -62,15 +64,16 @@ public NDArray evaluate(NDList labels, NDList predictions) { * @return the fraction of correct predictions. */ public NDArray accuracy(NDList labels, NDList predictions) { - MemoryScope scope = MemoryScope.from(labels).add(predictions); - NDArray label = labels.get(labelIdx); - NDArray predictionLogProbs = predictions.get(nextSentencePredictionIdx); - // predictions are log(softmax) -> highest confidence is highest (negative) value near 0 - NDArray prediction = predictionLogProbs.argMax(1).toType(DataType.INT32, false); - NDArray equalCount = label.eq(prediction).sum().toType(DataType.FLOAT32, false); - NDArray result = equalCount.div(label.getShape().size()); - scope.remove(labels, predictions).waitToRead(result).close(); + try (NDManager scope = NDManager.from(labels)) { + scope.tempAttachAll(labels, predictions); + NDArray label = labels.get(labelIdx); + NDArray predictionLogProbs = predictions.get(nextSentencePredictionIdx); + // predictions are log(softmax) -> highest confidence is highest (negative) value near 0 + NDArray prediction = predictionLogProbs.argMax(1).toType(DataType.INT32, false); + NDArray equalCount = label.eq(prediction).sum().toType(DataType.FLOAT32, false); + NDArray result = equalCount.div(label.getShape().size()); - return result; + return scope.ret(result); + } } } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index d7f44089564..5e60f28c25c 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -70,30 +70,31 @@ protected NDList forwardInternal( NDArray typeIds = inputs.get(1); NDArray sequenceMasks = inputs.get(2); NDArray maskedIndices = inputs.get(3); - MemoryScope scope = MemoryScope.from(tokenIds).add(typeIds, sequenceMasks, maskedIndices); - // run the core bert model - NDList bertResult = - bertBlock.forward(ps, new NDList(tokenIds, typeIds, sequenceMasks), training); - NDArray embeddedSequence = bertResult.get(0); - NDArray pooledOutput = bertResult.get(1); - // apply pooled output to the classifier - NDArray nextSentenceProbabilities = - nsBlock.forward(ps, new NDList(pooledOutput), training).singletonOrThrow(); - // de-mask masked tokens - NDArray embeddingTable = - bertBlock.getTokenEmbedding().getValue(ps, embeddedSequence.getDevice(), training); - NDArray logProbs = - mlBlock.forward( - ps, - new NDList(embeddedSequence, maskedIndices, embeddingTable), - training) - .singletonOrThrow(); + try (NDManager scope = NDManager.from(tokenIds)) { + scope.tempAttachAll(inputs); + // run the core bert model + NDList bertResult = + bertBlock.forward(ps, new NDList(tokenIds, typeIds, sequenceMasks), training); + NDArray embeddedSequence = bertResult.get(0); + NDArray pooledOutput = bertResult.get(1); + // apply pooled output to the classifier + NDArray nextSentenceProbabilities = + nsBlock.forward(ps, new NDList(pooledOutput), training).singletonOrThrow(); + // de-mask masked tokens + NDArray embeddingTable = + bertBlock + .getTokenEmbedding() + .getValue(ps, embeddedSequence.getDevice(), training); + NDArray logProbs = + mlBlock.forward( + ps, + new NDList(embeddedSequence, maskedIndices, embeddingTable), + training) + .singletonOrThrow(); - scope.remove(tokenIds, typeIds, sequenceMasks, maskedIndices) - .waitToRead(nextSentenceProbabilities, logProbs) - .close(); - // return the next sentence & masked language result to apply the loss to - return new NDList(nextSentenceProbabilities, logProbs); + // return the next sentence & masked language result to apply the loss to + return scope.ret(new NDList(nextSentenceProbabilities, logProbs)); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/MemoryScope.java b/api/src/main/java/ai/djl/nn/transformer/MemoryScope.java deleted file mode 100644 index e1bf4ca159a..00000000000 --- a/api/src/main/java/ai/djl/nn/transformer/MemoryScope.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.nn.transformer; - -import ai.djl.ndarray.LazyNDArray; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; - -/** - * Helper class for more complicated memory management scenarios. Allows to avoid boilerplate for - * memory handling. Makes sure the sub NDManager used is connected to the correct GPU to avoid - * crashes. - */ -public final class MemoryScope implements AutoCloseable { - - private NDManager parentManager; - private NDManager subManager; - - private MemoryScope(NDManager parentManager, NDManager subManager) { - this.parentManager = parentManager; - this.subManager = subManager; - } - - /** - * Adds all arrays in the given lists to this memory scope. - * - * @param lists the lists whose arrays to add to this scope, may be empty - * @return this scope - */ - public MemoryScope add(NDList... lists) { - for (NDList list : lists) { - list.attach(subManager); - } - return this; - } - - /** - * Adds the given arrays to this scopes sub manager. - * - * @param arrays the arrays to add - * @return this scope - */ - public MemoryScope add(NDArray... arrays) { - for (NDArray array : arrays) { - array.attach(subManager); - } - return this; - } - - /** - * Remove the given arrays from this scope and attach them back to this scopes parent NDManager. - * - * @param lists the lists containing the arrays to remove - * @return this scope - */ - public MemoryScope remove(NDList... lists) { - for (NDList list : lists) { - list.attach(parentManager); - } - return this; - } - - /** - * Remove the given arrays from this scope and attach them back to this scopes parent NDManager. - * - * @param arrays arrays to remove - * @return this scope - */ - public MemoryScope remove(NDArray... arrays) { - for (NDArray array : arrays) { - array.attach(parentManager); - } - return this; - } - - /** - * Returns the NDManager used to manage this scopes resources. - * - * @return the NDManager used to manage this scopes resources - */ - public NDManager getScopeManager() { - return subManager; - } - - /** - * Waits for all given arrays to be ready to read, i.e. waits for pending computations that - * write to them, then removes them from this scope. - * - * @param arrays arrays to wait for - * @return this scope - */ - public MemoryScope waitToRead(NDArray... arrays) { - for (NDArray array : arrays) { - if (array instanceof LazyNDArray) { - LazyNDArray lazyNDArray = (LazyNDArray) array; - lazyNDArray.waitToRead(); - } - remove(array); - } - return this; - } - - /** - * Waits for all arrays in all given lists to be ready to be read, i.e. waits for pending - * computations that write to them, then removes them from this scope. - * - * @param lists may be empty - * @return this scope - */ - public MemoryScope waitToRead(NDList... lists) { - for (NDList list : lists) { - if (list != null) { - for (NDArray array : list) { - waitToRead(array); - } - } - } - return this; - } - - /** - * Closes this scope by closing the sub manager used to manage it. This causes all arrays still - * attached to this scope to be closed as well. - */ - @Override - public void close() { - subManager.close(); - } - - /** - * Creates a new memory scope for the device of the given array and adds the array. - * - * @param ndArray an array - * @return a new memory scrope containing the array - */ - public static MemoryScope from(final NDArray ndArray) { - return new MemoryScope( - ndArray.getManager(), - ndArray.getManager().newSubManager(ndArray.getDevice())) - .add(ndArray); - } - - /** - * Creates a new memory scope that fits the device of the first array in the given list, adds - * all arrays in the given list. - * - * @param list a list of arrays, must not be empty - * @return a new memory scope - */ - public static MemoryScope from(final NDList list) { - NDArray ndArray = list.head(); - return new MemoryScope( - ndArray.getManager(), - ndArray.getManager().newSubManager(ndArray.getDevice())) - .add(list); - } -} diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java index 8cb8fa653cc..f8c1b8fc119 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java @@ -45,7 +45,7 @@ public class DlrNDArray implements NDArrayAdapter { this.data = data; this.shape = shape; uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); } /** {@inheritDoc} */ @@ -94,18 +94,25 @@ public Shape getShape() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (DlrNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (DlrNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = DlrNDManager.getSystemManager(); } diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java index f879c907cc3..ec87ba6675d 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java @@ -54,7 +54,7 @@ public ByteBuffer allocateDirect(int capacity) { @Override public DlrNDManager newSubManager(Device dev) { DlrNDManager manager = new DlrNDManager(this, dev); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -105,11 +105,11 @@ private static final class SystemManager extends DlrNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java index 30460a25592..d13c75e684a 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java @@ -73,7 +73,7 @@ public CachedOp( this.dataIndicesMap = dataIndices.toMap(); // holds all parameter and data NDArray values, final inputs to CachedOp this.manager = manager; - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** @@ -139,7 +139,7 @@ public NDList forward(ParameterStore parameterStore, NDList data, boolean traini public void close() { Pointer pointer = handle.getAndSet(null); if (pointer != null) { - manager.detach(getUid()); + manager.detachInternal(getUid()); JnaUtils.freeCachedOp(pointer); manager = null; } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index c1277dca17a..1b9b066be0e 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -92,7 +92,7 @@ public class MxNDArray extends NativeResource implements LazyNDArray { super(handle); this.manager = manager; mxNDArrayEx = new MxNDArrayEx(this); - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** @@ -163,18 +163,25 @@ public SparseFormat getSparseFormat() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (MxNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { NDManager original = this.manager; detach(); this.manager = (MxNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = MxNDManager.getSystemManager(); } @@ -1609,7 +1616,7 @@ public void close() { if (pointer != null) { JnaUtils.waitToRead(pointer); JnaUtils.freeNdArray(pointer); - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = null; } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index 09f154e964d..5e9c32cf047 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -377,8 +377,7 @@ public void adadeltaUpdate( // create a baseManager to close all intermediate NDArrays try (NDManager subManager = NDManager.newBaseManager()) { - List inputManagers = inputs.attach(subManager); - List weightManagers = weights.attach(subManager); + subManager.tempAttachAll(inputs, weights); // Preprocess Gradient grad.muli(rescaleGrad); @@ -394,10 +393,6 @@ public void adadeltaUpdate( // Update weight weight.subi(g); - - // attach back to their previous managers - inputs.attach(inputManagers); - weights.attach(weightManagers); } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java index 1effa33d3a5..b76f0492c52 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java @@ -272,7 +272,7 @@ public NDArray randomMultinomial(int n, NDArray pValues) { @Override public MxNDManager newSubManager(Device dev) { MxNDManager manager = new MxNDManager(this, dev, version); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -386,11 +386,11 @@ private static final class SystemManager extends MxNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java index 9bd825b3c7d..8fcc92e5061 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java @@ -52,7 +52,7 @@ public class Symbol extends NativeResource { Symbol(MxNDManager manager, Pointer pointer) { super(pointer); this.manager = manager; - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); // argParams = JnaUtils.listSymbolArguments(getHandle()); // auxParams = JnaUtils.listSymbolAuxiliaryStates(getHandle()); } @@ -311,7 +311,7 @@ public String toString() { public void close() { Pointer pointer = handle.getAndSet(null); if (pointer != null) { - manager.detach(getUid()); + manager.detachInternal(getUid()); JnaUtils.freeSymbol(pointer); manager = null; } diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java index 806fb96714e..997f6c29162 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java @@ -48,7 +48,7 @@ public class OrtNDArray implements NDArrayAdapter { this.manager = manager; this.tensor = tensor; uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); } OnnxTensor getTensor() { @@ -106,18 +106,25 @@ public Shape getShape() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (OrtNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (OrtNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = OrtNDManager.getSystemManager(); } diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index b0f4dd19e02..4120b64df40 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -136,7 +136,7 @@ public NDArray ones(Shape shape, DataType dataType) { @Override public OrtNDManager newSubManager(Device device) { OrtNDManager manager = new OrtNDManager(this, device, env); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -155,11 +155,11 @@ private static final class SystemManager extends OrtNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java index 9e160e52213..cdb7ba30ba8 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java @@ -39,7 +39,7 @@ public class PpNDArray extends NativeResource implements NDArrayAdapter { public PpNDArray(PpNDManager manager, long handle) { super(handle); this.manager = manager; - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** {@inheritDoc} */ @@ -86,18 +86,25 @@ public Shape getShape() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (PpNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (PpNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = PpNDManager.getSystemManager(); } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java index b277456af7b..f0aca98460d 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java @@ -51,7 +51,7 @@ public PpNDManager newSubManager() { @Override public PpNDManager newSubManager(Device device) { PpNDManager manager = new PpNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -156,11 +156,11 @@ private static final class SystemManager extends PpNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 1a9e6613797..93f811326c0 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -69,7 +69,7 @@ public PtNDArray(PtNDManager manager, long handle) { super(handle); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); } /** @@ -84,7 +84,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { super(handle); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attach(getUid(), this); + manager.attachInternal(getUid(), this); dataRef = data; } @@ -285,18 +285,25 @@ public void copyTo(NDArray array) { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (PtNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (PtNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = PtNDManager.getSystemManager(); } @@ -1442,7 +1449,7 @@ public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { JniUtils.deleteNDArray(pointer); - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = null; dataRef = null; } diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index ac51d54180c..7d2fa9a9123 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -181,7 +181,7 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy @Override public PtNDManager newSubManager(Device device) { PtNDManager manager = new PtNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -200,11 +200,11 @@ private static final class SystemManager extends PtNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 964649cd5a1..9eebb11846b 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -65,7 +65,7 @@ public PtSymbolBlock(PtNDManager manager, long handle) { this.handle = new AtomicReference<>(handle); this.manager = manager; uid = String.valueOf(handle); - manager.attach(uid, this); + manager.attachInternal(uid, this); // training mode is on by default isTrain = true; first = true; @@ -90,7 +90,7 @@ public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { JniUtils.deleteModule(pointer); - manager.detach(uid); + manager.detachInternal(uid); manager = null; } } @@ -177,7 +177,7 @@ public void loadParameters(NDManager manager, DataInputStream is) long rawHandle = JniUtils.loadModuleHandle(is, manager.getDevice(), true); this.handle = new AtomicReference<>(rawHandle); uid = String.valueOf(rawHandle); - manager.attach(uid, this); + manager.attachInternal(uid, this); } /** diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 8707eac015c..76dcefbbf3a 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -73,7 +73,7 @@ public class TfNDArray implements NDArray { this.manager = (TfNDManager) manager; this.tf = this.manager.getTf(); uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); this.operand = this.manager .getEagerSession() @@ -286,18 +286,25 @@ public void set(Buffer data) { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (TfNDManager) manager; + manager.attachInternal(uid, this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (TfNDManager) manager; - manager.attach(uid, this); - return original; + manager.tempAttachInternal(original, uid, this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = TfNDManager.getSystemManager(); } diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index 2bfac11167c..90048fe9c2c 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -423,7 +423,7 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy @Override public TfNDManager newSubManager(Device device) { TfNDManager manager = new TfNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); // initialize eager sessions and operators only for sub managers manager.getEagerSession(); manager.getTf(); @@ -447,11 +447,11 @@ private static final class SystemManager extends TfNDManager { /** {@inheritDoc} */ @Override - public void attach(String resrouceId, AutoCloseable resource) {} + public void attachInternal(String resrouceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java index 9b59b4893ab..86a87cc8fac 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java @@ -39,7 +39,7 @@ public class TfLiteNDArray implements NDArrayAdapter { TfLiteNDArray(TfLiteNDManager manager, Tensor tensor) { this.manager = manager; uid = UUID.randomUUID().toString(); - manager.attach(uid, this); + manager.attachInternal(uid, this); this.tensor = tensor; shape = new Shape(Arrays.stream(tensor.shape()).mapToLong(i -> i).toArray()); dataType = TfLiteDataType.fromTf(tensor.dataType()); @@ -103,18 +103,25 @@ public SparseFormat getSparseFormat() { /** {@inheritDoc} */ @Override - public NDManager attach(NDManager manager) { + public void attach(NDManager manager) { + detach(); + this.manager = (TfLiteNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { detach(); NDManager original = this.manager; this.manager = (TfLiteNDManager) manager; - manager.attach(getUid(), this); - return original; + manager.tempAttachInternal(original, getUid(), this); } /** {@inheritDoc} */ @Override public void detach() { - manager.detach(getUid()); + manager.detachInternal(getUid()); manager = TfLiteNDManager.getSystemManager(); } diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java index 8d55da2e7a1..64da62767f7 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java @@ -132,7 +132,7 @@ public NDArray ones(Shape shape, DataType dataType) { @Override public TfLiteNDManager newSubManager(Device device) { TfLiteNDManager manager = new TfLiteNDManager(this, device); - attach(manager.uid, manager); + attachInternal(manager.uid, manager); return manager; } @@ -151,11 +151,11 @@ private static final class SystemManager extends TfLiteNDManager { /** {@inheritDoc} */ @Override - public void attach(String resourceId, AutoCloseable resource) {} + public void attachInternal(String resourceId, AutoCloseable resource) {} /** {@inheritDoc} */ @Override - public void detach(String resourceId) {} + public void detachInternal(String resourceId) {} /** {@inheritDoc} */ @Override From 43e5891acc4cdee73ed7793b9d1991e7311930d1 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 9 Mar 2021 11:23:55 -0800 Subject: [PATCH 24/25] Remove erroneous random forest application (#726) The application was changed to the more accurate softmax_regression (matching the terminology from the D2L book). Change-Id: I1f69f005bbe38b125f2709c2988d06c14eebb765 --- api/src/main/java/ai/djl/Application.java | 11 +---------- .../machine_learning_with_ONNXRuntime.ipynb | 2 +- .../main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java | 2 +- .../IrisClassificationModelLoader.java | 5 +++-- .../IrisFlower.java | 2 +- .../package-info.java | 2 +- .../test/java/ai/djl/onnxruntime/engine/OrtTest.java | 2 +- .../ai/djl/onnxruntime/iris_flowers/metadata.json | 2 +- 8 files changed, 10 insertions(+), 18 deletions(-) rename onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/{randomforest => softmax_regression}/IrisClassificationModelLoader.java (96%) rename onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/{randomforest => softmax_regression}/IrisFlower.java (96%) rename onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/{randomforest => softmax_regression}/package-info.java (92%) rename onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/{random_forest => softmax_regression}/ai/djl/onnxruntime/iris_flowers/metadata.json (94%) diff --git a/api/src/main/java/ai/djl/Application.java b/api/src/main/java/ai/djl/Application.java index 76d8e8926ac..299c7e1ccc0 100644 --- a/api/src/main/java/ai/djl/Application.java +++ b/api/src/main/java/ai/djl/Application.java @@ -268,15 +268,6 @@ public interface Tabular { * @see The D2L * chapter introducing this application */ - Application SOFTMAX_REGRESSION = new Application("tabular/linear_regression"); - - /** - * This is erroneous because random forest is a technique (not deep learning), not an - * application. - * - *

The actual application is likely to be in {@link Tabular}, especially {@link - * #SOFTMAX_REGRESSION}. - */ - Application RANDOM_FOREST = new Application("tabular/random_forest"); + Application SOFTMAX_REGRESSION = new Application("tabular/softmax_regression"); } } diff --git a/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb b/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb index 7ec80efabb9..7ab1b7a7106 100644 --- a/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb +++ b/jupyter/onnxruntime/machine_learning_with_ONNXRuntime.ipynb @@ -178,7 +178,7 @@ "metadata": {}, "outputs": [], "source": [ - "String modelUrl = \"https://mlrepo.djl.ai/model/tabular/random_forest/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip\";\n", + "String modelUrl = \"https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip\";\n", "Criteria criteria = Criteria.builder()\n", " .setTypes(IrisFlower.class, Classifications.class)\n", " .optModelUrls(modelUrl)\n", diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java index 0c5dfa5ec9e..3dadcc532af 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java @@ -13,7 +13,7 @@ package ai.djl.onnxruntime.zoo; import ai.djl.onnxruntime.engine.OrtEngine; -import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisClassificationModelLoader; +import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisClassificationModelLoader; import ai.djl.repository.Repository; import ai.djl.repository.zoo.ModelZoo; import java.util.Collections; diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/IrisClassificationModelLoader.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java similarity index 96% rename from onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/IrisClassificationModelLoader.java rename to onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java index 67109f547ea..8b48fbb393b 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/IrisClassificationModelLoader.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java @@ -10,9 +10,10 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.onnxruntime.zoo.tabular.randomforest; +package ai.djl.onnxruntime.zoo.tabular.softmax_regression; import ai.djl.Application; +import ai.djl.Application.Tabular; import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.modality.Classifications; @@ -39,7 +40,7 @@ /** Model loader for onnx iris_flowers models. */ public class IrisClassificationModelLoader extends BaseModelLoader { - private static final Application APPLICATION = Application.Tabular.RANDOM_FOREST; + private static final Application APPLICATION = Tabular.SOFTMAX_REGRESSION; private static final String GROUP_ID = OrtModelZoo.GROUP_ID; private static final String ARTIFACT_ID = "iris_flowers"; private static final String VERSION = "0.0.1"; diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/IrisFlower.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisFlower.java similarity index 96% rename from onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/IrisFlower.java rename to onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisFlower.java index ba77ec71b96..27263e520f2 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/IrisFlower.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisFlower.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.onnxruntime.zoo.tabular.randomforest; +package ai.djl.onnxruntime.zoo.tabular.softmax_regression; /** A class holds the iris flower features. */ public class IrisFlower { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/package-info.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/package-info.java similarity index 92% rename from onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/package-info.java rename to onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/package-info.java index 7759b6962de..d6e6aee3bf3 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/randomforest/package-info.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/package-info.java @@ -14,4 +14,4 @@ /** * Contains classes for the classification models in the {@link ai.djl.onnxruntime.zoo.OrtModelZoo}. */ -package ai.djl.onnxruntime.zoo.tabular.randomforest; +package ai.djl.onnxruntime.zoo.tabular.softmax_regression; diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index 4d8df458605..67369346487 100644 --- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -20,7 +20,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.Shape; -import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisFlower; +import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisFlower; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; diff --git a/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/random_forest/ai/djl/onnxruntime/iris_flowers/metadata.json b/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/metadata.json similarity index 94% rename from onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/random_forest/ai/djl/onnxruntime/iris_flowers/metadata.json rename to onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/metadata.json index ced6365dc22..e5979c1b771 100644 --- a/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/random_forest/ai/djl/onnxruntime/iris_flowers/metadata.json +++ b/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/metadata.json @@ -1,7 +1,7 @@ { "metadataVersion": "0.2", "resourceType": "model", - "application": "tabular/random_forest", + "application": "tabular/softmax_regression", "groupId": "ai.djl.onnxruntime", "artifactId": "iris_flowers", "name": "iris_flowers", From 2158e99d3c456f5e317e1c2ae4aac90c1d3fef80 Mon Sep 17 00:00:00 2001 From: Lanking Date: Tue, 9 Mar 2021 14:26:25 -0800 Subject: [PATCH 25/25] Minor fixes on duplicated code (#736) * remove methods that already defined in the NDArrayAdapter Change-Id: I01cc03a7f5b427bf31c6b3fd8d2136f2a27fe93b * refactor toString Change-Id: Iea22b16e1daa9f759b55c1a8b8b85536482e551a * remove sparse NDArray Change-Id: Icb44096519775f54cb32cc768c14f49e33dc7ea5 * fix test Change-Id: Icef580ed77e7bba22864ce44577de3cba51e3e41 --- api/src/main/java/ai/djl/BaseModel.java | 16 +++++ .../java/ai/djl/dlr/engine/DlrNDArray.java | 12 ---- .../tests/ndarray/NDArrayCreationOpTest.java | 4 +- .../java/ai/djl/mxnet/engine/MxModel.java | 16 ----- .../java/ai/djl/mxnet/engine/MxNDArray.java | 3 + .../java/ai/djl/mxnet/engine/MxNDManager.java | 13 ++-- .../ai/djl/mxnet/engine/MxSparseNDArray.java | 62 ------------------- .../ai/djl/paddlepaddle/engine/PpModel.java | 16 ----- 8 files changed, 27 insertions(+), 115 deletions(-) delete mode 100644 mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSparseNDArray.java diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index 31e81f0efd2..d6fa016fc18 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -267,6 +267,22 @@ public Path getModelPath() { return modelDir; } + /** {@inheritDoc} */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(200); + sb.append("Model (\n\tName: ").append(modelName); + if (modelDir != null) { + sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath()); + } + sb.append("\n\tData Type: ").append(dataType); + for (Map.Entry entry : properties.entrySet()) { + sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue()); + } + sb.append("\n)"); + return sb.toString(); + } + /** {@inheritDoc} */ @SuppressWarnings("deprecation") @Override diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java index f8c1b8fc119..59ab1683913 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java @@ -116,18 +116,6 @@ public void detach() { manager = DlrNDManager.getSystemManager(); } - /** {@inheritDoc} */ - @Override - public boolean hasGradient() { - return false; - } - - /** {@inheritDoc} */ - @Override - public NDArray stopGradient() { - throw new UnsupportedOperationException("Not supported for DLR"); - } - /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java index 090e96c4b7d..9a020f41259 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java @@ -70,7 +70,7 @@ public void testCreateCSRMatrix() { long[] indptr = {0, 2, 2, 3}; long[] indices = {0, 2, 1}; NDArray nd = manager.createCSR(buf, indptr, indices, new Shape(3, 4)); - float[] array = nd.toFloatArray(); + float[] array = nd.toDense().toFloatArray(); Assert.assertEquals(array[0], expected[0]); Assert.assertEquals(array[2], expected[1]); Assert.assertEquals(array[9], expected[2]); @@ -85,7 +85,7 @@ public void testCreateRowSparseMatrix() { FloatBuffer buf = FloatBuffer.wrap(expected); long[] indices = {0, 1, 3}; NDArray nd = manager.createRowSparse(buf, new Shape(3, 2), indices, new Shape(4, 2)); - float[] array = nd.toFloatArray(); + float[] array = nd.toDense().toFloatArray(); Assert.assertEquals(array[0], expected[0]); Assert.assertEquals(array[1], expected[1]); Assert.assertEquals(array[2], expected[2]); diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index 54c58bf1a01..5c675a9519c 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -200,20 +200,4 @@ private void loadParameters(Path paramFile, Map options) dataType = paramNDlist.head().getDataType(); logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType); } - - /** {@inheritDoc} */ - @Override - public String toString() { - StringBuilder sb = new StringBuilder(200); - sb.append("Model (\n\tName: ").append(modelName); - if (modelDir != null) { - sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath()); - } - sb.append("\n\tData Type: ").append(dataType); - for (Map.Entry entry : properties.entrySet()) { - sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue()); - } - sb.append("\n)"); - return sb.toString(); - } } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 1b9b066be0e..e48cce9af31 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -290,6 +290,9 @@ public String[] toStringArray() { /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { + if (getSparseFormat() != SparseFormat.DENSE) { + throw new IllegalStateException("Require Dense NDArray, actual " + getSparseFormat()); + } Shape sh = getShape(); DataType dType = getDataType(); long product = sh.size(); diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java index b76f0492c52..cb18c373c68 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java @@ -81,8 +81,8 @@ public MxNDArray create(Pointer handle) { * @param fmt the sparse format to use * @return the created array */ - public MxSparseNDArray create(Pointer handle, SparseFormat fmt) { - return new MxSparseNDArray(this, handle, fmt); + public MxNDArray create(Pointer handle, SparseFormat fmt) { + return new MxNDArray(this, handle, fmt); } /** {@inheritDoc} */ @@ -97,7 +97,7 @@ public MxNDArray create(Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override - public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { + public MxNDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { SparseFormat fmt = SparseFormat.CSR; DataType dataType = DataType.fromBuffer(data); MxNDArray indptrNd = create(new Shape(indptr.length), DataType.INT64); @@ -113,7 +113,7 @@ public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Sha new DataType[] {indptrNd.getDataType(), indicesNd.getDataType()}, new Shape[] {indptrNd.getShape(), indicesNd.getShape()}, false); - MxSparseNDArray sparse = create(handle, fmt); + MxNDArray sparse = create(handle, fmt); MxNDArray dataNd = create(new Shape(data.remaining()), dataType); dataNd.set(data); JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1); @@ -124,8 +124,7 @@ public MxSparseNDArray createCSR(Buffer data, long[] indptr, long[] indices, Sha /** {@inheritDoc} */ @Override - public MxSparseNDArray createRowSparse( - Buffer data, Shape dataShape, long[] indices, Shape shape) { + public MxNDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) { SparseFormat fmt = SparseFormat.ROW_SPARSE; DataType dataType = DataType.fromBuffer(data); MxNDArray indicesNd = create(new Shape(indices.length), DataType.INT64); @@ -139,7 +138,7 @@ public MxSparseNDArray createRowSparse( new DataType[] {indicesNd.getDataType()}, new Shape[] {indicesNd.getShape()}, false); - MxSparseNDArray sparse = create(handle, fmt); + MxNDArray sparse = create(handle, fmt); MxNDArray dataNd = create(dataShape, dataType); dataNd.set(data); JnaUtils.ndArraySyncCopyFromNdArray(sparse, dataNd, -1); diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSparseNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSparseNDArray.java deleted file mode 100644 index 942591ceafe..00000000000 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSparseNDArray.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.mxnet.engine; - -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.index.NDIndex; -import ai.djl.ndarray.types.SparseFormat; -import com.sun.jna.Pointer; -import java.nio.Buffer; -import java.nio.ByteBuffer; - -/** - * {@code MxSparseNDArray} is an instance of {@link MxNDArray} and {@link NDArray} for sparse - * NDArrays. - * - *

{@code MxSparseNDArray}s are created automatically when the Engine creates Arrays that are - * sparse. They can be created deliberately by specifying the {@link SparseFormat}. Some operations - * may not be supported with Sparse NDArrays in MXNet. - * - * @see SparseFormat - */ -public class MxSparseNDArray extends MxNDArray { - - /** - * Constructs a {@code MxSparseNDArray} for the given data. - * - * @param manager the manager to attach the array to - * @param handle the pointer to the native memory of the MXNDArray - * @param fmt the sparse format - */ - MxSparseNDArray(MxNDManager manager, Pointer handle, SparseFormat fmt) { - super(manager, handle, fmt); - } - - /** {@inheritDoc} */ - @Override - public void set(Buffer data) { - throw new IllegalStateException("Unsupported operation for Sparse"); - } - - /** {@inheritDoc} */ - @Override - public NDArray get(NDIndex index) { - return toDense().get(index); - } - - /** {@inheritDoc} */ - @Override - public ByteBuffer toByteBuffer() { - return toDense().toByteBuffer(); - } -} diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java index 458b7333d23..25ebbb4c680 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java @@ -96,22 +96,6 @@ private String[] findModelFile(Path dir) { return null; } - /** {@inheritDoc} */ - @Override - public String toString() { - StringBuilder sb = new StringBuilder(200); - sb.append("Model (\n\tName: ").append(modelName); - if (modelDir != null) { - sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath()); - } - sb.append("\n\tData Type: ").append(dataType); - for (Map.Entry entry : properties.entrySet()) { - sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue()); - } - sb.append("\n)"); - return sb.toString(); - } - /** {@inheritDoc} */ @Override public Predictor newPredictor(Translator translator) {