diff --git a/packages/api/amplify_api_dart/lib/src/decorators/web_socket_auth_utils.dart b/packages/api/amplify_api_dart/lib/src/decorators/web_socket_auth_utils.dart index 2eb8660a33b..ee26bf2d779 100644 --- a/packages/api/amplify_api_dart/lib/src/decorators/web_socket_auth_utils.dart +++ b/packages/api/amplify_api_dart/lib/src/decorators/web_socket_auth_utils.dart @@ -26,28 +26,18 @@ const _requiredHeaders = { AWSHeaders.contentType: 'application/json; charset=utf-8', }; -// AppSync expects "{}" encoded in the URI as the payload during handshake. -const _emptyBody = {}; +/// The default payload to include to AppSync. +/// +/// AppSync expects "{}" encoded in the URI as the payload during handshake. +@internal +const appSyncDefaultPayload = {}; /// Generate a URI for the connection and all subscriptions. /// /// See https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html#handshake-details-to-establish-the-websocket-connection= -Future generateConnectionUri( - ApiOutputs config, - AmplifyAuthProviderRepository authRepo, -) async { - // First, generate auth query parameters. - final authorizationHeaders = await _generateAuthorizationHeaders( - config, - isConnectionInit: true, - authRepo: authRepo, - body: _emptyBody, - ); - final encodedAuthHeaders = - base64.encode(json.encode(authorizationHeaders).codeUnits); +Future generateConnectionUri(ApiOutputs config) async { final authQueryParameters = { - 'header': encodedAuthHeaders, - 'payload': base64.encode(utf8.encode(json.encode(_emptyBody))), + 'payload': base64.encode(utf8.encode(json.encode(appSyncDefaultPayload))), }; // Conditionally format the URI for a) AppSync domain b) custom domain. var endpointUriHost = Uri.parse(config.url).host; @@ -86,7 +76,7 @@ Future required GraphQLRequest request, }) async { final body = {'variables': request.variables, 'query': request.document}; - final authorizationHeaders = await _generateAuthorizationHeaders( + final authorizationHeaders = await generateAuthorizationHeaders( config, isConnectionInit: false, authRepo: authRepo, @@ -114,7 +104,8 @@ Future /// a canonical HTTP request that is authorized but never sent. The headers from /// the HTTP request are reformatted and returned. This logic applies for all auth /// modes as determined by [authRepo] parameter. -Future> _generateAuthorizationHeaders( +@internal +Future> generateAuthorizationHeaders( ApiOutputs config, { required bool isConnectionInit, required AmplifyAuthProviderRepository authRepo, diff --git a/packages/api/amplify_api_dart/lib/src/graphql/web_socket/services/web_socket_service.dart b/packages/api/amplify_api_dart/lib/src/graphql/web_socket/services/web_socket_service.dart index 8cb33714861..0ffc1d4c2bf 100644 --- a/packages/api/amplify_api_dart/lib/src/graphql/web_socket/services/web_socket_service.dart +++ b/packages/api/amplify_api_dart/lib/src/graphql/web_socket/services/web_socket_service.dart @@ -12,6 +12,8 @@ import 'package:amplify_api_dart/src/graphql/web_socket/types/subscriptions_even import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_message_stream_transformer.dart'; import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_types.dart'; import 'package:amplify_core/amplify_core.dart'; +// ignore: implementation_imports +import 'package:amplify_core/src/config/amplify_outputs/api_outputs.dart'; import 'package:async/async.dart'; import 'package:meta/meta.dart'; import 'package:stream_transform/stream_transform.dart'; @@ -72,15 +74,14 @@ class AmplifyWebSocketService ); try { - const webSocketProtocols = ['graphql-ws']; - final connectionUri = await generateConnectionUri( + final protocols = await generateProtocols( state.config, state.authProviderRepo, ); - + final connectionUri = await generateConnectionUri(state.config); final channel = WebSocketChannel.connect( connectionUri, - protocols: webSocketProtocols, + protocols: protocols, ); sink = channel.sink; @@ -95,6 +96,24 @@ class AmplifyWebSocketService } } + /// Generates a list of protocols from a [WebSocketState]. + @visibleForTesting + Future> generateProtocols( + ApiOutputs outputs, + AmplifyAuthProviderRepository authRepo, + ) async { + final authorizationHeaders = await generateAuthorizationHeaders( + outputs, + isConnectionInit: true, + authRepo: authRepo, + body: appSyncDefaultPayload, + ); + final encodedAuthHeaders = base64Url.encode( + json.encode(authorizationHeaders).codeUnits, + ); + return ['graphql-ws', 'header-$encodedAuthHeaders']; + } + @override Future register( ConnectedState state, diff --git a/packages/api/amplify_api_dart/test/util.dart b/packages/api/amplify_api_dart/test/util.dart index f9c0494865a..f4d7e9e3900 100644 --- a/packages/api/amplify_api_dart/test/util.dart +++ b/packages/api/amplify_api_dart/test/util.dart @@ -89,9 +89,9 @@ const testApiKeyConfigCustomDomain = DataOutputs( ); const expectedApiKeyWebSocketConnectionUrl = - 'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJBY2NlcHQiOiJhcHBsaWNhdGlvbi9qc29uLCB0ZXh0L2phdmFzY3JpcHQiLCJDb250ZW50LUVuY29kaW5nIjoiYW16LTEuMCIsIkNvbnRlbnQtVHlwZSI6ImFwcGxpY2F0aW9uL2pzb247IGNoYXJzZXQ9dXRmLTgiLCJYLUFwaS1LZXkiOiJhYmMtMTIzIiwiSG9zdCI6ImFiYzEyMy5hcHBzeW5jLWFwaS51cy1lYXN0LTEuYW1hem9uYXdzLmNvbSJ9&payload=e30%3D'; + 'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?payload=e30%3D'; const expectedApiKeyWebSocketConnectionUrlCustomDomain = - 'wss://foo.bar.aws.dev/graphql/realtime?header=eyJBY2NlcHQiOiJhcHBsaWNhdGlvbi9qc29uLCB0ZXh0L2phdmFzY3JpcHQiLCJDb250ZW50LUVuY29kaW5nIjoiYW16LTEuMCIsIkNvbnRlbnQtVHlwZSI6ImFwcGxpY2F0aW9uL2pzb247IGNoYXJzZXQ9dXRmLTgiLCJYLUFwaS1LZXkiOiJhYmMtMTIzIiwiSG9zdCI6ImZvby5iYXIuYXdzLmRldiJ9&payload=e30%3D'; + 'wss://foo.bar.aws.dev/graphql/realtime?payload=e30%3D'; AmplifyAuthProviderRepository getTestAuthProviderRepo() { final testAuthProviderRepo = AmplifyAuthProviderRepository() @@ -341,3 +341,24 @@ void testQueryPredicateTranslation( } final deepEquals = const DeepCollectionEquality().equals; + +/// Creates [DataOutputs] and [AmplifyAuthProviderRepository] for use in tests. +(DataOutputs, AmplifyAuthProviderRepository) createOutputsAndRepo( + AmplifyAuthProvider authProvider, + APIAuthorizationType type, [ + String? apiKey, +]) { + final repo = AmplifyAuthProviderRepository() + ..registerAuthProvider( + type.authProviderToken, + authProvider, + ); + final outputs = DataOutputs( + awsRegion: 'us-east-1', + url: 'https://example.com/', + defaultAuthorizationType: type, + authorizationTypes: [type], + apiKey: type == APIAuthorizationType.apiKey ? apiKey : null, + ); + return (outputs, repo); +} diff --git a/packages/api/amplify_api_dart/test/web_socket/web_socket_auth_utils_test.dart b/packages/api/amplify_api_dart/test/web_socket/web_socket_auth_utils_test.dart index ed9d4100820..c32396435e7 100644 --- a/packages/api/amplify_api_dart/test/web_socket/web_socket_auth_utils_test.dart +++ b/packages/api/amplify_api_dart/test/web_socket/web_socket_auth_utils_test.dart @@ -47,20 +47,17 @@ void main() { } group('generateConnectionUri', () { - test('should generate authorized connection URI', () async { - final actualConnectionUri = - await generateConnectionUri(testApiKeyConfig, authProviderRepo); + test('should generate connection URI', () async { + final actualConnectionUri = await generateConnectionUri(testApiKeyConfig); expect( actualConnectionUri.toString(), expectedApiKeyWebSocketConnectionUrl, ); }); - test('should generate authorized connection URI with a custom domain', - () async { + test('should generate connection URI with a custom domain', () async { final actualConnectionUri = await generateConnectionUri( testApiKeyConfigCustomDomain, - authProviderRepo, ); expect( actualConnectionUri.toString(), @@ -141,4 +138,68 @@ void main() { ); }); }); + + group('generateAuthorizationHeaders', () { + const apiKey = 'fake-key'; + + test('should generate headers for API key Authorization', () async { + final (outputs, repo) = createOutputsAndRepo( + AppSyncApiKeyAuthProvider(), + APIAuthorizationType.apiKey, + apiKey, + ); + final headers = await generateAuthorizationHeaders( + outputs, + isConnectionInit: true, + authRepo: repo, + body: {}, + ); + expect(headers[xApiKey], apiKey); + expect(headers.containsKey(AWSHeaders.accept), true); + expect(headers.containsKey(AWSHeaders.contentEncoding), true); + expect(headers.containsKey(AWSHeaders.contentType), true); + expect(headers.containsKey(AWSHeaders.host), true); + }); + + test('should generate headers for IAM Authorization', () async { + final (outputs, repo) = createOutputsAndRepo( + TestIamAuthProvider(), + APIAuthorizationType.iam, + ); + final headers = await generateAuthorizationHeaders( + outputs, + isConnectionInit: true, + authRepo: repo, + body: {}, + ); + expect( + headers['Authorization']!.contains('Credential=fake-access-key-123'), + true, + ); + expect(headers.containsKey(AWSHeaders.date), true); + expect(headers.containsKey(AWSHeaders.contentSHA256), true); + expect(headers.containsKey(AWSHeaders.accept), true); + expect(headers.containsKey(AWSHeaders.contentEncoding), true); + expect(headers.containsKey(AWSHeaders.contentType), true); + expect(headers.containsKey(AWSHeaders.host), true); + }); + + test('should generate headers for user pool Authorization', () async { + final (outputs, repo) = createOutputsAndRepo( + TestTokenAuthProvider(), + APIAuthorizationType.userPools, + ); + final headers = await generateAuthorizationHeaders( + outputs, + isConnectionInit: true, + authRepo: repo, + body: {}, + ); + expect(headers[AWSHeaders.authorization], 'test-access-token-123'); + expect(headers.containsKey(AWSHeaders.accept), true); + expect(headers.containsKey(AWSHeaders.contentEncoding), true); + expect(headers.containsKey(AWSHeaders.contentType), true); + expect(headers.containsKey(AWSHeaders.host), true); + }); + }); } diff --git a/packages/api/amplify_api_dart/test/web_socket/web_socket_service_test.dart b/packages/api/amplify_api_dart/test/web_socket/web_socket_service_test.dart new file mode 100644 index 00000000000..7c4759b64d1 --- /dev/null +++ b/packages/api/amplify_api_dart/test/web_socket/web_socket_service_test.dart @@ -0,0 +1,37 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +import 'dart:convert'; + +import 'package:amplify_api_dart/src/graphql/providers/app_sync_api_key_auth_provider.dart'; +import 'package:amplify_api_dart/src/graphql/web_socket/services/web_socket_service.dart'; +import 'package:amplify_core/amplify_core.dart'; +import 'package:test/test.dart'; + +import '../util.dart'; + +void main() { + group('AmplifyWebSocketService', () { + group('generateProtocols', () {}); + const apiKey = 'fake-key'; + test('should generate a protocol that includes the appropriate headers', + () async { + final (outputs, repo) = createOutputsAndRepo( + AppSyncApiKeyAuthProvider(), + APIAuthorizationType.apiKey, + apiKey, + ); + final service = AmplifyWebSocketService(); + final protocols = await service.generateProtocols(outputs, repo); + final encodedHeaders = protocols[1].replaceFirst('header-', ''); + final headers = json.decode( + String.fromCharCodes(base64Url.decode(encodedHeaders)), + ) as Map; + expect(headers[xApiKey], apiKey); + expect(headers.containsKey(AWSHeaders.accept), true); + expect(headers.containsKey(AWSHeaders.contentEncoding), true); + expect(headers.containsKey(AWSHeaders.contentType), true); + expect(headers.containsKey(AWSHeaders.host), true); + }); + }); +}