Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming Predict and streamable BytesSupplier #2470

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions api/src/main/java/ai/djl/inference/Predictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.streaming.StreamingBlock;
import ai.djl.inference.streaming.StreamingTranslator;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.ndarray.LazyNDArray;
Expand Down Expand Up @@ -190,6 +192,70 @@ public List<O> batchPredict(List<I> inputs) throws TranslateException {
}
}

/**
 * Predicts an item for inference, producing the output from a stream of result chunks.
 *
 * <p>The block of this predictor must be a {@link StreamingBlock} and its translator must be a
 * {@link StreamingTranslator}. The chunks produced by the block's {@code forwardStream} are
 * handed to the translator's {@code processStreamOutput}. When the translator has a batchifier,
 * the input is wrapped into a singleton batch and every chunk is unbatchified again.
 *
 * <p>The prediction context is registered as a close handler of the chunk stream, so on success
 * it is released when that stream is closed. If the prediction fails before an output is
 * returned, the context is released here instead.
 *
 * @param input the input
 * @return the output object defined by the user
 * @throws TranslateException if an error occurs during prediction
 * @throws IllegalStateException if the block is not a {@link StreamingBlock} or the translator
 *     is not a {@link StreamingTranslator}
 */
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"})
public O streamingPredict(I input) throws TranslateException {

    if (!(block instanceof StreamingBlock)) {
        throw new IllegalStateException(
                "streamingPredict() can only be called with a StreamingBlock");
    }
    StreamingBlock streamingBlock = (StreamingBlock) block;

    if (!(translator instanceof StreamingTranslator)) {
        throw new IllegalStateException(
                "streamingPredict() can only be called with a StreamingTranslator");
    }
    StreamingTranslator<I, O> streamingTranslator = (StreamingTranslator<I, O>) translator;

    try {
        PredictorContext context = new PredictorContext();
        // Fully qualified on purpose: the import block of this file is not changed by this fix.
        java.util.stream.Stream<NDList> chunks = null;
        boolean succeeded = false;
        try {
            if (!prepared) {
                translator.prepare(context);
                prepared = true;
            }

            Batchifier batchifier = translator.getBatchifier();
            java.util.stream.Stream<NDList> forwarded;
            if (batchifier == null) {
                NDList ndList = translator.processInput(context, input);
                forwarded = streamingBlock.forwardStream(parameterStore, ndList, false);
            } else {
                // For the batched case, need to create singleton batch and unbatchify
                // singleton
                NDList inputBatch = processInputs(context, Collections.singletonList(input));
                forwarded =
                        streamingBlock
                                .forwardStream(parameterStore, inputBatch, false)
                                .map(
                                        result -> {
                                            NDList[] unbatched = batchifier.unbatchify(result);
                                            if (unbatched.length != 1) {
                                                throw new IllegalStateException(
                                                        "Unexpected number of outputs from"
                                                                + " model");
                                            }
                                            return unbatched[0];
                                        });
            }

            // From here on the stream owns the context: closing the stream closes the context.
            chunks = forwarded.onClose(context::close);
            O output = streamingTranslator.processStreamOutput(context, chunks);
            succeeded = true;
            return output;
        } finally {
            if (!succeeded) {
                // No output was returned, so no caller can close the stream. Release the
                // context now, going through the stream once it exists so that its close
                // handler is the single place that closes the context.
                if (chunks != null) {
                    chunks.close();
                } else {
                    context.close();
                }
            }
        }
    } catch (TranslateException e) {
        throw e;
    } catch (Exception e) {
        throw new TranslateException(e);
    }
}

/**
* Attaches a Metrics param to use for benchmark.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality;
package ai.djl.inference.streaming;

import ai.djl.ndarray.BytesSupplier;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.inference.streaming;

import ai.djl.ndarray.BytesSupplier;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Iterator;

/**
 * A streaming {@link BytesSupplier} for synchronous usage that is backed by an iterator of other
 * suppliers.
 *
 * <p>The content can be pulled one chunk at a time through the {@link Iterator} methods, or all
 * remaining chunks can be drained and concatenated with {@link #getAsBytes()}.
 */
