Add support for asymmetric embedding models #710
base: main
Conversation
@br3no can you add an entry in the changelog?
@br3no Thanks for raising the PR. I am wondering whether we require this change. In the ML Commons repository a generic ML inference processor is being launched, which is supposed to do the inference of any kind of model, both during ingestion and search. RFC: opensearch-project/ml-commons#2173. That capability is being built as of now. Do you think we still need this feature then?
@navneet1v I have been loosely following the discussions in the mentioned RFC. It's a large change that I don't expect to be stable soon – the PR is very much in flux. Also, I don't see the use-case of asymmetric embedding models being addressed. This PR here is much smaller in comparison and is not in any way in conflict with the RFC work. Once the work on the ML Inference Processors is finished and the use-case is addressed there as well, we can deprecate and eventually remove this functionality again. Until then, this PR offers users the chance to use more modern local embeddings. I'm eager to put this to spin, tbh.
If that is the case I would recommend posting the same on the RFC to ensure that your use case is handled. On the other hand, I do agree this is an interesting feature. I would like to get some eyes on this change, mainly on whether it should be added or not, given that a more generic processor is around the corner. As far as my opinion is concerned, the main reason for the generic processor was to avoid creating new processors or updating existing ones to support new model types, which is what is happening in this PR. Thoughts? @jmazanec15, @martin-gaievski, @vamshin, @vibrantvarun. Let me add some PMs from the OpenSearch project too to get their thoughts. @dylan-tong-aws
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@             Coverage Diff              @@
##               main     #710      +/-   ##
============================================
- Coverage     85.02%   84.41%    -0.61%
+ Complexity      790      785        -5
============================================
  Files            60       59        -1
  Lines          2430     2464       +34
  Branches        410      409        -1
============================================
+ Hits           2066     2080       +14
- Misses          202      215       +13
- Partials        162      169        +7
@navneet1v I have added a comment earlier today to the RFC (cf. opensearch-project/ml-commons#2173 (comment)). Sure, let's open the discussion and get some PMs into it. I really don't mind leaving this out if the support is introduced in another PR in 2.14. I'm concerned that opensearch-project/ml-commons#2173 is a much larger effort that won't be ready that quickly... It's not about my contribution – I need the feature. 🙃
I can see the feature is marked for the 2.14 release of OpenSearch. Let me add maintainers from the ML team too. @mingshl, @ylwu-amzn
@mingshl @ylwu-amzn, I'd really like to have this feature in 2.14. Do you think this use-case will be fully supported with opensearch-project/ml-commons#2173? Cf. opensearch-project/ml-commons#2173 (comment). If not, I'd be happy to help this PR get merged as an interim solution! Let me know what you think!
@br3no ml inference processor is targeting at first supporting remote model only. How did you usually connect this model? is it local or remote? if remote, can you please provide a SageMaker deployment code piece then I can quickly test it in 2.14 test cluster. Thanks |
@br3no The ML inference processor is targeting support for remote models only at first. How do you usually connect this model? Is it local or remote? If remote, can you please provide a SageMaker deployment code snippet so I can quickly test it in a 2.14 test cluster. Thanks
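To make the "content hints" idea above concrete, here is a minimal sketch of how inference input for an asymmetric model is prefixed depending on content type. This is not plugin code: the class, method names, and the prefix values are illustrative stand-ins for whatever the model's configuration (e.g. its passage/query prefixes) actually supplies.

```java
import java.util.List;

// Illustrative sketch only: an asymmetric model distinguishes queries from
// passages by a string prefix on each input. The prefix strings below are
// hypothetical placeholders for the values found in the model configuration.
public class AsymmetricPrefixer {
    public enum ContentType { QUERY, PASSAGE }

    // Hypothetical prefixes, as an asymmetric model's config might define them.
    private static final String QUERY_PREFIX = "query: ";
    private static final String PASSAGE_PREFIX = "passage: ";

    public static List<String> prefix(List<String> texts, ContentType type) {
        String p = (type == ContentType.QUERY) ? QUERY_PREFIX : PASSAGE_PREFIX;
        // Apply the content-type prefix to every input text before inference.
        return texts.stream().map(t -> p + t).toList();
    }

    public static void main(String[] args) {
        System.out.println(prefix(List.of("how big is the sun?"), ContentType.QUERY));
    }
}
```

The point of the sketch is only that the same text yields different model input, and therefore a different embedding, depending on whether it is treated as a query or a passage.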
I also vote for this PR; we need this functionality.
@br3no would it be possible for you to contribute local model support back to the MLInference processor? Is that even an option?
@navneet1v you mean making sure this works there as well? Sure, I can commit to that. I'd propose then to merge this PR now and then start the work to eventually replace this once the MLInference processor supports this use case...
I am going to pick this PR up right now, @br3no @reuschling sorry for the late reply. Let me check your code first. @br3no could you kindly rebase your code first? It will be more convenient for me to check the function diff.
@br3no I read through the context and I believe this is a valid use case; this PR gives users the flexibility to integrate the feature, thank you very much! But currently the PR needs a rebase onto main. In case you're not monitoring the PR, I've rebased your code on main and created this PR against your repo: br3no#1. Please take time to check it; you can fix this PR or cherry-pick my commit into it. I would like to push this forward to make it released in 2.18, which is 5th Nov. Thanks again.
@zane-neo great to hear this! I have merged your PR. Is there anything else I need to do?
@br3no The main blocker for now is that the current PR has conflicts that need to be resolved, so I was saying there are two approaches to fix this:
|
Signed-off-by: br3no <[email protected]>
6d3dba6 to bc6caad
@zane-neo I rebased the changes and adapted new tests that were affected by my changes. I think this is now good to go.
Signed-off-by: br3no <[email protected]>
c9d0d7f to b19ffe2
@@ -40,6 +48,7 @@
 public class MLCommonsClientAccessor {
     private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
     private final MachineLearningNodeClient mlClient;
+    private final Map<String, Boolean> modelAsymmetryCache = new ConcurrentHashMap<>();
@br3no It's fine since this is not accessible from a public API. To use a huge number of nonexistent models, a malicious actor would first have to get permission to update the pipeline, and if that were true the actor would have much more straightforward means to attack the system.
    Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenGetModelException_thenFailure() {
@br3no I think Martin was talking about adding more test methods to increase the test coverage, e.g. adding a new test method testInferenceSentences_whenGetModelNodeDisconnectedException_thenRetryToSuccess. In this method you only need to change

RuntimeException exception = new RuntimeException("Bam!");
setupMocksForTextEmbeddingModelAsymmetryCheck(exception);

to

NodeDisconnectedException exception = new NodeDisconnectedException("node disconnected!");
setupMocksForTextEmbeddingModelAsymmetryCheck(exception);

and with that the new method should cover the retry part.
Left several comments, please take a look; the other parts LGTM, thanks.
Signed-off-by: br3no <[email protected]>
@zane-neo I have pushed a small refactoring of a test, as you suggested. Thanks! I have also added comments to the issues you raised in the other threads.
    Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenGetModelException_thenFailure() {
Agreed, it's fine not to retry client.getModel since it's essentially different from client.predict. For predict requests, ml-commons has internal routing logic to dispatch the request to different nodes, and it's possible that a node disconnects from the cluster right after receiving the request; there, retry is useful, as the next predict request will be dispatched to a different node in a round-robin fashion. But client.getModel internally uses the client.get operation, which is more deterministic, e.g. the user might specify to read the result from the primary shard or from the only replica; retry won't help in these cases. And it's a valid argument that in OpenSearch not all operations are wrapped with retry logic.
Thanks @br3no, currently it LGTM; I approved the PR. I'll push for another approval and then we can merge it. Sadly it's not possible to release it in 2.18 as the release flow has started and the code has been frozen, but I'm sure we can make it into 2.19.
((client, actionListener) -> ML_CLIENT.inferenceSentences(
    modelId(),
    inferenceInput,
    AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(),
nit: always add new parameters at the end of the argument list so that it is easy to understand the changes.
@br3no How are we ensuring BWC test compatibility?
@@ -40,6 +48,7 @@
 public class MLCommonsClientAccessor {
     private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
     private final MachineLearningNodeClient mlClient;
+    private final Map<String, Boolean> modelAsymmetryCache = new ConcurrentHashMap<>();
@br3no thank you for the response, it all makes sense.
I think we need to:
- use an existing OpenSearch component for the cache
- set a limit on either size or time in the cache; that should lower the chance of a data node reaching a critical level of heap memory usage.
Please take a look at these classes: https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/common/cache/Cache.java and https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/common/cache/CacheBuilder.java
They have many nice things like multiple eviction modes and stats.
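To illustrate the size-limit idea above: here is a minimal plain-JDK sketch of a bounded cache. The actual suggestion is OpenSearch's Cache/CacheBuilder, which additionally offers time-based eviction and stats; this LinkedHashMap version only caps the entry count and is purely illustrative, not the plugin's code.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch of a size-bounded cache: once the map exceeds maxEntries, the
// least-recently-used entry is evicted. Unlike ConcurrentHashMap this is
// not thread-safe; real code would use Collections.synchronizedMap or,
// better, OpenSearch's own Cache/CacheBuilder.
public class BoundedModelCache<K, V> extends LinkedHashMap<K, V> {
    private final int maxEntries;

    public BoundedModelCache(int maxEntries) {
        super(16, 0.75f, true); // accessOrder = true gives LRU behaviour
        this.maxEntries = maxEntries;
    }

    @Override
    protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
        // Called after every put; returning true evicts the eldest entry.
        return size() > maxEntries;
    }
}
```

With a cap like this, the cache's heap footprint is bounded regardless of how many distinct model ids are ever looked up.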
mlCommonsClientAccessor.inferenceSentences(
    this.modelId,
    inferenceList,
    AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build(),
this can be pulled up into a constant
((client, actionListener) -> ML_CLIENT.inferenceSentences(
    modelId(),
    inferenceInput,
    AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(),
this is effectively a constant, no need to compute it on every call
    Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenGetModelException_thenFailure() {
@zane-neo @br3no please help me understand the flow for reading the model config. Is there a chance that this call is remote, i.e. going outside the OpenSearch cluster, if the model is deployed remotely and connected via an ml-commons remote connector?
If that's not the case then I'm fine with logic without retry; it's a transport request within a single OS domain. In that case, where is the model configuration stored: in ml-commons memory or in cluster metadata?
If it can be a remote call then we need a retry. No matter how simple the operation is, there can be a transient network issue or anything of this sort. In that case it's different from most OS APIs, which are primarily case 1, transport requests inside the cluster.
Hi @vibrantvarun! Since this PR is targeted at 2.19, we still have time to raise the next PRs before the code freeze date of 2.19 (Jan 28, 2025). I think we can leave the BWC tests for the next PR. @br3no Can you discuss with @vibrantvarun whether a DTO object is needed for this PR? I understand a DTO object is a best practice, but I don't want it to take too long before merging this PR.
@@ -40,6 +48,7 @@
 public class MLCommonsClientAccessor {
     private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
     private final MachineLearningNodeClient mlClient;
+    private final Map<String, Boolean> modelAsymmetryCache = new ConcurrentHashMap<>();
I'm open to making the change, but I don't think it's mandatory. The cache stores the model info and consumes very little space compared to a model itself, so heap memory would come under pressure from registering a new model rather than from writing this little metadata into the cache. The only risky case would be a way to bypass model registration and put a huge amount of info into the cache directly, which, per the explanation above, is not possible.
    Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
}

public void testInferenceSentences_whenGetModelException_thenFailure() {
There is no case where model metadata is fetched from remote services; it is all stored in the OS cluster.
This is a valid concern. I took a look at the code and it could have a BWC issue during serialization/deserialization. The latest version serializes the AsymmetricTextEmbeddingParameters, and if a node is deployed with a legacy version of OS, then during deserialization it fetches the class from the internal cache map here; since a legacy version doesn't have this class, an IllegalArgumentException will be thrown. @br3no Can you test this case to double-confirm whether this is true? Thanks.
@zane-neo for some reason the GitHub UI is glitching for me, so I cannot put my answer under your comments.
@br3no DTO objects help make the code more readable and understandable. Issue 790 is currently not prioritized and is in the backlog; we don't know when it will be prioritized. Therefore, I would request you to add it in this PR. It will not be a design-level change or something that completely changes the purpose of this PR.
Description
This PR adds support for asymmetric embedding models, such as https://huggingface.co/intfloat/multilingual-e5-small, to the neural-search plugin.
It builds on the work done in opensearch-project/ml-commons#1799.
Asymmetric embedding models behave differently when embedding passages and queries. To that end, the model must "know" at inference time what kind of data it is embedding.
The changes are:
1. src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
   The processor signals that it is embedding passages by passing the new AsymmetricTextEmbeddingParameters using the content type EmbeddingContentType.PASSAGE.
2. src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java
   Analogously, the query builder uses EmbeddingContentType.QUERY.
3. src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
   Here is where most of the work was done. The class has been extended in a backwards-compatible way with inference methods that allow one to pass MLAlgoParams objects. Usage of AsymmetricTextEmbeddingParameters (which implements MLAlgoParams) is mandatory for asymmetric models; at the same time, symmetric models do not accept them.

The only way to know whether a model is asymmetric or symmetric is by reading its model configuration (if the model's configuration contains a passage_prefix and/or a query_prefix, it is asymmetric; otherwise it is symmetric).

The src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java class deals with this, keeping the complexity in one place and not requiring any API change to the neural-search plugin (as proposed in #620). When calling the inference methods, clients (such as the TextEmbeddingProcessor) may pass the AsymmetricTextEmbeddingParameters object without caring whether the model they are using is symmetric or asymmetric. The accessor class will first read the model's configuration (by calling the getModel API of the mlClient) and deal with it appropriately.

To avoid adding this extra roundtrip to every inference call, the asymmetry information is kept in an in-memory cache.
Issues Resolved
#620
Check List
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.