Skip to content

Commit

Permalink
Revert client pool changes from authentication method
Browse files Browse the repository at this point in the history
Signed-off-by: Bandini Bhopi <bandinib@amazon.com>
  • Loading branch information
bandinib-amzn committed Mar 15, 2024
1 parent d273c57 commit bbd290e
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,12 @@

import { AuthenticationMethodRegistery } from './authentication_methods_registry';
import { AuthenticationMethod } from '../../server/types';
import { AuthType } from '../../common/data_sources';
import { OpenSearchClientPoolSetup } from '../client';

const clientPoolSetup: OpenSearchClientPoolSetup = {
getClientFromPool: jest.fn(),
addClientToPool: jest.fn(),
};

const createAuthenticationMethod = (
authMethod: Partial<AuthenticationMethod>
): AuthenticationMethod => ({
name: 'unknown',
authType: AuthType.NoAuth,
credentialProvider: jest.fn(),
clientPoolSetup,
legacyClientPoolSetup: clientPoolSetup,
...authMethod,
});

Expand Down Expand Up @@ -69,14 +59,14 @@ describe('AuthenticationMethodRegistery', () => {
registry.registerAuthenticationMethod(
createAuthenticationMethod({
name: 'typeA',
authType: AuthType.NoAuth,
credentialProvider: jest.fn(),
})
);

const typeA = registry.getAuthenticationMethod('typeA')!;

expect(() => {
typeA.authType = AuthType.SigV4;
typeA.credentialProvider = jest.fn();
}).toThrow();
expect(() => {
typeA.name = 'foo';
Expand Down
21 changes: 4 additions & 17 deletions src/plugins/data_source/server/client/client_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,13 @@
import { Client } from '@opensearch-project/opensearch';
import { Client as LegacyClient } from 'elasticsearch';
import LRUCache from 'lru-cache';
import { Logger, OpenSearchDashboardsRequest } from 'src/core/server';
import { Logger } from 'src/core/server';
import { AuthType } from '../../common/data_sources';
import { DataSourcePluginConfigType } from '../../config';

export interface OpenSearchClientPoolSetup {
getClientFromPool: (
endpoint: string,
authType: AuthType,
request?: OpenSearchDashboardsRequest
) => Client | LegacyClient | undefined;
addClientToPool: (
endpoint: string,
authType: AuthType,
client: Client | LegacyClient,
request?: OpenSearchDashboardsRequest
) => void;
getClientFromPool: (endpoint: string, authType: AuthType) => Client | LegacyClient | undefined;
addClientToPool: (endpoint: string, authType: AuthType, client: Client | LegacyClient) => void;
}

/**
Expand Down Expand Up @@ -82,11 +73,7 @@ export class OpenSearchClientPool {
});
this.logger.info(`Created data source aws client pool of size ${size}`);

const getClientFromPool = (
key: string,
authType: AuthType,
request?: OpenSearchDashboardsRequest
) => {
const getClientFromPool = (key: string, authType: AuthType) => {
const selectedCache = authType === AuthType.SigV4 ? this.awsClientCache : this.clientCache;

return selectedCache!.get(key);
Expand Down
53 changes: 24 additions & 29 deletions src/plugins/data_source/server/client/configure_client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { ClientOptions } from '@opensearch-project/opensearch';
import { opensearchClientMock } from '../../../../core/server/opensearch/client/mocks';
import { cryptographyServiceSetupMock } from '../cryptography_service.mocks';
import { CryptographyServiceSetup } from '../cryptography_service';
import { DataSourceClientParams, AuthenticationMethod } from '../types';
import { DataSourceClientParams, AuthenticationMethod, ClientParameters } from '../types';
import { CustomApiSchemaRegistry } from '../schema_registry';
import { IAuthenticationMethodRegistery } from '../auth_registry';
import { authenticationMethodRegisteryMock } from '../auth_registry/authentication_methods_registry.mock';
Expand All @@ -47,6 +47,7 @@ describe('configureClient', () => {
let sigV4AuthContent: SigV4Content;
let customApiSchemaRegistry: CustomApiSchemaRegistry;
let authenticationMethodRegistery: jest.Mocked<IAuthenticationMethodRegistery>;
let clientParameters: ClientParameters;

const customAuthContent = {
region: 'us-east-1',
Expand All @@ -60,10 +61,7 @@ describe('configureClient', () => {

const authMethod: AuthenticationMethod = {
name: 'typeA',
authType: AuthType.SigV4,
credentialProvider: jest.fn(),
clientPoolSetup,
legacyClientPoolSetup: clientPoolSetup,
};

beforeEach(() => {
Expand Down Expand Up @@ -122,12 +120,21 @@ describe('configureClient', () => {
customApiSchemaRegistryPromise: Promise.resolve(customApiSchemaRegistry),
};

clientParameters = {
authType: AuthType.SigV4,
endpoint: dataSourceAttr.endpoint,
cacheKeySuffix: '',
credentials: sigV4AuthContent,
};

ClientMock.mockImplementation(() => dsClient);
authenticationMethodRegistery.getAuthenticationMethod.mockImplementation(() => authMethod);
authRegistryCredentialProviderMock.mockReturnValue(clientParameters);
});

afterEach(() => {
ClientMock.mockReset();
authRegistryCredentialProviderMock.mockReset();
});

test('configure client with auth.type == no_auth, will call new Client() to create client', async () => {
Expand Down Expand Up @@ -291,11 +298,6 @@ describe('configureClient', () => {
references: [],
});

authRegistryCredentialProviderMock.mockReturnValue({
credential: sigV4AuthContent,
type: AuthType.SigV4,
});

const client = await configureClient(
{ ...dataSourceClientParams, authRegistry: authenticationMethodRegistery },
clientPoolSetup,
Expand Down Expand Up @@ -336,8 +338,8 @@ describe('configureClient', () => {
});

authRegistryCredentialProviderMock.mockReturnValue({
credential: mockCredentials,
type: AuthType.SigV4,
...clientParameters,
credentials: mockCredentials,
});

const client = await configureClient(
Expand Down Expand Up @@ -376,11 +378,6 @@ describe('configureClient', () => {
references: [],
});

authRegistryCredentialProviderMock.mockReturnValue({
credential: sigV4AuthContent,
type: AuthType.SigV4,
});

const client = await configureClient(
{ ...dataSourceClientParams, authRegistry: authenticationMethodRegistery },
clientPoolSetup,
Expand Down Expand Up @@ -556,10 +553,7 @@ describe('configureClient', () => {
beforeEach(() => {
const authMethodWithClientPool: AuthenticationMethod = {
name: 'clientPoolTest',
authType: AuthType.SigV4,
credentialProvider: jest.fn(),
clientPoolSetup: opensearchClientPoolSetup,
legacyClientPoolSetup: clientPoolSetup,
};
authenticationMethodRegistery.getAuthenticationMethod
.mockReset()
Expand All @@ -577,33 +571,29 @@ describe('configureClient', () => {
},
references: [],
});
authRegistryCredentialProviderMock.mockReturnValue({
credential: sigV4AuthContent,
type: AuthType.SigV4,
});
});
test('Auth Method from Registry: If endpoint is same for multiple requests client pool size should be 1', async () => {
test('If endpoint is same for multiple requests client pool size should be 1', async () => {
await configureClient(
{ ...dataSourceClientParams, authRegistry: authenticationMethodRegistery },
clientPoolSetup,
opensearchClientPoolSetup,
config,
logger
);

await configureClient(
{ ...dataSourceClientParams, authRegistry: authenticationMethodRegistery },
clientPoolSetup,
opensearchClientPoolSetup,
config,
logger
);

expect(ClientMock).toHaveBeenCalledTimes(1);
});

test('Auth Method from Registry: If endpoint is different for two requests client pool size should be 2', async () => {
test('If endpoint is different for two requests client pool size should be 2', async () => {
await configureClient(
{ ...dataSourceClientParams, authRegistry: authenticationMethodRegistery },
clientPoolSetup,
opensearchClientPoolSetup,
config,
logger
);
Expand All @@ -625,10 +615,15 @@ describe('configureClient', () => {
},
references: [],
});
authRegistryCredentialProviderMock.mockReturnValue({
...clientParameters,
endpoint: 'http://test.com',
cacheKeySuffix: 'test',
});

await configureClient(
{ ...dataSourceClientParams, authRegistry: authenticationMethodRegistery },
clientPoolSetup,
opensearchClientPoolSetup,
config,
logger
);
Expand Down
50 changes: 29 additions & 21 deletions src/plugins/data_source/server/client/configure_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

import { Client, ClientOptions } from '@opensearch-project/opensearch';
import { Client as LegacyClient } from 'elasticsearch';
import { AwsSigv4Signer } from '@opensearch-project/opensearch/aws';
import { Logger, OpenSearchDashboardsRequest } from '../../../../../src/core/server';
import {
Expand All @@ -16,7 +17,7 @@ import {
import { DataSourcePluginConfigType } from '../../config';
import { CryptographyServiceSetup } from '../cryptography_service';
import { createDataSourceError } from '../lib/error';
import { AuthenticationMethod, DataSourceClientParams } from '../types';
import { DataSourceClientParams, ClientParameters } from '../types';
import { parseClientOptions } from './client_config';
import { OpenSearchClientPoolSetup } from './client_pool';
import {
Expand All @@ -25,6 +26,7 @@ import {
getCredential,
getDataSource,
getAuthenticationMethod,
generateCacheKey,
} from './configure_client_utils';
import { authRegistryCredentialProvider } from '../util/credential_provider';

Expand All @@ -44,6 +46,7 @@ export const configureClient = async (
): Promise<Client> => {
let dataSource;
let requireDecryption = true;
let clientParams;

try {
// configure test client
Expand All @@ -66,25 +69,32 @@ export const configureClient = async (
dataSource = await getDataSource(dataSourceId!, savedObjects);
}

let clientPool = openSearchClientPoolSetup;
const authenticationMethod = getAuthenticationMethod(dataSource, authRegistry);
if (authenticationMethod !== undefined) {
clientPool = authenticationMethod.clientPoolSetup;
clientParams = await authRegistryCredentialProvider(authenticationMethod, {
dataSourceAttr: dataSource,
request,
cryptography,
});
}
const rootClient = getRootClient(dataSource, clientPool.getClientFromPool, request) as Client;
const rootClient = getRootClient(
dataSource,
openSearchClientPoolSetup.getClientFromPool,
clientParams
) as Client;

const registeredSchema = (await customApiSchemaRegistryPromise).getAll();

return await getQueryClient(
dataSource,
clientPool,
openSearchClientPoolSetup.addClientToPool,
config,
registeredSchema,
cryptography,
rootClient,
dataSourceId,
request,
authenticationMethod,
clientParams,
requireDecryption
);
} catch (error: any) {
Expand Down Expand Up @@ -113,43 +123,41 @@ export const configureClient = async (
*/
const getQueryClient = async (
dataSourceAttr: DataSourceAttributes,
clientPool: OpenSearchClientPoolSetup,
addClientToPool: (endpoint: string, authType: AuthType, client: Client | LegacyClient) => void,
config: DataSourcePluginConfigType,
registeredSchema: any[],
cryptography?: CryptographyServiceSetup,
rootClient?: Client,
dataSourceId?: string,
request?: OpenSearchDashboardsRequest,
authenticationMethod?: AuthenticationMethod,
clientParams?: ClientParameters,
requireDecryption: boolean = true
): Promise<Client> => {
let credential;
let cacheKeySuffix;
let {
auth: { type },
endpoint,
} = dataSourceAttr;
const { endpoint } = dataSourceAttr;
const clientOptions = parseClientOptions(config, endpoint, registeredSchema);

if (authenticationMethod !== undefined) {
const credentialProvider = await authRegistryCredentialProvider(authenticationMethod, {
dataSourceAttr,
request,
cryptography,
});
credential = credentialProvider.credential;
type = credentialProvider.type;
if (clientParams !== undefined) {
credential = clientParams.credentials;
type = clientParams.authType;
cacheKeySuffix = clientParams.cacheKeySuffix;
endpoint = clientParams.endpoint;

if (credential.service === undefined) {
credential = { ...credential, service: dataSourceAttr.auth.credentials?.service };
}
}

const cacheKey = endpoint;
const cacheKey = generateCacheKey(endpoint, cacheKeySuffix);

switch (type) {
case AuthType.NoAuth:
if (!rootClient) rootClient = new Client(clientOptions);
clientPool.addClientToPool(cacheKey, type, rootClient, request);
addClientToPool(cacheKey, type, rootClient);

return rootClient.child();

Expand All @@ -161,7 +169,7 @@ const getQueryClient = async (
: (dataSourceAttr.auth.credentials as UsernamePasswordTypedContent));

if (!rootClient) rootClient = new Client(clientOptions);
clientPool.addClientToPool(cacheKey, type, rootClient, request);
addClientToPool(cacheKey, type, rootClient);

return getBasicAuthClient(rootClient, credential);

Expand All @@ -175,7 +183,7 @@ const getQueryClient = async (
if (!rootClient) {
rootClient = getAWSClient(credential, clientOptions);
}
clientPool.addClientToPool(cacheKey, type, rootClient, request);
addClientToPool(cacheKey, type, rootClient);

return getAWSChildClient(rootClient, credential);

Expand Down
Loading

0 comments on commit bbd290e

Please sign in to comment.