From 5588571ee57b2aa5d259ff6f3dc4166a1b39b945 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Sun, 28 Apr 2024 15:44:58 +0800 Subject: [PATCH] Move visualization tool to ml-commons (#296) Signed-off-by: Hailong Cui --- .../java/org/opensearch/agent/ToolPlugin.java | 3 - .../agent/tools/VisualizationsTool.java | 178 ------------------ .../agent/tools/VisualizationsToolTests.java | 161 ---------------- .../integTest/VisualizationsToolIT.java | 114 ----------- .../opensearch/agent/tools/visualization.json | 58 ------ .../agent/tools/visualization_not_found.json | 18 -- 6 files changed, 532 deletions(-) delete mode 100644 src/main/java/org/opensearch/agent/tools/VisualizationsTool.java delete mode 100644 src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java delete mode 100644 src/test/java/org/opensearch/integTest/VisualizationsToolIT.java delete mode 100644 src/test/resources/org/opensearch/agent/tools/visualization.json delete mode 100644 src/test/resources/org/opensearch/agent/tools/visualization_not_found.json diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index 0124aa7a..db07ac0b 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -18,7 +18,6 @@ import org.opensearch.agent.tools.SearchAnomalyResultsTool; import org.opensearch.agent.tools.SearchMonitorsTool; import org.opensearch.agent.tools.VectorDBTool; -import org.opensearch.agent.tools.VisualizationsTool; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -62,7 +61,6 @@ public Collection createComponents( this.xContentRegistry = xContentRegistry; PPLTool.Factory.getInstance().init(client); - VisualizationsTool.Factory.getInstance().init(client); NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); VectorDBTool.Factory.getInstance().init(client, xContentRegistry); RAGTool.Factory.getInstance().init(client, xContentRegistry); @@ -80,7 +78,6 @@ public List> getToolFactories() { PPLTool.Factory.getInstance(), NeuralSparseSearchTool.Factory.getInstance(), VectorDBTool.Factory.getInstance(), - VisualizationsTool.Factory.getInstance(), RAGTool.Factory.getInstance(), SearchAlertsTool.Factory.getInstance(), SearchAnomalyDetectorsTool.Factory.getInstance(), diff --git a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java deleted file mode 100644 index 2fa6b996..00000000 --- a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.tools; - -import java.util.Arrays; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; - -import org.opensearch.ExceptionsHelper; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.client.Requests; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.Strings; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import org.opensearch.search.SearchHits; -import org.opensearch.search.builder.SearchSourceBuilder; - -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; - -@Log4j2 -@ToolAnnotation(VisualizationsTool.TYPE) -public class VisualizationsTool implements Tool { - public static final String NAME = "FindVisualizations"; - public static final String TYPE = "VisualizationTool"; - public static final String VERSION = "v1.0"; - - public static final String SAVED_OBJECT_TYPE = "visualization"; - - /** - * default number of visualizations returned - */ - private static final int DEFAULT_SIZE = 3; - private static final String DEFAULT_DESCRIPTION = - "Use this tool to find user created visualizations. This tool takes the visualization name as input and returns matching visualizations"; - @Setter - @Getter - private String description = DEFAULT_DESCRIPTION; - - @Getter - @Setter - private String name = NAME; - @Getter - @Setter - private String type = TYPE; - @Getter - private final String version = VERSION; - private final Client client; - @Getter - private final String index; - @Getter - private final int size; - - @Builder - public VisualizationsTool(Client client, String index, int size) { - this.client = client; - this.index = index; - this.size = size; - } - - @Override - public void run(Map parameters, ActionListener listener) { - BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE)); - boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input"))); - - SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder); - searchSourceBuilder.from(0).size(size); - SearchRequest searchRequest = Requests.searchRequest(index).source(searchSourceBuilder); - - client.search(searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse searchResponse) { - SearchHits hits = searchResponse.getHits(); - StringBuilder visBuilder = new StringBuilder(); - visBuilder.append("Title,Id\n"); - if (hits.getTotalHits().value > 0) { - Arrays.stream(hits.getHits()).forEach(h -> { - String id = trimIdPrefix(h.getId()); - Map visMap = (Map) h.getSourceAsMap().get(SAVED_OBJECT_TYPE); - String title = visMap.get("title"); - visBuilder.append(String.format(Locale.ROOT, "%s,%s\n", title, id)); - }); - - listener.onResponse((T) visBuilder.toString()); - } else { - listener.onResponse((T) "No Visualization found"); - } - } - - @Override - public void onFailure(Exception e) { - if (ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) { - listener.onResponse((T) "No Visualization found"); - } else { - listener.onFailure(e); - } - } - }); - } - - String trimIdPrefix(String id) { - id = Optional.ofNullable(id).orElse(""); - if (id.startsWith(SAVED_OBJECT_TYPE)) { - String prefix = String.format(Locale.ROOT, "%s:", SAVED_OBJECT_TYPE); - return id.substring(prefix.length()); - } - return id; - } - - @Override - public boolean validate(Map parameters) { - return parameters.containsKey("input") && !Strings.isNullOrEmpty(parameters.get("input")); - } - - public static class Factory implements Tool.Factory { - private Client client; - - private static VisualizationsTool.Factory INSTANCE; - - public static VisualizationsTool.Factory getInstance() { - if (INSTANCE != null) { - return INSTANCE; - } - synchronized (VisualizationsTool.class) { - if (INSTANCE != null) { - return INSTANCE; - } - INSTANCE = new VisualizationsTool.Factory(); - return INSTANCE; - } - } - - public void init(Client client) { - this.client = client; - } - - @Override - public VisualizationsTool create(Map params) { - String index = params.get("index") == null ? ".kibana" : (String) params.get("index"); - String sizeStr = params.get("size") == null ? "3" : (String) params.get("size"); - int size; - try { - size = Integer.parseInt(sizeStr); - } catch (NumberFormatException ignored) { - size = DEFAULT_SIZE; - } - return VisualizationsTool.builder().client(client).index(index).size(size).build(); - } - - @Override - public String getDefaultDescription() { - return DEFAULT_DESCRIPTION; - } - - @Override - public String getDefaultType() { - return TYPE; - } - - @Override - public String getDefaultVersion() { - return null; - } - } -} diff --git a/src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java b/src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java deleted file mode 100644 index 9cd79ff9..00000000 --- a/src/test/java/org/opensearch/agent/tools/VisualizationsToolTests.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.tools; - -import static org.junit.Assert.assertEquals; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.CompletableFuture; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.ml.common.spi.tools.Tool; - -public class VisualizationsToolTests { - @Mock - private Client client; - - private String searchResponse = "{}"; - private String searchResponseNotFound = "{}"; - - @Before - public void setup() throws IOException { - MockitoAnnotations.openMocks(this); - VisualizationsTool.Factory.getInstance().init(client); - try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization.json")) { - if (searchResponseIns != null) { - searchResponse = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); - } - } - try (InputStream searchResponseIns = VisualizationsToolTests.class.getResourceAsStream("visualization_not_found.json")) { - if (searchResponseIns != null) { - searchResponseNotFound = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); - } - } - } - - @Test - public void testToolIndexName() { - VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - assertEquals(tool1.getIndex(), ".kibana"); - - VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("index", "test-index")); - assertEquals(tool2.getIndex(), "test-index"); - } - - @Test - public void testNumberOfVisualizationReturned() { - VisualizationsTool tool1 = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - assertEquals(tool1.getSize(), 3); - - VisualizationsTool tool2 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "1")); - assertEquals(tool2.getSize(), 1); - - VisualizationsTool tool3 = VisualizationsTool.Factory.getInstance().create(Map.of("size", "badString")); - assertEquals(tool3.getSize(), 3); - } - - @Test - public void testTrimPrefix() { - VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - assertEquals(tool.trimIdPrefix(null), ""); - assertEquals(tool.trimIdPrefix("abc"), "abc"); - assertEquals(tool.trimIdPrefix("visualization:abc"), "abc"); - } - - @Test - public void testParameterValidation() { - VisualizationsTool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - Assert.assertFalse(tool.validate(Collections.emptyMap())); - Assert.assertFalse(tool.validate(Map.of("input", ""))); - Assert.assertTrue(tool.validate(Map.of("input", "question"))); - } - - @Test - public void testRunToolWithVisualizationFound() throws Exception { - Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - final CompletableFuture future = new CompletableFuture<>(); - ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally); - - ArgumentCaptor> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); - Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); - - Map params = Map.of("input", "Sales by gender"); - - tool.run(params, listener); - - SearchResponse response = SearchResponse - .fromXContent( - JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponse) - ); - searchResponseListener.getValue().onResponse(response); - - future.join(); - assertEquals("Title,Id\n[Ecommerce]Sales by gender,aeb212e0-4c84-11e8-b3d7-01146121b73d\n", future.get()); - } - - @Test - public void testRunToolWithNoVisualizationFound() throws Exception { - Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - final CompletableFuture future = new CompletableFuture<>(); - ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally); - - ArgumentCaptor> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); - Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); - - Map params = Map.of("input", "Sales by gender"); - - tool.run(params, listener); - - SearchResponse response = SearchResponse - .fromXContent( - JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, searchResponseNotFound) - ); - searchResponseListener.getValue().onResponse(response); - - future.join(); - assertEquals("No Visualization found", future.get()); - } - - @Test - public void testRunToolWithIndexNotExists() throws Exception { - Tool tool = VisualizationsTool.Factory.getInstance().create(Collections.emptyMap()); - final CompletableFuture future = new CompletableFuture<>(); - ActionListener listener = ActionListener.wrap(future::complete, future::completeExceptionally); - - ArgumentCaptor> searchResponseListener = ArgumentCaptor.forClass(ActionListener.class); - Mockito.doNothing().when(client).search(ArgumentMatchers.any(SearchRequest.class), searchResponseListener.capture()); - - Map params = Map.of("input", "Sales by gender"); - - tool.run(params, listener); - - IndexNotFoundException notFoundException = new IndexNotFoundException("test-index"); - searchResponseListener.getValue().onFailure(notFoundException); - - future.join(); - assertEquals("No Visualization found", future.get()); - } -} diff --git a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java deleted file mode 100644 index 2bf0e611..00000000 --- a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.integTest; - -import java.io.IOException; -import java.util.List; -import java.util.Locale; -import java.util.UUID; - -import org.junit.Assert; -import org.opensearch.agent.tools.VisualizationsTool; -import org.opensearch.client.Request; -import org.opensearch.client.Response; -import org.opensearch.core.rest.RestStatus; - -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; - -import lombok.extern.log4j.Log4j2; - -@Log4j2 -public class VisualizationsToolIT extends ToolIntegrationTest { - @Override - List promptHandlers() { - return List.of(new PromptHandler() { - @Override - LLMThought llmThought() { - return LLMThought - .builder() - .action(VisualizationsTool.TYPE) - .actionInput("RAM") - .question("can you show me RAM info with visualization?") - .build(); - } - }, new PromptHandler() { - @Override - LLMThought llmThought() { - return LLMThought - .builder() - .action(VisualizationsTool.TYPE) - .actionInput("sales") - .question("how about the sales about this month?") - .build(); - } - }); - } - - String toolType() { - return VisualizationsTool.TYPE; - } - - public void testVisualizationNotFound() throws IOException { - Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); - request.setJsonEntity("{\"parameters\":{\"question\":\"can you show me RAM info with visualization?\"}}"); - Response response = executeRequest(request); - String responseStr = readResponse(response); - String toolOutput = extractAdditionalInfo(responseStr); - Assert.assertEquals("No Visualization found", toolOutput); - } - - public void testVisualizationFound() throws IOException { - String title = "[eCommerce] Sales by Category"; - String id = UUID.randomUUID().toString(); - prepareVisualization(title, id); - Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); - request.setJsonEntity("{\"parameters\":{\"question\":\"how about the sales about this month?\"}}"); - Response response = executeRequest(request); - String responseStr = readResponse(response); - String toolOutput = extractAdditionalInfo(responseStr); - Assert.assertEquals("Title,Id\n" + String.format(Locale.ROOT, "%s,%s\n", title, id), toolOutput); - } - - private void prepareVisualization(String title, String id) { - String body = "{\n" - + " \"visualization\": {\n" - + " \"title\": \"" - + title - + "\"\n" - + " },\n" - + " \"type\": \"visualization\"\n" - + "}"; - Response response = makeRequest(client(), "POST", String.format(Locale.ROOT, ".kibana/_doc/%s?refresh=true", id), null, body, null); - Assert.assertEquals(response.getStatusLine().getStatusCode(), RestStatus.CREATED.getStatus()); - } - - private String extractAdditionalInfo(String responseStr) { - JsonArray output = JsonParser - .parseString(responseStr) - .getAsJsonObject() - .get("inference_results") - .getAsJsonArray() - .get(0) - .getAsJsonObject() - .get("output") - .getAsJsonArray(); - for (JsonElement element : output) { - if ("response".equals(element.getAsJsonObject().get("name").getAsString())) { - return element - .getAsJsonObject() - .get("dataAsMap") - .getAsJsonObject() - .get("additional_info") - .getAsJsonObject() - .get(String.format(Locale.ROOT, "%s.output", toolType())) - .getAsString(); - } - } - return null; - } -} diff --git a/src/test/resources/org/opensearch/agent/tools/visualization.json b/src/test/resources/org/opensearch/agent/tools/visualization.json deleted file mode 100644 index 8901706e..00000000 --- a/src/test/resources/org/opensearch/agent/tools/visualization.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "took": 4, - "timed_out": false, - "_shards": { - "total": 1, - "successful": 1, - "skipped": 0, - "failed": 0 - }, - "hits": { - "total": { - "value": 1, - "relation": "eq" - }, - "max_score": 0.2847877, - "hits": [ - { - "_index": ".kibana_1", - "_id": "visualization:aeb212e0-4c84-11e8-b3d7-01146121b73d", - "_score": 0.2847877, - "_source": { - "visualization": { - "title": "[Ecommerce]Sales by gender", - "visState": "", - "uiStateJSON": "{}", - "description": "", - "version": 1, - "kibanaSavedObjectMeta": { - "searchSourceJSON": "{}" - } - }, - "type": "visualization", - "references": [ - { - "name": "control_0_index_pattern", - "type": "index-pattern", - "id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" - }, - { - "name": "control_1_index_pattern", - "type": "index-pattern", - "id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" - }, - { - "name": "control_2_index_pattern", - "type": "index-pattern", - "id": "d3d7af60-4c81-11e8-b3d7-01146121b73d" - } - ], - "migrationVersion": { - "visualization": "7.10.0" - }, - "updated_at": "2023-11-10T02:50:24.881Z" - } - } - ] - } -} diff --git a/src/test/resources/org/opensearch/agent/tools/visualization_not_found.json b/src/test/resources/org/opensearch/agent/tools/visualization_not_found.json deleted file mode 100644 index 40a0e9d3..00000000 --- a/src/test/resources/org/opensearch/agent/tools/visualization_not_found.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "took": 1, - "timed_out": false, - "_shards": { - "total": 1, - "successful": 1, - "skipped": 0, - "failed": 0 - }, - "hits": { - "total": { - "value": 0, - "relation": "eq" - }, - "max_score": null, - "hits": [] - } -}