Skip to content

Commit

Permalink
add parameter validate for PPL tool (#259)
Browse files Browse the repository at this point in the history
* add parameter validate for PPL tool

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* simplify code

Signed-off-by: xinyual <[email protected]>

* prevent NPE

Signed-off-by: xinyual <[email protected]>

* change logic

Signed-off-by: xinyual <[email protected]>

* apply spot

Signed-off-by: xinyual <[email protected]>

* fix UT

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual authored Mar 18, 2024
1 parent 54050dd commit 571dab7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
24 changes: 24 additions & 0 deletions src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.commons.text.StringSubstitutor;
import org.json.JSONObject;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -313,6 +315,7 @@ public void init(Client client) {

@Override
public PPLTool create(Map<String, Object> map) {
validatePPLToolParameters(map);
return new PPLTool(
client,
(String) map.get("model_id"),
Expand Down Expand Up @@ -356,6 +359,27 @@ private GetMappingsRequest buildGetMappingRequest(String indexName) {
return getMappingsRequest;
}

private static void validatePPLToolParameters(Map<String, Object> map) {
if (StringUtils.isBlank((String) map.get("model_id"))) {
throw new IllegalArgumentException("PPL tool needs non blank model id.");
}
if (map.containsKey("execute") && Objects.nonNull(map.get("execute"))) {
String execute = map.get("execute").toString().toLowerCase(Locale.ROOT);
if (!execute.equals("true") && !execute.equals("false")) {
throw new IllegalArgumentException("PPL tool parameter execute must be false or true");
}

}
if (map.containsKey("head")) {
String head = map.get("head").toString();
try {
int headInt = NumberUtils.createInteger(head);
} catch (Exception e) {
throw new IllegalArgumentException("PPL tool parameter head must be integer.");
}
}
}

private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMetadata> mappings) throws PrivilegedActionException {
String firstIndexName = (String) mappings.keySet().toArray()[0];
MappingMetadata mappingMetadata = mappings.get(firstIndexName);
Expand Down
42 changes: 40 additions & 2 deletions src/test/java/org/opensearch/agent/tools/PPLToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,47 @@ public void setup() {
PPLTool.Factory.getInstance().init(client);
}

@Test
public void testTool_WithoutModelId() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> PPLTool.Factory.getInstance().create(ImmutableMap.of("prompt", "contextPrompt"))
);
assertEquals("PPL tool needs non blank model id.", exception.getMessage());
}

@Test
public void testTool_WithBlankModelId() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", " "))
);
assertEquals("PPL tool needs non blank model id.", exception.getMessage());
}

@Test
public void testTool_WithNonIntegerHead() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "demo", "head", "11.5"))
);
assertEquals("PPL tool parameter head must be integer.", exception.getMessage());
}

@Test
public void testTool_WithNonBooleanExecute() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "demo", "execute", "hello"))
);
assertEquals("PPL tool parameter execute must be false or true", exception.getMessage());
}

@Test
public void testTool() {
PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt"));
PPLTool tool = PPLTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "100"));
assertEquals(PPLTool.TYPE, tool.getName());

tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.<String>wrap(executePPLResult -> {
Expand All @@ -142,7 +180,7 @@ public void testTool() {
public void testTool_withPreviousInput() {
PPLTool tool = PPLTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "previous_tool_name", "previousTool"));
.create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "previous_tool_name", "previousTool", "head", "-5"));
assertEquals(PPLTool.TYPE, tool.getName());

tool.run(ImmutableMap.of("previousTool.output", "demo", "question", "demo"), ActionListener.<String>wrap(executePPLResult -> {
Expand Down

0 comments on commit 571dab7

Please sign in to comment.