From 6f5bb95e3ff5befd44cddf0be1d690663d0343cd Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 24 Mar 2023 07:34:19 -0700 Subject: [PATCH] Adds non-blocking poll() for BytesSupplier (#2478) --- .../java/ai/djl/modality/ChunkedBytesSupplier.java | 10 ++++++++++ .../java/ai/djl/modality/ChunkedBytesSupplierTest.java | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java b/api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java index a40d4fe84f9..c8b2d5d0e0e 100644 --- a/api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java +++ b/api/src/main/java/ai/djl/modality/ChunkedBytesSupplier.java @@ -71,6 +71,16 @@ public byte[] nextChunk(long timeout, TimeUnit unit) throws InterruptedException return data.getAsBytes(); } + /** + * Retrieves and removes the head of chunk or returns {@code null} if data is not available. + * + * @return the head of chunk or returns {@code null} if data is not available + */ + public byte[] poll() { + BytesSupplier data = queue.poll(); + return data == null ? null : data.getAsBytes(); + } + /** {@inheritDoc} */ @Override public byte[] getAsBytes() { diff --git a/api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java b/api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java index 8af1aeafa50..1d41352f375 100644 --- a/api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java +++ b/api/src/test/java/ai/djl/modality/ChunkedBytesSupplierTest.java @@ -24,6 +24,7 @@ public class ChunkedBytesSupplierTest { @Test public void test() throws InterruptedException, IOException { ChunkedBytesSupplier supplier = new ChunkedBytesSupplier(); + Assert.assertNull(supplier.poll()); Assert.assertThrows(() -> supplier.nextChunk(1, TimeUnit.MICROSECONDS)); supplier.appendContent(new byte[] {1, 2}, false); @@ -37,7 +38,7 @@ public void test() throws InterruptedException, IOException { data.appendContent(new byte[] {1, 2}, true); Assert.assertTrue(data.hasNext()); - Assert.assertEquals(data.nextChunk(1, TimeUnit.MILLISECONDS).length, 0); + Assert.assertEquals(data.poll().length, 0); Assert.assertEquals(data.nextChunk(1, TimeUnit.MILLISECONDS), new byte[] {1, 2}); Assert.assertFalse(data.hasNext());