diff --git a/Package.swift b/Package.swift index 977f8f365ec..db774cfd0da 100644 --- a/Package.swift +++ b/Package.swift @@ -555,4 +555,4 @@ let servicesWithIntegrationTests: [String] = [ servicesWithIntegrationTests.forEach(addIntegrationTestTarget) // Uncomment this line to enable protocol tests -addProtocolTests() \ No newline at end of file +// addProtocolTests() \ No newline at end of file diff --git a/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift b/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift index e484e571007..169f9b0225b 100644 --- a/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift +++ b/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift @@ -17,6 +17,9 @@ extension AWSEventStream { private var decoder: EventStreamMessageDecoder? private var messageBuffer: [EventStream.Message] = [] private var error: Error? + private var initialMessage: Data = Data() + private var onInitialResponseReceived: ((Data?) -> Void)? + private var didProcessInitialMessage = false private var decodedPayload = Data() private var decodededHeaders: [EventStreamHeader] = [] @@ -44,8 +47,22 @@ extension AWSEventStream { self.logger.debug("onComplete") let message = EventStream.Message(headers: self.decodededHeaders.toHeaders(), payload: self.decodedPayload) - self.messageBuffer.append(message) + if message.headers.contains( + EventStream.Header(name: ":event-type", value: .string("initial-response")) + ) { + self.initialMessage = message.payload + self.onInitialResponseReceived?(self.initialMessage) + self.onInitialResponseReceived = nil // ensure reference counter is set to 0 and cleaned up + self.didProcessInitialMessage = true + } else { + self.messageBuffer.append(message) + if !self.didProcessInitialMessage { + self.onInitialResponseReceived?(nil) // Signal that initial-response will never come. + self.onInitialResponseReceived = nil // ensure reference counter is set to 0 and cleaned up + self.didProcessInitialMessage = true + } + } // This could be end of the stream, hence reset the state self.decodedPayload = Data() self.decodededHeaders = [] @@ -88,6 +105,38 @@ extension AWSEventStream { return message } + // Responsible for waiting on the initial response. + // It uses Swift's concurrency model to asynchronously return the data. + public func awaitInitialResponse() async -> Data? { + // The 'withCheckedContinuation' function is used to bridge asynchronous code + // that doesn't use Swift's concurrency model with code that does. + return await withCheckedContinuation { continuation in + // Here, we attempt to retrieve the initial response. + // Once the data is retrieved (or determined to be nil), + // the continuation is resumed with the result. + retrieveInitialResponse { data in + continuation.resume(returning: data) + } + } + } + + // Attempt to get the initial response. + // If the initial message has been processed, it immediately calls the completion handler. + // Otherwise, it sets up a callback to be triggered once the initial response is received. + private func retrieveInitialResponse(completion: @escaping (Data?) -> Void) { + // Check if the initial message has already been processed. + if self.didProcessInitialMessage { + // If it has been processed, immediately call the completion handler + // with the potentially nil or populated 'initialMessage' value. + completion(initialMessage) + } else { + // If the initial message hasn't been processed, + // set the 'onInitialResponseReceived' callback to our completion handler, + // so it can be called later once the initial response is received. + self.onInitialResponseReceived = completion + } + } + /// Throws an error if one has occurred. /// This should be called before any other methods to make sure /// that the decoder is in a valid state. diff --git a/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageDecoderTests.swift b/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageDecoderTests.swift index 3c66d8ed665..8fd13659e92 100644 --- a/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageDecoderTests.swift +++ b/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageDecoderTests.swift @@ -30,6 +30,13 @@ final class AWSMessageDecoderTests: XCTestCase { XCTAssertEqual(validMessageNoHeaders, decodedMessage) } + func testDecode_MessageWithInitialResponse() { + try! sut.feed(data: validInitialResponseMessageData) + let decodedMessage = try! sut.message() + // initialResponse message should not be added to the messageBuffer + XCTAssertNil(decodedMessage) + } + func testEndOfStream_StreamClosed() { try! sut.feed(data: validMessageDataNoHeaders[0..