Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rt): add conversions to and from ByteStream and Flow<ByteArray> #947

Merged
merged 5 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changes/f82c0433-30f9-4246-8f18-91402c5ac0ab.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "f82c0433-30f9-4246-8f18-91402c5ac0ab",
"type": "feature",
"description": "Add conversions to and from `Flow<ByteArray>` and `ByteStream`",
"issues": [
"awslabs/aws-sdk-kotlin#612"
]
}
4 changes: 4 additions & 0 deletions runtime/runtime-core/api/runtime-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ public final class aws/smithy/kotlin/runtime/content/ByteStreamKt {
public static final fun cancel (Laws/smithy/kotlin/runtime/content/ByteStream;)V
public static final fun decodeToString (Laws/smithy/kotlin/runtime/content/ByteStream;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun toByteArray (Laws/smithy/kotlin/runtime/content/ByteStream;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun toByteStream (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;Ljava/lang/Long;)Laws/smithy/kotlin/runtime/content/ByteStream;
public static synthetic fun toByteStream$default (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;Ljava/lang/Long;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/content/ByteStream;
public static final fun toFlow (Laws/smithy/kotlin/runtime/content/ByteStream;J)Lkotlinx/coroutines/flow/Flow;
public static synthetic fun toFlow$default (Laws/smithy/kotlin/runtime/content/ByteStream;JILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;
}

public abstract class aws/smithy/kotlin/runtime/content/Document {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
*/
package aws.smithy.kotlin.runtime.content

import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.readToBuffer
import aws.smithy.kotlin.runtime.io.readToByteArray
import aws.smithy.kotlin.runtime.io.*
import aws.smithy.kotlin.runtime.io.internal.SdkDispatchers
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch

/**
* Represents an abstract read-only stream of bytes
Expand Down Expand Up @@ -106,3 +107,92 @@ public fun ByteStream.cancel() {
is ByteStream.SourceStream -> stream.readFrom().close()
}
}

/**
* Return a [Flow] that consumes the underlying [ByteStream] when collected.
*
* @param bufferSize the size of the buffers to emit from the flow. All buffers emitted
* will be of this size except for the last one which may be less than the requested buffer size.
* This parameter has no effect for the [ByteStream.Buffer] variant. The emitted [ByteArray]
* will be whatever size the in-memory buffer already is in that case.
*/
public fun ByteStream.toFlow(bufferSize: Long = 8192): Flow<ByteArray> = when (this) {
is ByteStream.Buffer -> flowOf(bytes())
is ByteStream.ChannelStream -> readFrom().toFlow(bufferSize)
is ByteStream.SourceStream -> readFrom().toFlow(bufferSize).flowOn(SdkDispatchers.IO)
}

/**
* Create a [ByteStream] from a [Flow] of byte arrays.
*
* @param scope the [CoroutineScope] to use for launching a coroutine to do the collection in.
* @param contentLength the overall content length of the [Flow] (if known). If set this will be
* used as [ByteStream.contentLength]. Some APIs require a known `Content-Length` header and
* since the total size of the flow can't be calculated without collecting it callers should set this
* parameter appropriately in those cases.
*/
public fun Flow<ByteArray>.toByteStream(
scope: CoroutineScope,
contentLength: Long? = null,
): ByteStream {
val ch = SdkByteChannel(true)
var totalWritten = 0L
val job = scope.launch {
collect { bytes ->
ch.write(bytes)
totalWritten += bytes.size

check(contentLength == null || totalWritten <= contentLength) {
"$totalWritten bytes collected from flow exceeds reported content length of $contentLength"
}
}

check(contentLength == null || totalWritten == contentLength) {
"expected $contentLength bytes collected from flow, got $totalWritten"
}

ch.close()
}

job.invokeOnCompletion { cause ->
ch.close(cause)
}

return object : ByteStream.ChannelStream() {
override val contentLength: Long? = contentLength
override val isOneShot: Boolean = true
override fun readFrom(): SdkByteReadChannel = ch
}
}

private fun SdkByteReadChannel.toFlow(bufferSize: Long): Flow<ByteArray> = flow {
val chan = this@toFlow
val sink = SdkBuffer()
while (!chan.isClosedForRead) {
val rc = chan.read(sink, bufferSize)
if (rc == -1L) break
if (sink.size >= bufferSize) {
val bytes = sink.readByteArray(bufferSize)
emit(bytes)
}
}
if (sink.size > 0L) {
emit(sink.readByteArray())
}
}

private fun SdkSource.toFlow(bufferSize: Long): Flow<ByteArray> = flow {
val source = this@toFlow
val sink = SdkBuffer()
while (true) {
val rc = source.read(sink, bufferSize)
if (rc == -1L) break
if (sink.size >= bufferSize) {
val bytes = sink.readByteArray(bufferSize)
emit(bytes)
}
}
if (sink.size > 0L) {
emit(sink.readByteArray())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.content

import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.source

fun interface ByteStreamFactory {
fun byteStream(input: ByteArray): ByteStream
companion object {
val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { input -> ByteStream.fromBytes(input) }

val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.SourceStream() {
override fun readFrom(): SdkSource = input.source()
override val contentLength: Long = input.size.toLong()
}
}

val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.ChannelStream() {
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(input)
override val contentLength: Long = input.size.toLong()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.content

import io.kotest.matchers.string.shouldContain
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.test.runTest
import java.lang.RuntimeException
import kotlin.test.*

class ByteStreamBufferFlowTest : ByteStreamFlowTest(ByteStreamFactory.BYTE_ARRAY)
class ByteStreamSourceStreamFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_SOURCE)
class ByteStreamChannelSourceFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_CHANNEL)

abstract class ByteStreamFlowTest(
private val factory: ByteStreamFactory,
) {
@Test
fun testToFlowWithSizeHint() = runTest {
val data = "a korf is a tiger".repeat(1024).encodeToByteArray()
val bufferSize = 8182 * 2
val byteStream = factory.byteStream(data)
val flow = byteStream.toFlow(bufferSize.toLong())
val buffers = mutableListOf<ByteArray>()
flow.toList(buffers)

val totalCollected = buffers.sumOf { it.size }
assertEquals(data.size, totalCollected)

if (byteStream is ByteStream.Buffer) {
assertEquals(1, buffers.size)
assertContentEquals(data, buffers.first())
} else {
val expectedFullBuffers = data.size / bufferSize
for (i in 0 until expectedFullBuffers) {
val b = buffers[i]
val expected = data.sliceArray((i * bufferSize)until (i * bufferSize + bufferSize))
assertEquals(bufferSize, b.size)
assertContentEquals(expected, b)
}

val last = buffers.last()
val expected = data.sliceArray(((buffers.size - 1) * bufferSize) until data.size)
assertContentEquals(expected, last)
}
}

class FlowToByteStreamTest {
private fun testByteArray(size: Int): ByteArray = ByteArray(size) { i -> i.toByte() }

val data = listOf(
testByteArray(576),
testByteArray(9172),
testByteArray(3278),
)

@Test
fun testFlowToByteStreamReadAll() = runTest {
val flow = data.asFlow()
val scope = CoroutineScope(coroutineContext)
val byteStream = flow.toByteStream(scope)

assertNull(byteStream.contentLength)

val actual = byteStream.toByteArray()
val expected = data.reduce { acc, bytes -> acc + bytes }
assertContentEquals(expected, actual)
}

@Test
fun testContentLengthOverflow() = runTest {
val advertisedContentLength = 1024L
testInvalidContentLength(advertisedContentLength, "9748 bytes collected from flow exceeds reported content length of 1024")
}

@Test
fun testContentLengthUnderflow() = runTest {
val advertisedContentLength = data.sumOf { it.size } + 100L
testInvalidContentLength(advertisedContentLength, "expected 13126 bytes collected from flow, got 13026")
}

private suspend fun testInvalidContentLength(advertisedContentLength: Long, expectedMessage: String) {
val job = Job()
val uncaughtExceptions = mutableListOf<Throwable>()
val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) }
val scope = CoroutineScope(job + exHandler)
val byteStream = data
.asFlow()
.toByteStream(scope, advertisedContentLength)

assertEquals(advertisedContentLength, byteStream.contentLength)

val ex = assertFailsWith<IllegalStateException> {
byteStream.toByteArray()
}

ex.message?.shouldContain(expectedMessage)
assertTrue(job.isCancelled)
job.join()

assertEquals(1, uncaughtExceptions.size)
}

@Test
fun testScopeCancellation() = runTest {
// cancelling the scope should close/cancel the channel
val waiter = Channel<Unit>(1)
val flow = flow {
emit(testByteArray(128))
emit(testByteArray(277))
waiter.receive()
emit(testByteArray(97))
}

val job = Job()
val scope = CoroutineScope(job)
val byteStream = flow.toByteStream(scope)
assertIs<ByteStream.ChannelStream>(byteStream)
assertNull(byteStream.contentLength)
yield()

job.cancel("scope cancelled")
waiter.send(Unit)
job.join()

val ch = byteStream.readFrom()
assertTrue(ch.isClosedForRead)
assertTrue(ch.isClosedForWrite)
assertIs<CancellationException>(ch.closedCause)
ch.closedCause?.message.shouldContain("scope cancelled")
}

@Test
fun testChannelCancellation() = runTest {
// cancelling the channel should cancel the scope (via write failing)
val waiter = Channel<Unit>(1)
val flow = flow {
emit(testByteArray(128))
emit(testByteArray(277))
waiter.receive()
emit(testByteArray(97))
}

val uncaughtExceptions = mutableListOf<Throwable>()
val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) }
val job = Job()
val scope = CoroutineScope(job + exHandler)
val byteStream = flow.toByteStream(scope)
assertIs<ByteStream.ChannelStream>(byteStream)

val ch = byteStream.readFrom()
val cause = RuntimeException("chan cancelled")
ch.cancel(cause)

// unblock the flow
waiter.send(Unit)

job.join()
assertTrue(job.isCancelled)
assertEquals(1, uncaughtExceptions.size)
uncaughtExceptions.first().message.shouldContain("chan cancelled")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,12 @@
*/
package aws.smithy.kotlin.runtime.content

import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.source
import java.io.InputStream
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

fun interface ByteStreamFactory {
fun inputStream(input: ByteArray): InputStream
companion object {
val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { input -> ByteStream.fromBytes(input).toInputStream() }

val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.SourceStream() {
override fun readFrom(): SdkSource = input.source()
override val contentLength: Long = input.size.toLong()
}.toInputStream()
}

val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.ChannelStream() {
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(input)
override val contentLength: Long = input.size.toLong()
}.toInputStream()
}
}
}
fun ByteStreamFactory.inputStream(input: ByteArray): InputStream = byteStream(input).toInputStream()

class ByteStreamBufferInputStreamTest : ByteStreamInputStreamTest(ByteStreamFactory.BYTE_ARRAY)
class ByteStreamSourceStreamInputStreamTest : ByteStreamInputStreamTest(ByteStreamFactory.SDK_SOURCE)
Expand Down