Add support for asymmetric embedding models #710

Open · wants to merge 7 commits into base: main (changes shown from 6 commits)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Add support for asymmetric embedding models ([#710](https://github.com/opensearch-project/neural-search/pull/710))
### Enhancements
### Bug Fixes
### Infrastructure

Large diffs are not rendered by default.

@@ -13,6 +13,8 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
@@ -47,10 +49,15 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentences(
this.modelId,
inferenceList,
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
Review comment (Member): this can be pulled to a constant level

ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

@Override
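The comment above suggests hoisting the parameters object out of the per-document call. A minimal sketch of that refactoring, assuming a hypothetical constant name (the PR itself builds the object inline):

    // Sketch only: the ingest side always embeds passages, so the parameters object
    // never changes and can be created once instead of on every doExecute call.
    private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters
        .builder()
        .embeddingContentType(EmbeddingContentType.PASSAGE)
        .build();

    // The call site then becomes:
    mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, PASSAGE_PARAMETERS, ActionListener.wrap(vectors -> {
        setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
        handler.accept(ingestDocument, null);
    }, e -> { handler.accept(null, e); }));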
@@ -44,6 +44,8 @@
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import com.google.common.annotations.VisibleForTesting;
@@ -333,10 +335,15 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
inferenceInput.put(INPUT_IMAGE, queryImage());
}
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
((client, actionListener) -> ML_CLIENT.inferenceSentences(
modelId(),
inferenceInput,
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(),
Review comment (Member): this is effectively a constant, no need to compute on every call

Review comment (Member): nit: always add new parameters at the end of the argument list so that it is easy to understand the changes.

ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)
))
);
return new NeuralQueryBuilder(
fieldName(),
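Taken together, the two comments above suggest building the query-side parameters once and keeping new arguments at the end of the signature. A rough sketch under those assumptions (the constant name is hypothetical, and the argument order shown is the one from this revision of the PR):

    // Sketch only: the query side always embeds queries, so this too is effectively a constant.
    private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters
        .builder()
        .embeddingContentType(EmbeddingContentType.QUERY)
        .build();

    // Call site inside doRewrite; the ordering nit would additionally move the new
    // parameter to the end of the inferenceSentences argument list.
    queryRewriteContext.registerAsyncAction(
        (client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, QUERY_PARAMETERS, ActionListener.wrap(floatList -> {
            vectorSetOnce.set(vectorAsListToArray(floatList));
            actionListener.onResponse(null);
        }, actionListener::onFailure))
    );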
@@ -23,7 +23,11 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
@@ -59,15 +63,34 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentence(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST.get(0), singleSentenceResultListener);
accessor.inferenceSentence(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST.get(0),
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
singleSentenceResultListener
);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

private void setupMocksForTextEmbeddingModelAsymmetryCheck(boolean isAsymmetric) {
MLModel modelMock = mock(MLModel.class);
TextEmbeddingModelConfig textEmbeddingModelConfigMock = mock(TextEmbeddingModelConfig.class);
Mockito.when(textEmbeddingModelConfigMock.getPassagePrefix()).thenReturn(isAsymmetric ? "passage: " : null);
Mockito.when(textEmbeddingModelConfigMock.getQueryPrefix()).thenReturn(isAsymmetric ? "query: " : null);
Mockito.when(modelMock.getModelConfig()).thenReturn(textEmbeddingModelConfigMock);
Mockito.doAnswer(invocation -> {
final ActionListener<MLModel> actionListener = invocation.getArgument(1);
actionListener.onResponse(modelMock);
return null;
}).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class));
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
@@ -76,6 +99,8 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
@@ -92,6 +117,8 @@ public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
actionListener.onResponse(createModelTensorOutput(new Float[] {}));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
@@ -107,6 +134,8 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
actionListener.onFailure(exception);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
@@ -130,6 +159,9 @@ public void testInferenceSentences_whenNodeNotConnectedException_thenRetry_3Time
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
@@ -149,6 +181,9 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
actionListener.onFailure(illegalStateException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(
TestCommonConstants.TARGET_RESPONSE_FILTERS,
TestCommonConstants.MODEL_ID,
@@ -161,6 +196,62 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
Mockito.verify(resultListener).onFailure(illegalStateException);
}

public void testInferenceSentences_whenModelAsymmetric_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
setupMocksForTextEmbeddingModelAsymmetryCheck(true);

accessor.inferenceSentence(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST.get(0),
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
singleSentenceResultListener
);

Mockito.verify(client)
.predict(
Mockito.eq(TestCommonConstants.MODEL_ID),
Mockito.argThat((MLInput input) -> input.getParameters() != null),
Mockito.isA(ActionListener.class)
);
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenGetModelException_thenFailure() {
Review comment (Member):

We need to test the scenario where we're retrying 1-2 times. I only see a scenario where the first request has failed with an error that isn't retryable; this isn't full coverage.

Review comment (Author):

@martin-gaievski thanks for pointing this out.

I'm wondering, though, whether we should make this retryable at all. Let me elaborate:

In my understanding, inference requests are retried because they tend to fail more often than regular operations in OpenSearch. I don't know the history and complete reasoning behind this, so I speculate it has to do with the fact that the inference is done natively and that many things can go wrong there.

With my change, if fetching the model information fails (mlClient.getModel(modelId, ...)), there is no retry. Model information is fetched the first time inference is requested with a particular model. After that, the result is cached and the method behaves exactly as before the PR.

So my argument is: should we really add retry logic to this relatively simple operation? If getModel fails, it is most likely to fail again, so retrying wouldn't make sense. By that argument, one could say that all operations in OpenSearch should be wrapped in retry logic.

Review comment (Collaborator):

@br3no I think Martin was talking about adding another test method to increase the test coverage, e.g. a new test method testInferenceSentences_whenGetModelNodeDisconnectedException_thenRetryToSuccess. For that method you would only need to change

RuntimeException exception = new RuntimeException("Bam!");
setupMocksForTextEmbeddingModelAsymmetryCheck(exception);

to

NodeDisconnectedException exception = new NodeDisconnectedException("node disconnected!");
setupMocksForTextEmbeddingModelAsymmetryCheck(exception);

and the new method should then cover the retry part.

Review comment (Author):

@zane-neo if I replace the exception, the behavior will not change, i.e. there will not be a successful retry.

This is because the checkModelAsymmetryAndThenPredict method in MLCommonsClientAccessor will not retry mlClient.getModel if there is a failure.

The reasoning was that if there is an exception in mlClient.getModel, which is a very simple operation (it's just fetching the model information), then retrying will most likely lead to the same failure. If we think this operation should be retried then, as I argued above, all operations in OpenSearch would be equally good candidates to be wrapped in retry logic.

I'm happy to implement the retry for this operation, but I don't believe it's worth it. It would introduce more complexity to an already quite intricate code path, making it less readable. At the same time I believe it would almost never be useful, as a failing mlClient.getModel will almost always fail again when repeated.

Review comment (Collaborator):

Agreed, it's fine not to retry client.getModel since it's essentially different from client.predict. For a predict request, ml-commons has internal routing logic to dispatch the request to different nodes, and it's possible that a node disconnects from the cluster right after receiving the request; retrying is useful there because the next predict request will be dispatched to a different node in a round-robin fashion. But client.getModel internally uses the client.get operation, which is more deterministic (e.g. the user might specify reading the result from the primary shard or from the only replica), so retrying won't help in those cases. And it's a valid argument that in OpenSearch not all operations are wrapped with retry logic.

Review comment (Member):

@zane-neo @br3no please help me understand the flow for reading the model config. Is there a chance that this call is remote, meaning it goes outside the OpenSearch cluster if the model is deployed remotely and connected via an ml-commons remote connector?

If that's not the case then I'm fine with the logic without retry; it's a transport request within a single OS domain. In that case, where is the model configuration stored: in ml-commons memory or in cluster metadata?

If it can be a remote call then we need a retry. It doesn't matter how simple the operation is; there can be a transient network issue or anything of that sort. That case would be different from most OS APIs, which are primarily the first kind, transport requests inside the cluster.

Review comment (Collaborator):

No, there is no case where model metadata is fetched from remote services, because it is all stored in the OS cluster.

final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
RuntimeException exception = new RuntimeException("Bam!");
setupMocksForTextEmbeddingModelAsymmetryCheck(exception);

accessor.inferenceSentence(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST.get(0),
AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
singleSentenceResultListener
);

Mockito.verify(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}
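For context on the thread above, here is a rough sketch of the check-then-predict flow the author describes: the model config is fetched once per model id via mlClient.getModel, the asymmetry flag is derived from the passage/query prefixes and cached, and only the predict call keeps its retry behavior. Names and structure are an illustration based on the discussion, not the PR's exact code:

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import org.opensearch.core.action.ActionListener;
    import org.opensearch.ml.common.MLModel;
    import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

    // Illustration only: asymmetry is looked up once per model id and cached;
    // the getModel call is not retried, while predict keeps its existing retry logic.
    private final Map<String, Boolean> modelAsymmetryCache = new ConcurrentHashMap<>();

    private void checkModelAsymmetryAndThenPredict(String modelId, Runnable predictAction, ActionListener<?> listener) {
        Boolean cached = modelAsymmetryCache.get(modelId);
        if (cached != null) {
            // Model already checked: behave exactly as before the PR and predict directly.
            predictAction.run();
            return;
        }
        mlClient.getModel(modelId, ActionListener.wrap(mlModel -> {
            TextEmbeddingModelConfig config = (TextEmbeddingModelConfig) mlModel.getModelConfig();
            boolean isAsymmetric = config != null
                && (config.getPassagePrefix() != null || config.getQueryPrefix() != null);
            modelAsymmetryCache.put(modelId, isAsymmetric);
            predictAction.run();
        }, listener::onFailure)); // a getModel failure is surfaced directly, without retry
    }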

private void setupMocksForTextEmbeddingModelAsymmetryCheck(Exception exception) {
Mockito.doAnswer(invocation -> {
final ActionListener<MLModel> actionListener = invocation.getArgument(1);
actionListener.onFailure(exception);
return null;
}).when(client).getModel(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(ActionListener.class));
}

public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
final Map<String, String> map = Map.of("key", "value");
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);
@@ -169,6 +260,9 @@ public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
actionListener.onResponse(createModelTensorOutput(map));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
@@ -185,6 +279,9 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
@@ -209,6 +306,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
@@ -236,6 +336,9 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTha
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
@@ -255,6 +358,9 @@ public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Tim
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client, times(4))
@@ -270,6 +376,9 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<List<Map<String, ?>>> resultListener = mock(ActionListener.class);

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client, times(1))
@@ -285,6 +394,8 @@ public void testInferenceMultimodal_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
@@ -300,6 +411,9 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() {
actionListener.onFailure(exception);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client)
@@ -318,6 +432,9 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener);

Mockito.verify(client, times(4))
@@ -333,6 +450,8 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
@@ -354,6 +473,8 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",
@@ -378,6 +499,8 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTi
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

setupMocksForTextEmbeddingModelAsymmetryCheck(false);

accessor.inferenceSimilarity(
TestCommonConstants.MODEL_ID,
"is it sunny",