[FEATURE] Add z-score for the normalization processor #376 #468

Closed
wants to merge 5 commits
6 changes: 6 additions & 0 deletions build.gradle
@@ -218,6 +218,12 @@ integTest {
if (System.getProperty("test.debug") != null) {
jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005'
}

systemProperty 'log4j2.configurationFile', "${projectDir}/src/test/resources/log4j2-test.xml"

// Set this to true if you want to see the logs in the terminal test output.
// note: if left false, the log output will still show in your IDE
testLogging.showStandardStreams = true
}

testClusters.integTest {
@@ -26,8 +26,6 @@
@Log4j2
public class ScoreCombiner {

private static final Float ZERO_SCORE = 0.0f;
Collaborator: any reason why we are removing this?

Author: It's not in use anywhere in the code.

/**
* Performs score combination based on input combination technique. Mutates input object by updating combined scores
* Main steps we're doing for combination:
@@ -19,7 +19,9 @@ public class ScoreNormalizationFactory {
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
new MinMaxScoreNormalizationTechnique(),
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
new L2ScoreNormalizationTechnique()
new L2ScoreNormalizationTechnique(),
ZScoreNormalizationTechnique.TECHNIQUE_NAME,
new ZScoreNormalizationTechnique()
);

/**
@@ -0,0 +1,168 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.normalization;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import lombok.ToString;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import com.google.common.primitives.Floats;

/**
* Implementation of z-score normalization technique for hybrid query
* This is currently modeled based on the existing normalization techniques {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}
* However, this class as well as the original ones require a significant work to improve style and ease of use, see TODO items below
*/
/*
TODO: Some todo items that apply here but also on the original normalization techniques on which it is modeled {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}
1. Random access to abstract list object is a bad practice both stylistically and from performance perspective and should be removed
Collaborator: Can you please provide an alternative for what should be used?

As per my understanding, random access on a List is bad if the concrete implementation is LinkedList. But what I have generally seen is that we use ArrayList, which is backed by an array, so random access is done in constant time.

Member: It should be fine if we know the exact implementation of the List, as Navneet mentioned. But with a List we can use a functional style more easily, without an expensive array -> stream conversion; that was the reason we switched to a List.

Author: It is usually highly discouraged to call List.get() on an abstract List because the implementation might not support random access efficiently (e.g. LinkedList). The suggested alternative is to explicitly declare it as ArrayList throughout the hot paths that require random access.

Member: I'm ok to switch from the general List to ArrayList; that still works with the stream API and keeps our requirements on caller code clean. I expect that change will affect a lot of classes, so I'd prefer to see it as a separate refactoring PR.

Author: @martin-gaievski same here, I added the comment with the intention of proposing it as a separate refactoring PR.

2. Identical sub queries and their distribution between shards is currently completely implicit based on ordering and should be explicit based on identifier
Collaborator (@navneet1v, Oct 19, 2023): This is really a good thought, but the problem is that none of the query clauses in OpenSearch support identifiers. This was discussed during the implementation. The problem is the way results are returned after the QueryPhase: they come back in a ScoreDocs array, which doesn't support identifiers.

We could work around that, but it would require interface changes in OpenSearch Core, so we decided against it to stay compatible with OpenSearch core.

If there is an alternative supported in OpenSearch, please let us know; maybe we are missing something.

Author: sounds good @navneet1v, I will give it some thought and come up with a suggestion. In any case I'm not planning to do it as part of this change. We can keep it for now, and I can suggest a refactor or just remove it if not achievable.

3. Implicit calculation of numOfSubQueries instead of having a more explicit upstream indicator/metadata regarding it
*/
@ToString(onlyExplicitlyIncluded = true)
public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "z_score";
private static final float SINGLE_RESULT_SCORE = 1.0f;

@Override
public void normalize(final List<CompoundTopDocs> queryTopDocs) {
/*
TODO: There is an implicit assumption in this calculation that probably needs to be made clearer by passing some metadata with the results.
Currently assuming that any single non-empty shard result will contain entries for all sub queries, including those with 0 hits.
*/
final Optional<CompoundTopDocs> maybeCompoundTopDocs = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getTopDocs().size() > 0)
.findAny();

final int numOfSubQueries = maybeCompoundTopDocs.map(compoundTopDocs -> compoundTopDocs.getTopDocs().size()).orElse(0);

// to be done for each subquery
float[] sumPerSubquery = findScoreSumPerSubQuery(queryTopDocs, numOfSubQueries);
long[] elementsPerSubquery = findNumberOfElementsPerSubQuery(queryTopDocs, numOfSubQueries);
float[] meanPerSubQuery = findMeanPerSubquery(sumPerSubquery, elementsPerSubquery);
float[] stdPerSubquery = findStdPerSubquery(queryTopDocs, meanPerSubQuery, elementsPerSubquery, numOfSubQueries);

// do normalization using actual score and z-scores for corresponding sub query
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
scoreDoc.score = normalizeSingleScore(scoreDoc.score, stdPerSubquery[j], meanPerSubQuery[j]);
}
}
}
}

private static float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
Collaborator: nit: private would be better unless you have a specific reason for this to be static. A better approach would be moving all these methods to another class, making it easier to write unit tests.

Author: The convention I was following is that a method that does not depend on any instance state should be static. Regarding refactoring the methods out to a utility class: are there any other classes that could use it now or in the future? Ideally I would like to avoid creating unnecessary abstraction.

final float[] sumOfScorePerSubQuery = new float[numOfScores];
Arrays.fill(sumOfScorePerSubQuery, 0);
// TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
sumOfScorePerSubQuery[j] += sumScoreDocsArray(topDocsPerSubQuery.get(j).scoreDocs);
}
}

return sumOfScorePerSubQuery;
}

private static long[] findNumberOfElementsPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
final long[] numberOfElementsPerSubQuery = new long[numOfScores];
Arrays.fill(numberOfElementsPerSubQuery, 0);
// TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
numberOfElementsPerSubQuery[j] += topDocsPerSubQuery.get(j).totalHits.value;
}
}

return numberOfElementsPerSubQuery;
}

private static float[] findMeanPerSubquery(final float[] sumPerSubquery, final long[] elementsPerSubquery) {
final float[] meanPerSubQuery = new float[elementsPerSubquery.length];
for (int i = 0; i < elementsPerSubquery.length; i++) {
if (elementsPerSubquery[i] == 0) {
meanPerSubQuery[i] = 0;
} else {
meanPerSubQuery[i] = sumPerSubquery[i] / elementsPerSubquery[i];
}
}

return meanPerSubQuery;
}

private static float[] findStdPerSubquery(
final List<CompoundTopDocs> queryTopDocs,
final float[] meanPerSubQuery,
final long[] elementsPerSubquery,
final int numOfScores
) {
final double[] deltaSumPerSubquery = new double[numOfScores];
Arrays.fill(deltaSumPerSubquery, 0);
// TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
deltaSumPerSubquery[j] += Math.pow(scoreDoc.score - meanPerSubQuery[j], 2);
}
}
}

final float[] stdPerSubQuery = new float[numOfScores];
for (int i = 0; i < deltaSumPerSubquery.length; i++) {
if (elementsPerSubquery[i] == 0) {
stdPerSubQuery[i] = 0;
} else {
stdPerSubQuery[i] = (float) Math.sqrt(deltaSumPerSubquery[i] / elementsPerSubquery[i]);
}
}

return stdPerSubQuery;
}

private static float sumScoreDocsArray(final ScoreDoc[] scoreDocs) {
float sum = 0;
for (ScoreDoc scoreDoc : scoreDocs) {
sum += scoreDoc.score;
}

return sum;
}

private static float normalizeSingleScore(final float score, final float standardDeviation, final float mean) {
// edge case: the score equals the mean (e.g. a single result, or all scores identical so std is 0); return a fixed score and avoid division by zero
if (Floats.compare(mean, score) == 0) {
return SINGLE_RESULT_SCORE;
}
return (score - mean) / standardDeviation;
}
}
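The per-sub-query math above (mean, then standard deviation, then `(score - mean) / std`) can be illustrated with a minimal standalone sketch. `ZScoreSketch` and `zScoreNormalize` are hypothetical names, not part of this PR; the edge-case handling deliberately mirrors `normalizeSingleScore`:

```java
import java.util.Arrays;

// Hypothetical sketch of the z-score normalization applied per sub query above.
public class ZScoreSketch {
    static float[] zScoreNormalize(float[] scores) {
        float sum = 0f;
        for (float s : scores) {
            sum += s;
        }
        final float mean = sum / scores.length;

        double deltaSum = 0.0;
        for (float s : scores) {
            deltaSum += Math.pow(s - mean, 2);
        }
        final float std = (float) Math.sqrt(deltaSum / scores.length);

        final float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // mirrors normalizeSingleScore: a score equal to the mean maps to the
            // fixed SINGLE_RESULT_SCORE of 1.0 (this also avoids dividing by a zero std)
            normalized[i] = Float.compare(mean, scores[i]) == 0 ? 1.0f : (scores[i] - mean) / std;
        }
        return normalized;
    }

    public static void main(String[] args) {
        // for {1, 2, 3}: mean = 2.0, std = sqrt(2/3) ≈ 0.816
        System.out.println(Arrays.toString(zScoreNormalize(new float[] { 1.0f, 2.0f, 3.0f })));
    }
}
```

Note that, like the technique in the diff, the sketch maps a score exactly equal to the mean to 1.0 rather than 0, which is part of the behavior the review discussion around `normalizeSingleScore` touches on.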
@@ -760,4 +760,19 @@ private String registerModelGroup() {
assertNotNull(modelGroupId);
return modelGroupId;
}

protected List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
}

protected Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (Map<String, Object>) hitsMap.get("total");
}

protected Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
}
}
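The hits-parsing helpers added above traverse a search response that has been deserialized into nested `Map`s, where `max_score` arrives as a `Double`. A hypothetical standalone sketch (class name and test data are illustrative, not from the PR) shows the same traversal exercised against a hand-built response map:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical sketch mirroring the getMaxScore helper above.
public class MaxScoreSketch {
    @SuppressWarnings("unchecked")
    static Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
        // "hits" is itself a map holding "hits", "total", and "max_score"
        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
        return hitsMap.get("max_score") == null
            ? Optional.empty()
            : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
    }

    public static void main(String[] args) {
        Map<String, Object> hits = new HashMap<>();
        hits.put("max_score", 0.98d);
        Map<String, Object> response = new HashMap<>();
        response.put("hits", hits);
        System.out.println(getMaxScore(response)); // value present
        hits.remove("max_score");
        System.out.println(getMaxScore(response)); // empty Optional
    }
}
```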