From 8ae88a7049291abc3be2003552a04fbc6d092fb1 Mon Sep 17 00:00:00 2001 From: Craig Perkins Date: Thu, 19 Sep 2024 13:37:26 -0400 Subject: [PATCH] Add ensureCustomSerialization to ensure that headers are serialized correctly with multiple transport hops (#4741) Signed-off-by: Craig Perkins --- .../configuration/ClusterInfoHolder.java | 12 +++ .../security/filter/SecurityFilter.java | 2 +- .../security/support/Base64Helper.java | 28 +++++- .../transport/SecurityInterceptor.java | 24 +++-- .../security/support/Base64HelperTest.java | 9 ++ .../transport/SecurityInterceptorTests.java | 95 +++++++++++++++++-- 6 files changed, 150 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/security/configuration/ClusterInfoHolder.java b/src/main/java/org/opensearch/security/configuration/ClusterInfoHolder.java index d7429c5d1d..a9f08eb5f1 100644 --- a/src/main/java/org/opensearch/security/configuration/ClusterInfoHolder.java +++ b/src/main/java/org/opensearch/security/configuration/ClusterInfoHolder.java @@ -29,6 +29,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.Version; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateListener; import org.opensearch.cluster.node.DiscoveryNode; @@ -67,6 +68,17 @@ public boolean isInitialized() { return initialized; } + public Version getMinNodeVersion() { + if (nodes == null) { + if (log.isDebugEnabled()) { + log.debug("Cluster Info Holder not initialized yet for 'nodes'"); + } + return null; + } + + return nodes.getMinNodeVersion(); + } + public Boolean hasNode(DiscoveryNode node) { if (nodes == null) { if (log.isDebugEnabled()) { diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index 1116e70845..f0ab7bb487 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -185,7 +185,7 @@ private void ap } if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) { - threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false); + threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true); } final ComplianceConfig complianceConfig = auditLog.getComplianceConfig(); diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index a5fbab8515..7e104ace54 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -35,11 +35,11 @@ public static String serializeObject(final Serializable object, final boolean us } public static String serializeObject(final Serializable object) { - return serializeObject(object, false); + return serializeObject(object, true); } public static Serializable deserializeObject(final String string) { - return deserializeObject(string, false); + return deserializeObject(string, true); } public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) { @@ -69,4 +69,28 @@ public static String ensureJDKSerialized(final String string) { // If we see an exception now, we want the caller to see it - return Base64Helper.serializeObject(serializable, true); } + + /** + * Ensures that the returned string is custom serialized. + * + * If the supplied string is a JDK serialized representation, will deserialize it and further serialize using + * custom, otherwise returns the string as is. + * + * @param string original string, can be JDK or custom serialized + * @return custom serialized string + */ + public static String ensureCustomSerialized(final String string) { + Serializable serializable; + try { + serializable = Base64Helper.deserializeObject(string, true); + } catch (Exception e) { + // We received an exception when de-serializing the given string. It is probably custom serialized. + // Try to deserialize using custom + Base64Helper.deserializeObject(string, false); + // Since we could deserialize the object using custom, the string is already custom serialized, return as is + return string; + } + // If we see an exception now, we want the caller to see it - + return Base64Helper.serializeObject(serializable, false); + } } diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index f55d9ac338..9741014fda 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -39,6 +39,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.Version; import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction; import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.opensearch.action.get.GetRequest; @@ -231,13 +232,22 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL } try { - if (serializationFormat == SerializationFormat.JDK) { - Map jdkSerializedHeaders = new HashMap<>(); - HeaderHelper.getAllSerializedHeaderNames() - .stream() - .filter(k -> headerMap.get(k) != null) - .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); - headerMap.putAll(jdkSerializedHeaders); + if (clusterInfoHolder.getMinNodeVersion() == null || clusterInfoHolder.getMinNodeVersion().before(Version.V_2_14_0)) { + if (serializationFormat == SerializationFormat.JDK) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); + headerMap.putAll(jdkSerializedHeaders); + } else if (serializationFormat == SerializationFormat.CustomSerializer_2_11) { + Map customSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k)))); + headerMap.putAll(customSerializedHeaders); + } } getThreadContext().putHeader(headerMap); } catch (IllegalArgumentException iae) { diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index de21c67d52..7c7e68b342 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -53,6 +53,15 @@ public void testEnsureJDKSerialized() { assertThat(Base64Helper.ensureJDKSerialized(customSerialized), is(jdkSerialized)); } + @Test + public void testEnsureCustomSerialized() { + String test = "string"; + String jdkSerialized = Base64Helper.serializeObject(test, true); + String customSerialized = Base64Helper.serializeObject(test, false); + assertThat(Base64Helper.ensureCustomSerialized(jdkSerialized), is(customSerialized)); + assertThat(Base64Helper.ensureCustomSerialized(customSerialized), is(customSerialized)); + } + @Test public void testDuplicatedItemSizes() { var largeObject = new HashMap(); diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index 42884862a2..d12fafb247 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -119,9 +119,12 @@ public class SecurityInterceptorTests { private Connection connection3; private DiscoveryNode otherRemoteNode; private Connection connection4; + private DiscoveryNode remoteNodeWithCustomSerialization; + private Connection connection5; private AsyncSender sender; - private AsyncSender serializedSender; + private AsyncSender jdkSerializedSender; + private AsyncSender customSerializedSender; private AtomicReference senderLatch = new AtomicReference<>(new CountDownLatch(1)); @Before @@ -199,7 +202,30 @@ public void setup() { otherRemoteNode = new DiscoveryNode("remote-node2", new TransportAddress(remoteAddress, 9876), remoteNodeVersion); connection4 = transportService.getConnection(otherRemoteNode); - serializedSender = new AsyncSender() { + remoteNodeWithCustomSerialization = new DiscoveryNode( + "remote-node-with-custom-serialization", + new TransportAddress(localAddress, 7456), + Version.V_2_12_0 + ); + connection5 = transportService.getConnection(remoteNodeWithCustomSerialization); + + jdkSerializedSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + User deserializedUser = (User) Base64Helper.deserializeObject(serializedUserHeader, true); + assertThat(deserializedUser, is(user)); + senderLatch.get().countDown(); + } + }; + + customSerializedSender = new AsyncSender() { @Override public void sendRequest( Connection connection, @@ -209,7 +235,7 @@ public void sendRequest( TransportResponseHandler handler ) { String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); - assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, true))); + assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, false))); senderLatch.get().countDown(); } }; @@ -265,6 +291,27 @@ final void completableRequestDecorate( senderLatch.set(new CountDownLatch(1)); } + @SuppressWarnings({ "rawtypes", "unchecked" }) + final void completableRequestDecorateWithPreviouslyPopulatedHeaders( + AsyncSender sender, + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler, + DiscoveryNode localNode + ) { + securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode); + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + + // Reset the latch so another request can be processed + senderLatch.set(new CountDownLatch(1)); + } + @Test public void testSendRequestDecorateLocalConnection() { @@ -278,16 +325,44 @@ public void testSendRequestDecorateLocalConnection() { public void testSendRequestDecorateRemoteConnection() { // this is a remote request - completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode); + completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode); // this is a remote request where the transport address is different - completableRequestDecorate(serializedSender, connection4, action, request, options, handler, localNode); + completableRequestDecorate(jdkSerializedSender, connection4, action, request, options, handler, localNode); + } + + @Test + public void testSendRequestDecorateRemoteConnectionUsesJDKSerialization() { + threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, false)); + completableRequestDecorateWithPreviouslyPopulatedHeaders( + jdkSerializedSender, + connection3, + action, + request, + options, + handler, + localNode + ); + } + + @Test + public void testSendRequestDecorateRemoteConnectionUsesCustomSerialization() { + threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, true)); + completableRequestDecorateWithPreviouslyPopulatedHeaders( + customSerializedSender, + connection5, + action, + request, + options, + handler, + localNode + ); } @Test public void testSendNoOriginNodeCausesSerialization() { // this is a request where the local node is null; have to use the remote connection since the serialization will fail - completableRequestDecorate(serializedSender, connection3, action, request, options, handler, null); + completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, null); } @Test @@ -296,7 +371,7 @@ public void testSendNoConnectionShouldThrowNPE() { // The completable version swallows the NPE so have to call actual method assertThrows( java.lang.NullPointerException.class, - () -> securityInterceptor.sendRequestDecorate(serializedSender, null, action, request, options, handler, localNode) + () -> securityInterceptor.sendRequestDecorate(jdkSerializedSender, null, action, request, options, handler, localNode) ); } @@ -328,7 +403,7 @@ public void testCustomRemoteAddressCausesSerialization() { ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, String.valueOf(new TransportAddress(new InetSocketAddress("8.8.8.8", 80))) ); - completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode); + completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode); } @Test @@ -351,7 +426,7 @@ public void testFakeHeaderIsIgnored() { // this is a local request completableRequestDecorate(sender, connection1, action, request, options, handler, localNode); // this is a remote request - completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode); + completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode); } @Test @@ -363,7 +438,7 @@ public void testNullHeaderIsIgnored() { // this is a local request completableRequestDecorate(sender, connection1, action, request, options, handler, localNode); // this is a remote request - completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode); + completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode); } @Test