Skip to content

Commit

Permalink
feat: implement mvi search filters (#1119)
Browse files Browse the repository at this point in the history
Adds filter expression data types and implementation for search and
searchAndFetchVectors. Includes search filters for equality checks and logical expressions. A followup PR will include the exhaustive set of comparison operators.
  • Loading branch information
malandis authored Feb 2, 2024
1 parent 7d5e0e7 commit b2d6314
Show file tree
Hide file tree
Showing 6 changed files with 562 additions and 34 deletions.
109 changes: 97 additions & 12 deletions packages/client-sdk-nodejs/src/internal/vector-index-data-client.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import {version} from '../../package.json';
import {IVectorIndexDataClient} from '@gomomento/sdk-core/dist/src/internal/clients/vector/IVectorIndexDataClient';
import {
ALL_VECTOR_METADATA,
CredentialProvider,
InvalidArgumentError,
MomentoLogger,
MomentoLoggerFactory,
SearchOptions,
VECTOR_DEFAULT_TOPK,
VectorCountItems,
VectorDeleteItemBatch,
vectorFilters as F,
VectorGetItemBatch,
VectorGetItemMetadataBatch,
VectorIndexItem,
VectorIndexMetadata,
VectorIndexStoredItem,
VectorSearch,
VectorSearchAndFetchVectors,
VectorIndexMetadata,
VectorIndexItem,
VectorUpsertItemBatch,
VectorIndexStoredItem,
VectorGetItemBatch,
VectorGetItemMetadataBatch,
ALL_VECTOR_METADATA,
VECTOR_DEFAULT_TOPK,
} from '@gomomento/sdk-core';
import {VectorIndexConfiguration} from '../config/vector-index-configuration';
import {ChannelCredentials, Interceptor} from '@grpc/grpc-js';
Expand Down Expand Up @@ -297,7 +298,7 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
return await this.sendSearch(indexName, queryVector, options);
}

private static prepareMetadataRequest(
private static buildMetadataRequest(
options?: SearchOptions
): vectorindex._MetadataRequest {
const metadataRequest = new vectorindex._MetadataRequest();
Expand Down Expand Up @@ -325,6 +326,84 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
}
}

private static buildFilterExpression(
filterExpression?: F.VectorFilterExpression
): vectorindex._FilterExpression | undefined {
if (filterExpression === undefined) {
return undefined;
} else if (filterExpression instanceof F.VectorFilterAndExpression) {
return new vectorindex._FilterExpression({
and_expression: new vectorindex._AndExpression({
first_expression: VectorIndexDataClient.buildFilterExpression(
filterExpression.FirstExpression
),
second_expression: VectorIndexDataClient.buildFilterExpression(
filterExpression.SecondExpression
),
}),
});
} else if (filterExpression instanceof F.VectorFilterOrExpression) {
return new vectorindex._FilterExpression({
or_expression: new vectorindex._OrExpression({
first_expression: VectorIndexDataClient.buildFilterExpression(
filterExpression.FirstExpression
),
second_expression: VectorIndexDataClient.buildFilterExpression(
filterExpression.SecondExpression
),
}),
});
} else if (filterExpression instanceof F.VectorFilterNotExpression) {
return new vectorindex._FilterExpression({
not_expression: new vectorindex._NotExpression({
expression_to_negate: VectorIndexDataClient.buildFilterExpression(
filterExpression.Expression
),
}),
});
} else if (filterExpression instanceof F.VectorFilterEqualsExpression) {
if (typeof filterExpression.Value === 'string') {
return new vectorindex._FilterExpression({
equals_expression: new vectorindex._EqualsExpression({
field: filterExpression.Field,
string_value: filterExpression.Value,
}),
});
} else if (typeof filterExpression.Value === 'number') {
if (Number.isInteger(filterExpression.Value)) {
return new vectorindex._FilterExpression({
equals_expression: new vectorindex._EqualsExpression({
field: filterExpression.Field,
integer_value: filterExpression.Value,
}),
});
} else {
return new vectorindex._FilterExpression({
equals_expression: new vectorindex._EqualsExpression({
field: filterExpression.Field,
float_value: filterExpression.Value,
}),
});
}
} else if (typeof filterExpression.Value === 'boolean') {
return new vectorindex._FilterExpression({
equals_expression: new vectorindex._EqualsExpression({
field: filterExpression.Field,
boolean_value: filterExpression.Value,
}),
});
} else {
throw new InvalidArgumentError(
`Filter value for field '${
filterExpression.Field
}' is not a valid type. Value is of type '${typeof filterExpression.Value} and is not a string, number, or boolean.'`
);
}
}

throw new InvalidArgumentError('Filter expression is not a valid type.');
}

