diff --git a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java index 60bb274ab..714ad353f 100644 --- a/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java +++ b/src/main/java/org/opensearch/ad/client/AnomalyDetectionNodeClient.java @@ -17,12 +17,15 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; public class AnomalyDetectionNodeClient implements AnomalyDetectionClient { private final Client client; + private final NamedWriteableRegistry namedWriteableRegistry; - public AnomalyDetectionNodeClient(Client client) { + public AnomalyDetectionNodeClient(Client client, NamedWriteableRegistry namedWriteableRegistry) { this.client = client; + this.namedWriteableRegistry = namedWriteableRegistry; } @Override @@ -46,6 +49,9 @@ public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionL // We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic // ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins. + // Additionally, we need to inject the configured NamedWriteableRegistry so NamedWriteables (present in sub-fields of + // GetAnomalyDetectorResponse) are able to be re-serialized and prevent errors like the following: + // "can't read named writeable from StreamInput" private ActionListener getAnomalyDetectorResponseActionListener( ActionListener listener ) { @@ -53,7 +59,8 @@ private ActionListener getAnomalyDetectorResponseAct listener.onResponse(getAnomalyDetectorResponse); }, listener::onFailure); ActionListener actionListener = wrapActionListener(internalListener, actionResponse -> { - GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse); + GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse + .fromActionResponse(actionResponse, this.namedWriteableRegistry); return response; }); return actionListener; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index 091c396fb..84ad4659e 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -24,6 +24,8 @@ import org.opensearch.ad.util.RestHandlerUtils; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -218,16 +220,21 @@ public AnomalyDetector getDetector() { return detector; } - public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) { + public static GetAnomalyDetectorResponse fromActionResponse( + ActionResponse actionResponse, + NamedWriteableRegistry namedWriteableRegistry + ) { if (actionResponse instanceof GetAnomalyDetectorResponse) { return (GetAnomalyDetectorResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos); actionResponse.writeTo(osso); - try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new GetAnomalyDetectorResponse(input); - } + InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray())); + NamedWriteableAwareStreamInput namedWriteableAwareInput = new NamedWriteableAwareStreamInput(input, namedWriteableRegistry); + return new GetAnomalyDetectorResponse(namedWriteableAwareInput); } catch (IOException e) { throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e); } diff --git a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java index 1b9866309..df845301f 100644 --- a/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java +++ b/src/test/java/org/opensearch/ad/client/AnomalyDetectionNodeClientTests.java @@ -41,6 +41,7 @@ import org.opensearch.client.Client; import org.opensearch.common.lucene.uid.Versions; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -62,7 +63,7 @@ public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTest @Before public void setup() { clientSpy = spy(client()); - adClient = new AnomalyDetectionNodeClient(clientSpy); + adClient = new AnomalyDetectionNodeClient(clientSpy, mock(NamedWriteableRegistry.class)); } @Test diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java index 8bf6a3ce4..9a43fb40b 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java @@ -19,8 +19,11 @@ import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.TestHelpers; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; @@ -76,6 +79,21 @@ public void testSerializationWithJobAndTask() throws IOException { assertEquals(response.getDetector(), parsedResponse.getDetector()); } + public void testFromActionResponse() throws IOException { + GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(true, true); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + + GetAnomalyDetectorResponse reserializedResponse = GetAnomalyDetectorResponse + .fromActionResponse((ActionResponse) response, writableRegistry()); + assertEquals(response, reserializedResponse); + + ActionResponse invalidActionResponse = new TestActionResponse(input); + assertThrows(Exception.class, () -> GetAnomalyDetectorResponse.fromActionResponse(invalidActionResponse, writableRegistry())); + + } + private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean returnJob, boolean returnTask) throws IOException { GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( randomLong(), @@ -95,4 +113,17 @@ private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean retu ); return response; } + + // A test ActionResponse class with an inactive writeTo class. Used to ensure exceptions + // are thrown when parsing implementations of such class. + private class TestActionResponse extends ActionResponse { + public TestActionResponse(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + return; + } + } }