Skip to content

Commit

Permalink
Base64Helper uses custom serialization
Browse files Browse the repository at this point in the history
Signed-off-by: Paras Jain <parasjaz@amazon.com>
  • Loading branch information
Paras Jain committed Jun 2, 2023
1 parent 76a5d7f commit add8f29
Show file tree
Hide file tree
Showing 14 changed files with 472 additions and 27 deletions.
15 changes: 15 additions & 0 deletions src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package com.amazon.dlic.auth.ldap;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -20,6 +21,8 @@

import com.amazon.dlic.auth.ldap.util.Utils;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.security.support.WildcardMatcher;
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.user.User;
Expand All @@ -45,6 +48,12 @@ public LdapUser(
attributes.putAll(extractLdapAttributes(originalUsername, userEntry, customAttrMaxValueLen, allowlistedCustomLdapAttrMatcher));
}

public LdapUser(StreamInput in) throws IOException {
super(in);
userEntry = null;
originalUsername = in.readString();
}

/**
* May return null because ldapEntry is transient
*
Expand Down Expand Up @@ -88,4 +97,10 @@ public static Map<String, String> extractLdapAttributes(
}
return Collections.unmodifiableMap(attributes);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(originalUsername);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -584,15 +584,15 @@ private Origin getOrigin() {
private TransportAddress getRemoteAddress() {
TransportAddress address = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS);
if(address == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) != null) {
address = new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER)));
address = new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER), threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)));
}
return address;
}

private String getUser() {
User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
if(user == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) != null) {
user = (User) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER));
user = (User) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}
return user==null?null:user.getName();
}
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/org/opensearch/security/auth/UserInjector.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

package org.opensearch.security.auth;

import java.io.IOException;
import java.io.ObjectStreamException;
import java.net.InetAddress;
import java.net.UnknownHostException;
Expand All @@ -36,6 +37,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -63,13 +66,18 @@ public UserInjector(Settings settings, ThreadPool threadPool, AuditLog auditLog,

}

static class InjectedUser extends User {
public static class InjectedUser extends User {
private transient TransportAddress transportAddress;

public InjectedUser(String name) {
super(name);
}

public InjectedUser(StreamInput in) throws IOException {
super(in);
this.setInjected(true);
}

private Object writeReplace() throws ObjectStreamException {
User user = new User(getName());
user.addRoles(getRoles());
Expand All @@ -96,6 +104,11 @@ public void setTransportAddress(String addr) throws UnknownHostException, Illega

this.transportAddress = new TransportAddress(iAdress, port);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}
}

public InjectedUser getInjectedUser() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ private void setDlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request)
}
} else {
if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) != null) {
Object deserializedDlsQueries = Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER));
Object deserializedDlsQueries = Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
if (!dlsQueries.equals(deserializedDlsQueries)) {
throw new OpenSearchSecurityException(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER + " does not match (SG 900D)");
}
Expand Down Expand Up @@ -437,7 +437,7 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request)
} else {

if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER) != null) {
if (!maskedFieldsMap.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER)))) {
if (!maskedFieldsMap.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)))) {
throw new OpenSearchSecurityException(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER + " does not match (SG 901D)");
} else {
if (log.isDebugEnabled()) {
Expand All @@ -463,9 +463,9 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request)
}
} else {
if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) != null) {
if (!flsFields.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER)))) {
if (!flsFields.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)))) {
throw new OpenSearchSecurityException(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER + " does not match (SG 901D) " + flsFields
+ "---" + Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER)));
+ "---" + Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)));
} else {
if (log.isDebugEnabled()) {
log.debug(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER + " already set");
Expand Down
174 changes: 172 additions & 2 deletions src/main/java/org/opensearch/security/support/Base64Helper.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import java.util.regex.Pattern;

import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.BaseEncoding;
Expand All @@ -61,7 +63,12 @@

import org.opensearch.OpenSearchException;
import org.opensearch.SpecialPermission;
import org.opensearch.common.io.stream.BytesStreamInput;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.core.common.Strings;
import org.opensearch.security.auth.UserInjector;
import org.opensearch.security.user.User;

public class Base64Helper {
Expand All @@ -88,10 +95,40 @@ public class Base64Helper {
Enum.class
);


private enum CustomSerializationFormat {

WRITEABLE(1),
STREAMABLE(2),
GENERIC(3);

private final int id;

CustomSerializationFormat(int id) {
this.id = id;
}

static CustomSerializationFormat fromId(int id) {
switch (id) {
case 1: return WRITEABLE;
case 2: return STREAMABLE;
case 3: return GENERIC;
default: throw new IllegalArgumentException(String.format("%d is not a valid id", id));
}
}

}

private static final BiMap<Class<?>, Integer> writeableClassToIdMap = HashBiMap.create();
private static final StreamableRegistry streamableRegistry = StreamableRegistry.getInstance();
private static final Set<String> SAFE_CLASS_NAMES = Collections.singleton(
"org.ldaptive.LdapAttribute$LdapAttributeValues"
);

static {
registerAllWriteables();
}

private static boolean isSafeClass(Class<?> cls) {
return cls.isArray() ||
SAFE_CLASSES.contains(cls) ||
Expand Down Expand Up @@ -156,7 +193,7 @@ protected Object replaceObject(Object obj) throws IOException {
}
}

public static String serializeObject(final Serializable object) {
private static String serializeObjectJDK(final Serializable object) {

Preconditions.checkArgument(object != null, "object must not be null");

Expand All @@ -170,7 +207,47 @@ public static String serializeObject(final Serializable object) {
return BaseEncoding.base64().encode(bytes);
}

public static Serializable deserializeObject(final String string) {
private static String serializeObjectCustom(final Serializable object) {

Preconditions.checkArgument(object != null, "object must not be null");
final BytesStreamOutput streamOutput = new BytesStreamOutput(128);
Class<?> clazz = object.getClass();
try {
CustomSerializationFormat customSerializationFormat = getCustomSerializationMode(clazz);
switch (customSerializationFormat) {
case WRITEABLE:
streamOutput.writeByte((byte) CustomSerializationFormat.WRITEABLE.id);
streamOutput.writeByte((byte) getWriteableClassID(clazz).intValue());
((Writeable) object).writeTo(streamOutput);
break;
case STREAMABLE:
streamOutput.writeByte((byte) CustomSerializationFormat.STREAMABLE.id);
streamableRegistry.writeTo(streamOutput, object);
break;
case GENERIC:
streamOutput.writeByte((byte) CustomSerializationFormat.GENERIC.id);
streamOutput.writeGenericValue(object);
break;
default:
throw new IllegalArgumentException(String.format("Could not determine custom serialization mode for class %s", clazz.getName()));
}
} catch (final Exception e) {
throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass());
}
final byte[] bytes = streamOutput.bytes().toBytesRef().bytes;
streamOutput.close();
return BaseEncoding.base64().encode(bytes);
}

public static String serializeObject(final Serializable object, final boolean useJDKSerialization) {
return useJDKSerialization ? serializeObjectJDK(object) : serializeObjectCustom(object);
}

public static String serializeObject(final Serializable object) {
return serializeObjectCustom(object);
}

private static Serializable deserializeObjectJDK(final String string) {

Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty");

Expand All @@ -183,6 +260,37 @@ public static Serializable deserializeObject(final String string) {
}
}

private static Serializable deserializeObjectCustom(final String string) {

Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty");
final byte[] bytes = BaseEncoding.base64().decode(string);
try (final BytesStreamInput streamInput = new BytesStreamInput(bytes)) {
CustomSerializationFormat serializationFormat = CustomSerializationFormat.fromId(streamInput.readByte());
switch (serializationFormat) {
case WRITEABLE:
final int classId = streamInput.readByte();
Class<?> clazz = getWriteableClassFromId(classId);
return (Serializable) clazz.getConstructor(StreamInput.class).newInstance(streamInput);
case STREAMABLE:
return (Serializable) streamableRegistry.readFrom(streamInput);
case GENERIC:
return (Serializable) streamInput.readGenericValue();
default:
throw new IllegalArgumentException("Could not determine custom deserialization mode");
}
} catch (final Exception e) {
throw new OpenSearchException(e);
}
}

public static Serializable deserializeObject(final String string) {
return deserializeObjectCustom(string);
}

public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
return useJDKDeserialization ? deserializeObjectJDK(string) : deserializeObjectCustom(string);
}

private final static class SafeObjectInputStream extends ObjectInputStream {

public SafeObjectInputStream(InputStream in) throws IOException {
Expand All @@ -200,4 +308,66 @@ protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, Clas
throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName());
}
}

private static boolean isWriteable(Class<?> clazz) {
return Writeable.class.isAssignableFrom(clazz);
}

/**
* Returns integer ID for the registered Writeable class
* <br/>
* Protected for testing
*/
protected static Integer getWriteableClassID(Class<?> clazz) {
if ( !isWriteable(clazz) ) {
throw new OpenSearchException("clazz should implement Writeable ", clazz);
}
if( !writeableClassToIdMap.containsKey(clazz) ) {
throw new OpenSearchException("Writeable clazz not registered ", clazz);
}
return writeableClassToIdMap.get(clazz);
}

