Skip to content

Commit

Permalink
fix create_pit enum bug
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Alfonsi <[email protected]>
  • Loading branch information
Peter Alfonsi committed Dec 4, 2024
1 parent d199096 commit 0c9110d
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ public String getName() {
}

/**
* Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined
* in {@link SearchPhaseName}
* @return {@link SearchPhaseName}
* Returns the SearchPhase name as {@link SearchPhaseName}. If unrecognized, returns the catch-all OTHER_PHASE_TYPES.
*/
public SearchPhaseName getSearchPhaseName() {
return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT));
try {
return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT));
} catch (IllegalArgumentException e) {
return SearchPhaseName.OTHER_PHASE_TYPES;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,29 @@
*/
@PublicApi(since = "2.9.0")
public enum SearchPhaseName {
DFS_PRE_QUERY("dfs_pre_query"),
QUERY("query"),
FETCH("fetch"),
DFS_QUERY("dfs_query"),
EXPAND("expand"),
CAN_MATCH("can_match");
DFS_PRE_QUERY("dfs_pre_query", true),
QUERY("query", true),
FETCH("fetch", true),
DFS_QUERY("dfs_query", true),
EXPAND("expand", true),
CAN_MATCH("can_match", true),

// A catch-all for other phase types which shouldn't appear in the search phase stats API.
OTHER_PHASE_TYPES("other_phase_types", false);

private final String name;
private final boolean shouldTrack;

SearchPhaseName(final String name) {
SearchPhaseName(final String name, final boolean shouldTrack) {
this.name = name;
this.shouldTrack = shouldTrack;
}

public String getName() {
return name;
}

public boolean shouldTrack() {
return shouldTrack;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,29 @@ public long getTookMetric() {

@Override
protected void onPhaseStart(SearchPhaseContext context) {
phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc();
SearchPhaseName phaseName = context.getCurrentPhase().getSearchPhaseName();
if (phaseName.shouldTrack()) {
phaseStatsMap.get(phaseName).current.inc();
}
}

@Override
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {
StatsHolder phaseStats = phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName());
phaseStats.current.dec();
phaseStats.total.inc();
phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos()));
SearchPhaseName phaseName = context.getCurrentPhase().getSearchPhaseName();
if (phaseName.shouldTrack()) {
StatsHolder phaseStats = phaseStatsMap.get(phaseName);
phaseStats.current.dec();
phaseStats.total.inc();
phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos()));
}
}

