From 943ee01cd9b95b5d4e40194709067ad5946616f9 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Fri, 29 Dec 2023 18:27:59 +0800 Subject: [PATCH 1/5] mock server and integTest for visualization Signed-off-by: Hailong Cui --- build.gradle | 1 + .../agent/tools/VisualizationsTool.java | 5 +- .../agent/integtest/PromptHandler.java | 29 ++ .../agent/integtest/ToolIntegrationTest.java | 303 ++++++++++++++++++ .../agent/integtest/VisualizationsToolIT.java | 109 +++++++ 5 files changed, 443 insertions(+), 4 deletions(-) create mode 100644 src/test/java/org/opensearch/agent/integtest/PromptHandler.java create mode 100644 src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java create mode 100644 src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java diff --git a/build.gradle b/build.gradle index bd63a17f..591aefec 100644 --- a/build.gradle +++ b/build.gradle @@ -131,6 +131,7 @@ dependencies { // Test dependencies testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' + testImplementation group: 'org.json', name: 'json', version: '20231013' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' testImplementation("net.bytebuddy:byte-buddy:1.14.7") diff --git a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java index 31f5cf09..958f1eda 100644 --- a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java +++ b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java @@ -16,6 +16,7 @@ import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -24,9 +25,6 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Strings; - import lombok.Builder; import lombok.Getter; import lombok.Setter; @@ -113,7 +111,6 @@ public void onFailure(Exception e) { }); } - @VisibleForTesting String trimIdPrefix(String id) { id = Optional.ofNullable(id).orElse(""); if (id.startsWith(SAVED_OBJECT_TYPE)) { diff --git a/src/test/java/org/opensearch/agent/integtest/PromptHandler.java b/src/test/java/org/opensearch/agent/integtest/PromptHandler.java new file mode 100644 index 00000000..28c8acab --- /dev/null +++ b/src/test/java/org/opensearch/agent/integtest/PromptHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.integtest; + +import org.apache.commons.lang3.tuple.Pair; + +import com.google.gson.annotations.SerializedName; + +import lombok.Data; + +public abstract class PromptHandler { + + boolean apply(String prompt) { + return prompt.contains(questionAndInput().getKey()); + } + + abstract Pair questionAndInput(); + + @Data + static class LLMResponse { + String completion; + @SerializedName("stop_reason") + String stopReason = "stop_sequence"; + String stop = "\\n\\nHuman:"; + } +} diff --git a/src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java b/src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java new file mode 100644 index 00000000..f06f2f9b --- /dev/null +++ b/src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java @@ -0,0 +1,303 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.integtest; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.google.gson.Gson; +import com.google.gson.JsonParser; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public abstract class ToolIntegrationTest extends OpenSearchIntegTestCase { + protected HttpServer server; + protected String modelId; + protected String agentId; + protected String modelGroupId; + protected String connectorId; + + private final Gson gson = new Gson(); + + abstract List promptHandlers(); + + abstract String toolType(); + + @Before + public void setupAgent() throws IOException, InterruptedException { + setupMockLLM(); + clusterSettings(false); + try { + connectorId = setUpConnector(); + } catch (Exception e) { + // Wait for ML encryption master key has been initialized + TimeUnit.SECONDS.sleep(10); + connectorId = setUpConnector(); + } + modelGroupId = setupModelGroup(); + modelId = setupModel(connectorId, modelGroupId); + agentId = setUpAgent(modelId); + } + + @After + public void cleanUpClusterSetting() throws IOException { + clusterSettings(true); + } + + @After + public void stopMockLLM() { + server.stop(1); + } + + private void setupMockLLM() throws IOException { + server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + + server.createContext("/invoke", exchange -> { + InputStream ins = exchange.getRequestBody(); + String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8); + Map map = gson.fromJson(req, Map.class); + String prompt = map.get("prompt"); + log.debug("prompt received: {}", prompt); + + String llmRes = ""; + for (PromptHandler promptHandler : promptHandlers()) { + if (promptHandler.apply(prompt)) { + PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse(); + if (prompt.contains("TOOL RESPONSE: ")) { + llmResponse + .setCompletion( + "```json{\n" + + " \"thought\": \"Thought: Now I know the final answer\",\n" + + " \"final_answer\": \"final answer\"\n" + + "}```" + ); + } else { + String actionInput = promptHandler.questionAndInput().getValue(); + llmResponse + .setCompletion( + "```json{\n" + + " \"thought\": \"Thought: Let me use tool to figure out\",\n" + + " \"action\": \"" + + toolType() + + "\",\n" + + " \"action_input\": \"" + + actionInput + + "\"\n" + + "}```" + ); + } + llmRes = gson.toJson(llmResponse); + break; + } + } + byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, llmResBytes.length); + exchange.getResponseBody().write(llmResBytes); + exchange.close(); + }); + server.start(); + } + + private String setUpConnector() throws IOException { + String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort()); + Request request = new Request("POST", "/_plugins/_ml/connectors/_create"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"BedRock test claude Connector\",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"us-east-1\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n" + + " \"auth\": \"Sig_V4\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"\",\n" + + " \"secret_key\": \"\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"" + + url + + "\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}" + ); + Response response = executeRequest(request); + + return JsonParser.parseString(readResponse(response)).getAsJsonObject().get("connector_id").getAsString(); + } + + private void clusterSettings(boolean reset) throws IOException { + Request request = new Request("PUT", "_cluster/settings"); + if (!reset) { + request + .setJsonEntity( + "{\n" + + " \"persistent\": {\n" + + " \"plugins.ml_commons.only_run_on_ml_node\": false,\n" + + " \"plugins.ml_commons.memory_feature_enabled\": true,\n" + + " \"plugins.ml_commons.trusted_connector_endpoints_regex\": [\n" + + " \"^.*$\"\n" + + " ]\n" + + " }\n" + + "}" + ); + } else { + request + .setJsonEntity( + "{\n" + + " \"persistent\": {\n" + + " \"plugins.ml_commons.only_run_on_ml_node\": null,\n" + + " \"plugins.ml_commons.memory_feature_enabled\": null,\n" + + " \"plugins.ml_commons.trusted_connector_endpoints_regex\": null" + + " }\n" + + "}" + ); + } + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + getRestClient().performRequest(request); + } + + private String setupModelGroup() throws IOException { + Request request = new Request("POST", "/_plugins/_ml/model_groups/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"test_model_group_bedrock-" + + UUID.randomUUID().toString() + + "\",\n" + + " \"description\": \"This is a public model group\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString(); + } + + private String setupModel(String connectorId, String modelGroupId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/models/_register?deploy=true"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"Bedrock Claude V2 model\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); + } + + private String setUpAgent(String modelId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/agents/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"integTest-agent\",\n" + + " \"type\": \"conversational\",\n" + + " \"description\": \"this is a test agent\",\n" + + " \"llm\": {\n" + + " \"model_id\": \"" + + modelId + + "\",\n" + + " \"parameters\": {\n" + + " \"max_iteration\": \"5\",\n" + + " \"stop_when_no_tool_found\": \"true\",\n" + + " \"response_filter\": \"$.completion\"\n" + + " }\n" + + " },\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"" + + toolType() + + "\",\n" + + " \"name\": \"" + + toolType() + + "\",\n" + + " \"include_output_in_agent_response\": true,\n" + + " \"description\": \"tool description\"\n" + + " }\n" + + " ],\n" + + " \"memory\": {\n" + + " \"type\": \"conversation_index\"\n" + + " }\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString(); + } + + public static Response executeRequest(Request request) throws IOException { + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + return getRestClient().performRequest(request); + } + + public static String readResponse(Response response) throws IOException { + StringBuilder sb = new StringBuilder(); + + try (BufferedReader br = new BufferedReader(new InputStreamReader(response.getEntity().getContent()))) { + String line; + while ((line = br.readLine()) != null) { + sb.append(line); + } + } + return sb.toString(); + } +} diff --git a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java b/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java new file mode 100644 index 00000000..221723b2 --- /dev/null +++ b/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.integtest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.UUID; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.opensearch.agent.tools.VisualizationsTool; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; + +import com.google.gson.JsonParser; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class VisualizationsToolIT extends ToolIntegrationTest { + @Override + List promptHandlers() { + return List.of(new PromptHandler() { + @Override + Pair questionAndInput() { + return Pair.of("can you show me RAM info with visualization?", "RAM"); + } + }, new PromptHandler() { + @Override + Pair questionAndInput() { + return Pair.of("how about the sales about this month?", "sales"); + } + }); + } + + String toolType() { + return VisualizationsTool.TYPE; + } + + public void testVisualizationNotFound() throws IOException { + Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); + request.setJsonEntity("{\"parameters\":{\"question\":\"can you show me RAM info with visualization?\"}}"); + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + Response response = getRestClient().performRequest(request); + String responseStr = readResponse(response); + String toolOutput = extractAdditionalInfo(responseStr); + Assert.assertEquals("No Visualization found", toolOutput); + } + + public void testVisualizationFound() throws IOException { + String title = "[eCommerce] Sales by Category"; + String id = UUID.randomUUID().toString(); + prepareVisualization(title, id); + Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); + request.setJsonEntity("{\"parameters\":{\"question\":\"how about the sales about this month?\"}}"); + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + Response response = getRestClient().performRequest(request); + String responseStr = readResponse(response); + String toolOutput = extractAdditionalInfo(responseStr); + Assert.assertEquals("Title,Id\n" + String.format(Locale.ROOT, "%s,%s\n", title, id), toolOutput); + } + + private void prepareVisualization(String title, String id) throws IOException { + Request request = new Request("POST", ".kibana/_doc/" + id); + request + .setJsonEntity( + "{\n" + + " \"visualization\": {\n" + + " \"title\": \"" + + title + + "\"\n" + + " },\n" + + " \"type\": \"visualization\"\n" + + "}" + ); + Response response = executeRequest(request); + Assert.assertEquals(response.getStatusLine().getStatusCode(), RestStatus.OK.getStatus()); + } + + private String extractAdditionalInfo(String responseStr) { + return JsonParser + .parseString(responseStr) + .getAsJsonObject() + .get("inference_results") + .getAsJsonArray() + .get(0) + .getAsJsonObject() + .get("output") + .getAsJsonArray() + .get(0) + .getAsJsonObject() + .get("dataAsMap") + .getAsJsonObject() + .get("additional_info") + .getAsJsonObject() + .get(String.format(Locale.ROOT, "%s.output", toolType())) + .getAsString(); + } +} From da5e9614e9807cffa972b3156a365782fef27133 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Fri, 29 Dec 2023 18:35:24 +0800 Subject: [PATCH 2/5] update rest status Signed-off-by: Hailong Cui --- .../org/opensearch/agent/integtest/VisualizationsToolIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java b/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java index 221723b2..51bee782 100644 --- a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java +++ b/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java @@ -84,7 +84,7 @@ private void prepareVisualization(String title, String id) throws IOException { + "}" ); Response response = executeRequest(request); - Assert.assertEquals(response.getStatusLine().getStatusCode(), RestStatus.OK.getStatus()); + Assert.assertEquals(response.getStatusLine().getStatusCode(), RestStatus.CREATED.getStatus()); } private String extractAdditionalInfo(String responseStr) { From bbe44fbd74efd731996debb4468b20a585b9daed Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Sat, 30 Dec 2023 00:08:39 +0800 Subject: [PATCH 3/5] add refresh Signed-off-by: Hailong Cui --- .../org/opensearch/agent/integtest/VisualizationsToolIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java b/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java index 51bee782..7f58377c 100644 --- a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java +++ b/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java @@ -71,7 +71,7 @@ public void testVisualizationFound() throws IOException { } private void prepareVisualization(String title, String id) throws IOException { - Request request = new Request("POST", ".kibana/_doc/" + id); + Request request = new Request("POST", String.format(Locale.ROOT, ".kibana/_doc/%s?refresh=true", id)); request .setJsonEntity( "{\n" From a91e705241e0fdc2db82bf50aeddebb1e5ecdeb4 Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Tue, 2 Jan 2024 14:19:34 +0800 Subject: [PATCH 4/5] merge from main Signed-off-by: Hailong Cui --- .../agent/integtest/PromptHandler.java | 29 -- .../agent/integtest/ToolIntegrationTest.java | 303 ------------------ .../opensearch/integTest/MockHttpServer.java | 52 +++ .../opensearch/integTest/PromptHandler.java | 61 ++++ .../integTest/ToolIntegrationTest.java | 226 +++++++++++++ .../VisualizationsToolIT.java | 56 ++-- 6 files changed, 366 insertions(+), 361 deletions(-) delete mode 100644 src/test/java/org/opensearch/agent/integtest/PromptHandler.java delete mode 100644 src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java create mode 100644 src/test/java/org/opensearch/integTest/MockHttpServer.java create mode 100644 src/test/java/org/opensearch/integTest/PromptHandler.java create mode 100644 src/test/java/org/opensearch/integTest/ToolIntegrationTest.java rename src/test/java/org/opensearch/{agent/integtest => integTest}/VisualizationsToolIT.java (65%) diff --git a/src/test/java/org/opensearch/agent/integtest/PromptHandler.java b/src/test/java/org/opensearch/agent/integtest/PromptHandler.java deleted file mode 100644 index 28c8acab..00000000 --- a/src/test/java/org/opensearch/agent/integtest/PromptHandler.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.integtest; - -import org.apache.commons.lang3.tuple.Pair; - -import com.google.gson.annotations.SerializedName; - -import lombok.Data; - -public abstract class PromptHandler { - - boolean apply(String prompt) { - return prompt.contains(questionAndInput().getKey()); - } - - abstract Pair questionAndInput(); - - @Data - static class LLMResponse { - String completion; - @SerializedName("stop_reason") - String stopReason = "stop_sequence"; - String stop = "\\n\\nHuman:"; - } -} diff --git a/src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java b/src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java deleted file mode 100644 index f06f2f9b..00000000 --- a/src/test/java/org/opensearch/agent/integtest/ToolIntegrationTest.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.agent.integtest; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.TimeUnit; - -import org.junit.After; -import org.junit.Before; -import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; -import org.opensearch.client.Response; -import org.opensearch.test.OpenSearchIntegTestCase; - -import com.google.gson.Gson; -import com.google.gson.JsonParser; -import com.sun.net.httpserver.HttpServer; - -import lombok.extern.log4j.Log4j2; - -@Log4j2 -public abstract class ToolIntegrationTest extends OpenSearchIntegTestCase { - protected HttpServer server; - protected String modelId; - protected String agentId; - protected String modelGroupId; - protected String connectorId; - - private final Gson gson = new Gson(); - - abstract List promptHandlers(); - - abstract String toolType(); - - @Before - public void setupAgent() throws IOException, InterruptedException { - setupMockLLM(); - clusterSettings(false); - try { - connectorId = setUpConnector(); - } catch (Exception e) { - // Wait for ML encryption master key has been initialized - TimeUnit.SECONDS.sleep(10); - connectorId = setUpConnector(); - } - modelGroupId = setupModelGroup(); - modelId = setupModel(connectorId, modelGroupId); - agentId = setUpAgent(modelId); - } - - @After - public void cleanUpClusterSetting() throws IOException { - clusterSettings(true); - } - - @After - public void stopMockLLM() { - server.stop(1); - } - - private void setupMockLLM() throws IOException { - server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); - - server.createContext("/invoke", exchange -> { - InputStream ins = exchange.getRequestBody(); - String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8); - Map map = gson.fromJson(req, Map.class); - String prompt = map.get("prompt"); - log.debug("prompt received: {}", prompt); - - String llmRes = ""; - for (PromptHandler promptHandler : promptHandlers()) { - if (promptHandler.apply(prompt)) { - PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse(); - if (prompt.contains("TOOL RESPONSE: ")) { - llmResponse - .setCompletion( - "```json{\n" - + " \"thought\": \"Thought: Now I know the final answer\",\n" - + " \"final_answer\": \"final answer\"\n" - + "}```" - ); - } else { - String actionInput = promptHandler.questionAndInput().getValue(); - llmResponse - .setCompletion( - "```json{\n" - + " \"thought\": \"Thought: Let me use tool to figure out\",\n" - + " \"action\": \"" - + toolType() - + "\",\n" - + " \"action_input\": \"" - + actionInput - + "\"\n" - + "}```" - ); - } - llmRes = gson.toJson(llmResponse); - break; - } - } - byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8); - exchange.sendResponseHeaders(200, llmResBytes.length); - exchange.getResponseBody().write(llmResBytes); - exchange.close(); - }); - server.start(); - } - - private String setUpConnector() throws IOException { - String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort()); - Request request = new Request("POST", "/_plugins/_ml/connectors/_create"); - request - .setJsonEntity( - "{\n" - + " \"name\": \"BedRock test claude Connector\",\n" - + " \"description\": \"The connector to BedRock service for claude model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"us-east-1\",\n" - + " \"service_name\": \"bedrock\",\n" - + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" - + " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n" - + " \"auth\": \"Sig_V4\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens_to_sample\": 8000,\n" - + " \"temperature\": 0.0001,\n" - + " \"response_filter\": \"$.completion\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"\",\n" - + " \"secret_key\": \"\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"" - + url - + "\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-amz-content-sha256\": \"required\"\n" - + " },\n" - + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}" - ); - Response response = executeRequest(request); - - return JsonParser.parseString(readResponse(response)).getAsJsonObject().get("connector_id").getAsString(); - } - - private void clusterSettings(boolean reset) throws IOException { - Request request = new Request("PUT", "_cluster/settings"); - if (!reset) { - request - .setJsonEntity( - "{\n" - + " \"persistent\": {\n" - + " \"plugins.ml_commons.only_run_on_ml_node\": false,\n" - + " \"plugins.ml_commons.memory_feature_enabled\": true,\n" - + " \"plugins.ml_commons.trusted_connector_endpoints_regex\": [\n" - + " \"^.*$\"\n" - + " ]\n" - + " }\n" - + "}" - ); - } else { - request - .setJsonEntity( - "{\n" - + " \"persistent\": {\n" - + " \"plugins.ml_commons.only_run_on_ml_node\": null,\n" - + " \"plugins.ml_commons.memory_feature_enabled\": null,\n" - + " \"plugins.ml_commons.trusted_connector_endpoints_regex\": null" - + " }\n" - + "}" - ); - } - RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); - builder.addHeader("Content-Type", "application/json"); - request.setOptions(builder); - getRestClient().performRequest(request); - } - - private String setupModelGroup() throws IOException { - Request request = new Request("POST", "/_plugins/_ml/model_groups/_register"); - request - .setJsonEntity( - "{\n" - + " \"name\": \"test_model_group_bedrock-" - + UUID.randomUUID().toString() - + "\",\n" - + " \"description\": \"This is a public model group\"\n" - + "}" - ); - Response response = executeRequest(request); - - String resp = readResponse(response); - - return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString(); - } - - private String setupModel(String connectorId, String modelGroupId) throws IOException { - Request request = new Request("POST", "/_plugins/_ml/models/_register?deploy=true"); - request - .setJsonEntity( - "{\n" - + " \"name\": \"Bedrock Claude V2 model\",\n" - + " \"function_name\": \"remote\",\n" - + " \"model_group_id\": \"" - + modelGroupId - + "\",\n" - + " \"description\": \"test model\",\n" - + " \"connector_id\": \"" - + connectorId - + "\"\n" - + "}" - ); - Response response = executeRequest(request); - - String resp = readResponse(response); - - return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); - } - - private String setUpAgent(String modelId) throws IOException { - Request request = new Request("POST", "/_plugins/_ml/agents/_register"); - request - .setJsonEntity( - "{\n" - + " \"name\": \"integTest-agent\",\n" - + " \"type\": \"conversational\",\n" - + " \"description\": \"this is a test agent\",\n" - + " \"llm\": {\n" - + " \"model_id\": \"" - + modelId - + "\",\n" - + " \"parameters\": {\n" - + " \"max_iteration\": \"5\",\n" - + " \"stop_when_no_tool_found\": \"true\",\n" - + " \"response_filter\": \"$.completion\"\n" - + " }\n" - + " },\n" - + " \"tools\": [\n" - + " {\n" - + " \"type\": \"" - + toolType() - + "\",\n" - + " \"name\": \"" - + toolType() - + "\",\n" - + " \"include_output_in_agent_response\": true,\n" - + " \"description\": \"tool description\"\n" - + " }\n" - + " ],\n" - + " \"memory\": {\n" - + " \"type\": \"conversation_index\"\n" - + " }\n" - + "}" - ); - Response response = executeRequest(request); - - String resp = readResponse(response); - - return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString(); - } - - public static Response executeRequest(Request request) throws IOException { - RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); - builder.addHeader("Content-Type", "application/json"); - request.setOptions(builder); - return getRestClient().performRequest(request); - } - - public static String readResponse(Response response) throws IOException { - StringBuilder sb = new StringBuilder(); - - try (BufferedReader br = new BufferedReader(new InputStreamReader(response.getEntity().getContent()))) { - String line; - while ((line = br.readLine()) != null) { - sb.append(line); - } - } - return sb.toString(); - } -} diff --git a/src/test/java/org/opensearch/integTest/MockHttpServer.java b/src/test/java/org/opensearch/integTest/MockHttpServer.java new file mode 100644 index 00000000..f64adcd1 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/MockHttpServer.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import com.google.gson.Gson; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class MockHttpServer { + + private static Gson gson = new Gson(); + + public static HttpServer setupMockLLM(List promptHandlers) throws IOException { + HttpServer server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + + server.createContext("/invoke", exchange -> { + InputStream ins = exchange.getRequestBody(); + String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8); + Map map = gson.fromJson(req, Map.class); + String prompt = map.get("prompt"); + log.debug("prompt received: {}", prompt); + + String llmRes = ""; + for (PromptHandler promptHandler : promptHandlers) { + if (promptHandler.apply(prompt)) { + PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse(); + llmResponse.setCompletion(promptHandler.response(prompt)); + llmRes = gson.toJson(llmResponse); + break; + } + } + byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, llmResBytes.length); + exchange.getResponseBody().write(llmResBytes); + exchange.close(); + }); + return server; + } +} diff --git a/src/test/java/org/opensearch/integTest/PromptHandler.java b/src/test/java/org/opensearch/integTest/PromptHandler.java new file mode 100644 index 00000000..da6aa19d --- /dev/null +++ b/src/test/java/org/opensearch/integTest/PromptHandler.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +public class PromptHandler { + + boolean apply(String prompt) { + return prompt.contains(toolInput().getQuestion()); + } + + ToolInput toolInput() { + return new ToolInput(); + } + + String response(String prompt) { + if (prompt.contains("TOOL RESPONSE: ")) { + return "```json{\n" + + " \"thought\": \"Thought: Now I know the final answer\",\n" + + " \"final_answer\": \"final answer\"\n" + + "}```"; + } else { + return "```json{\n" + + " \"thought\": \"Thought: Let me use tool to figure out\",\n" + + " \"action\": \"" + + this.toolInput().getToolType() + + "\",\n" + + " \"action_input\": \"" + + this.toolInput().getToolInput() + + "\"\n" + + "}```"; + } + } + + @Builder + @NoArgsConstructor + @AllArgsConstructor + @Data + static class ToolInput { + String question; + String toolType; + String toolInput; + } + + @Data + static class LLMResponse { + String completion; + @SerializedName("stop_reason") + String stopReason = "stop_sequence"; + String stop = "\\n\\nHuman:"; + } +} diff --git a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java new file mode 100644 index 00000000..e30b62f1 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; + +import com.google.gson.Gson; +import com.google.gson.JsonParser; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public abstract class ToolIntegrationTest extends BaseAgentToolsIT { + protected HttpServer server; + protected String modelId; + protected String agentId; + protected String modelGroupId; + protected String connectorId; + + private final Gson gson = new Gson(); + + abstract List promptHandlers(); + + abstract String toolType(); + + @Before + public void setupTestAgent() throws IOException, InterruptedException { + server = MockHttpServer.setupMockLLM(promptHandlers()); + server.start(); + clusterSettings(false); + try { + connectorId = setUpConnector(); + } catch (Exception e) { + // Wait for ML encryption master key has been initialized + TimeUnit.SECONDS.sleep(10); + connectorId = setUpConnector(); + } + modelGroupId = setupModelGroup(); + modelId = setupLLMModel(connectorId, modelGroupId); + // wait for model to get deployed + TimeUnit.SECONDS.sleep(1); + agentId = setupConversationalAgent(modelId); + log.info("model_id: {}, agent_id: {}", modelId, agentId); + } + + @After + public void cleanUpClusterSetting() throws IOException { + clusterSettings(true); + } + + @After + public void stopMockLLM() { + server.stop(1); + } + + private String setUpConnector() { + String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort()); + return createConnector( + "{\n" + + " \"name\": \"BedRock test claude Connector\",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"us-east-1\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n" + + " \"auth\": \"Sig_V4\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"\",\n" + + " \"secret_key\": \"\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"" + + url + + "\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}" + ); + } + + private void clusterSettings(boolean clean) throws IOException { + if (!clean) { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); + updateClusterSettings("plugins.ml_commons.memory_feature_enabled", true); + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); + } else { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", null); + updateClusterSettings("plugins.ml_commons.memory_feature_enabled", null); + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", null); + } + } + + private String setupModelGroup() throws IOException { + Request request = new Request("POST", "/_plugins/_ml/model_groups/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"test_model_group_bedrock-" + + UUID.randomUUID() + + "\",\n" + + " \"description\": \"This is a public model group\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString(); + } + + private String setupLLMModel(String connectorId, String modelGroupId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/models/_register?deploy=true"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"Bedrock Claude V2 model\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); + } + + private String setupConversationalAgent(String modelId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/agents/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"integTest-agent\",\n" + + " \"type\": \"conversational\",\n" + + " \"description\": \"this is a test agent\",\n" + + " \"llm\": {\n" + + " \"model_id\": \"" + + modelId + + "\",\n" + + " \"parameters\": {\n" + + " \"max_iteration\": \"5\",\n" + + " \"stop_when_no_tool_found\": \"true\",\n" + + " \"response_filter\": \"$.completion\"\n" + + " }\n" + + " },\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"" + + toolType() + + "\",\n" + + " \"name\": \"" + + toolType() + + "\",\n" + + " \"include_output_in_agent_response\": true,\n" + + " \"description\": \"tool description\"\n" + + " }\n" + + " ],\n" + + " \"memory\": {\n" + + " \"type\": \"conversation_index\"\n" + + " }\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString(); + } + + public static Response executeRequest(Request request) throws IOException { + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + return client().performRequest(request); + } + + public static String readResponse(Response response) throws IOException { + StringBuilder sb = new StringBuilder(); + + try (BufferedReader br = new BufferedReader(new InputStreamReader(response.getEntity().getContent()))) { + String line; + while ((line = br.readLine()) != null) { + sb.append(line); + } + } + return sb.toString(); + } +} diff --git a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java similarity index 65% rename from src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java rename to src/test/java/org/opensearch/integTest/VisualizationsToolIT.java index 7f58377c..d0324ad5 100644 --- a/src/test/java/org/opensearch/agent/integtest/VisualizationsToolIT.java +++ b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java @@ -3,18 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.agent.integtest; +package org.opensearch.integTest; import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.UUID; -import org.apache.commons.lang3.tuple.Pair; import org.junit.Assert; import org.opensearch.agent.tools.VisualizationsTool; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; @@ -28,13 +26,23 @@ public class VisualizationsToolIT extends ToolIntegrationTest { List promptHandlers() { return List.of(new PromptHandler() { @Override - Pair questionAndInput() { - return Pair.of("can you show me RAM info with visualization?", "RAM"); + ToolInput toolInput() { + return ToolInput + .builder() + .toolType(VisualizationsTool.TYPE) + .toolInput("RAM") + .question("can you show me RAM info with visualization?") + .build(); } }, new PromptHandler() { @Override - Pair questionAndInput() { - return Pair.of("how about the sales about this month?", "sales"); + ToolInput toolInput() { + return ToolInput + .builder() + .toolType(VisualizationsTool.TYPE) + .toolInput("sales") + .question("how about the sales about this month?") + .build(); } }); } @@ -46,10 +54,7 @@ String toolType() { public void testVisualizationNotFound() throws IOException { Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); request.setJsonEntity("{\"parameters\":{\"question\":\"can you show me RAM info with visualization?\"}}"); - RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); - builder.addHeader("Content-Type", "application/json"); - request.setOptions(builder); - Response response = getRestClient().performRequest(request); + Response response = executeRequest(request); String responseStr = readResponse(response); String toolOutput = extractAdditionalInfo(responseStr); Assert.assertEquals("No Visualization found", toolOutput); @@ -61,29 +66,22 @@ public void testVisualizationFound() throws IOException { prepareVisualization(title, id); Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); request.setJsonEntity("{\"parameters\":{\"question\":\"how about the sales about this month?\"}}"); - RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); - builder.addHeader("Content-Type", "application/json"); - request.setOptions(builder); - Response response = getRestClient().performRequest(request); + Response response = executeRequest(request); String responseStr = readResponse(response); String toolOutput = extractAdditionalInfo(responseStr); Assert.assertEquals("Title,Id\n" + String.format(Locale.ROOT, "%s,%s\n", title, id), toolOutput); } - private void prepareVisualization(String title, String id) throws IOException { - Request request = new Request("POST", String.format(Locale.ROOT, ".kibana/_doc/%s?refresh=true", id)); - request - .setJsonEntity( - "{\n" - + " \"visualization\": {\n" - + " \"title\": \"" - + title - + "\"\n" - + " },\n" - + " \"type\": \"visualization\"\n" - + "}" - ); - Response response = executeRequest(request); + private void prepareVisualization(String title, String id) { + String body = "{\n" + + " \"visualization\": {\n" + + " \"title\": \"" + + title + + "\"\n" + + " },\n" + + " \"type\": \"visualization\"\n" + + "}"; + Response response = makeRequest(client(), "POST", String.format(Locale.ROOT, ".kibana/_doc/%s?refresh=true", id), null, body, null); Assert.assertEquals(response.getStatusLine().getStatusCode(), RestStatus.CREATED.getStatus()); } From 0bb88ebbf35fe4cacf8a6ad0445432bb13f374bb Mon Sep 17 00:00:00 2001 From: Hailong Cui Date: Wed, 3 Jan 2024 15:44:26 +0800 Subject: [PATCH 5/5] rename variable name Signed-off-by: Hailong Cui --- .../org/opensearch/integTest/PromptHandler.java | 16 ++++++++-------- .../integTest/ToolIntegrationTest.java | 13 +++---------- .../integTest/VisualizationsToolIT.java | 16 ++++++++-------- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/src/test/java/org/opensearch/integTest/PromptHandler.java b/src/test/java/org/opensearch/integTest/PromptHandler.java index da6aa19d..a3f9314d 100644 --- a/src/test/java/org/opensearch/integTest/PromptHandler.java +++ b/src/test/java/org/opensearch/integTest/PromptHandler.java @@ -15,11 +15,11 @@ public class PromptHandler { boolean apply(String prompt) { - return prompt.contains(toolInput().getQuestion()); + return prompt.contains(llmThought().getQuestion()); } - ToolInput toolInput() { - return new ToolInput(); + LLMThought llmThought() { + return new LLMThought(); } String response(String prompt) { @@ -32,10 +32,10 @@ String response(String prompt) { return "```json{\n" + " \"thought\": \"Thought: Let me use tool to figure out\",\n" + " \"action\": \"" - + this.toolInput().getToolType() + + this.llmThought().getAction() + "\",\n" + " \"action_input\": \"" - + this.toolInput().getToolInput() + + this.llmThought().getActionInput() + "\"\n" + "}```"; } @@ -45,10 +45,10 @@ String response(String prompt) { @NoArgsConstructor @AllArgsConstructor @Data - static class ToolInput { + static class LLMThought { String question; - String toolType; - String toolInput; + String action; + String actionInput; } @Data diff --git a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java index e30b62f1..aba39573 100644 --- a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java +++ b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java @@ -5,9 +5,8 @@ package org.opensearch.integTest; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStreamReader; +import java.io.InputStream; import java.util.List; import java.util.Locale; import java.util.UUID; @@ -213,14 +212,8 @@ public static Response executeRequest(Request request) throws IOException { } public static String readResponse(Response response) throws IOException { - StringBuilder sb = new StringBuilder(); - - try (BufferedReader br = new BufferedReader(new InputStreamReader(response.getEntity().getContent()))) { - String line; - while ((line = br.readLine()) != null) { - sb.append(line); - } + try (InputStream ins = response.getEntity().getContent()) { + return String.join("", org.opensearch.common.io.Streams.readAllLines(ins)); } - return sb.toString(); } } diff --git a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java index d0324ad5..e7f54521 100644 --- a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java +++ b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java @@ -26,21 +26,21 @@ public class VisualizationsToolIT extends ToolIntegrationTest { List promptHandlers() { return List.of(new PromptHandler() { @Override - ToolInput toolInput() { - return ToolInput + LLMThought llmThought() { + return LLMThought .builder() - .toolType(VisualizationsTool.TYPE) - .toolInput("RAM") + .action(VisualizationsTool.TYPE) + .actionInput("RAM") .question("can you show me RAM info with visualization?") .build(); } }, new PromptHandler() { @Override - ToolInput toolInput() { - return ToolInput + LLMThought llmThought() { + return LLMThought .builder() - .toolType(VisualizationsTool.TYPE) - .toolInput("sales") + .action(VisualizationsTool.TYPE) + .actionInput("sales") .question("how about the sales about this month?") .build(); }