Skip to content

Commit

Permalink
Fix UT
Browse files Browse the repository at this point in the history
Signed-off-by: Heng Qian <[email protected]>
  • Loading branch information
qianheng-aws committed Nov 21, 2024
1 parent 585e6d5 commit f779fd9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
8 changes: 5 additions & 3 deletions src/main/java/org/opensearch/agent/tools/CreateAlertTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public class CreateAlertTool implements Tool {
@Getter
private final String modelId;
@Getter
private final String modelType;
@Getter
private final String toolPrompt;

private static final String MODEL_ID = "model_id";
Expand Down Expand Up @@ -93,9 +95,9 @@ public static ModelType from(String value) {
public CreateAlertTool(Client client, String modelId, String modelType, String prompt) {
this.client = client;
this.modelId = modelId;
modelType = String.valueOf(ModelType.from(modelType));
this.modelType = String.valueOf(ModelType.from(modelType));
if (prompt.isEmpty()) {
if (!promptDict.containsKey(modelType)) {
if (!promptDict.containsKey(this.modelType)) {
throw new IllegalArgumentException(
LoggerMessageFormat
.format(
Expand All @@ -106,7 +108,7 @@ public CreateAlertTool(Client client, String modelId, String modelType, String p
)
);
}
this.toolPrompt = promptDict.get(modelType);
this.toolPrompt = promptDict.get(this.modelType);
} else {
this.toolPrompt = prompt;
}
Expand Down
22 changes: 12 additions & 10 deletions src/test/java/org/opensearch/agent/tools/CreateAlertToolTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,18 @@ public void testTool_WithBlankModelId() {

@Test
public void testTool_WithNonSupportedModelType() {
Exception exception = assertThrows(
IllegalArgumentException.class,
() -> CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType"))
);
assertEquals(
"Failed to find the right prompt for modelType: non_supported_modelType, this tool supports prompts for these models: [CLAUDE,OPENAI]",
exception.getMessage()
);
CreateAlertTool alertTool = CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "model_type", "non_supported_modelType"));
assertEquals("CLAUDE", alertTool.getModelType());
}

@Test
public void testTool_WithEmptyModelType() {
CreateAlertTool alertTool = CreateAlertTool.Factory
.getInstance()
.create(ImmutableMap.of("model_id", "modelId", "model_type", ""));
assertEquals("CLAUDE", alertTool.getModelType());
}

@Test
Expand Down

0 comments on commit f779fd9

Please sign in to comment.