From 717ac88a3746c81ca054a485c04ac4b285fe17f5 Mon Sep 17 00:00:00 2001 From: Jeff Yemin Date: Thu, 4 Jun 2020 08:57:26 -0400 Subject: [PATCH] Apply client-side encryption in transactions on sharded clusters This fixes a bug in both sync and async drivers where client-side encryption is not applied when in a transaction. JAVA-3752 --- .../async/client/ClientSessionBinding.java | 98 +++++----- .../client/internal/AsyncCryptBinding.java | 15 ++ .../ClientSideEncryptionSessionTest.java | 172 ++++++++++++++++++ .../mongodb/binding/AsyncClusterBinding.java | 7 + .../com/mongodb/binding/ClusterBinding.java | 15 +- .../AsyncClusterAwareReadWriteBinding.java | 11 ++ .../binding/ClusterAwareReadWriteBinding.java | 8 + .../client/internal/ClientSessionBinding.java | 31 ++-- .../mongodb/client/internal/CryptBinding.java | 6 + .../ClientSideEncryptionSessionTest.java | 154 ++++++++++++++++ 10 files changed, 440 insertions(+), 77 deletions(-) create mode 100644 driver-async/src/test/functional/com/mongodb/async/client/ClientSideEncryptionSessionTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionSessionTest.java diff --git a/driver-async/src/main/com/mongodb/async/client/ClientSessionBinding.java b/driver-async/src/main/com/mongodb/async/client/ClientSessionBinding.java index 9a75818a135..e0878a821fd 100644 --- a/driver-async/src/main/com/mongodb/async/client/ClientSessionBinding.java +++ b/driver-async/src/main/com/mongodb/async/client/ClientSessionBinding.java @@ -21,7 +21,6 @@ import com.mongodb.async.SingleResultCallback; import com.mongodb.binding.AsyncConnectionSource; import com.mongodb.binding.AsyncReadWriteBinding; -import com.mongodb.binding.AsyncSingleServerBinding; import com.mongodb.connection.AsyncConnection; import com.mongodb.connection.ClusterType; import com.mongodb.connection.Server; @@ -53,29 +52,19 @@ public ReadPreference getReadPreference() { @Override public void getReadConnectionSource(final SingleResultCallback callback) { - wrapped.getReadConnectionSource(new SingleResultCallback() { - @Override - public void onResult(final AsyncConnectionSource result, final Throwable t) { - if (t != null) { - callback.onResult(null, t); - } else { - wrapConnectionSource(result, callback); - } - } - }); + if (isActiveShardedTxn()) { + getPinnedConnectionSource(callback); + } else { + wrapped.getReadConnectionSource(new WrappingCallback(callback)); + } } public void getWriteConnectionSource(final SingleResultCallback callback) { - wrapped.getWriteConnectionSource(new SingleResultCallback() { - @Override - public void onResult(final AsyncConnectionSource result, final Throwable t) { - if (t != null) { - callback.onResult(null, t); - } else { - wrapConnectionSource(result, callback); - } - } - }); + if (isActiveShardedTxn()) { + getPinnedConnectionSource(callback); + } else { + wrapped.getWriteConnectionSource(new WrappingCallback(callback)); + } } @Override @@ -83,47 +72,25 @@ public SessionContext getSessionContext() { return sessionContext; } - private void wrapConnectionSource(final AsyncConnectionSource connectionSource, - final SingleResultCallback callback) { - if (isActiveShardedTxn()) { - if (session.getPinnedServerAddress() == null) { - wrapped.getCluster().selectServerAsync( - new ReadPreferenceServerSelector(wrapped.getReadPreference()), - new SingleResultCallback() { - @Override - public void onResult(final Server server, final Throwable t) { - if (t != null) { - callback.onResult(null, t); - } else { - session.setPinnedServerAddress(server.getDescription().getAddress()); - setSingleServerBindingConnectionSource(callback); - } + private void getPinnedConnectionSource(final SingleResultCallback callback) { + if (session.getPinnedServerAddress() == null) { + wrapped.getCluster().selectServerAsync( + new ReadPreferenceServerSelector(wrapped.getReadPreference()), new SingleResultCallback() { + @Override + public void onResult(final Server server, final Throwable t) { + if (t != null) { + callback.onResult(null, t); + } else { + session.setPinnedServerAddress(server.getDescription().getAddress()); + wrapped.getConnectionSource(session.getPinnedServerAddress(), new WrappingCallback(callback)); } - }); - } else { - setSingleServerBindingConnectionSource(callback); - } + } + }); } else { - callback.onResult(new SessionBindingAsyncConnectionSource(connectionSource), null); + wrapped.getConnectionSource(session.getPinnedServerAddress(), new WrappingCallback(callback)); } } - private void setSingleServerBindingConnectionSource(final SingleResultCallback callback) { - final AsyncSingleServerBinding binding = - new AsyncSingleServerBinding(wrapped.getCluster(), session.getPinnedServerAddress(), wrapped.getReadPreference()); - binding.getWriteConnectionSource(new SingleResultCallback() { - @Override - public void onResult(final AsyncConnectionSource result, final Throwable t) { - binding.release(); - if (t != null) { - callback.onResult(null, t); - } else { - callback.onResult(new SessionBindingAsyncConnectionSource(result), null); - } - } - }); - } - @Override public int getCount() { return wrapped.getCount(); @@ -225,4 +192,21 @@ public ReadConcern getReadConcern() { } } } + + private class WrappingCallback implements SingleResultCallback { + private final SingleResultCallback callback; + + WrappingCallback(final SingleResultCallback callback) { + this.callback = callback; + } + + @Override + public void onResult(final AsyncConnectionSource result, final Throwable t) { + if (t != null) { + callback.onResult(null, t); + } else { + callback.onResult(new SessionBindingAsyncConnectionSource(result), null); + } + } + } } diff --git a/driver-async/src/main/com/mongodb/async/client/internal/AsyncCryptBinding.java b/driver-async/src/main/com/mongodb/async/client/internal/AsyncCryptBinding.java index 5d9ef119cee..251aa256b41 100644 --- a/driver-async/src/main/com/mongodb/async/client/internal/AsyncCryptBinding.java +++ b/driver-async/src/main/com/mongodb/async/client/internal/AsyncCryptBinding.java @@ -17,6 +17,7 @@ package com.mongodb.async.client.internal; import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; import com.mongodb.async.SingleResultCallback; import com.mongodb.binding.AsyncConnectionSource; import com.mongodb.binding.AsyncReadWriteBinding; @@ -74,6 +75,20 @@ public void onResult(final AsyncConnectionSource result, final Throwable t) { }); } + @Override + public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback callback) { + wrapped.getConnectionSource(serverAddress, new SingleResultCallback() { + @Override + public void onResult(final AsyncConnectionSource result, final Throwable t) { + if (t != null) { + callback.onResult(null, t); + } else { + callback.onResult(new AsyncCryptConnectionSource(result), null); + } + } + }); + } + @Override public int getCount() { return wrapped.getCount(); diff --git a/driver-async/src/test/functional/com/mongodb/async/client/ClientSideEncryptionSessionTest.java b/driver-async/src/test/functional/com/mongodb/async/client/ClientSideEncryptionSessionTest.java new file mode 100644 index 00000000000..c760e0ff0f3 --- /dev/null +++ b/driver-async/src/test/functional/com/mongodb/async/client/ClientSideEncryptionSessionTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.async.client; + +import com.mongodb.AutoEncryptionSettings; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoNamespace; +import com.mongodb.WriteConcern; +import com.mongodb.async.FutureResultCallback; +import com.mongodb.client.test.CollectionHelper; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.codecs.BsonDocumentCodec; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +import static com.mongodb.ClusterFixture.isNotAtLeastJava8; +import static com.mongodb.ClusterFixture.isStandalone; +import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.async.client.Fixture.getDefaultDatabaseName; +import static com.mongodb.async.client.Fixture.getMongoClient; +import static com.mongodb.async.client.Fixture.getMongoClientBuilderFromConnectionString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeFalse; +import static org.junit.Assume.assumeTrue; +import static util.JsonPoweredTestHelper.getTestDocument; + +@RunWith(Parameterized.class) +public class ClientSideEncryptionSessionTest { + private static final String COLLECTION_NAME = "clientSideEncryptionSessionsTest"; + + private MongoClient client = getMongoClient(); + private MongoClient clientEncrypted; + private final boolean useTransaction; + + @Parameterized.Parameters(name = "useTransaction: {0}") + public static Collection data() { + return Arrays.asList(new Object[]{true}, new Object[]{false}); + } + + public ClientSideEncryptionSessionTest(final boolean useTransaction) { + this.useTransaction = useTransaction; + } + + @Before + public void setUp() throws Throwable { + assumeFalse(isNotAtLeastJava8()); + assumeTrue(serverVersionAtLeast(4, 2)); + assumeFalse(isStandalone()); + + /* Step 1: get unencrypted client and recreate keys collection */ + client = getMongoClient(); + MongoDatabase keyVaultDatabase = client.getDatabase("keyvault"); + MongoCollection dataKeys = keyVaultDatabase.getCollection("datakeys", BsonDocument.class) + .withWriteConcern(WriteConcern.MAJORITY); + FutureResultCallback voidCallback = new FutureResultCallback(); + dataKeys.drop(voidCallback); + voidCallback.get(); + + voidCallback = new FutureResultCallback(); + dataKeys.insertOne(bsonDocumentFromPath("external-key.json"), voidCallback); + voidCallback.get(); + + /* Step 2: create encryption objects. */ + Map> kmsProviders = new HashMap>(); + Map localMasterkey = new HashMap(); + Map schemaMap = new HashMap(); + + byte[] localMasterKeyBytes = Base64.getDecoder().decode("Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBM" + + "UN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"); + localMasterkey.put("key", localMasterKeyBytes); + kmsProviders.put("local", localMasterkey); + schemaMap.put(getDefaultDatabaseName() + "." + COLLECTION_NAME, bsonDocumentFromPath("external-schema.json")); + + MongoClientSettings clientSettings = getMongoClientBuilderFromConnectionString() + .autoEncryptionSettings(AutoEncryptionSettings.builder() + .keyVaultNamespace("keyvault.datakeys") + .kmsProviders(kmsProviders) + .schemaMap(schemaMap).build()) + .build(); + clientEncrypted = MongoClients.create(clientSettings); + + CollectionHelper collectionHelper = + new CollectionHelper(new BsonDocumentCodec(), new MongoNamespace(getDefaultDatabaseName(), COLLECTION_NAME)); + collectionHelper.drop(); + collectionHelper.create(); + } + + @After + public void after() { + if (clientEncrypted != null) { + try { + clientEncrypted.close(); + } catch (Exception e) { + // ignore + } + } + } + + @Test + public void testWithExplicitSession() throws Throwable { + BsonString unencryptedValue = new BsonString("test"); + + FutureResultCallback clientSessionCallback = new FutureResultCallback(); + clientEncrypted.startSession(clientSessionCallback); + ClientSession clientSession = clientSessionCallback.get(); + try { + if (useTransaction) { + clientSession.startTransaction(); + } + MongoCollection autoEncryptedCollection = clientEncrypted.getDatabase(getDefaultDatabaseName()) + .getCollection(COLLECTION_NAME, BsonDocument.class); + FutureResultCallback insertCallback = new FutureResultCallback(); + autoEncryptedCollection.insertOne(clientSession, new BsonDocument().append("encrypted", new BsonString("test")), + insertCallback); + insertCallback.get(); + + FutureResultCallback findCallback = new FutureResultCallback(); + autoEncryptedCollection.find(clientSession).first(findCallback); + BsonDocument unencryptedDocument = findCallback.get(); + assertEquals(unencryptedValue, unencryptedDocument.getString("encrypted")); + + if (useTransaction) { + FutureResultCallback commitCallback = new FutureResultCallback(); + clientSession.commitTransaction(commitCallback); + commitCallback.get(); + } + } finally { + clientSession.close(); + } + + MongoCollection encryptedCollection = client.getDatabase(getDefaultDatabaseName()) + .getCollection(COLLECTION_NAME, BsonDocument.class); + FutureResultCallback findCallback = new FutureResultCallback(); + encryptedCollection.find().first(findCallback); + BsonDocument encryptedDocument = findCallback.get(); + assertTrue(encryptedDocument.isBinary("encrypted")); + assertEquals(6, encryptedDocument.getBinary("encrypted").getType()); + } + + private static BsonDocument bsonDocumentFromPath(final String path) throws IOException, URISyntaxException { + return getTestDocument(new File(ClientSideEncryptionSessionTest.class + .getResource("/client-side-encryption-external/" + path).toURI())); + } +} diff --git a/driver-core/src/main/com/mongodb/binding/AsyncClusterBinding.java b/driver-core/src/main/com/mongodb/binding/AsyncClusterBinding.java index 3216a8605a8..1c008441c87 100644 --- a/driver-core/src/main/com/mongodb/binding/AsyncClusterBinding.java +++ b/driver-core/src/main/com/mongodb/binding/AsyncClusterBinding.java @@ -18,6 +18,7 @@ import com.mongodb.ReadConcern; import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; import com.mongodb.async.SingleResultCallback; import com.mongodb.connection.AsyncConnection; import com.mongodb.connection.Cluster; @@ -27,6 +28,7 @@ import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding; import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext; import com.mongodb.selector.ReadPreferenceServerSelector; +import com.mongodb.selector.ServerAddressSelector; import com.mongodb.selector.ServerSelector; import com.mongodb.selector.WritableServerSelector; import com.mongodb.session.SessionContext; @@ -102,6 +104,11 @@ public void getWriteConnectionSource(final SingleResultCallback callback) { + getAsyncClusterBindingConnectionSource(new ServerAddressSelector(serverAddress), callback); + } + private void getAsyncClusterBindingConnectionSource(final ServerSelector serverSelector, final SingleResultCallback callback) { cluster.selectServerAsync(serverSelector, new SingleResultCallback() { diff --git a/driver-core/src/main/com/mongodb/binding/ClusterBinding.java b/driver-core/src/main/com/mongodb/binding/ClusterBinding.java index 9aa89a314e3..51460746642 100644 --- a/driver-core/src/main/com/mongodb/binding/ClusterBinding.java +++ b/driver-core/src/main/com/mongodb/binding/ClusterBinding.java @@ -18,6 +18,7 @@ import com.mongodb.ReadConcern; import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; import com.mongodb.connection.Cluster; import com.mongodb.connection.Connection; import com.mongodb.connection.Server; @@ -26,6 +27,7 @@ import com.mongodb.internal.binding.ClusterAwareReadWriteBinding; import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext; import com.mongodb.selector.ReadPreferenceServerSelector; +import com.mongodb.selector.ServerAddressSelector; import com.mongodb.selector.ServerSelector; import com.mongodb.selector.WritableServerSelector; import com.mongodb.session.SessionContext; @@ -89,13 +91,13 @@ public ReadPreference getReadPreference() { } @Override - public ConnectionSource getReadConnectionSource() { - return new ClusterBindingConnectionSource(new ReadPreferenceServerSelector(readPreference)); + public SessionContext getSessionContext() { + return new ReadConcernAwareNoOpSessionContext(readConcern); } @Override - public SessionContext getSessionContext() { - return new ReadConcernAwareNoOpSessionContext(readConcern); + public ConnectionSource getReadConnectionSource() { + return new ClusterBindingConnectionSource(new ReadPreferenceServerSelector(readPreference)); } @Override @@ -103,6 +105,11 @@ public ConnectionSource getWriteConnectionSource() { return new ClusterBindingConnectionSource(new WritableServerSelector()); } + @Override + public ConnectionSource getConnectionSource(final ServerAddress serverAddress) { + return new ClusterBindingConnectionSource(new ServerAddressSelector(serverAddress)); + } + private final class ClusterBindingConnectionSource extends AbstractReferenceCounted implements ConnectionSource { private final Server server; diff --git a/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java b/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java index ec0523f23f7..c84df3d84e9 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/AsyncClusterAwareReadWriteBinding.java @@ -17,6 +17,9 @@ package com.mongodb.internal.binding; +import com.mongodb.ServerAddress; +import com.mongodb.async.SingleResultCallback; +import com.mongodb.binding.AsyncConnectionSource; import com.mongodb.binding.AsyncReadWriteBinding; import com.mongodb.connection.Cluster; @@ -25,4 +28,12 @@ */ public interface AsyncClusterAwareReadWriteBinding extends AsyncReadWriteBinding { Cluster getCluster(); + + /** + * Returns a connection source to the specified server + * + * @param serverAddress the server address + * @param callback the to be passed the connection source + */ + void getConnectionSource(ServerAddress serverAddress, SingleResultCallback callback); } diff --git a/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java b/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java index cc460620957..e3db0354be8 100644 --- a/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java +++ b/driver-core/src/main/com/mongodb/internal/binding/ClusterAwareReadWriteBinding.java @@ -16,6 +16,8 @@ package com.mongodb.internal.binding; +import com.mongodb.ServerAddress; +import com.mongodb.binding.ConnectionSource; import com.mongodb.binding.ReadWriteBinding; import com.mongodb.connection.Cluster; @@ -24,4 +26,10 @@ */ public interface ClusterAwareReadWriteBinding extends ReadWriteBinding { Cluster getCluster(); + + /** + * Returns a connection source to the specified server address. + * @return the connection source + */ + ConnectionSource getConnectionSource(ServerAddress serverAddress); } diff --git a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java index 5ee9028106d..ef956d2f06b 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java +++ b/driver-sync/src/main/com/mongodb/client/internal/ClientSessionBinding.java @@ -18,9 +18,9 @@ import com.mongodb.ReadConcern; import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; import com.mongodb.binding.ConnectionSource; import com.mongodb.binding.ReadWriteBinding; -import com.mongodb.binding.SingleServerBinding; import com.mongodb.client.ClientSession; import com.mongodb.connection.ClusterType; import com.mongodb.connection.Connection; @@ -79,23 +79,19 @@ private void closeSessionIfCountIsZero() { @Override public ConnectionSource getReadConnectionSource() { - return new SessionBindingConnectionSource(wrapConnectionSource(wrapped.getReadConnectionSource())); + if (isActiveShardedTxn()) { + return new SessionBindingConnectionSource(wrapped.getConnectionSource(pinServer())); + } else { + return new SessionBindingConnectionSource(wrapped.getReadConnectionSource()); + } } public ConnectionSource getWriteConnectionSource() { - return new SessionBindingConnectionSource(wrapConnectionSource(wrapped.getWriteConnectionSource())); - } - - private ConnectionSource wrapConnectionSource(final ConnectionSource connectionSource) { - ConnectionSource retVal = connectionSource; if (isActiveShardedTxn()) { - setPinnedServerAddress(); - SingleServerBinding binding = new SingleServerBinding(wrapped.getCluster(), session.getPinnedServerAddress(), - wrapped.getReadPreference()); - retVal = binding.getWriteConnectionSource(); - binding.release(); + return new SessionBindingConnectionSource(wrapped.getConnectionSource(pinServer())); + } else { + return new SessionBindingConnectionSource(wrapped.getWriteConnectionSource()); } - return retVal; } @Override @@ -107,11 +103,14 @@ private boolean isActiveShardedTxn() { return session.hasActiveTransaction() && wrapped.getCluster().getDescription().getType() == ClusterType.SHARDED; } - private void setPinnedServerAddress() { - if (session.getPinnedServerAddress() == null) { + private ServerAddress pinServer() { + ServerAddress pinnedServerAddress = session.getPinnedServerAddress(); + if (pinnedServerAddress == null) { Server server = wrapped.getCluster().selectServer(new ReadPreferenceServerSelector(wrapped.getReadPreference())); - session.setPinnedServerAddress(server.getDescription().getAddress()); + pinnedServerAddress = server.getDescription().getAddress(); + session.setPinnedServerAddress(pinnedServerAddress); } + return pinnedServerAddress; } private class SessionBindingConnectionSource implements ConnectionSource { diff --git a/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java b/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java index bd05f633a65..2a6a4b095fa 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java +++ b/driver-sync/src/main/com/mongodb/client/internal/CryptBinding.java @@ -17,6 +17,7 @@ package com.mongodb.client.internal; import com.mongodb.ReadPreference; +import com.mongodb.ServerAddress; import com.mongodb.binding.ConnectionSource; import com.mongodb.binding.ReadWriteBinding; import com.mongodb.connection.Cluster; @@ -49,6 +50,11 @@ public ConnectionSource getWriteConnectionSource() { return new CryptConnectionSource(wrapped.getWriteConnectionSource()); } + @Override + public ConnectionSource getConnectionSource(final ServerAddress serverAddress) { + return new CryptConnectionSource(wrapped.getConnectionSource(serverAddress)); + } + @Override public SessionContext getSessionContext() { return wrapped.getSessionContext(); diff --git a/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionSessionTest.java b/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionSessionTest.java new file mode 100644 index 00000000000..a488d36df50 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/ClientSideEncryptionSessionTest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.AutoEncryptionSettings; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoNamespace; +import com.mongodb.WriteConcern; +import com.mongodb.client.test.CollectionHelper; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.codecs.BsonDocumentCodec; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +import static com.mongodb.ClusterFixture.isNotAtLeastJava8; +import static com.mongodb.ClusterFixture.isStandalone; +import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.client.Fixture.getDefaultDatabaseName; +import static com.mongodb.client.Fixture.getMongoClient; +import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeFalse; +import static org.junit.Assume.assumeTrue; +import static util.JsonPoweredTestHelper.getTestDocument; + +@RunWith(Parameterized.class) +public class ClientSideEncryptionSessionTest { + private static final String COLLECTION_NAME = "clientSideEncryptionSessionsTest"; + private MongoClient client, clientEncrypted; + private final boolean useTransaction; + + @Parameterized.Parameters(name = "useTransaction: {0}") + public static Collection data() { + return Arrays.asList(new Object[]{true}, new Object[]{false}); + } + + public ClientSideEncryptionSessionTest(final boolean useTransaction) { + this.useTransaction = useTransaction; + } + + @Before + public void setUp() throws IOException, URISyntaxException { + assumeFalse(isNotAtLeastJava8()); + assumeTrue(serverVersionAtLeast(4, 2)); + assumeFalse(isStandalone()); + + /* Step 1: get unencrypted client and recreate keys collection */ + client = getMongoClient(); + MongoDatabase keyvaultDatabase = client.getDatabase("keyvault"); + MongoCollection datakeys = keyvaultDatabase.getCollection("datakeys", BsonDocument.class) + .withWriteConcern(WriteConcern.MAJORITY); + datakeys.drop(); + datakeys.insertOne(bsonDocumentFromPath("external-key.json")); + + /* Step 2: create encryption objects. */ + Map> kmsProviders = new HashMap>(); + Map localMasterkey = new HashMap(); + Map schemaMap = new HashMap(); + + byte[] localMasterkeyBytes = Base64.getDecoder().decode("Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBM" + + "UN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"); + localMasterkey.put("key", localMasterkeyBytes); + kmsProviders.put("local", localMasterkey); + schemaMap.put(getDefaultDatabaseName() + "." + COLLECTION_NAME, bsonDocumentFromPath("external-schema.json")); + + AutoEncryptionSettings autoEncryptionSettings = AutoEncryptionSettings.builder() + .keyVaultNamespace("keyvault.datakeys") + .kmsProviders(kmsProviders) + .schemaMap(schemaMap).build(); + + MongoClientSettings clientSettings = getMongoClientSettingsBuilder() + .autoEncryptionSettings(autoEncryptionSettings) + .build(); + clientEncrypted = MongoClients.create(clientSettings); + + CollectionHelper collectionHelper = + new CollectionHelper(new BsonDocumentCodec(), new MongoNamespace(getDefaultDatabaseName(), COLLECTION_NAME)); + collectionHelper.drop(); + collectionHelper.create(); + } + + @After + public void after() { + if (clientEncrypted != null) { + try { + clientEncrypted.close(); + } catch (Exception e) { + // ignore + } + } + } + + @Test + public void testWithExplicitSession() { + BsonString unencryptedValue = new BsonString("test"); + + ClientSession clientSession = clientEncrypted.startSession(); + try { + if (useTransaction) { + clientSession.startTransaction(); + } + MongoCollection encryptedCollection = clientEncrypted.getDatabase(getDefaultDatabaseName()) + .getCollection(COLLECTION_NAME, BsonDocument.class); + encryptedCollection.insertOne(clientSession, new BsonDocument().append("encrypted", unencryptedValue)); + BsonDocument unencryptedDocument = encryptedCollection.find(clientSession).first(); + assertEquals(unencryptedValue, unencryptedDocument.getString("encrypted")); + if (useTransaction) { + clientSession.commitTransaction(); + } + } finally { + clientSession.close(); + } + + MongoCollection unencryptedCollection = client.getDatabase(getDefaultDatabaseName()) + .getCollection(COLLECTION_NAME, BsonDocument.class); + BsonDocument encryptedDocument = unencryptedCollection.find().first(); + assertTrue(encryptedDocument.isBinary("encrypted")); + assertEquals(6, encryptedDocument.getBinary("encrypted").getType()); + } + + + private static BsonDocument bsonDocumentFromPath(final String path) throws IOException, URISyntaxException { + return getTestDocument(new File(ClientSideEncryptionSessionTest.class + .getResource("/client-side-encryption-external/" + path).toURI())); + } +}