private static Class<?> getWriteableClassFromId(int id) {
return writeableClassToIdMap.inverse().get(id);
}

/**
* Registers the given <code>Writeable</code> class for custom serialization by assigning an incrementing integer ID
* IDs are stored in a HashBiMap
* @param clazz class to be registered
*/
private static void registerWriteable(Class<? extends Writeable> clazz) {
if ( writeableClassToIdMap.containsKey(clazz) ) {
throw new OpenSearchException("writeable clazz is already registered ", clazz.getName());
}
int id = writeableClassToIdMap.size() + 1;
writeableClassToIdMap.put(clazz, id);
}

/**
* Registers all <code>Writeable</code> classes for custom serialization support.
* Removing existing classes / changing order of registration will cause a breaking change in the serialization protocol
* as <code>registerWriteable</code> assigns an incrementing integer ID to each of the classes in the order it is called
* starting from <code>1</code>.
*<br/>
* New classes can safely be added towards the end.
*/
private static void registerAllWriteables() {
registerWriteable(User.class);
registerWriteable(LdapUser.class);
registerWriteable(UserInjector.InjectedUser.class);
registerWriteable(SourceFieldsContext.class);
}

private static CustomSerializationFormat getCustomSerializationMode(Class<?> clazz) {
if ( isWriteable(clazz) ) {
return CustomSerializationFormat.WRITEABLE;
} else if (streamableRegistry.isStreamable(clazz) ) {
return CustomSerializationFormat.STREAMABLE;
} else {
return CustomSerializationFormat.GENERIC;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public class ConfigConstants {

public static final String OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER = OPENDISTRO_SECURITY_CONFIG_PREFIX+"initial_action_class_header";

public static final String OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT = OPENDISTRO_SECURITY_CONFIG_PREFIX+"source_field_context";

/**
* Set by SSL plugin for https requests only
*/
Expand Down Expand Up @@ -296,6 +298,8 @@ public enum RolesMappingResolution {
public static final String TENANCY_GLOBAL_TENANT_NAME = "global";
public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = "";

public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization";

public static Set<String> getSettingAsSet(final Settings settings, final String key, final List<String> defaultList, final boolean ignoreCaseForNone) {
final List<String> list = settings.getAsList(key, defaultList);
if (list.size() == 1 && "NONE".equals(ignoreCaseForNone? list.get(0).toUpperCase() : list.get(0))) {
Expand Down
Loading

0 comments on commit add8f29

Please sign in to comment.