-
Notifications
You must be signed in to change notification settings - Fork 65
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for asymmetric embedding models #710
base: main
Are you sure you want to change the base?
Changes from 6 commits
ac4f7b3
4c36060
1b84577
bc6caad
b19ffe2
dfa03fe
a5977b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,8 @@ | |
import org.opensearch.knn.index.query.parser.RescoreParser; | ||
import org.opensearch.knn.index.query.rescore.RescoreContext; | ||
import org.opensearch.neuralsearch.common.MinClusterVersionUtil; | ||
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; | ||
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
|
@@ -333,10 +335,15 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { | |
inferenceInput.put(INPUT_IMAGE, queryImage()); | ||
} | ||
queryRewriteContext.registerAsyncAction( | ||
((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { | ||
vectorSetOnce.set(vectorAsListToArray(floatList)); | ||
actionListener.onResponse(null); | ||
}, actionListener::onFailure))) | ||
((client, actionListener) -> ML_CLIENT.inferenceSentences( | ||
modelId(), | ||
inferenceInput, | ||
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is effectively a constant, no need to compute on every call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: always add new parameters in the end of argument list so that it is easy to understand the changes. |
||
ActionListener.wrap(floatList -> { | ||
vectorSetOnce.set(vectorAsListToArray(floatList)); | ||
actionListener.onResponse(null); | ||
}, actionListener::onFailure) | ||
)) | ||
); | ||
return new NeuralQueryBuilder( | ||
fieldName(), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,11 @@ | |
import org.opensearch.cluster.node.DiscoveryNode; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.MLModel; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; | ||
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
import org.opensearch.ml.common.output.MLOutput; | ||
import org.opensearch.ml.common.output.model.MLResultDataType; | ||
import org.opensearch.ml.common.output.model.ModelTensor; | ||
|
@@ -59,15 +63,34 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { | |
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener); | ||
accessor.inferenceSentence( | ||
TestCommonConstants.MODEL_ID, | ||
TestCommonConstants.SENTENCES_LIST.get(0), | ||
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), | ||
singleSentenceResultListener | ||
); | ||
|
||
Mockito.verify(client) | ||
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
Mockito.verify(singleSentenceResultListener).onResponse(vector); | ||
Mockito.verifyNoMoreInteractions(singleSentenceResultListener); | ||
} | ||
|
||
private void setupMocksForTextEmbeddingModelAsymmetryCheck(boolean isAsymmetric) { | ||
MLModel modelMock = mock(MLModel.class); | ||
TextEmbeddingModelConfig textEmbeddingModelConfigMock = mock(TextEmbeddingModelConfig.class); | ||
Mockito.when(textEmbeddingModelConfigMock.getPassagePrefix()).thenReturn(isAsymmetric ? "passage: " : null); | ||
Mockito.when(textEmbeddingModelConfigMock.getQueryPrefix()).thenReturn(isAsymmetric ? "query: " : null); | ||
Mockito.when(modelMock.getModelConfig()).thenReturn(textEmbeddingModelConfigMock); | ||
Mockito.doAnswer(invocation -> { | ||
final ActionListener<MLModel> actionListener = invocation.getArgument(1); | ||
actionListener.onResponse(modelMock); | ||
return null; | ||
}).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); | ||
} | ||
|
||
public void testInferenceSentences_whenValidInputThenSuccess() { | ||
final List<List<Float>> vectorList = new ArrayList<>(); | ||
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
|
@@ -76,6 +99,8 @@ public void testInferenceSentences_whenValidInputThenSuccess() { | |
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -92,6 +117,8 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { | |
actionListener.onResponse(createModelTensorOutput(new Float[] {})); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -107,6 +134,8 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { | |
actionListener.onFailure(exception); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences( | ||
TestCommonConstants.TARGET_RESPONSE_FILTERS, | ||
TestCommonConstants.MODEL_ID, | ||
|
@@ -130,6 +159,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time | |
actionListener.onFailure(nodeNodeConnectedException); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences( | ||
TestCommonConstants.TARGET_RESPONSE_FILTERS, | ||
TestCommonConstants.MODEL_ID, | ||
|
@@ -149,6 +181,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { | |
actionListener.onFailure(illegalStateException); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences( | ||
TestCommonConstants.TARGET_RESPONSE_FILTERS, | ||
TestCommonConstants.MODEL_ID, | ||
|
@@ -161,6 +196,62 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { | |
Mockito.verify(resultListener).onFailure(illegalStateException); | ||
} | ||
|
||
public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { | ||
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
Mockito.doAnswer(invocation -> { | ||
final ActionListener<MLOutput> actionListener = invocation.getArgument(2); | ||
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
setupMocksForTextEmbeddingModelAsymmetryCheck(true); | ||
|
||
accessor.inferenceSentence( | ||
TestCommonConstants.MODEL_ID, | ||
TestCommonConstants.SENTENCES_LIST.get(0), | ||
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), | ||
singleSentenceResultListener | ||
); | ||
|
||
Mockito.verify(client) | ||
.predict( | ||
Mockito.eq(TestCommonConstants.MODEL_ID), | ||
Mockito.argThat((MLInput input) -> input.getParameters() != null), | ||
Mockito.isA(ActionListener.class) | ||
); | ||
Mockito.verify(singleSentenceResultListener).onResponse(vector); | ||
Mockito.verifyNoMoreInteractions(singleSentenceResultListener); | ||
} | ||
|
||
public void testInferenceSentences_whenGetModelException_thenFailure() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need to test scenario when we're retrying 1-2 times. I see scenario when first request has failed with error that isn't retryable, this isn't a full coverage There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @martin-gaievski thanks for pointing this out. I'm wondering though, if we should make this retryable at all. Let me elaborate: In my understanding, inference requests are retried because they tend to fail more often than regular operations in OpenSearch. I don't know the history and complete reasoning behind this, so I speculate it has to do with the fact that the inference is done natively and that many things can go wrong there. With my change, if fetching the model information fails ( So my argument is: should we really add a retry logic to this relatively simple operation? If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @br3no I think Martin was talking about adding more test method to increase the test coverage, e.g. add a new test method
to
in this method, and by that the new method should cover the retry part. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zane-neo if I replace the exception, the behavior will not change, i.e. there will not be a successful retry. This is because the The reasoning was that if there is an exception in I'm happy to implement the retry in this operation, but I don't believe it's worth it. It will introduce more complexity to an already quite intricate code path, making it less readable. At the same time I believe it will almost never really be useful, as a failing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, it's fine to not retry the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @zane-neo @br3no please help me to understand the flow for reading model config. Is there a chance that this call is remote, meaning going outside the opensearch cluster if model deployed remotely and connected via ml-commons remote connector? If that's not the case then I'm fine with logic without retry, it's a transport request within single OS domain. In this case where the mode configuration is stored, is it in ml-commons memory or cluster metadata? If it can be a remote call then we need a retry. Doesn't matter how simple operation is, there can be a transient network issue or anything of this sort. In this case it's different from most of OS APIs, those are primality case 1, transport requests inside cluster. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No case is to fetching model metadata from remote services, cause they are all stored in OS cluster. |
||
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
Mockito.doAnswer(invocation -> { | ||
final ActionListener<MLOutput> actionListener = invocation.getArgument(2); | ||
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
RuntimeException exception = new RuntimeException("Bam!"); | ||
setupMocksForTextEmbeddingModelAsymmetryCheck(exception); | ||
|
||
accessor.inferenceSentence( | ||
TestCommonConstants.MODEL_ID, | ||
TestCommonConstants.SENTENCES_LIST.get(0), | ||
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), | ||
singleSentenceResultListener | ||
); | ||
|
||
Mockito.verify(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); | ||
Mockito.verify(singleSentenceResultListener).onFailure(exception); | ||
Mockito.verifyNoMoreInteractions(singleSentenceResultListener); | ||
} | ||
|
||
private void setupMocksForTextEmbeddingModelAsymmetryCheck(Exception exception) { | ||
Mockito.doAnswer(invocation -> { | ||
final ActionListener<MLModel> actionListener = invocation.getArgument(1); | ||
actionListener.onFailure(exception); | ||
return null; | ||
}).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); | ||
} | ||
|
||
public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { | ||
final Map<String, String> map = Map.of("key", "value"); | ||
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class); | ||
|
@@ -169,6 +260,9 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { | |
actionListener.onResponse(createModelTensorOutput(map)); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -185,6 +279,9 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx | |
actionListener.onResponse(modelTensorOutput); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -209,6 +306,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc | |
actionListener.onResponse(modelTensorOutput); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -236,6 +336,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha | |
actionListener.onResponse(modelTensorOutput); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -255,6 +358,9 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim | |
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client, times(4)) | ||
|
@@ -270,6 +376,9 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa | |
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); | ||
|
||
Mockito.verify(client, times(1)) | ||
|
@@ -285,6 +394,8 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { | |
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -300,6 +411,9 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { | |
actionListener.onFailure(exception); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); | ||
|
||
Mockito.verify(client) | ||
|
@@ -318,6 +432,9 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR | |
actionListener.onFailure(nodeNodeConnectedException); | ||
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); | ||
|
||
Mockito.verify(client, times(4)) | ||
|
@@ -333,6 +450,8 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { | |
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSimilarity( | ||
TestCommonConstants.MODEL_ID, | ||
"is it sunny", | ||
|
@@ -354,6 +473,8 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { | |
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSimilarity( | ||
TestCommonConstants.MODEL_ID, | ||
"is it sunny", | ||
|
@@ -378,6 +499,8 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi | |
return null; | ||
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); | ||
|
||
setupMocksForTextEmbeddingModelAsymmetryCheck(false); | ||
|
||
accessor.inferenceSimilarity( | ||
TestCommonConstants.MODEL_ID, | ||
"is it sunny", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can be pulled to a constant level