diff --git a/api/src/main/java/ai/djl/inference/Predictor.java b/api/src/main/java/ai/djl/inference/Predictor.java index fb5663602f4..b5fb26348cc 100644 --- a/api/src/main/java/ai/djl/inference/Predictor.java +++ b/api/src/main/java/ai/djl/inference/Predictor.java @@ -14,6 +14,8 @@ import ai.djl.Device; import ai.djl.Model; +import ai.djl.inference.streaming.StreamingBlock; +import ai.djl.inference.streaming.StreamingTranslator; import ai.djl.metric.Metrics; import ai.djl.metric.Unit; import ai.djl.ndarray.LazyNDArray; @@ -190,6 +192,70 @@ public List batchPredict(List inputs) throws TranslateException { } } + /** + * Predicts an item for inference. + * + * @param input the input + * @return the output object defined by the user + * @throws TranslateException if an error occurs during prediction + */ + @SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"}) + public O streamingPredict(I input) throws TranslateException { + + if (!(block instanceof StreamingBlock)) { + throw new IllegalStateException( + "streamingPredict() can only be called with a StreamingBlock"); + } + StreamingBlock streamingBlock = (StreamingBlock) block; + + if (!(translator instanceof StreamingTranslator)) { + throw new IllegalStateException( + "streamingPredict() can only be called with a StreamingTranslator"); + } + StreamingTranslator streamingTranslator = (StreamingTranslator) translator; + + try { + PredictorContext context = new PredictorContext(); + if (!prepared) { + translator.prepare(context); + prepared = true; + } + Batchifier batchifier = translator.getBatchifier(); + if (batchifier == null) { + NDList ndList = translator.processInput(context, input); + + return streamingTranslator.processStreamOutput( + context, + streamingBlock + .forwardStream(parameterStore, ndList, false) + .onClose(context::close)); + } + + // For the batched case, need to create singleton batch and unbatchify singleton + NDList inputBatch = processInputs(context, Collections.singletonList(input)); + return streamingTranslator.processStreamOutput( + context, + streamingBlock + .forwardStream(parameterStore, inputBatch, false) + .map( + result -> { + NDList[] unbatched = + translator.getBatchifier().unbatchify(result); + if (unbatched.length != 1) { + throw new IllegalStateException( + "Unexpected number of outputs from model"); + } + return unbatched[0]; + }) + .onClose(context::close)); + + } catch (TranslateException e) { + throw e; + } catch (Exception e) { + throw new TranslateException(e); + } + } + /** * Attaches a Metrics param to use for benchmark. * diff --git a/api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java b/api/src/main/java/ai/djl/inference/streaming/ChunkedBytesSupplier.java similarity index 98% rename from api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java rename to api/src/main/java/ai/djl/inference/streaming/ChunkedBytesSupplier.java index c8b2d5d0e0e..61bf10f5a7d 100644 --- a/api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java +++ b/api/src/main/java/ai/djl/inference/streaming/ChunkedBytesSupplier.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.modality; +package ai.djl.inference.streaming; import ai.djl.ndarray.BytesSupplier; diff --git a/api/src/main/java/ai/djl/inference/streaming/IteratorBytesSupplier.java b/api/src/main/java/ai/djl/inference/streaming/IteratorBytesSupplier.java new file mode 100644 index 00000000000..0e95af67c36 --- /dev/null +++ b/api/src/main/java/ai/djl/inference/streaming/IteratorBytesSupplier.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.inference.streaming; + +import ai.djl.ndarray.BytesSupplier; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Iterator; + +/** + * An {@link IteratorBytesSupplier} is a streaming {@link BytesSupplier} suitable for synchronous + * usage. + */ +public class IteratorBytesSupplier implements BytesSupplier, Iterator { + + private Iterator sources; + + /** + * Constructs an {@link IteratorBytesSupplier}. + * + * @param sources the source suppliers + */ + public IteratorBytesSupplier(Iterator sources) { + this.sources = sources; + } + + /** {@inheritDoc} */ + @Override + public boolean hasNext() { + return sources.hasNext(); + } + + /** {@inheritDoc} */ + @Override + public byte[] next() { + return sources.next().getAsBytes(); + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(getAsBytes()); + } + + /** {@inheritDoc} */ + @Override + public byte[] getAsBytes() { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { + while (hasNext()) { + bos.write(next()); + } + return bos.toByteArray(); + } catch (IOException e) { + throw new AssertionError("Failed to read BytesSupplier", e); + } + } +} diff --git a/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java new file mode 100644 index 00000000000..d83c4678f33 --- /dev/null +++ b/api/src/main/java/ai/djl/inference/streaming/PublisherBytesSupplier.java @@ -0,0 +1,130 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.inference.streaming; + +import ai.djl.ndarray.BytesSupplier; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +/** + * An {@link PublisherBytesSupplier} is a streaming {@link BytesSupplier} suitable for reactive + * asynchronous usage. + */ +public class PublisherBytesSupplier implements BytesSupplier { + + private final List allData; + private final AtomicBoolean completed; + private Consumer subscriber; + private final AtomicInteger dataPushed; + + /** Constructs a {@link PublisherBytesSupplier}. */ + public PublisherBytesSupplier() { + allData = new ArrayList<>(); + completed = new AtomicBoolean(); + dataPushed = new AtomicInteger(); + } + + /** + * Appends content to the {@code BytesSupplier}. + * + * @param data bytes to append + * @param lastChunk true if this is the last chunk + */ + public void appendContent(byte[] data, boolean lastChunk) { + synchronized (allData) { + allData.add(data); + } + if (lastChunk) { + completed.set(true); + } + pushData(); + } + + /** + * Adds the subscriber to the {@link BytesSupplier} to get notified about additional data. + * + * @param subscriber a consumer function that will receive bytes when new daata is added and + * null when completed + */ + public void subscribe(Consumer subscriber) { + if (this.subscriber != null) { + throw new IllegalStateException( + "The PublisherBytesSupplier only allows a single Subscriber"); + } + this.subscriber = subscriber; + pushData(); + } + + private void pushData() { + if (subscriber == null) { + return; + } + + int dataAvailable; + synchronized (allData) { + dataAvailable = allData.size(); + } + + int sent = dataPushed.getAndSet(dataAvailable); + if (sent < dataAvailable) { + synchronized (this) { + for (; sent < dataAvailable; sent++) { + subscriber.accept(allData.get(sent)); + } + if (completed.get()) { + subscriber.accept(null); + } + } + } + } + + /** Waits until completed before passing thread (BLOCKS THREAD!). */ + @SuppressWarnings("PMD.EmptyControlStatement") + public void waitToRead() { + // Block until complete!!! + while (!completed.get()) { + // Do nothing + } + } + + /** {@inheritDoc} */ + @Override + public byte[] getAsBytes() { + if (!completed.get()) { + throw new IllegalStateException( + "PublisherByteSupplier must be completely filled before reading."); + } + + try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { + for (byte[] data : allData) { + bos.write(data); + } + return bos.toByteArray(); + } catch (IOException e) { + throw new AssertionError("Failed to read BytesSupplier", e); + } + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(getAsBytes()); + } +} diff --git a/api/src/main/java/ai/djl/inference/streaming/StreamingBlock.java b/api/src/main/java/ai/djl/inference/streaming/StreamingBlock.java new file mode 100644 index 00000000000..3b456ffc40f --- /dev/null +++ b/api/src/main/java/ai/djl/inference/streaming/StreamingBlock.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.inference.streaming; + +import ai.djl.ndarray.NDList; +import ai.djl.nn.Block; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +/** + * A {@link Block} possessing the additional streaming forward capabilities used by {@link + * ai.djl.inference.Predictor#streamingPredict(Object)}. + */ +public interface StreamingBlock extends Block { + + /** + * Applies the operating function of the block once, but returns the result in chunks. This + * method should only be called on blocks that are initialized. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass (turn on dropout and layerNorm) + * @return the output of the forward pass + */ + default Stream forwardStream( + ParameterStore parameterStore, NDList inputs, boolean training) { + return forwardStream(parameterStore, inputs, training, null); + } + + /** + * Applies the operating function of the block once, but returns the result in chunks. This + * method should only be called on blocks that are initialized. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass (turn on dropout and layerNorm) + * @param params optional parameters + * @return the output of the forward pass + */ + default Stream forwardStream( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + Iterator itr = forwardStreamIter(parameterStore, inputs, training, params); + Spliterator spitr = Spliterators.spliteratorUnknownSize(itr, Spliterator.NONNULL); + return StreamSupport.stream(spitr, false); + } + + /** + * Applies the operating function of the block once, but returns the result in chunks. This + * method should only be called on blocks that are initialized. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass (turn on dropout and layerNorm) + * @param params optional parameters + * @return the output of the forward pass + */ + Iterator forwardStreamIter( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params); +} diff --git a/api/src/main/java/ai/djl/inference/streaming/StreamingTranslator.java b/api/src/main/java/ai/djl/inference/streaming/StreamingTranslator.java new file mode 100644 index 00000000000..c035613723f --- /dev/null +++ b/api/src/main/java/ai/djl/inference/streaming/StreamingTranslator.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.inference.streaming; + +import ai.djl.ndarray.NDList; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import java.util.stream.Stream; + +/** + * An expansion of the {@link Translator} with postProcessing for the {@link StreamingBlock} (used + * by {@link ai.djl.inference.Predictor#streamingPredict(Object)}. + * + * @param the input type + * @param the output type + */ +public interface StreamingTranslator extends Translator { + + /** + * Processes the output NDList to the corresponding output object. + * + * @param ctx the toolkit used for post-processing + * @param list the output NDList after inference, usually immutable in engines like + * PyTorch. @see Issue 1774 + * @return the output object of expected type + * @throws Exception if an error occurs during processing output + */ + @SuppressWarnings("PMD.SignatureDeclareThrowsException") + O processStreamOutput(TranslatorContext ctx, Stream list) throws Exception; +} diff --git a/api/src/main/java/ai/djl/inference/streaming/package-info.java b/api/src/main/java/ai/djl/inference/streaming/package-info.java new file mode 100644 index 00000000000..5fc43b6e7ea --- /dev/null +++ b/api/src/main/java/ai/djl/inference/streaming/package-info.java @@ -0,0 +1,19 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** + * Contains classes to implement streaming inference tasks. + * + * @see ai.djl.inference.Predictor + */ +package ai.djl.inference.streaming; diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java index 06fa9826347..d8f0988efec 100644 --- a/api/src/main/java/ai/djl/nn/SequentialBlock.java +++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java @@ -13,6 +13,7 @@ package ai.djl.nn; import ai.djl.MalformedModelException; +import ai.djl.inference.streaming.StreamingBlock; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; @@ -27,6 +28,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Iterator; import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; @@ -38,7 +40,7 @@ * *

