Skip to content

Commit

Permalink
fix(NODE-5127): implement reject kmsRequest on server close (#3964)
Browse files Browse the repository at this point in the history
  • Loading branch information
alenakhineika authored Jan 18, 2024
1 parent 4e56482 commit 568e05f
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 62 deletions.
139 changes: 77 additions & 62 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
import { type ProxyOptions } from '../cmap/connection';
import { getSocks, type SocksLib } from '../deps';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { BufferPool, MongoDBCollectionNamespace } from '../utils';
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
import { type DataKey } from './client_encryption';
import { MongoCryptError } from './errors';
import { type MongocryptdManager } from './mongocryptd_manager';
Expand Down Expand Up @@ -282,7 +282,7 @@ export class StateMachine {
* @param kmsContext - A C++ KMS context returned from the bindings
* @returns A promise that resolves when the KMS reply has be fully parsed
*/
kmsRequest(request: MongoCryptKMSRequest): Promise<void> {
async kmsRequest(request: MongoCryptKMSRequest): Promise<void> {
const parsedUrl = request.endpoint.split(':');
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
const options: tls.ConnectionOptions & { host: string; port: number } = {
Expand All @@ -291,52 +291,73 @@ export class StateMachine {
port
};
const message = request.message;
const buffer = new BufferPool();

// TODO(NODE-3959): We can adopt `for-await on(socket, 'data')` with logic to control abort
// eslint-disable-next-line @typescript-eslint/no-misused-promises, no-async-promise-executor
return new Promise(async (resolve, reject) => {
const buffer = new BufferPool();
const netSocket: net.Socket = new net.Socket();
let socket: tls.TLSSocket;

// eslint-disable-next-line prefer-const
let socket: net.Socket;
let rawSocket: net.Socket;

function destroySockets() {
for (const sock of [socket, rawSocket]) {
if (sock) {
sock.removeAllListeners();
sock.destroy();
}
function destroySockets() {
for (const sock of [socket, netSocket]) {
if (sock) {
sock.removeAllListeners();
sock.destroy();
}
}
}

function ontimeout() {
destroySockets();
reject(new MongoCryptError('KMS request timed out'));
}
function ontimeout() {
return new MongoCryptError('KMS request timed out');
}

function onerror(cause: Error) {
return new MongoCryptError('KMS request failed', { cause });
}

function onerror(err: Error) {
destroySockets();
const mcError = new MongoCryptError('KMS request failed', { cause: err });
reject(mcError);
function onclose() {
return new MongoCryptError('KMS request closed');
}

const tlsOptions = this.options.tlsOptions;
if (tlsOptions) {
const kmsProvider = request.kmsProvider as ClientEncryptionDataKeyProvider;
const providerTlsOptions = tlsOptions[kmsProvider];
if (providerTlsOptions) {
const error = this.validateTlsOptions(kmsProvider, providerTlsOptions);
if (error) {
throw error;
}
try {
await this.setTlsOptions(providerTlsOptions, options);
} catch (err) {
throw onerror(err);
}
}
}

const {
promise: willConnect,
reject: rejectOnNetSocketError,
resolve: resolveOnNetSocketConnect
} = promiseWithResolvers<void>();
netSocket
.once('timeout', () => rejectOnNetSocketError(ontimeout()))
.once('error', err => rejectOnNetSocketError(onerror(err)))
.once('close', () => rejectOnNetSocketError(onclose()))
.once('connect', () => resolveOnNetSocketConnect());

try {
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
rawSocket = net.connect({
netSocket.connect({
host: this.options.proxyOptions.proxyHost,
port: this.options.proxyOptions.proxyPort || 1080
});
await willConnect;

rawSocket.on('timeout', ontimeout);
rawSocket.on('error', onerror);
try {
// eslint-disable-next-line @typescript-eslint/no-var-requires
const events = require('events') as typeof import('events');
await events.once(rawSocket, 'connect');
socks ??= loadSocks();
options.socket = (
await socks.SocksClient.createConnection({
existing_socket: rawSocket,
existing_socket: netSocket,
command: 'connect',
destination: { host: options.host, port: options.port },
proxy: {
Expand All @@ -350,45 +371,39 @@ export class StateMachine {
})
).socket;
} catch (err) {
return onerror(err);
throw onerror(err);
}
}

const tlsOptions = this.options.tlsOptions;
if (tlsOptions) {
const kmsProvider = request.kmsProvider as ClientEncryptionDataKeyProvider;
const providerTlsOptions = tlsOptions[kmsProvider];
if (providerTlsOptions) {
const error = this.validateTlsOptions(kmsProvider, providerTlsOptions);
if (error) reject(error);
try {
await this.setTlsOptions(providerTlsOptions, options);
} catch (error) {
return onerror(error);
}
}
}
socket = tls.connect(options, () => {
socket.write(message);
});

socket.once('timeout', ontimeout);
socket.once('error', onerror);

socket.on('data', data => {
buffer.append(data);
while (request.bytesNeeded > 0 && buffer.length) {
const bytesNeeded = Math.min(request.bytesNeeded, buffer.length);
request.addResponse(buffer.read(bytesNeeded));
}
const {
promise: willResolveKmsRequest,
reject: rejectOnTlsSocketError,
resolve
} = promiseWithResolvers<void>();
socket
.once('timeout', () => rejectOnTlsSocketError(ontimeout()))
.once('error', err => rejectOnTlsSocketError(onerror(err)))
.once('close', () => rejectOnTlsSocketError(onclose()))
.on('data', data => {
buffer.append(data);
while (request.bytesNeeded > 0 && buffer.length) {
const bytesNeeded = Math.min(request.bytesNeeded, buffer.length);
request.addResponse(buffer.read(bytesNeeded));
}

if (request.bytesNeeded <= 0) {
// There's no need for any more activity on this socket at this point.
destroySockets();
resolve();
}
});
});
if (request.bytesNeeded <= 0) {
resolve();
}
});
await willResolveKmsRequest;
} finally {
// There's no need for any more activity on this socket at this point.
destroySockets();
}
}

*requests(context: MongoCryptContext) {
Expand Down
90 changes: 90 additions & 0 deletions test/unit/client-side-encryption/state_machine.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,96 @@ describe('StateMachine', function () {
});
});

context('when server closed the socket', function () {
context('Socks5', function () {
let server;

beforeEach(async function () {
server = net.createServer(async socket => {
socket.end();
});
server.listen(0);
await once(server, 'listening');
});

afterEach(function () {
server.close();
});

it('throws a MongoCryptError with SocksClientError cause', async function () {
const stateMachine = new StateMachine({
proxyOptions: {
proxyHost: 'localhost',
proxyPort: server.address().port
}
} as any);
const request = new MockRequest(Buffer.from('foobar'), 500);

try {
await stateMachine.kmsRequest(request);
} catch (err) {
expect(err.name).to.equal('MongoCryptError');
expect(err.message).to.equal('KMS request failed');
expect(err.cause.constructor.name).to.equal('SocksClientError');
return;
}
expect.fail('missed exception');
});
});

context('endpoint with host and port', function () {
let server;
let serverSocket;

beforeEach(async function () {
server = net.createServer(async socket => {
serverSocket = socket;
});
server.listen(0);
await once(server, 'listening');
});

afterEach(function () {
server.close();
});

beforeEach(async function () {
const netSocket = net.connect({
port: server.address().port
});
await once(netSocket, 'connect');
this.sinon.stub(tls, 'connect').returns(netSocket);
});

afterEach(function () {
server.close();
this.sinon.restore();
});

it('throws a MongoCryptError error', async function () {
const stateMachine = new StateMachine({
host: 'localhost',
port: server.address().port
} as any);
const request = new MockRequest(Buffer.from('foobar'), 500);

try {
const kmsRequestPromise = stateMachine.kmsRequest(request);

await promisify(setTimeout)(0);
serverSocket.end();

await kmsRequestPromise;
} catch (err) {
expect(err.name).to.equal('MongoCryptError');
expect(err.message).to.equal('KMS request closed');
return;
}
expect.fail('missed exception');
});
});
});

afterEach(function () {
this.sinon.restore();
});
Expand Down

0 comments on commit 568e05f

Please sign in to comment.