Skip to content

Commit

Permalink
deprecate max_token_score
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Oct 30, 2023
1 parent 15906e5 commit 6a0bb56
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
// We used the max_token_score field to help the WAND scorer prune query clauses in Lucene 9.7. In Lucene 9.8 the
// internal logic changed, so this field is no longer needed.
@VisibleForTesting
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score");
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();

private static MLCommonsClientAccessor ML_CLIENT;

Expand Down Expand Up @@ -163,9 +165,6 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw
sparseEncodingQueryBuilder.modelId(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
);
if (sparseEncodingQueryBuilder.maxTokenScore != null && sparseEncodingQueryBuilder.maxTokenScore <= 0) {
throw new IllegalArgumentException(MAX_TOKEN_SCORE_FIELD.getPreferredName() + " must be larger than 0.");
}

return sparseEncodingQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,32 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName());
}

/**
 * Verifies that a query which still supplies the deprecated {@code max_token_score} field is parsed
 * successfully and that the corresponding deprecation warning is emitted.
 */
@SneakyThrows
public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning() {
    /*
      {
          "VECTOR_FIELD": {
            "query_text": "string",
            "model_id": "string",
            "max_token_score": 123.0
          }
      }
    */
    XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject(FIELD_NAME)
        .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
        .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
        .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
        .endObject()
        .endObject();

    XContentParser contentParser = createParser(xContentBuilder);
    // Advance onto the root START_OBJECT token before handing the parser to fromXContent.
    contentParser.nextToken();
    NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

    // The deprecated field must not prevent the rest of the query from being parsed.
    assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId());
    // Using the deprecated field must surface a deprecation warning to the caller.
    assertWarnings("Deprecated field [max_token_score] used, this field is unused and will be removed entirely");
}

@SneakyThrows
public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,7 @@ public void testBasicQueryWithMaxTokenScore() {
Map<String, Object> firstInnerHit = getFirstInnerHit(searchResponseAsMap);

assertEquals("1", firstInnerHit.get("_id"));
Map<String, Float> queryTokens = runSparseModelInference(modelId, TEST_QUERY_TEXT);
float expectedScore = 0f;
for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
if (testRankFeaturesDoc.containsKey(entry.getKey())) {
expectedScore += entry.getValue() * Math.min(
getFeatureFieldCompressedNumber(testRankFeaturesDoc.get(entry.getKey())),
maxTokenScore
);
}
}
float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT);
assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA);
}

Expand Down

0 comments on commit 6a0bb56

Please sign in to comment.