diff --git a/build.gradle b/build.gradle index bd63a17f..591aefec 100644 --- a/build.gradle +++ b/build.gradle @@ -131,6 +131,7 @@ dependencies { // Test dependencies testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation group: 'junit', name: 'junit', version: '4.13.2' + testImplementation group: 'org.json', name: 'json', version: '20231013' testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.8.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' testImplementation("net.bytebuddy:byte-buddy:1.14.7") diff --git a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java index 31f5cf09..958f1eda 100644 --- a/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java +++ b/src/main/java/org/opensearch/agent/tools/VisualizationsTool.java @@ -16,6 +16,7 @@ import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -24,9 +25,6 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Strings; - import lombok.Builder; import lombok.Getter; import lombok.Setter; @@ -113,7 +111,6 @@ public void onFailure(Exception e) { }); } - @VisibleForTesting String trimIdPrefix(String id) { id = Optional.ofNullable(id).orElse(""); if (id.startsWith(SAVED_OBJECT_TYPE)) { diff --git a/src/test/java/org/opensearch/integTest/MockHttpServer.java b/src/test/java/org/opensearch/integTest/MockHttpServer.java new file mode 100644 index 00000000..f64adcd1 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/MockHttpServer.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import com.google.gson.Gson; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class MockHttpServer { + + private static Gson gson = new Gson(); + + public static HttpServer setupMockLLM(List promptHandlers) throws IOException { + HttpServer server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + + server.createContext("/invoke", exchange -> { + InputStream ins = exchange.getRequestBody(); + String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8); + Map map = gson.fromJson(req, Map.class); + String prompt = map.get("prompt"); + log.debug("prompt received: {}", prompt); + + String llmRes = ""; + for (PromptHandler promptHandler : promptHandlers) { + if (promptHandler.apply(prompt)) { + PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse(); + llmResponse.setCompletion(promptHandler.response(prompt)); + llmRes = gson.toJson(llmResponse); + break; + } + } + byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, llmResBytes.length); + exchange.getResponseBody().write(llmResBytes); + exchange.close(); + }); + return server; + } +} diff --git a/src/test/java/org/opensearch/integTest/PromptHandler.java b/src/test/java/org/opensearch/integTest/PromptHandler.java new file mode 100644 index 00000000..a3f9314d --- /dev/null +++ b/src/test/java/org/opensearch/integTest/PromptHandler.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +public class PromptHandler { + + boolean apply(String prompt) { + return prompt.contains(llmThought().getQuestion()); + } + + LLMThought llmThought() { + return new LLMThought(); + } + + String response(String prompt) { + if (prompt.contains("TOOL RESPONSE: ")) { + return "```json{\n" + + " \"thought\": \"Thought: Now I know the final answer\",\n" + + " \"final_answer\": \"final answer\"\n" + + "}```"; + } else { + return "```json{\n" + + " \"thought\": \"Thought: Let me use tool to figure out\",\n" + + " \"action\": \"" + + this.llmThought().getAction() + + "\",\n" + + " \"action_input\": \"" + + this.llmThought().getActionInput() + + "\"\n" + + "}```"; + } + } + + @Builder + @NoArgsConstructor + @AllArgsConstructor + @Data + static class LLMThought { + String question; + String action; + String actionInput; + } + + @Data + static class LLMResponse { + String completion; + @SerializedName("stop_reason") + String stopReason = "stop_sequence"; + String stop = "\\n\\nHuman:"; + } +} diff --git a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java new file mode 100644 index 00000000..aba39573 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java @@ -0,0 +1,219 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; + +import com.google.gson.Gson; +import com.google.gson.JsonParser; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public abstract class ToolIntegrationTest extends BaseAgentToolsIT { + protected HttpServer server; + protected String modelId; + protected String agentId; + protected String modelGroupId; + protected String connectorId; + + private final Gson gson = new Gson(); + + abstract List promptHandlers(); + + abstract String toolType(); + + @Before + public void setupTestAgent() throws IOException, InterruptedException { + server = MockHttpServer.setupMockLLM(promptHandlers()); + server.start(); + clusterSettings(false); + try { + connectorId = setUpConnector(); + } catch (Exception e) { + // Wait for ML encryption master key has been initialized + TimeUnit.SECONDS.sleep(10); + connectorId = setUpConnector(); + } + modelGroupId = setupModelGroup(); + modelId = setupLLMModel(connectorId, modelGroupId); + // wait for model to get deployed + TimeUnit.SECONDS.sleep(1); + agentId = setupConversationalAgent(modelId); + log.info("model_id: {}, agent_id: {}", modelId, agentId); + } + + @After + public void cleanUpClusterSetting() throws IOException { + clusterSettings(true); + } + + @After + public void stopMockLLM() { + server.stop(1); + } + + private String setUpConnector() { + String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort()); + return createConnector( + "{\n" + + " \"name\": \"BedRock test claude Connector\",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"us-east-1\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n" + + " \"auth\": \"Sig_V4\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"\",\n" + + " \"secret_key\": \"\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"" + + url + + "\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}" + ); + } + + private void clusterSettings(boolean clean) throws IOException { + if (!clean) { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); + updateClusterSettings("plugins.ml_commons.memory_feature_enabled", true); + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); + } else { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", null); + updateClusterSettings("plugins.ml_commons.memory_feature_enabled", null); + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", null); + } + } + + private String setupModelGroup() throws IOException { + Request request = new Request("POST", "/_plugins/_ml/model_groups/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"test_model_group_bedrock-" + + UUID.randomUUID() + + "\",\n" + + " \"description\": \"This is a public model group\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString(); + } + + private String setupLLMModel(String connectorId, String modelGroupId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/models/_register?deploy=true"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"Bedrock Claude V2 model\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); + } + + private String setupConversationalAgent(String modelId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/agents/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"integTest-agent\",\n" + + " \"type\": \"conversational\",\n" + + " \"description\": \"this is a test agent\",\n" + + " \"llm\": {\n" + + " \"model_id\": \"" + + modelId + + "\",\n" + + " \"parameters\": {\n" + + " \"max_iteration\": \"5\",\n" + + " \"stop_when_no_tool_found\": \"true\",\n" + + " \"response_filter\": \"$.completion\"\n" + + " }\n" + + " },\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"" + + toolType() + + "\",\n" + + " \"name\": \"" + + toolType() + + "\",\n" + + " \"include_output_in_agent_response\": true,\n" + + " \"description\": \"tool description\"\n" + + " }\n" + + " ],\n" + + " \"memory\": {\n" + + " \"type\": \"conversation_index\"\n" + + " }\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString(); + } + + public static Response executeRequest(Request request) throws IOException { + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + return client().performRequest(request); + } + + public static String readResponse(Response response) throws IOException { + try (InputStream ins = response.getEntity().getContent()) { + return String.join("", org.opensearch.common.io.Streams.readAllLines(ins)); + } + } +} diff --git a/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java new file mode 100644 index 00000000..e7f54521 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/VisualizationsToolIT.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.UUID; + +import org.junit.Assert; +import org.opensearch.agent.tools.VisualizationsTool; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; + +import com.google.gson.JsonParser; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class VisualizationsToolIT extends ToolIntegrationTest { + @Override + List promptHandlers() { + return List.of(new PromptHandler() { + @Override + LLMThought llmThought() { + return LLMThought + .builder() + .action(VisualizationsTool.TYPE) + .actionInput("RAM") + .question("can you show me RAM info with visualization?") + .build(); + } + }, new PromptHandler() { + @Override + LLMThought llmThought() { + return LLMThought + .builder() + .action(VisualizationsTool.TYPE) + .actionInput("sales") + .question("how about the sales about this month?") + .build(); + } + }); + } + + String toolType() { + return VisualizationsTool.TYPE; + } + + public void testVisualizationNotFound() throws IOException { + Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); + request.setJsonEntity("{\"parameters\":{\"question\":\"can you show me RAM info with visualization?\"}}"); + Response response = executeRequest(request); + String responseStr = readResponse(response); + String toolOutput = extractAdditionalInfo(responseStr); + Assert.assertEquals("No Visualization found", toolOutput); + } + + public void testVisualizationFound() throws IOException { + String title = "[eCommerce] Sales by Category"; + String id = UUID.randomUUID().toString(); + prepareVisualization(title, id); + Request request = new Request("POST", "/_plugins/_ml/agents/" + agentId + "/_execute"); + request.setJsonEntity("{\"parameters\":{\"question\":\"how about the sales about this month?\"}}"); + Response response = executeRequest(request); + String responseStr = readResponse(response); + String toolOutput = extractAdditionalInfo(responseStr); + Assert.assertEquals("Title,Id\n" + String.format(Locale.ROOT, "%s,%s\n", title, id), toolOutput); + } + + private void prepareVisualization(String title, String id) { + String body = "{\n" + + " \"visualization\": {\n" + + " \"title\": \"" + + title + + "\"\n" + + " },\n" + + " \"type\": \"visualization\"\n" + + "}"; + Response response = makeRequest(client(), "POST", String.format(Locale.ROOT, ".kibana/_doc/%s?refresh=true", id), null, body, null); + Assert.assertEquals(response.getStatusLine().getStatusCode(), RestStatus.CREATED.getStatus()); + } + + private String extractAdditionalInfo(String responseStr) { + return JsonParser + .parseString(responseStr) + .getAsJsonObject() + .get("inference_results") + .getAsJsonArray() + .get(0) + .getAsJsonObject() + .get("output") + .getAsJsonArray() + .get(0) + .getAsJsonObject() + .get("dataAsMap") + .getAsJsonObject() + .get("additional_info") + .getAsJsonObject() + .get(String.format(Locale.ROOT, "%s.output", toolType())) + .getAsString(); + } +}