Skip to content

Commit

Permalink
Optimize the default prompt and make prompt customizable for create a…
Browse files Browse the repository at this point in the history
…nomaly detector tool (#399) (#419)

* Optimize the prompt for create anomaly detector tool



* Remove whitespace



* Make the prompt for CreateAnomalyDetectorTool customizable



* format the code



* Fix test failure



* fix test failure



* Format the code



* Add more tests



---------


(cherry picked from commit 06a8537)

Signed-off-by: gaobinlong <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent a26c924 commit 098e6d7
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,18 @@ public static ModelType from(String value) {
* @param client the OpenSearch transport client
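* @param modelId the model ID of LLM
* @param modelType the model type of the LLM; only OPENAI and CLAUDE (case-insensitive) are accepted, any other value throws IllegalArgumentException
* @param contextPrompt the custom prompt to use; when empty, the default prompt for the given model type is used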
*/
public CreateAnomalyDetectorTool(Client client, String modelId, String modelType) {
public CreateAnomalyDetectorTool(Client client, String modelId, String modelType, String contextPrompt) {
this.client = client;
this.modelId = modelId;
if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) {
throw new IllegalArgumentException("Unsupported model_type: " + modelType);
}
this.modelType = ModelType.from(modelType);
this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), "");
if (contextPrompt.isEmpty()) {
this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), "");
} else {
this.contextPrompt = contextPrompt;
}
}

/**
Expand Down Expand Up @@ -432,7 +436,8 @@ public CreateAnomalyDetectorTool create(Map<String, Object> map) {
if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) {
throw new IllegalArgumentException("Unsupported model_type: " + modelType);
}
return new CreateAnomalyDetectorTool(client, modelId, modelType);
String prompt = (String) map.getOrDefault("prompt", "");
return new CreateAnomalyDetectorTool(client, modelId, modelType, prompt);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"",
"OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. "
"CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"",
"OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}."
}
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ public void testToolWithPredictModelFailed() {
}));
}

/**
 * Checks that a prompt supplied to the factory under the "prompt" key is kept as the
 * tool's context prompt, then runs the tool against the mocked index.
 * No "model_type" is supplied here, and the test expects the tool to report CLAUDE.
 */
@Test
public void testToolWithCustomPrompt() {
    // Factory parameters: only the model id and the caller-provided prompt.
    final ImmutableMap<String, Object> factoryParams = ImmutableMap.of("model_id", "modelId", "prompt", "custom prompt");
    final CreateAnomalyDetectorTool customPromptTool = CreateAnomalyDetectorTool.Factory.getInstance().create(factoryParams);

    // The tool's identity and the prompt it holds must reflect the factory input.
    assertEquals(CreateAnomalyDetectorTool.TYPE, customPromptTool.getName());
    assertEquals("modelId", customPromptTool.getModelId());
    assertEquals("CLAUDE", customPromptTool.getModelType().toString());
    assertEquals("custom prompt", customPromptTool.getContextPrompt());

    // A successful response must equal the mocked result; a failure is only logged.
    final ActionListener<String> resultListener = ActionListener
        .<String>wrap(response -> assertEquals(mockedResult, response), log::info);
    customPromptTool.run(ImmutableMap.of("index", mockedIndexName), resultListener);
}

private void createMappings() {
indexMappings = new HashMap<>();
indexMappings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,25 @@ private String registerAgent() {
)
);
registerAgentRequestBody = registerAgentRequestBody.replace("<MODEL_ID>", modelId);
registerAgentRequestBody = registerAgentRequestBody
.replace(
"<CUSTOM_PROMPT>",
"Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, "
+ " {\\\"time_field\\\":\\\"timestamp\\\",\\\"indices\\\":[\\\"server_log*\\\"],\\\"feature_attributes\\\":"
+ "[{\\\"feature_name\\\":\\\"test\\\",\\\"feature_enabled\\\":true,"
+ "\\\"aggregation_query\\\":{\\\"test\\\":{\\\"sum\\\":{\\\"field\\\":\\\"value\\\"}}}}],\\\"category_field\\\":[\\\"ip\\\"]},"
+ " and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, "
+ "and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about "
+ "creating an anomaly detector for the index ${indexInfo.indexName}, "
+ "you need to give the key information: the top 3 suitable aggregation fields which are numeric types and "
+ "the suitable aggregation method for each field, "
+ "if there are no numeric type fields, both the aggregation field and method are empty string, "
+ " and also give the category field if there exists a keyword type field like ip, address, host, city, country or region,"
+ " if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list "
+ "wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited"
+ " list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. "
);

return createAgent(registerAgentRequestBody);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
{
"type": "CreateAnomalyDetectorTool",
"parameters": {
"model_id": "<MODEL_ID>"
"model_id": "<MODEL_ID>",
"prompt": "<CUSTOM_PROMPT>"
}
}
]
Expand Down

0 comments on commit 098e6d7

Please sign in to comment.