diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java index aef29626d..83358bc9d 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java @@ -11,10 +11,14 @@ package org.opensearch.ad.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.model.Entity; @@ -119,4 +123,19 @@ public void writeTo(StreamOutput out) throws IOException { public ActionRequestValidationException validate() { return null; } + + public static GetAnomalyDetectorRequest fromActionRequest(final ActionRequest actionRequest) { + if (actionRequest instanceof GetAnomalyDetectorRequest) { + return (GetAnomalyDetectorRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new GetAnomalyDetectorRequest(input); + } + } catch (IOException e) { + throw new IllegalArgumentException("failed to parse ActionRequest into GetAnomalyDetectorRequest", e); + } + } } diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 3b040c9e1..bdc4460f2 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -35,6 +35,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetRequest; import org.opensearch.action.get.MultiGetResponse; @@ -75,7 +76,7 @@ import com.google.common.collect.Sets; -public class GetAnomalyDetectorTransportAction extends HandledTransportAction { +public class GetAnomalyDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); @@ -131,7 +132,9 @@ public GetAnomalyDetectorTransportAction( } @Override - protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionListener actionListener) { + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener actionListener) { + GetAnomalyDetectorRequest request = GetAnomalyDetectorRequest.fromActionRequest(actionRequest); + String detectorID = request.getDetectorID(); User user = getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR);