Skip to content

Commit

Permalink
[2.x] Use predictable serialization logic for transport headers
Browse files Browse the repository at this point in the history
This change will prevent new clusters from using the 'custom serialization' format that was causing performance impact with customers in OpenSearch 2.11.

Background: the serialization changes from opensearch-project#2802 introduced issues where for certain serialization headers that were previously very small for over the wire transmission become much larger. The root cause of this was that the JDK serialization process was able to detect duplicate entries and then use an encoding format to make it compressible. Adding this logic into the serialization system from OpenSearch is non-trivial and is not being invested in.

Signed-off-by: Peter Nied <petern@amazon.com>
  • Loading branch information
peternied committed Apr 25, 2024
1 parent badea0d commit 60045e7
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.security.ssl.util.ExceptionUtils;
import org.opensearch.security.ssl.util.SSLRequestHelper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.SerializationFormat;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
Expand Down Expand Up @@ -92,7 +93,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
SerializationFormat.determineFormat(channel.getVersion()) == SerializationFormat.JDK
);

if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.opensearch.Version;
import org.opensearch.common.settings.Settings;
import org.opensearch.security.auditlog.impl.AuditCategory;

Expand Down Expand Up @@ -332,7 +331,6 @@ public enum RolesMappingResolution {
public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = "";

public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization";
public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0;

// On-behalf-of endpoints settings
// CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.support;

import org.opensearch.Version;

public enum SerializationFormat {
/** Uses Java's native serialization system */
JDK,
/** Uses a custom serializer built ontop of OpenSearch 2.11 */
CustomSerializer_2_11;

private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0;
private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0;

/**
* Determines the format of serialization that should be used from a version identifier
*/
public static SerializationFormat determineFormat(final Version version) {
if (version.onOrAfter(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
&& version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) {
return SerializationFormat.CustomSerializer_2_11;
}
return SerializationFormat.JDK;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.opensearch.security.support.Base64Helper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.support.SerializationFormat;
import org.opensearch.security.user.User;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport.Connection;
Expand Down Expand Up @@ -150,7 +151,8 @@ public <T extends TransportResponse> void sendRequestDecorate(
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean isDebugEnabled = log.isDebugEnabled();
final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);

final var serializationFormat = SerializationFormat.determineFormat(connection.getVersion());
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
Expand Down Expand Up @@ -228,25 +230,28 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
);
}

if (useJDKSerialization) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
try {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
}
getThreadContext().putHeader(headerMap);
} catch (IllegalArgumentException iae) {
log.debug("Failed to add headers information onto on thread context", iae);
}

getThreadContext().putHeader(headerMap);

ensureCorrectHeaders(
remoteAddress0,
user0,
origin0,
injectedUserString,
injectedRolesString,
isSameNodeRequest,
useJDKSerialization
serializationFormat
);

if (actionTraceEnabled.get()) {
Expand Down Expand Up @@ -275,7 +280,7 @@ private void ensureCorrectHeaders(
final String injectedUserString,
final String injectedRolesString,
final boolean isSameNodeRequest,
final boolean useJDKSerialization
final SerializationFormat format
) {
// keep original address

Expand Down Expand Up @@ -313,6 +318,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE
getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserString);
}
} else {
final var useJDKSerialization = format == SerializationFormat.JDK;
if (transportAddress != null) {
getThreadContext().putHeader(
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
package org.opensearch.security.support;

import java.io.Serializable;
import java.util.HashMap;
import java.util.stream.IntStream;

import org.junit.Assert;
import org.junit.Test;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.opensearch.security.support.Base64Helper.deserializeObject;
import static org.opensearch.security.support.Base64Helper.serializeObject;
import static org.junit.Assert.assertThat;

public class Base64HelperTest {

Expand Down Expand Up @@ -48,4 +53,22 @@ public void testEnsureJDKSerialized() {
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));
}

@Test
public void testDuplicatedItemSizes() {
var largeObject = new HashMap<String, Object>();
var hm = new HashMap<>();
IntStream.range(0, 100).forEach(i -> { hm.put("c" + i, "cvalue" + i); });
IntStream.range(0, 100).forEach(i -> { largeObject.put("b" + i, hm); });

final var jdkSerialized = Base64Helper.serializeObject(largeObject, true);
final var customSerialized = Base64Helper.serializeObject(largeObject, false);
final var customSerializedOnlyHashMap = Base64Helper.serializeObject(hm, false);

assertThat(jdkSerialized.length(), equalTo(3832));
// The custom serializer is ~50x larger than the jdk serialized version
assertThat(customSerialized.length(), equalTo(184792));
// Show that the majority of the size of the custom serialized large object is the map duplicated ~100 times
assertThat((double) customSerializedOnlyHashMap.length(), closeTo(customSerialized.length() / 100, 70d));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -120,7 +121,7 @@ public class SecurityInterceptorTests {

private AsyncSender sender;
private AsyncSender serializedSender;
private AsyncSender nullSender;
private AtomicReference<CountDownLatch> senderLatch = new AtomicReference<>(new CountDownLatch(1));

@Before
public void setup() {
Expand Down Expand Up @@ -208,6 +209,7 @@ public <T extends TransportResponse> void sendRequest(
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, true));
senderLatch.get().countDown();
}
};

Expand All @@ -222,6 +224,7 @@ public <T extends TransportResponse> void sendRequest(
) {
User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
assertEquals(transientUser, user);
senderLatch.get().countDown();
}
};

Expand Down Expand Up @@ -249,17 +252,16 @@ final void completableRequestDecorate(
TransportResponseHandler handler,
DiscoveryNode localNode
) {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
verifyOriginalContext(user);
try {
senderLatch.get().await(1, TimeUnit.SECONDS);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}

ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor();

singleThreadExecutor.execute(() -> {
try {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
verifyOriginalContext(user);
} finally {
singleThreadExecutor.shutdown();
}
});
// Reset the latch so another request can be processed
senderLatch.set(new CountDownLatch(1));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand All @@ -111,6 +121,16 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand Down

0 comments on commit 60045e7

Please sign in to comment.