private static deserializeMetadata(
metadata: vectorindex._Metadata[],
errorCallback: () => void
Expand Down Expand Up @@ -364,7 +443,10 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
index_name: indexName,
query_vector: new vectorindex._Vector({elements: queryVector}),
top_k: options?.topK ?? VECTOR_DEFAULT_TOPK,
metadata_fields: VectorIndexDataClient.prepareMetadataRequest(options),
metadata_fields: VectorIndexDataClient.buildMetadataRequest(options),
filter_expression: VectorIndexDataClient.buildFilterExpression(
options?.filterExpression
),
});
VectorIndexDataClient.applyScoreThreshold(request, options);

Expand Down Expand Up @@ -438,7 +520,10 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
index_name: indexName,
query_vector: new vectorindex._Vector({elements: queryVector}),
top_k: options?.topK ?? VECTOR_DEFAULT_TOPK,
metadata_fields: VectorIndexDataClient.prepareMetadataRequest(options),
metadata_fields: VectorIndexDataClient.buildMetadataRequest(options),
filter_expression: VectorIndexDataClient.buildFilterExpression(
options?.filterExpression
),
});
VectorIndexDataClient.applyScoreThreshold(request, options);

Expand Down Expand Up @@ -504,7 +589,7 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
const request = new vectorindex._GetItemBatchRequest({
index_name: indexName,
ids: ids,
metadata_fields: VectorIndexDataClient.prepareMetadataRequest({
metadata_fields: VectorIndexDataClient.buildMetadataRequest({
metadataFields: ALL_VECTOR_METADATA,
}),
});
Expand Down Expand Up @@ -586,7 +671,7 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
const request = new vectorindex._GetItemMetadataBatchRequest({
index_name: indexName,
ids: ids,
metadata_fields: VectorIndexDataClient.prepareMetadataRequest({
metadata_fields: VectorIndexDataClient.buildMetadataRequest({
metadataFields: ALL_VECTOR_METADATA,
}),
});
Expand Down
110 changes: 96 additions & 14 deletions packages/client-sdk-web/src/internal/vector-index-data-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@ import {VectorIndexClient} from '@gomomento/generated-types-webtext/dist/Vectori
import * as vectorindex from '@gomomento/generated-types-webtext/dist/vectorindex_pb';
import {IVectorIndexDataClient} from '@gomomento/sdk-core/dist/src/internal/clients/vector/IVectorIndexDataClient';
import {
ALL_VECTOR_METADATA,
InvalidArgumentError,
MomentoLogger,
SearchOptions,
UnknownError,
VECTOR_DEFAULT_TOPK,
VectorCountItems,
VectorDeleteItemBatch,
vectorFilters as F,
VectorGetItemBatch,
VectorGetItemMetadataBatch,
VectorIndexItem,
VectorIndexMetadata,
VectorIndexStoredItem,
VectorSearch,
VectorSearchAndFetchVectors,
VectorIndexMetadata,
VectorIndexItem,
VectorUpsertItemBatch,
VectorIndexStoredItem,
VectorGetItemBatch,
VectorGetItemMetadataBatch,
InvalidArgumentError,
UnknownError,
ALL_VECTOR_METADATA,
VECTOR_DEFAULT_TOPK,
} from '@gomomento/sdk-core';
import {CacheServiceErrorMapper} from '../errors/cache-service-error-mapper';
import {
Expand Down Expand Up @@ -273,7 +274,7 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
return await this.sendSearch(indexName, queryVector, options);
}

private static prepareMetadataRequest(
private static buildMetadataRequest(
options?: SearchOptions
): vectorindex._MetadataRequest {
const metadataRequest = new vectorindex._MetadataRequest();
Expand Down Expand Up @@ -303,6 +304,81 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
}
}

private static buildFilterExpression(
filterExpression?: F.VectorFilterExpression
): vectorindex._FilterExpression | undefined {
if (filterExpression === undefined) {
return undefined;
}

if (filterExpression instanceof F.VectorFilterAndExpression) {
const and = new vectorindex._AndExpression();
and.setFirstExpression(
VectorIndexDataClient.buildFilterExpression(
filterExpression.FirstExpression
)
);
and.setSecondExpression(
VectorIndexDataClient.buildFilterExpression(
filterExpression.SecondExpression
)
);
const expression = new vectorindex._FilterExpression();
expression.setAndExpression(and);
return expression;
} else if (filterExpression instanceof F.VectorFilterOrExpression) {
const or = new vectorindex._OrExpression();
or.setFirstExpression(
VectorIndexDataClient.buildFilterExpression(
filterExpression.FirstExpression
)
);
or.setSecondExpression(
VectorIndexDataClient.buildFilterExpression(
filterExpression.SecondExpression
)
);
const expression = new vectorindex._FilterExpression();
expression.setOrExpression(or);
return expression;
} else if (filterExpression instanceof F.VectorFilterNotExpression) {
const not = new vectorindex._NotExpression();
not.setExpressionToNegate(
VectorIndexDataClient.buildFilterExpression(filterExpression.Expression)
);
const expression = new vectorindex._FilterExpression();
expression.setNotExpression(not);
return expression;
} else if (filterExpression instanceof F.VectorFilterEqualsExpression) {
const equals = new vectorindex._EqualsExpression();
equals.setField(filterExpression.Field);

if (typeof filterExpression.Value === 'string') {
equals.setStringValue(filterExpression.Value);
} else if (typeof filterExpression.Value === 'number') {
if (Number.isInteger(filterExpression.Value)) {
equals.setIntegerValue(filterExpression.Value);
} else {
equals.setFloatValue(filterExpression.Value);
}
} else if (typeof filterExpression.Value === 'boolean') {
equals.setBooleanValue(filterExpression.Value);
} else {
throw new InvalidArgumentError(
`Filter value for field '${
filterExpression.Field
}' is not a valid type. Value is of type '${typeof filterExpression.Value} and is not a string, number, or boolean.'`
);
}

const expression = new vectorindex._FilterExpression();
expression.setEqualsExpression(equals);
return expression;
}

throw new InvalidArgumentError('Filter expression is not a valid type.');
}

private static deserializeMetadata(
metadata: vectorindex._Metadata[],
errorCallback: () => void
Expand Down Expand Up @@ -345,9 +421,12 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
request.setQueryVector(vector);
request.setTopK(options?.topK ?? VECTOR_DEFAULT_TOPK);
request.setMetadataFields(
VectorIndexDataClient.prepareMetadataRequest(options)
VectorIndexDataClient.buildMetadataRequest(options)
);
VectorIndexDataClient.applyScoreThreshold(request, options);
request.setFilterExpression(
VectorIndexDataClient.buildFilterExpression(options?.filterExpression)
);

return await new Promise((resolve, reject) => {
this.client.search(
Expand Down Expand Up @@ -425,9 +504,12 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
request.setQueryVector(vector);
request.setTopK(options?.topK ?? VECTOR_DEFAULT_TOPK);
request.setMetadataFields(
VectorIndexDataClient.prepareMetadataRequest(options)
VectorIndexDataClient.buildMetadataRequest(options)
);
VectorIndexDataClient.applyScoreThreshold(request, options);
request.setFilterExpression(
VectorIndexDataClient.buildFilterExpression(options?.filterExpression)
);

return await new Promise((resolve, reject) => {
this.client.searchAndFetchVectors(
Expand Down Expand Up @@ -495,7 +577,7 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
request.setIndexName(indexName);
request.setIdsList(ids);
request.setMetadataFields(
VectorIndexDataClient.prepareMetadataRequest({
VectorIndexDataClient.buildMetadataRequest({
metadataFields: ALL_VECTOR_METADATA,
})
);
Expand Down Expand Up @@ -584,7 +666,7 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
request.setIndexName(indexName);
request.setIdsList(ids);
request.setMetadataFields(
VectorIndexDataClient.prepareMetadataRequest({
VectorIndexDataClient.buildMetadataRequest({
metadataFields: ALL_VECTOR_METADATA,
})
);
Expand Down
Loading

0 comments on commit b2d6314

Please sign in to comment.