diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 1db7915e..0124aa7a 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -16,7 +16,6 @@ import org.opensearch.agent.tools.SearchAlertsTool; import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; import org.opensearch.agent.tools.SearchAnomalyResultsTool; -import org.opensearch.agent.tools.SearchIndexTool; import org.opensearch.agent.tools.SearchMonitorsTool; import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.agent.tools.VisualizationsTool; @@ -66,7 +65,6 @@ public Collection createComponents( VisualizationsTool.Factory.getInstance().init(client); NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); VectorDBTool.Factory.getInstance().init(client, xContentRegistry); - SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); RAGTool.Factory.getInstance().init(client, xContentRegistry); SearchAlertsTool.Factory.getInstance().init(client); SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry); @@ -83,7 +81,6 @@ public List> getToolFactories() { NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance(), VisualizationsTool.Factory.getInstance(), - SearchIndexTool.Factory.getInstance(), RAGTool.Factory.getInstance(), SearchAlertsTool.Factory.getInstance(), SearchAnomalyDetectorsTool.Factory.getInstance(), diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java index 5003f0fa..f01dde7e 100644 --- a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -68,7 +68,7 @@ protected AbstractRetrieverTool( protected abstract String getQueryBody(String queryText); - public static Map processResponse(SearchHit hit) { + private static Map processResponse(SearchHit hit) { Map docContent = new HashMap<>(); docContent.put("_index", hit.getIndex()); docContent.put("_id", hit.getId()); diff --git a/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java b/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java deleted file mode 100644 index f6d5a80a..00000000 --- a/src/main/java/org/opensearch/agent/tools/SearchIndexTool.java +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.tools; - -import static org.opensearch.ml.common.CommonValue.*; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; -import org.opensearch.ml.common.transport.model.MLModelSearchAction; -import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; -import org.opensearch.ml.common.utils.StringUtils; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; - -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; - -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; - -@Getter -@Setter -@Log4j2 -@ToolAnnotation(SearchIndexTool.TYPE) -public class SearchIndexTool implements Tool { - - public static final String INPUT_FIELD = "input"; - public static final String INDEX_FIELD = "index"; - public static final String QUERY_FIELD = "query"; - - public static final String TYPE = "SearchIndexTool"; - private static final String DEFAULT_DESCRIPTION = - "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when a DSL query is available."; - - private String name = TYPE; - - private String description = DEFAULT_DESCRIPTION; - - private Client client; - - private NamedXContentRegistry xContentRegistry; - - public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { - this.client = client; - this.xContentRegistry = xContentRegistry; - } - - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public boolean validate(Map parameters) { - return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null; - } - - private SearchRequest getSearchRequest(String index, String query) throws IOException { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); - searchSourceBuilder.parseXContent(queryParser); - return new SearchRequest().source(searchSourceBuilder).indices(index); - } - - @Override - public void run(Map parameters, ActionListener listener) { - try { - String input = parameters.get(INPUT_FIELD); - JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); - String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null); - String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null); - if (index == null || query == null) { - listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!")); - return; - } - SearchRequest searchRequest = getSearchRequest(index, query); - - ActionListener actionListener = ActionListener.wrap(r -> { - SearchHit[] hits = r.getHits().getHits(); - - if (hits != null && hits.length > 0) { - StringBuilder contextBuilder = new StringBuilder(); - for (SearchHit hit : hits) { - String doc = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - Map docContent = AbstractRetrieverTool.processResponse(hit); - return StringUtils.gson.toJson(docContent); - }); - contextBuilder.append(doc).append("\n"); - } - listener.onResponse((T) contextBuilder.toString()); - } else { - listener.onResponse((T) ""); - } - }, e -> { - log.error("Failed to search index", e); - listener.onFailure(e); - }); - - // since searching connector and model needs access control, we need - // to forward the request corresponding transport action - if (Objects.equals(index, ML_CONNECTOR_INDEX)) { - client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener); - } else if (Objects.equals(index, ML_MODEL_INDEX)) { - client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener); - } else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) { - client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener); - } else { - client.search(searchRequest, actionListener); - } - } catch (Exception e) { - log.error("Failed to search index", e); - listener.onFailure(e); - } - } - - public static class Factory implements Tool.Factory { - - private Client client; - private static Factory INSTANCE; - - private NamedXContentRegistry xContentRegistry; - - /** - * Create or return the singleton factory instance - */ - public static Factory getInstance() { - if (INSTANCE != null) { - return INSTANCE; - } - synchronized (SearchIndexTool.class) { - if (INSTANCE != null) { - return INSTANCE; - } - INSTANCE = new Factory(); - return INSTANCE; - } - } - - public void init(Client client, NamedXContentRegistry xContentRegistry) { - this.client = client; - this.xContentRegistry = xContentRegistry; - } - - @Override - public SearchIndexTool create(Map params) { - return new SearchIndexTool(client, xContentRegistry); - } - - @Override - public String getDefaultDescription() { - return DEFAULT_DESCRIPTION; - } - - @Override - public String getDefaultType() { - return TYPE; - } - - @Override - public String getDefaultVersion() { - return null; - } - } -} diff --git a/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java deleted file mode 100644 index d228c0cb..00000000 --- a/src/test/java/org/opensearch/agent/tools/SearchIndexToolTests.java +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.tools; - -import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; - -import java.io.InputStream; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mockito; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.Strings; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; -import org.opensearch.ml.common.transport.model.MLModelSearchAction; -import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; -import org.opensearch.search.SearchModule; - -import lombok.SneakyThrows; - -public class SearchIndexToolTests { - static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( - new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() - ); - - private Client client; - - private SearchIndexTool mockedSearchIndexTool; - - private String mockedSearchResponseString; - - @Before - @SneakyThrows - public void setup() { - client = mock(Client.class); - mockedSearchIndexTool = Mockito - .mock( - SearchIndexTool.class, - Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS) - ); - - try (InputStream searchResponseIns = SearchIndexTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { - if (searchResponseIns != null) { - mockedSearchResponseString = new String(searchResponseIns.readAllBytes()); - } - } - } - - @Test - @SneakyThrows - public void testGetType() { - String type = mockedSearchIndexTool.getType(); - assertFalse(Strings.isNullOrEmpty(type)); - assertEquals("SearchIndexTool", type); - } - - @Test - @SneakyThrows - public void testValidate() { - Map parameters = Map.of("input", "{}"); - assertTrue(mockedSearchIndexTool.validate(parameters)); - } - - @Test - @SneakyThrows - public void testValidateWithEmptyInput() { - Map parameters = Map.of(); - assertFalse(mockedSearchIndexTool.validate(parameters)); - } - - @Test - public void testRunWithNormalIndex() { - String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, null); - Mockito.verify(client, times(1)).search(any(), any()); - Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); - } - - @Test - public void testRunWithConnectorIndex() { - String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}"; - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, null); - Mockito.verify(client, never()).search(any(), any()); - Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); - } - - @Test - public void testRunWithModelIndex() { - String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"query\": {\"match_all\": {}}}}"; - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, null); - Mockito.verify(client, never()).search(any(), any()); - Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); - } - - @Test - public void testRunWithModelGroupIndex() { - String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"query\": {\"match_all\": {}}}}"; - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, null); - Mockito.verify(client, never()).search(any(), any()); - Mockito.verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), any(), any()); - } - - @Test - @SneakyThrows - public void testRunWithSearchResults() { - SearchResponse mockedSearchResponse = SearchResponse - .fromXContent( - JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) - ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockedSearchResponse); - return null; - }).when(client).search(any(), any()); - - String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; - final CompletableFuture future = new CompletableFuture<>(); - ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, listener); - - future.join(); - - Mockito.verify(client, times(1)).search(any(), any()); - Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); - } - - @Test - @SneakyThrows - public void testRunWithEmptyQuery() { - String inputString = "{\"index\": \"test_index\"}"; - Map parameters = Map.of("input", inputString); - ActionListener listener = mock(ActionListener.class); - mockedSearchIndexTool.run(parameters, listener); - Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); - Mockito.verify(client, Mockito.never()).search(any(), any()); - } - - @Test - public void testRunWithInvalidQuery() { - String inputString = "{\"index\": \"test-index\", \"query\": \"invalid query\"}"; - Map parameters = Map.of("input", inputString); - ActionListener listener = mock(ActionListener.class); - mockedSearchIndexTool.run(parameters, listener); - Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); - Mockito.verify(client, Mockito.never()).search(any(), any()); - } - - @Test - public void testRunWithEmptyQueryBody() { - String inputString = "{\"index\": \"test-index\", \"query\": {}}"; - Map parameters = Map.of("input", inputString); - mockedSearchIndexTool.run(parameters, null); - Mockito.verify(client, times(1)).search(any(), any()); - Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); - } - - @Test - public void testFactory() { - SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap()); - assertEquals(SearchIndexTool.TYPE, searchIndexTool.getType()); - } -} diff --git a/src/test/java/org/opensearch/integTest/SearchIndexToolIT.java b/src/test/java/org/opensearch/integTest/SearchIndexToolIT.java deleted file mode 100644 index f989ebef..00000000 --- a/src/test/java/org/opensearch/integTest/SearchIndexToolIT.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.integTest; - -import static org.hamcrest.Matchers.containsString; - -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.List; - -import org.hamcrest.MatcherAssert; -import org.junit.After; -import org.junit.Before; -import org.opensearch.client.ResponseException; - -import lombok.SneakyThrows; - -public class SearchIndexToolIT extends BaseAgentToolsIT { - public static String TEST_INDEX_NAME = "test_index"; - private String registerAgentRequestBody; - - @SneakyThrows - private void prepareIndex() { - createIndexWithConfiguration( - TEST_INDEX_NAME, - "{\n" - + " \"mappings\": {\n" - + " \"properties\": {\n" - + " \"text\": {\n" - + " \"type\": \"text\"\n" - + " }\n" - + " }\n" - + " }\n" - + "}" - ); - addDocToIndex(TEST_INDEX_NAME, "0", List.of("text"), List.of("text doc 1")); - addDocToIndex(TEST_INDEX_NAME, "1", List.of("text"), List.of("text doc 2")); - addDocToIndex(TEST_INDEX_NAME, "2", List.of("text"), List.of("text doc 3")); - } - - @Before - @SneakyThrows - public void setUp() { - super.setUp(); - prepareIndex(); - registerAgentRequestBody = Files - .readString( - Path - .of( - this - .getClass() - .getClassLoader() - .getResource("org/opensearch/agent/tools/register_flow_agent_of_search_index_tool_request_body.json") - .toURI() - ) - ); - } - - @After - @SneakyThrows - public void tearDown() { - super.tearDown(); - deleteExternalIndices(); - } - - public void testSearchIndexToolInFlowAgent_withMatchAllQuery() { - String agentId = createAgent(registerAgentRequestBody); - String agentInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": {\n" - + " \"index\": \"test_index\",\n" - + " \"query\": {\n" - + " \"query\": {\n" - + " \"match_all\": {}\n" - + " }\n" - + " }\n" - + " } \n" - + " }\n" - + "}\n"; - String result = executeAgent(agentId, agentInput); - assertEquals( - "The search index result not equal with expected.", - "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 1\"},\"_id\":\"0\",\"_score\":1.0}\n" - + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.0}\n" - + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":1.0}\n", - result - ); - } - - public void testSearchIndexToolInFlowAgent_withEmptyIndexField_thenThrowException() { - String agentId = createAgent(registerAgentRequestBody); - String agentInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": {\n" - + " \"query\": {\n" - + " \"query\": {\n" - + " \"match_all\": {}\n" - + " }\n" - + " }\n" - + " } \n" - + " }\n" - + "}\n"; - Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); - MatcherAssert.assertThat(exception.getMessage(), containsString("SearchIndexTool's two parameter: index and query are required!")); - } - - public void testSearchIndexToolInFlowAgent_withEmptyQueryField_thenThrowException() { - String agentId = createAgent(registerAgentRequestBody); - String agentInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": {\n" - + " \"index\": \"test_index\"\n" - + " } \n" - + " }\n" - + "}\n"; - Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); - MatcherAssert.assertThat(exception.getMessage(), containsString("SearchIndexTool's two parameter: index and query are required!")); - } - - public void testSearchIndexToolInFlowAgent_withIllegalQueryField_thenThrowException() { - String agentId = createAgent(registerAgentRequestBody); - String agentInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": {\n" - + " \"index\": \"test_index\",\n" - + " \"query\": \"Invalid Query\"\n" - + " } \n" - + " }\n" - + "}\n"; - Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); - MatcherAssert.assertThat(exception.getMessage(), containsString("ParsingException")); - } -} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_index_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_index_tool_request_body.json deleted file mode 100644 index 52a67073..00000000 --- a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_search_index_tool_request_body.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "Test_Search_Index_Agent", - "type": "flow", - "tools": [ - { - "type": "SearchIndexTool", - "description": "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when a DSL query is available." - } - ] -} \ No newline at end of file