diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java index c760fde2a..c8f9ad6e3 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java @@ -9,13 +9,13 @@ import org.opensearch.knn.index.KNNSettings; import java.util.Collections; +import java.util.List; import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; public class WarmupIT extends AbstractRollingUpgradeTestCase { private static final String TEST_FIELD = "test-field"; private static final int DIMENSIONS = 5; - private static final int K = 5; private static final int NUM_DOCS = 10; public void testKNNWarmup() throws Exception { @@ -23,45 +23,22 @@ public void testKNNWarmup() throws Exception { switch (getClusterType()) { case OLD: createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); - int docIdOld = 0; - addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docIdOld, NUM_DOCS); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS); break; case MIXED: - int totalDocsCountMixed; - int docIdMixed; - if (isFirstMixedRound()) { - docIdMixed = NUM_DOCS; - totalDocsCountMixed = 2 * NUM_DOCS; - } else { - docIdMixed = 2 * NUM_DOCS; - totalDocsCountMixed = 3 * NUM_DOCS; - } - updateIndexSettings(testIndex, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0)); - validateKNNWarmupOnUpgrade(totalDocsCountMixed, docIdMixed); + int graphCount = getTotalGraphsInCache(); + knnWarmup(Collections.singletonList(testIndex)); + assertTrue(getTotalGraphsInCache() > graphCount); + clearCache(List.of(testIndex)); break; case UPGRADED: - int docIdUpgraded = 3 * NUM_DOCS; - int totalDocsCountUpgraded = 4 * NUM_DOCS; - validateKNNWarmupOnUpgrade(totalDocsCountUpgraded, docIdUpgraded); - + updateIndexSettings(testIndex, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0)); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, NUM_DOCS); + int updatedGraphCount = getTotalGraphsInCache(); + knnWarmup(Collections.singletonList(testIndex)); + assertTrue(getTotalGraphsInCache() > updatedGraphCount); deleteKNNIndex(testIndex); } - - } - - // validation steps for KNN Warmup after upgrading each node from old version to new version - public void validateKNNWarmupOnUpgrade(int totalDocsCount, int docId) throws Exception { - int graphCount = getTotalGraphsInCache(); - knnWarmup(Collections.singletonList(testIndex)); - assertTrue(getTotalGraphsInCache() > graphCount); - - addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docId, NUM_DOCS); - - int updatedGraphCount = getTotalGraphsInCache(); - knnWarmup(Collections.singletonList(testIndex)); - assertTrue(getTotalGraphsInCache() > updatedGraphCount); - - validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, totalDocsCount, K); } }