Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject NamedWriteableRegistry in AD node client #1164

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;

public class AnomalyDetectionNodeClient implements AnomalyDetectionClient {
private final Client client;
private final NamedWriteableRegistry namedWriteableRegistry;

public AnomalyDetectionNodeClient(Client client) {
public AnomalyDetectionNodeClient(Client client, NamedWriteableRegistry namedWriteableRegistry) {
this.client = client;
this.namedWriteableRegistry = namedWriteableRegistry;
}

@Override
Expand All @@ -46,14 +49,18 @@ public void getDetectorProfile(GetAnomalyDetectorRequest profileRequest, ActionL

// We need to wrap AD-specific response type listeners around an internal listener, and re-generate the response from a generic
// ActionResponse. This is needed to prevent classloader issues and ClassCastExceptions when executed by other plugins.
// Additionally, we need to inject the configured NamedWriteableRegistry so NamedWriteables (present in sub-fields of
// GetAnomalyDetectorResponse) are able to be re-serialized and prevent errors like the following:
// "can't read named writeable from StreamInput"
private ActionListener<GetAnomalyDetectorResponse> getAnomalyDetectorResponseActionListener(
ActionListener<GetAnomalyDetectorResponse> listener
) {
ActionListener<GetAnomalyDetectorResponse> internalListener = ActionListener.wrap(getAnomalyDetectorResponse -> {
listener.onResponse(getAnomalyDetectorResponse);
}, listener::onFailure);
ActionListener<GetAnomalyDetectorResponse> actionListener = wrapActionListener(internalListener, actionResponse -> {
GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse.fromActionResponse(actionResponse);
GetAnomalyDetectorResponse response = GetAnomalyDetectorResponse
.fromActionResponse(actionResponse, this.namedWriteableRegistry);
return response;
});
return actionListener;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.opensearch.ad.model.EntityProfile;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -218,16 +220,21 @@ public AnomalyDetector getDetector() {
return detector;
}

public static GetAnomalyDetectorResponse fromActionResponse(ActionResponse actionResponse) {
public static GetAnomalyDetectorResponse fromActionResponse(
ActionResponse actionResponse,
NamedWriteableRegistry namedWriteableRegistry
) {
if (actionResponse instanceof GetAnomalyDetectorResponse) {
return (GetAnomalyDetectorResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos);
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new GetAnomalyDetectorResponse(input);
}
InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()));
NamedWriteableAwareStreamInput namedWriteableAwareInput = new NamedWriteableAwareStreamInput(input, namedWriteableRegistry);
return new GetAnomalyDetectorResponse(namedWriteableAwareInput);
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into GetAnomalyDetectorResponse", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.client.Client;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
Expand All @@ -64,7 +65,7 @@ public class AnomalyDetectionNodeClientTests extends HistoricalAnalysisIntegTest
@Before
public void setup() {
clientSpy = spy(client());
adClient = new AnomalyDetectionNodeClient(clientSpy);
adClient = new AnomalyDetectionNodeClient(clientSpy, mock(NamedWriteableRegistry.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import java.util.Collection;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.InternalSettingsPlugin;
Expand Down Expand Up @@ -76,6 +79,21 @@ public void testSerializationWithJobAndTask() throws IOException {
assertEquals(response.getDetector(), parsedResponse.getDetector());
}

public void testFromActionResponse() throws IOException {
GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(true, true);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry());

GetAnomalyDetectorResponse reserializedResponse = GetAnomalyDetectorResponse
.fromActionResponse((ActionResponse) response, writableRegistry());
assertEquals(response, reserializedResponse);

ActionResponse invalidActionResponse = new TestActionResponse(input);
assertThrows(Exception.class, () -> GetAnomalyDetectorResponse.fromActionResponse(invalidActionResponse, writableRegistry()));

}

private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean returnJob, boolean returnTask) throws IOException {
GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse(
randomLong(),
Expand All @@ -95,4 +113,17 @@ private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean retu
);
return response;
}

// A test ActionResponse class with an inactive writeTo class. Used to ensure exceptions
// are thrown when parsing implementations of such class.
private class TestActionResponse extends ActionResponse {
public TestActionResponse(StreamInput in) throws IOException {
super(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
return;
}
}
}
Loading