revert accidental removal of tests from
br3no committed Oct 22, 2024
1 parent bc6caad commit c9d0d7f
Showing 1 changed file with 175 additions and 8 deletions.
@@ -6,24 +6,32 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import org.apache.commons.lang3.StringUtils;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.message.BasicHeader;
import org.apache.lucene.search.join.ScoreMode;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;

import com.google.common.collect.ImmutableList;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {

@@ -58,9 +66,7 @@ public void setUp() throws Exception {
public void testTextEmbeddingProcessor() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel(
Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()))
);
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING);
createTextEmbeddingIndex();
@@ -71,7 +77,170 @@ public void testTextEmbeddingProcessor() throws Exception {
}
}

private String uploadTextEmbeddingModel(String requestBody) throws Exception {
public void testTextEmbeddingProcessor_batch() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2);
createTextEmbeddingIndex();
ingestBatchDocumentWithBulk("batch_", 2, Collections.emptySet(), Collections.emptySet());
assertEquals(2, getDocCount(INDEX_NAME));

ingestDocument(String.format(LOCALE, INGEST_DOC1, "success"), "1");
ingestDocument(String.format(LOCALE, INGEST_DOC2, "success"), "2");

assertEquals(getDocById(INDEX_NAME, "1").get("_source"), getDocById(INDEX_NAME, "batch_1").get("_source"));
assertEquals(getDocById(INDEX_NAME, "2").get("_source"), getDocById(INDEX_NAME, "batch_2").get("_source"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC3, "3");
ingestDocument(INGEST_DOC4, "4");

assertDoc(
(Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source"),
TEXT_FIELD_VALUE_1,
Optional.of(TEXT_FIELD_VALUE_3)
);
assertDoc((Map<String, Object>) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_2, Optional.empty());

NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING,
QUERY_TEXT,
"",
modelId,
10,
null,
null,
null,
null,
null,
null
);
QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
assertNotNull(hits);

assertEquals(1.0, hits.get("max_score"));
List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(2, listOfHits.size());

Map<String, Object> innerHitDetails = listOfHits.get(0);
assertEquals("3", innerHitDetails.get("_id"));
assertEquals(1.0, innerHitDetails.get("_score"));

innerHitDetails = listOfHits.get(1);
assertEquals("4", innerHitDetails.get("_id"));
assertTrue((double) innerHitDetails.get("_score") <= 1.0);
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Optional<String> level3ExpectedValue) {
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
assertEquals(textFieldValue, level2.get(LEVEL_3_FIELD_TEXT));
Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
if (level3ExpectedValue.isPresent()) {
assertEquals(level3ExpectedValue.get(), level3.get("level_4_text_field"));
}
}

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json");
Objects.requireNonNull(pipelineURLPath);
String requestBody = Files.readString(Path.of(pipelineURLPath.toURI()));
createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null);
createTextEmbeddingIndex();
int docCount = 5;
ingestBatchDocumentWithBulk("batch_", docCount, Collections.emptySet(), Collections.emptySet());
assertEquals(5, getDocCount(INDEX_NAME));

for (int i = 0; i < docCount; ++i) {
String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2);
String payload = String.format(LOCALE, template, "success");
ingestDocument(payload, String.valueOf(i + 1));
}

for (int i = 0; i < docCount; ++i) {
assertEquals(
getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"),
getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source")
);

}
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
URL pipelineURLPath = classLoader.getResource("processor/PipelineConfigurationWithBatchSize.json");
Objects.requireNonNull(pipelineURLPath);
String requestBody = Files.readString(Path.of(pipelineURLPath.toURI()));
createPipelineProcessor(requestBody, PIPELINE_NAME, modelId, null);
createTextEmbeddingIndex();
int docCount = 5;
ingestBatchDocumentWithBulk("batch_", docCount, Set.of(0), Set.of(1));
assertEquals(3, getDocCount(INDEX_NAME));

for (int i = 2; i < docCount; ++i) {
String template = List.of(INGEST_DOC1, INGEST_DOC2).get(i % 2);
String payload = String.format(LOCALE, template, "success");
ingestDocument(payload, String.valueOf(i + 1));
}

for (int i = 2; i < docCount; ++i) {
assertEquals(
getDocById(INDEX_NAME, String.valueOf(i + 1)).get("_source"),
getDocById(INDEX_NAME, "batch_" + (i + 1)).get("_source")
);

}
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);
}

@@ -86,11 +255,9 @@ private void createTextEmbeddingIndex() throws Exception {
public void testAsymmetricTextEmbeddingProcessor() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel(
Files.readString(Path.of(classLoader.getResource("processor/UploadAsymmetricModelRequestBody.json").toURI()))
);
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING, 2);
createTextEmbeddingIndex();
ingestDocument();
assertEquals(1, getDocCount(INDEX_NAME));