Skip to content

Commit

Permalink
chore: move subscription headers to protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordan-Nelson committed Aug 29, 2024
1 parent 32c5043 commit 9ad62cd
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,18 @@ const _requiredHeaders = {
AWSHeaders.contentType: 'application/json; charset=utf-8',
};

// AppSync expects "{}" encoded in the URI as the payload during handshake.
const _emptyBody = <String, dynamic>{};
/// The default payload to include to AppSync.
///
/// AppSync expects "{}" encoded in the URI as the payload during handshake.
@internal
const appSyncDefaultPayload = <String, dynamic>{};

/// Generate a URI for the connection and all subscriptions.
///
/// See https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html#handshake-details-to-establish-the-websocket-connection=
Future<Uri> generateConnectionUri(
ApiOutputs config,
AmplifyAuthProviderRepository authRepo,
) async {
// First, generate auth query parameters.
final authorizationHeaders = await _generateAuthorizationHeaders(
config,
isConnectionInit: true,
authRepo: authRepo,
body: _emptyBody,
);
final encodedAuthHeaders =
base64.encode(json.encode(authorizationHeaders).codeUnits);
Future<Uri> generateConnectionUri(ApiOutputs config) async {
final authQueryParameters = {
'header': encodedAuthHeaders,
'payload': base64.encode(utf8.encode(json.encode(_emptyBody))),
'payload': base64.encode(utf8.encode(json.encode(appSyncDefaultPayload))),
};
// Conditionally format the URI for a) AppSync domain b) custom domain.
var endpointUriHost = Uri.parse(config.url).host;
Expand Down Expand Up @@ -86,7 +76,7 @@ Future<WebSocketSubscriptionRegistrationMessage>
required GraphQLRequest<T> request,
}) async {
final body = {'variables': request.variables, 'query': request.document};
final authorizationHeaders = await _generateAuthorizationHeaders(
final authorizationHeaders = await generateAuthorizationHeaders(
config,
isConnectionInit: false,
authRepo: authRepo,
Expand Down Expand Up @@ -114,7 +104,8 @@ Future<WebSocketSubscriptionRegistrationMessage>
/// a canonical HTTP request that is authorized but never sent. The headers from
/// the HTTP request are reformatted and returned. This logic applies for all auth
/// modes as determined by [authRepo] parameter.
Future<Map<String, String>> _generateAuthorizationHeaders(
@internal
Future<Map<String, String>> generateAuthorizationHeaders(
ApiOutputs config, {
required bool isConnectionInit,
required AmplifyAuthProviderRepository authRepo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import 'package:amplify_api_dart/src/graphql/web_socket/types/subscriptions_even
import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_message_stream_transformer.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_types.dart';
import 'package:amplify_core/amplify_core.dart';
// ignore: implementation_imports
import 'package:amplify_core/src/config/amplify_outputs/api_outputs.dart';
import 'package:async/async.dart';
import 'package:meta/meta.dart';
import 'package:stream_transform/stream_transform.dart';
Expand Down Expand Up @@ -72,15 +74,14 @@ class AmplifyWebSocketService
);

try {
const webSocketProtocols = ['graphql-ws'];
final connectionUri = await generateConnectionUri(
final protocols = await generateProtocols(
state.config,
state.authProviderRepo,
);

final connectionUri = await generateConnectionUri(state.config);
final channel = WebSocketChannel.connect(
connectionUri,
protocols: webSocketProtocols,
protocols: protocols,
);
sink = channel.sink;

Expand All @@ -95,6 +96,24 @@ class AmplifyWebSocketService
}
}

/// Generates a list of protocols from a [WebSocketState].
@visibleForTesting
Future<List<String>> generateProtocols(
ApiOutputs outputs,
AmplifyAuthProviderRepository authRepo,
) async {
final authorizationHeaders = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: authRepo,
body: appSyncDefaultPayload,
);
final encodedAuthHeaders = base64Url.encode(
json.encode(authorizationHeaders).codeUnits,
);
return ['graphql-ws', 'header-$encodedAuthHeaders'];
}

@override
Future<void> register(
ConnectedState state,
Expand Down
25 changes: 23 additions & 2 deletions packages/api/amplify_api_dart/test/util.dart
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ const testApiKeyConfigCustomDomain = DataOutputs(
);

const expectedApiKeyWebSocketConnectionUrl =
'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJBY2NlcHQiOiJhcHBsaWNhdGlvbi9qc29uLCB0ZXh0L2phdmFzY3JpcHQiLCJDb250ZW50LUVuY29kaW5nIjoiYW16LTEuMCIsIkNvbnRlbnQtVHlwZSI6ImFwcGxpY2F0aW9uL2pzb247IGNoYXJzZXQ9dXRmLTgiLCJYLUFwaS1LZXkiOiJhYmMtMTIzIiwiSG9zdCI6ImFiYzEyMy5hcHBzeW5jLWFwaS51cy1lYXN0LTEuYW1hem9uYXdzLmNvbSJ9&payload=e30%3D';
'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?payload=e30%3D';
const expectedApiKeyWebSocketConnectionUrlCustomDomain =
'wss://foo.bar.aws.dev/graphql/realtime?header=eyJBY2NlcHQiOiJhcHBsaWNhdGlvbi9qc29uLCB0ZXh0L2phdmFzY3JpcHQiLCJDb250ZW50LUVuY29kaW5nIjoiYW16LTEuMCIsIkNvbnRlbnQtVHlwZSI6ImFwcGxpY2F0aW9uL2pzb247IGNoYXJzZXQ9dXRmLTgiLCJYLUFwaS1LZXkiOiJhYmMtMTIzIiwiSG9zdCI6ImZvby5iYXIuYXdzLmRldiJ9&payload=e30%3D';
'wss://foo.bar.aws.dev/graphql/realtime?payload=e30%3D';

AmplifyAuthProviderRepository getTestAuthProviderRepo() {
final testAuthProviderRepo = AmplifyAuthProviderRepository()
Expand Down Expand Up @@ -341,3 +341,24 @@ void testQueryPredicateTranslation(
}

final deepEquals = const DeepCollectionEquality().equals;

/// Creates [DataOutputs] and [AmplifyAuthProviderRepository] for use in tests.
(DataOutputs, AmplifyAuthProviderRepository) createOutputsAndRepo(
AmplifyAuthProvider authProvider,
APIAuthorizationType type, [
String? apiKey,
]) {
final repo = AmplifyAuthProviderRepository()
..registerAuthProvider(
type.authProviderToken,
authProvider,
);
final outputs = DataOutputs(
awsRegion: 'us-east-1',
url: 'https://example.com/',
defaultAuthorizationType: type,
authorizationTypes: [type],
apiKey: type == APIAuthorizationType.apiKey ? apiKey : null,
);
return (outputs, repo);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,17 @@ void main() {
}

group('generateConnectionUri', () {
test('should generate authorized connection URI', () async {
final actualConnectionUri =
await generateConnectionUri(testApiKeyConfig, authProviderRepo);
test('should generate connection URI', () async {
final actualConnectionUri = await generateConnectionUri(testApiKeyConfig);
expect(
actualConnectionUri.toString(),
expectedApiKeyWebSocketConnectionUrl,
);
});

test('should generate authorized connection URI with a custom domain',
() async {
test('should generate connection URI with a custom domain', () async {
final actualConnectionUri = await generateConnectionUri(
testApiKeyConfigCustomDomain,
authProviderRepo,
);
expect(
actualConnectionUri.toString(),
Expand Down Expand Up @@ -141,4 +138,68 @@ void main() {
);
});
});

group('generateAuthorizationHeaders', () {
const apiKey = 'fake-key';

test('should generate headers for API key Authorization', () async {
final (outputs, repo) = createOutputsAndRepo(
AppSyncApiKeyAuthProvider(),
APIAuthorizationType.apiKey,
apiKey,
);
final headers = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: repo,
body: {},
);
expect(headers[xApiKey], apiKey);
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});

test('should generate headers for IAM Authorization', () async {
final (outputs, repo) = createOutputsAndRepo(
TestIamAuthProvider(),
APIAuthorizationType.iam,
);
final headers = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: repo,
body: {},
);
expect(
headers['Authorization']!.contains('Credential=fake-access-key-123'),
true,
);
expect(headers.containsKey(AWSHeaders.date), true);
expect(headers.containsKey(AWSHeaders.contentSHA256), true);
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});

test('should generate headers for user pool Authorization', () async {
final (outputs, repo) = createOutputsAndRepo(
TestTokenAuthProvider(),
APIAuthorizationType.userPools,
);
final headers = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: repo,
body: {},
);
expect(headers[AWSHeaders.authorization], 'test-access-token-123');
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import 'dart:convert';

import 'package:amplify_api_dart/src/graphql/providers/app_sync_api_key_auth_provider.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/services/web_socket_service.dart';
import 'package:amplify_core/amplify_core.dart';
import 'package:test/test.dart';

import '../util.dart';

void main() {
group('AmplifyWebSocketService', () {
group('generateProtocols', () {});
const apiKey = 'fake-key';
test('should generate a protocol that includes the appropriate headers',
() async {
final (outputs, repo) = createOutputsAndRepo(
AppSyncApiKeyAuthProvider(),
APIAuthorizationType.apiKey,
apiKey,
);
final service = AmplifyWebSocketService();
final protocols = await service.generateProtocols(outputs, repo);
final encodedHeaders = protocols[1].replaceFirst('header-', '');
final headers = json.decode(
String.fromCharCodes(base64Url.decode(encodedHeaders)),
) as Map<String, dynamic>;
expect(headers[xApiKey], apiKey);
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});
});
}

0 comments on commit 9ad62cd

Please sign in to comment.