{@code SequentialBlock} has no direct parameters. */ -public class SequentialBlock extends AbstractBlock { +public class SequentialBlock extends AbstractBlock implements StreamingBlock { private static final byte VERSION = 3; private boolean returnIntermediate; @@ -216,6 +218,16 @@ protected NDList forwardInternal( return current; } + /** {@inheritDoc} */ + @Override + public Iterator forwardStreamIter( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + return new StreamIterator(parameterStore, inputs, training); + } + /** {@inheritDoc} */ @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { @@ -266,4 +278,35 @@ public void loadMetadata(byte loadVersion, DataInputStream is) throw new MalformedModelException("Unsupported encoding version: " + loadVersion); } } + + private final class StreamIterator implements Iterator { + + private int childIndex; + private ParameterStore parameterStore; + private NDList current; + private boolean training; + + private StreamIterator(ParameterStore parameterStore, NDList inputs, boolean training) { + this.parameterStore = parameterStore; + this.current = inputs; + this.training = training; + childIndex = 0; + } + + /** {@inheritDoc} */ + @Override + public boolean hasNext() { + return childIndex < children.size(); + } + + /** {@inheritDoc} */ + @Override + public NDList next() { + current = + children.get(childIndex++) + .getValue() + .forward(parameterStore, current, training); + return current; + } + } } diff --git a/api/src/main/java/ai/djl/translate/ServingTranslator.java b/api/src/main/java/ai/djl/translate/ServingTranslator.java index 49d0bd247f0..ba16d8500d4 100644 --- a/api/src/main/java/ai/djl/translate/ServingTranslator.java +++ b/api/src/main/java/ai/djl/translate/ServingTranslator.java @@ -12,13 +12,19 @@ */ package ai.djl.translate; +import ai.djl.inference.streaming.IteratorBytesSupplier; +import ai.djl.inference.streaming.StreamingTranslator; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.ndarray.BytesSupplier; +import ai.djl.ndarray.NDList; +import java.util.Iterator; import java.util.Map; +import java.util.stream.Stream; /** A {@link Translator} that can handle generic {@link Input} and {@link Output}. */ -public interface ServingTranslator extends Translator { +public interface ServingTranslator extends StreamingTranslator { /** * Sets the configurations for the {@code Translator} instance. @@ -26,4 +32,24 @@ public interface ServingTranslator extends Translator { * @param arguments the configurations for the {@code Translator} instance */ void setArguments(Map arguments); + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("PMD.AvoidThrowingRawExceptionTypes") + default Output processStreamOutput(TranslatorContext ctx, Stream list) { + Iterator outputs = + list.map( + ndList -> { + try { + return processOutput(ctx, ndList).getData(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .iterator(); + IteratorBytesSupplier bytesSupplier = new IteratorBytesSupplier(outputs); + Output output = new Output(); + output.add(bytesSupplier); + return output; + } } diff --git a/api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/ChunkedBytesSupplierTest.java similarity index 97% rename from api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java rename to api/src/test/java/ai/djl/inference/streaming/ChunkedBytesSupplierTest.java index 1d41352f375..3aa1bcf3686 100644 --- a/api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java +++ b/api/src/test/java/ai/djl/inference/streaming/ChunkedBytesSupplierTest.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.modality; +package ai.djl.inference.streaming; import org.testng.Assert; import org.testng.annotations.Test; diff --git a/api/src/test/java/ai/djl/inference/streaming/IteratorBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/IteratorBytesSupplierTest.java new file mode 100644 index 00000000000..6985c0418a2 --- /dev/null +++ b/api/src/test/java/ai/djl/inference/streaming/IteratorBytesSupplierTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.inference.streaming; + +import ai.djl.ndarray.BytesSupplier; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Iterator; +import java.util.stream.Stream; + +public class IteratorBytesSupplierTest { + + @Test + public void testIterate() { + Iterator iterator = + Stream.of("a", "b", "c").map(BytesSupplier::wrap).iterator(); + IteratorBytesSupplier supplier = new IteratorBytesSupplier(iterator); + + Assert.assertTrue(supplier.hasNext()); + Assert.assertEquals(supplier.next(), new byte[] {97}); + Assert.assertEquals(supplier.next(), new byte[] {98}); + Assert.assertEquals(supplier.next(), new byte[] {99}); + Assert.assertFalse(supplier.hasNext()); + } + + @Test + public void testAsBytes() { + Iterator iterator = + Stream.of("a", "b", "c").map(BytesSupplier::wrap).iterator(); + IteratorBytesSupplier supplier = new IteratorBytesSupplier(iterator); + + Assert.assertEquals(supplier.getAsBytes(), new byte[] {97, 98, 99}); + } +} diff --git a/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java new file mode 100644 index 00000000000..8c140688124 --- /dev/null +++ b/api/src/test/java/ai/djl/inference/streaming/PublisherBytesSupplierTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.inference.streaming; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.concurrent.atomic.AtomicInteger; + +public class PublisherBytesSupplierTest { + + @Test + public void test() { + AtomicInteger contentCount = new AtomicInteger(); + PublisherBytesSupplier supplier = new PublisherBytesSupplier(); + + // Add to supplier without subscriber + supplier.appendContent(new byte[] {1}, false); + Assert.assertEquals(contentCount.get(), 0); + + // Subscribing with data should trigger subscriptions + supplier.subscribe( + d -> { + if (d == null) { + // Do nothing on completion + return; + } + contentCount.getAndIncrement(); + }); + Assert.assertEquals(contentCount.get(), 1); + + // Add to supplier with subscriber + supplier.appendContent(new byte[] {1}, true); + Assert.assertEquals(contentCount.get(), 2); + } +} diff --git a/api/src/test/java/ai/djl/inference/streaming/package-info.java b/api/src/test/java/ai/djl/inference/streaming/package-info.java new file mode 100644 index 00000000000..e389da05b71 --- /dev/null +++ b/api/src/test/java/ai/djl/inference/streaming/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains tests for {@link ai.djl.inference.streaming}. */ +package ai.djl.inference.streaming; diff --git a/integration/src/main/java/ai/djl/integration/tests/inference/StreamingTest.java b/integration/src/main/java/ai/djl/integration/tests/inference/StreamingTest.java new file mode 100644 index 00000000000..0c3b94e7316 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/inference/StreamingTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.integration.tests.inference; + +import ai.djl.Model; +import ai.djl.inference.Predictor; +import ai.djl.inference.streaming.StreamingTranslator; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter.Type; +import ai.djl.nn.SequentialBlock; +import ai.djl.nn.core.Linear; +import ai.djl.training.initializer.Initializer; +import ai.djl.translate.TranslateException; +import ai.djl.translate.TranslatorContext; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; + +public class StreamingTest { + + @Test + public void testSequential() throws TranslateException { + try (Model model = Model.newInstance("test")) { + SequentialBlock block = new SequentialBlock(); + block.add(Linear.builder().setUnits(1).build()); + block.add(Linear.builder().setUnits(1).build()); + model.setBlock(block); + + block.setInitializer(Initializer.ONES, Type.WEIGHT); + block.initialize(model.getNDManager(), DataType.FLOAT64, new Shape(1, 1)); + + try (Predictor predictor = + model.newPredictor(new TestTranslator())) { + List results = + predictor.streamingPredict(1.0).boxed().collect(Collectors.toList()); + Assert.assertEquals(results, Arrays.asList(1.0, 1.0)); + } + } + } + + private static class TestTranslator implements StreamingTranslator { + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Double input) { + return new NDList(ctx.getNDManager().create(input)); + } + + /** {@inheritDoc} */ + @Override + public DoubleStream processOutput(TranslatorContext ctx, NDList list) { + return Arrays.stream(list.singletonOrThrow().toDoubleArray()); + } + + @Override + public DoubleStream processStreamOutput(TranslatorContext ctx, Stream list) { + return list.mapToDouble(l -> l.singletonOrThrow().getDouble()); + } + } +}