public class IteratorBytesSupplier implements BytesSupplier, Iterator<byte[]> {

    private Iterator<BytesSupplier> sources;

    /**
     * Constructs an {@link IteratorBytesSupplier}.
     *
     * @param sources the source suppliers
     */
    public IteratorBytesSupplier(Iterator<BytesSupplier> sources) {
        this.sources = sources;
    }

    /** {@inheritDoc} */
    @Override
    public boolean hasNext() {
        return sources.hasNext();
    }

    /** {@inheritDoc} */
    @Override
    public byte[] next() {
        BytesSupplier supplier = sources.next();
        return supplier.getAsBytes();
    }

    /** {@inheritDoc} */
    @Override
    public ByteBuffer toByteBuffer() {
        byte[] remaining = getAsBytes();
        return ByteBuffer.wrap(remaining);
    }

    /** {@inheritDoc} */
    @Override
    public byte[] getAsBytes() {
        try (ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
            // Drains every chunk that has not been consumed yet, in iteration order.
            forEachRemaining(chunk -> buffer.write(chunk, 0, chunk.length));
            return buffer.toByteArray();
        } catch (IOException e) {
            throw new AssertionError("Failed to read BytesSupplier", e);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.inference.streaming;

import ai.djl.ndarray.BytesSupplier;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

/**
 * A {@link PublisherBytesSupplier} is a streaming {@link BytesSupplier} suitable for reactive
 * asynchronous usage.
 *
 * <p>Chunks are added with {@link #appendContent(byte[], boolean)}. A single subscriber
 * registered through {@link #subscribe(Consumer)} receives every chunk in the order it was
 * appended, followed by one {@code null} once the chunk marked as last has been delivered.
 *
 * <p>All mutable state is guarded by the monitor of the internal chunk list. The subscriber is
 * invoked while that monitor is held; this is what keeps delivery in append order when several
 * threads append concurrently, so the subscriber should return quickly.
 */
public class PublisherBytesSupplier implements BytesSupplier {

    /** Every chunk appended so far; its monitor guards all mutable state of this object. */
    private final List<byte[]> allData;

    /** Set once the last chunk was appended; only modified while holding the monitor. */
    private final AtomicBoolean completed;

    private Consumer<byte[]> subscriber;

    /** Index of the next chunk to hand to the subscriber. */
    private final AtomicInteger dataPushed;

    /** Whether the {@code null} completion signal was already sent to the subscriber. */
    private boolean completionSent;

    /** Constructs a {@link PublisherBytesSupplier}. */
    public PublisherBytesSupplier() {
        allData = new ArrayList<>();
        completed = new AtomicBoolean();
        dataPushed = new AtomicInteger();
    }

    /**
     * Appends content to the {@code BytesSupplier}.
     *
     * @param data bytes to append
     * @param lastChunk true if this is the last chunk
     */
    public void appendContent(byte[] data, boolean lastChunk) {
        synchronized (allData) {
            allData.add(data);
            if (lastChunk) {
                completed.set(true);
                // Wake up threads blocked in waitToRead().
                allData.notifyAll();
            }
            pushData();
        }
    }

    /**
     * Adds the subscriber to the {@link BytesSupplier} to get notified about additional data.
     *
     * <p>Chunks that were appended before subscribing are delivered immediately.
     *
     * @param subscriber a consumer function that will receive bytes when new data is added and
     *     null when completed
     * @throws IllegalStateException if a subscriber was already registered
     */
    public void subscribe(Consumer<byte[]> subscriber) {
        synchronized (allData) {
            if (this.subscriber != null) {
                throw new IllegalStateException(
                        "The PublisherBytesSupplier only allows a single Subscriber");
            }
            this.subscriber = subscriber;
            pushData();
        }
    }

    /**
     * Delivers all chunks that have not been delivered yet, then the completion signal.
     *
     * <p>The cursor is advanced and the subscriber is called under the same monitor, so chunks
     * cannot overtake each other and the {@code null} signal is sent exactly once, after the
     * final chunk.
     */
    private void pushData() {
        synchronized (allData) {
            if (subscriber == null) {
                return;
            }

            // The size is re-read on every iteration so that chunks appended by the subscriber
            // itself (a re-entrant call) are picked up in order as well.
            while (dataPushed.get() < allData.size()) {
                subscriber.accept(allData.get(dataPushed.getAndIncrement()));
            }

            if (completed.get() && !completionSent) {
                completionSent = true;
                subscriber.accept(null);
            }
        }
    }

    /**
     * Waits until completed before passing thread (BLOCKS THREAD!).
     *
     * <p>The calling thread is parked until the last chunk is appended. An interrupt received
     * while waiting does not end the wait; the interrupt status is restored before returning.
     */
    public void waitToRead() {
        boolean interrupted = false;
        synchronized (allData) {
            // Loop to guard against spurious wakeups.
            while (!completed.get()) {
                try {
                    allData.wait();
                } catch (InterruptedException e) {
                    // Keep waiting: this method has always blocked until completion.
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalStateException if the last chunk has not been appended yet
     */
    @Override
    public byte[] getAsBytes() {
        if (!completed.get()) {
            throw new IllegalStateException(
                    "PublisherByteSupplier must be completely filled before reading.");
        }

        synchronized (allData) {
            try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
                for (byte[] data : allData) {
                    bos.write(data);
                }
                return bos.toByteArray();
            } catch (IOException e) {
                throw new AssertionError("Failed to read BytesSupplier", e);
            }
        }
    }

    /** {@inheritDoc} */
    @Override
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(getAsBytes());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.inference.streaming;

import ai.djl.ndarray.NDList;
import ai.djl.nn.Block;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * A {@link Block} that can additionally run its forward pass in a streaming manner, as used by
 * {@link ai.djl.inference.Predictor#streamingPredict(Object)}.
 */
public interface StreamingBlock extends Block {

    /**
     * Runs the operating function of the block once and returns the result in chunks, without
     * any optional parameters. This method should only be called on blocks that are initialized.
     *
     * @param parameterStore the parameter store
     * @param inputs the input NDList
     * @param training true for a training forward pass (turn on dropout and layerNorm)
     * @return the output of the forward pass
     */
    default Stream<NDList> forwardStream(
            ParameterStore parameterStore, NDList inputs, boolean training) {
        // Same as the four-argument variant with the optional parameters left out.
        return forwardStream(parameterStore, inputs, training, null);
    }

    /**
     * Runs the operating function of the block once and returns the result in chunks, exposed as
     * a sequential {@link Stream} over the iterator obtained from {@link
     * #forwardStreamIter(ParameterStore, NDList, boolean, PairList)}. This method should only be
     * called on blocks that are initialized.
     *
     * @param parameterStore the parameter store
     * @param inputs the input NDList
     * @param training true for a training forward pass (turn on dropout and layerNorm)
     * @param params optional parameters
     * @return the output of the forward pass
     */
    default Stream<NDList> forwardStream(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        Iterator<NDList> chunks = forwardStreamIter(parameterStore, inputs, training, params);
        // The number of chunks is not known up front; the stream is not parallel.
        return StreamSupport.stream(
                Spliterators.spliteratorUnknownSize(chunks, Spliterator.NONNULL), false);
    }

    /**
     * Runs the operating function of the block once and returns the result in chunks as an
     * iterator. This method should only be called on blocks that are initialized.
     *
     * @param parameterStore the parameter store
     * @param inputs the input NDList
     * @param training true for a training forward pass (turn on dropout and layerNorm)
     * @param params optional parameters
     * @return the output of the forward pass
     */
    Iterator<NDList> forwardStreamIter(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params);
}
Loading