@Override
protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {
phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec();
SearchPhaseName phaseName = context.getCurrentPhase().getSearchPhaseName();
if (phaseName.shouldTrack()) {
phaseStatsMap.get(phaseName).current.dec();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
PhaseStatsLongHolder statsLongHolder = requestStatsLongHolder.requestStatsHolder.get(searchPhaseName.getName());
if (statsLongHolder == null) {
if (statsLongHolder == null || !searchPhaseName.shouldTrack()) {
continue;
}
builder.startObject(searchPhaseName.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.test.OpenSearchTestCase;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -25,6 +26,18 @@
import static org.mockito.Mockito.when;

public class SearchRequestStatsTests extends OpenSearchTestCase {

static List<SearchPhaseName> trackablePhases;

static {
trackablePhases = new ArrayList<>();
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
if (searchPhaseName.shouldTrack()) {
trackablePhases.add(searchPhaseName);
}
}
}

public void testSearchRequestStats_OnRequestFailure() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings);
Expand Down Expand Up @@ -67,7 +80,7 @@ public void testSearchRequestPhaseFailure() {
SearchPhase mockSearchPhase = mock(SearchPhase.class);
when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase);

for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
for (SearchPhaseName searchPhaseName : trackablePhases) {
when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName);
testRequestStats.onPhaseStart(ctx);
assertEquals(1, testRequestStats.getPhaseCurrent(searchPhaseName));
Expand All @@ -84,7 +97,7 @@ public void testSearchRequestStats() {
SearchPhase mockSearchPhase = mock(SearchPhase.class);
when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase);

for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
for (SearchPhaseName searchPhaseName : trackablePhases) {
when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName);
long tookTimeInMillis = randomIntBetween(1, 10);
testRequestStats.onPhaseStart(ctx);
Expand All @@ -109,10 +122,10 @@ public void testSearchRequestStatsOnPhaseStartConcurrently() throws InterruptedE
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings);
int numTasks = randomIntBetween(5, 50);
Thread[] threads = new Thread[numTasks * SearchPhaseName.values().length];
Phaser phaser = new Phaser(numTasks * SearchPhaseName.values().length + 1);
CountDownLatch countDownLatch = new CountDownLatch(numTasks * SearchPhaseName.values().length);
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
Thread[] threads = new Thread[numTasks * trackablePhases.size()];
Phaser phaser = new Phaser(numTasks * trackablePhases.size() + 1);
CountDownLatch countDownLatch = new CountDownLatch(numTasks * trackablePhases.size());
for (SearchPhaseName searchPhaseName : trackablePhases) {
SearchPhaseContext ctx = mock(SearchPhaseContext.class);
SearchPhase mockSearchPhase = mock(SearchPhase.class);
when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase);
Expand All @@ -128,7 +141,7 @@ public void testSearchRequestStatsOnPhaseStartConcurrently() throws InterruptedE
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await();
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
for (SearchPhaseName searchPhaseName : trackablePhases) {
assertEquals(numTasks, testRequestStats.getPhaseCurrent(searchPhaseName));
}
}
Expand All @@ -137,11 +150,11 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings);
int numTasks = randomIntBetween(5, 50);
Thread[] threads = new Thread[numTasks * SearchPhaseName.values().length];
Phaser phaser = new Phaser(numTasks * SearchPhaseName.values().length + 1);
CountDownLatch countDownLatch = new CountDownLatch(numTasks * SearchPhaseName.values().length);
Thread[] threads = new Thread[numTasks * trackablePhases.size()];
Phaser phaser = new Phaser(numTasks * trackablePhases.size() + 1);
CountDownLatch countDownLatch = new CountDownLatch(numTasks * trackablePhases.size());
Map<SearchPhaseName, Long> searchPhaseNameLongMap = new HashMap<>();
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
for (SearchPhaseName searchPhaseName : trackablePhases) {
SearchPhaseContext ctx = mock(SearchPhaseContext.class);
SearchPhase mockSearchPhase = mock(SearchPhase.class);
when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase);
Expand All @@ -168,7 +181,7 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await();
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
for (SearchPhaseName searchPhaseName : trackablePhases) {
assertEquals(numTasks, testRequestStats.getPhaseTotal(searchPhaseName));
assertThat(
testRequestStats.getPhaseMetric(searchPhaseName),
Expand All @@ -181,10 +194,10 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings);
int numTasks = randomIntBetween(5, 50);
Thread[] threads = new Thread[numTasks * SearchPhaseName.values().length];
Phaser phaser = new Phaser(numTasks * SearchPhaseName.values().length + 1);
CountDownLatch countDownLatch = new CountDownLatch(numTasks * SearchPhaseName.values().length);
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
Thread[] threads = new Thread[numTasks * trackablePhases.size()];
Phaser phaser = new Phaser(numTasks * trackablePhases.size() + 1);
CountDownLatch countDownLatch = new CountDownLatch(numTasks * trackablePhases.size());
for (SearchPhaseName searchPhaseName : trackablePhases) {
SearchPhaseContext ctx = mock(SearchPhaseContext.class);
SearchPhase mockSearchPhase = mock(SearchPhase.class);
when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase);
Expand All @@ -201,8 +214,48 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await();
for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) {
for (SearchPhaseName searchPhaseName : trackablePhases) {
assertEquals(0, testRequestStats.getPhaseCurrent(searchPhaseName));
}
}

public void testOtherPhaseNamesAreIgnored() {
// Unrecognized phase names shouldn't be tracked, but should not throw any error.
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings);
SearchPhaseContext ctx = mock(SearchPhaseContext.class);
SearchPhase mockSearchPhase = mock(SearchPhase.class);
when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase);

when(mockSearchPhase.getSearchPhaseName()).thenReturn(SearchPhaseName.OTHER_PHASE_TYPES);
testRequestStats.onPhaseStart(ctx);
long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(10);
when(mockSearchPhase.getStartTimeInNanos()).thenReturn(startTime);
// All values should return 0 for untracked phase types
assertEquals(0, testRequestStats.getPhaseCurrent(SearchPhaseName.OTHER_PHASE_TYPES));
testRequestStats.onPhaseEnd(
ctx,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
new SearchRequest(),
() -> null
)
);
assertEquals(0, testRequestStats.getPhaseCurrent(SearchPhaseName.OTHER_PHASE_TYPES));
assertEquals(0, testRequestStats.getPhaseTotal(SearchPhaseName.OTHER_PHASE_TYPES));
assertEquals(0, testRequestStats.getPhaseMetric(SearchPhaseName.OTHER_PHASE_TYPES));
}

public void testSearchPhaseCatchAll() {
// Test search phases with unrecognized names return the catch-all OTHER_PHASE_TYPES when getSearchPhaseName() is called.
// These may exist, for example, "create_pit".
String unrecognizedName = "unrecognized_name";
SearchPhase dummyPhase = new SearchPhase(unrecognizedName) {
@Override
public void run() {}
};

assertEquals(unrecognizedName, dummyPhase.getName());
assertEquals(SearchPhaseName.OTHER_PHASE_TYPES, dummyPhase.getSearchPhaseName());
}
}

0 comments on commit 0c9110d

Please sign in to comment.