diff --git a/CHANGELOG.md b/CHANGELOG.md index 595ea7dd4..c72b87e0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features +- Add support for asymmetric embedding models ([#710](https://github.com/opensearch-project/neural-search/pull/710)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index f9ddf73a9..a1a7a1601 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -11,16 +11,24 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import java.util.stream.Collectors; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.Nullable; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.output.model.ModelTensor; @@ -40,6 +48,7 @@ 
public class MLCommonsClientAccessor { private static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); private final MachineLearningNodeClient mlClient; + private final Map modelAsymmetryCache = new ConcurrentHashMap<>(); /** * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating @@ -54,7 +63,29 @@ public void inferenceSentence( @NonNull final String inputText, @NonNull final ActionListener> listener ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { + inferenceSentence(modelId, inputText, null, listener); + } + + /** + * Wrapper around {@link #inferenceSentences} that expected a single input text and produces a single floating + * point vector as a response. Supports passing {@link MLAlgoParams} to the inference. If the model is + * asymmetric, passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} is + * mandatory. This method will check whether the model being used is asymmetric and correctly handle the + * parameter, so it's okay to always pass the parameter (even if the model is symmetric). 
+ * + * @param modelId {@link String} + * @param inputText {@link List} of {@link String} on which inference needs to happen + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + */ + public void inferenceSentence( + @NonNull final String modelId, + @NonNull final String inputText, + @Nullable final MLAlgoParams mlAlgoParams, + @NonNull final ActionListener> listener + ) { + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), mlAlgoParams, ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( new IllegalStateException( @@ -69,43 +100,98 @@ public void inferenceSentence( } /** - * Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we - * need to run only TextEmbedding tasks only. + * Abstraction to call predict function of api of MLClient with default targetResponse filters. It + * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The + * return will be sent using the actionListener which will have a {@link List} of {@link List} of + * {@link Float} in the order of inputText. We are not making this function generic enough to take + * any function or TaskType as currently we need to run only TextEmbedding tasks only. 
* - * @param modelId {@link String} + * @param modelId {@link String} * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out + * @param listener {@link ActionListener} which will be called when prediction is completed or + * errored out */ public void inferenceSentences( @NonNull final String modelId, @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener); + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, null, listener); } /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we - * need to run only TextEmbedding tasks only. + * Abstraction to call predict function of api of MLClient with default targetResponse filters. It + * uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The + * return will be sent using the actionListener which will have a {@link List} of {@link List} of + * {@link Float} in the order of inputText. We are not making this function generic enough to take + * any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports + * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} + * is mandatory. 
This method will check whether the model being used is asymmetric and correctly + * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). * - * @param targetResponseFilters {@link List} of {@link String} which filters out the responses * @param modelId {@link String} * @param inputText {@link List} of {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + */ + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final List inputText, + @Nullable final MLAlgoParams mlAlgoParams, + @NonNull final ActionListener>> listener + ) { + inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, mlAlgoParams, listener); + } + + /** + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a {@link List} of {@link List} + * of {@link Float} in the order of inputText. We are not making this function generic enough to + * take any function or TaskType as currently we need to run only TextEmbedding tasks only. + * + * @param targetResponseFilters {@link List} of {@link String} which filters out the responses + * @param modelId {@link String} + * @param inputText {@link List} of {@link String} on which inference needs to happen + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. 
+ */ + public void inferenceSentences( + @NonNull final List targetResponseFilters, + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener + ) { + inferenceSentences(targetResponseFilters, modelId, inputText, null, listener); + } + + /** + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a {@link List} of {@link List} + * of {@link Float} in the order of inputText. We are not making this function generic enough to + * take any function or TaskType as currently we need to run only TextEmbedding tasks only. Supports + * passing {@link MLAlgoParams} to the inference. If the model is asymmetric, passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} + * is mandatory. This method will check whether the model being used is asymmetric and correctly + * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). + * + * @param targetResponseFilters {@link List} of {@link String} which filters out the responses + * @param modelId {@link String} + * @param inputText {@link List} of {@link String} on which inference needs to happen + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is + * completed or errored out. 
*/ public void inferenceSentences( @NonNull final List targetResponseFilters, @NonNull final String modelId, @NonNull final List inputText, + @Nullable final MLAlgoParams mlAlgoParams, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, mlAlgoParams, 0, listener); } public void inferenceSentencesWithMapResult( @@ -113,35 +199,65 @@ public void inferenceSentencesWithMapResult( @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); + retryableInferenceSentencesWithMapResult(modelId, inputText, null, 0, listener); } /** - * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the - * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a list of floats in the order of inputText. + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a list of floats in the order + * of inputText. * - * @param modelId {@link String} - * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen - * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + * @param modelId {@link String} + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to + * happen + * @param listener {@link ActionListener} which will be called when prediction is completed or + * errored out. 
*/ public void inferenceSentences( @NonNull final String modelId, @NonNull final Map inputObjects, @NonNull final ActionListener> listener ) { - retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + inferenceSentences(modelId, inputObjects, null, listener); } /** - * Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the - * {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing - * the similarity scores of the texts w.r.t. the query text, in the order of the input texts. + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. + * It uses the custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. + * The return will be sent using the actionListener which will have a list of floats in the order + * of inputText. Supports passing {@link MLAlgoParams} to the inference. If the model is asymmetric, + * passing a + * {@link org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters} + * is mandatory. This method will check whether the model being used is asymmetric and correctly + * handle the parameter, so it's okay to always pass the parameter (even if the model is symmetric). * - * @param modelId {@link String} ML-Commons Model Id + * @param modelId {@link String} + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to + * happen + * @param mlAlgoParams {@link MLAlgoParams} which will be used to run the inference + * @param listener {@link ActionListener} which will be called when prediction is completed or + * errored out. 
+ */ + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final Map inputObjects, + @Nullable final MLAlgoParams mlAlgoParams, + @NonNull final ActionListener> listener + ) { + retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, mlAlgoParams, 0, listener); + } + + /** + * Abstraction to call predict function of api of MLClient. It uses the custom model provided as + * modelId and the {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via + * actionListener as a list of floats representing the similarity scores of the texts w.r.t. the + * query text, in the order of the input texts. + * + * @param modelId {@link String} ML-Commons Model Id * @param queryText {@link String} The query to compare all the inputText to * @param inputText {@link List} of {@link String} The texts to compare to the query - * @param listener {@link ActionListener} receives the result of the inference + * @param listener {@link ActionListener} receives the result of the inference */ public void inferenceSimilarity( @NonNull final String modelId, @@ -155,42 +271,95 @@ public void inferenceSimilarity( private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLTextInput(null, inputText); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> result = buildMapResultFromResponse(mlOutput); - listener.onResponse(result); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + + Consumer runPrediction = isAsymmetricModel -> { + MLInput mlInput = createMLTextInput(null, inputText, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> result = buildMapResultFromResponse(mlOutput); + listener.onResponse(result); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithMapResult(modelId, inputText, mlAlgoParams, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, runPrediction); } private void retryableInferenceSentencesWithVectorResult( final List targetResponseFilters, final String modelId, final List inputText, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List> vector = buildVectorFromResponse(mlOutput); - listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); - } else { - listener.onFailure(e); + + Consumer runPrediction = isAsymmetricModel -> { + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> vector = buildVectorFromResponse(mlOutput); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithVectorResult( + targetResponseFilters, + modelId, + inputText, + mlAlgoParams, + retryTimeAdd, + listener + ); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, runPrediction); + } + + /** + * Check if the model is asymmetric and then run the prediction. Model asymmetry is a concept + * that is specific to TextEmbeddingModelConfig. If the model is not a TextEmbeddingModel, then + * this check is not applicable. + * + * The asymmetry of a model is static for a given model. To avoid repeated checks for the same + * model, we cache the model asymmetry status. Non-TextEmbeddingModels are cached as false. + * + * @param modelId The model id to check + * @param onFailure The action to take if the model cannot be retrieved + * @param runPrediction The action to take if the model is successfully retrieved + */ + private void checkModelAsymmetryAndThenPredict(String modelId, Consumer onFailure, Consumer runPrediction) { + CheckedConsumer checkModelAsymmetryListener = model -> { + MLModelConfig modelConfig = model.getModelConfig(); + if (!(modelConfig instanceof TextEmbeddingModelConfig)) { + modelAsymmetryCache.putIfAbsent(modelId, false); + return; } - })); + final TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig; + final boolean isAsymmetricModel = textEmbeddingModelConfig.getPassagePrefix() != null + || textEmbeddingModelConfig.getQueryPrefix() != null; + modelAsymmetryCache.putIfAbsent(modelId, isAsymmetricModel); + }; + if (modelAsymmetryCache.containsKey(modelId)) { + runPrediction.accept(modelAsymmetryCache.get(modelId)); + } else { + 
mlClient.getModel(modelId, ActionListener.wrap(mlModel -> { + checkModelAsymmetryListener.accept(mlModel); + runPrediction.accept(modelAsymmetryCache.get(modelId)); + }, onFailure)); + } } private void retryableInferenceSimilarityWithVectorResult( @@ -213,10 +382,10 @@ private void retryableInferenceSimilarityWithVectorResult( })); } - private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { + private MLInput createMLTextInput(final List targetResponseFilters, List inputText, MLAlgoParams mlAlgoParams) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + return new MLInput(FunctionName.TEXT_EMBEDDING, mlAlgoParams, inputDataset); } private MLInput createMLTextPairsInput(final String query, final List inputText) { @@ -264,25 +433,42 @@ private void retryableInferenceSentencesWithSingleVectorResult( final List targetResponseFilters, final String modelId, final Map inputObjects, + final MLAlgoParams mlAlgoParams, final int retryTime, final ActionListener> listener ) { - MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); - mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final List vector = buildSingleVectorFromResponse(mlOutput); - log.debug("Inference Response for input sentence is : {} ", vector); - listener.onResponse(vector); - }, e -> { - if (RetryUtil.shouldRetry(e, retryTime)) { - final int retryTimeAdd = retryTime + 1; - retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); - } else { - listener.onFailure(e); - } - })); + + Consumer predictConsumer = isAsymmetricModel -> { + MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects, isAsymmetricModel ? 
mlAlgoParams : null); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List vector = buildSingleVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence is : {} ", vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithSingleVectorResult( + targetResponseFilters, + modelId, + inputObjects, + mlAlgoParams, + retryTimeAdd, + listener + ); + } else { + listener.onFailure(e); + } + })); + }; + + checkModelAsymmetryAndThenPredict(modelId, listener::onFailure, predictConsumer); } - private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { + private MLInput createMLMultimodalInput( + final List targetResponseFilters, + final Map input, + MLAlgoParams mlAlgoParams + ) { List inputText = new ArrayList<>(); inputText.add(input.get(INPUT_TEXT)); if (input.containsKey(INPUT_IMAGE)) { @@ -290,6 +476,6 @@ private MLInput createMLMultimodalInput(final List targetResponseFilters } final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + return new MLInput(FunctionName.TEXT_EMBEDDING, mlAlgoParams, inputDataset); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c8f9f080d..23dc6af49 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -13,6 +13,8 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; +import 
org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; @@ -47,10 +49,15 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + mlCommonsClientAccessor.inferenceSentences( + this.modelId, + inferenceList, + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 915a79117..c8b2a1d4a 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -44,6 +44,8 @@ import org.opensearch.knn.index.query.parser.RescoreParser; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.neuralsearch.common.MinClusterVersionUtil; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import com.google.common.annotations.VisibleForTesting; @@ -333,10 +335,15 @@ protected QueryBuilder doRewrite(QueryRewriteContext 
queryRewriteContext) { inferenceInput.put(INPUT_IMAGE, queryImage()); } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { - vectorSetOnce.set(vectorAsListToArray(floatList)); - actionListener.onResponse(null); - }, actionListener::onFailure))) + ((client, actionListener) -> ML_CLIENT.inferenceSentences( + modelId(), + inferenceInput, + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(), + ActionListener.wrap(floatList -> { + vectorSetOnce.set(vectorAsListToArray(floatList)); + actionListener.onResponse(null); + }, actionListener::onFailure) + )) ); return new NeuralQueryBuilder( fieldName(), diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3749e63dc..3c0376909 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -23,7 +23,11 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; @@ -59,8 +63,14 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { 
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); - accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener); + accessor.inferenceSentence( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST.get(0), + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + singleSentenceResultListener + ); Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -68,6 +78,19 @@ public void testInferenceSentence_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } + private void setupMocksForTextEmbeddingModelAsymmetryCheck(boolean isAsymmetric) { + MLModel modelMock = mock(MLModel.class); + TextEmbeddingModelConfig textEmbeddingModelConfigMock = mock(TextEmbeddingModelConfig.class); + Mockito.when(textEmbeddingModelConfigMock.getPassagePrefix()).thenReturn(isAsymmetric ? "passage: " : null); + Mockito.when(textEmbeddingModelConfigMock.getQueryPrefix()).thenReturn(isAsymmetric ? 
"query: " : null); + Mockito.when(modelMock.getModelConfig()).thenReturn(textEmbeddingModelConfigMock); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(modelMock); + return null; + }).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + } + public void testInferenceSentences_whenValidInputThenSuccess() { final List> vectorList = new ArrayList<>(); vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY)); @@ -76,6 +99,8 @@ public void testInferenceSentences_whenValidInputThenSuccess() { actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -92,6 +117,8 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() { actionListener.onResponse(createModelTensorOutput(new Float[] {})); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -107,6 +134,8 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { actionListener.onFailure(exception); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( TestCommonConstants.TARGET_RESPONSE_FILTERS, 
TestCommonConstants.MODEL_ID, @@ -130,6 +159,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( TestCommonConstants.TARGET_RESPONSE_FILTERS, TestCommonConstants.MODEL_ID, @@ -149,6 +181,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { actionListener.onFailure(illegalStateException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences( TestCommonConstants.TARGET_RESPONSE_FILTERS, TestCommonConstants.MODEL_ID, @@ -161,6 +196,62 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceSentences_whenModelAsymmetric_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(true); + + accessor.inferenceSentence( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST.get(0), + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict( + 
Mockito.eq(TestCommonConstants.MODEL_ID), + Mockito.argThat((MLInput input) -> input.getParameters() != null), + Mockito.isA(ActionListener.class) + ); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSentences_whenGetModelException_thenFailure() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + RuntimeException exception = new RuntimeException("Bam!"); + setupMocksForTextEmbeddingModelAsymmetryCheck(exception); + + accessor.inferenceSentence( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST.get(0), + AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(), + singleSentenceResultListener + ); + + Mockito.verify(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + private void setupMocksForTextEmbeddingModelAsymmetryCheck(Exception exception) { + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(exception); + return null; + }).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class)); + } + public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); @@ -169,6 +260,9 @@ public void 
testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { actionListener.onResponse(createModelTensorOutput(map)); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -185,6 +279,9 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -209,6 +306,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -236,6 +336,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha actionListener.onResponse(modelTensorOutput); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) @@ -255,6 +358,9 @@ public void 
testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client, times(4)) @@ -270,6 +376,9 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client, times(1)) @@ -285,6 +394,8 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); Mockito.verify(client) @@ -300,6 +411,9 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { actionListener.onFailure(exception); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); Mockito.verify(client) @@ -318,6 +432,9 @@ public void 
testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); Mockito.verify(client, times(4)) @@ -333,6 +450,8 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( TestCommonConstants.MODEL_ID, "is it sunny", @@ -354,6 +473,8 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( TestCommonConstants.MODEL_ID, "is it sunny", @@ -378,6 +499,8 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + setupMocksForTextEmbeddingModelAsymmetryCheck(false); + accessor.inferenceSimilarity( TestCommonConstants.MODEL_ID, "is it sunny", diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 4afa4031d..1611daaed 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -51,6 +51,7 
@@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); + private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); private final String BULK_ITEM_TEMPLATE = Files.readString( Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI()) ); @@ -244,6 +245,11 @@ private String uploadTextEmbeddingModel() throws Exception { return registerModelGroupAndUploadModel(requestBody); } + private String uploadAsymmetricEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadAsymmetricModelRequestBody.json").toURI())); + return registerModelGroupAndUploadModel(requestBody); + } + private void createTextEmbeddingIndex() throws Exception { createIndexWithConfiguration( INDEX_NAME, @@ -252,6 +258,20 @@ private void createTextEmbeddingIndex() throws Exception { ); } + public void testAsymmetricTextEmbeddingProcessor() throws Exception { + String modelId = null; + try { + modelId = uploadAsymmetricEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2); + createTextEmbeddingIndex(); + ingestDocument(INGEST_DOC5, null); + assertEquals(1, getDocCount(INDEX_NAME)); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + private void ingestDocument(String doc, String id) throws Exception { String endpoint; if (StringUtils.isEmpty(id)) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 97e85e46e..1d83c8c95 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -44,6 +44,7 @@ import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.ingest.Processor; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; @@ -151,10 +152,10 @@ public void testExecute_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -184,7 +185,8 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep DESCRIPTION, config ); - doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + doThrow(new RuntimeException()).when(accessor) + .inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(RuntimeException.class)); @@ -230,10 +232,10 @@ public void testExecute_withListTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - 
ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -306,10 +308,10 @@ public void testExecute_withMapTypeInput_successful() { List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -347,10 +349,10 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -407,10 +409,10 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa List> modelTensorList = createRandomOneDimensionalMockVector(1, 
100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -465,10 +467,10 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -516,10 +518,10 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onResponse(modelTensorList); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); assertNotNull(ingestDocument); @@ -585,10 +587,10 @@ public void 
testExecute_MLClientAccessorThrowFail_handlerFailure() { TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); + ActionListener>> listener = invocation.getArgument(3); listener.onFailure(new IllegalArgumentException("illegal argument")); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(MLAlgoParams.class), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 6d8e810f3..5efbf3869 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -649,10 +649,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(3); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -685,10 +685,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); MLCommonsClientAccessor mlCommonsClientAccessor = 
mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); + ActionListener> listener = invocation.getArgument(3); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/resources/processor/UploadAsymmetricModelRequestBody.json b/src/test/resources/processor/UploadAsymmetricModelRequestBody.json new file mode 100644 index 000000000..8c5b6ec18 --- /dev/null +++ b/src/test/resources/processor/UploadAsymmetricModelRequestBody.json @@ -0,0 +1,17 @@ +{ + "name": "traced_small_model", + "version": "1.0.0", + "model_format": "TORCH_SCRIPT", + "model_task_type": "text_embedding", + "model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021", + "model_group_id": "%s", + "model_config": { + "model_type": "bert", + "embedding_dimension": 768, + "framework_type": "sentence_transformers", + "passage_prefix" : "passage: ", + "query_prefix" : "query: ", + "all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}" + }, + "url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true" +} diff --git a/src/test/resources/processor/ingest_doc5.json b/src/test/resources/processor/ingest_doc5.json new file mode 100644 index 000000000..e3302c75a --- /dev/null +++ b/src/test/resources/processor/ingest_doc5.json @@ -0,0 +1,21 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "test", + "hello", + "mock" + ], + "favorites": { + "game": "overwatch", 
+ "movie": null + }, + "nested_passages": [ + { + "text": "hello" + }, + { + "text": "world" + } + ] +}