Skip to content

Commit

Permalink
Initial code for adding the SearchPhaseInjectorProcessor interface in…
Browse files Browse the repository at this point in the history
… Search Pipeline

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 24, 2023
1 parent 66e49a6 commit 7945f01
Show file tree
Hide file tree
Showing 23 changed files with 377 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x]
### Added
- [Extensions] Moving Extensions APIs to protobuf serialization. ([#6960](https://github.com/opensearch-project/OpenSearch/pull/6960))
- [SearchPipeline] Initial code for adding the SearchPhaseInjectorProcessor interface in Search Pipeline.([#7283](https://github.com/opensearch-project/OpenSearch/pull/7283))

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.transport.Transport;

import java.util.ArrayDeque;
Expand Down Expand Up @@ -116,6 +117,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final boolean throttleConcurrentRequests;

private final List<Releasable> releasables = new ArrayList<>();
private final SearchPipelineService searchPipelineService;

AbstractSearchAsyncAction(
String name,
Expand All @@ -134,7 +136,8 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
SearchTask task,
SearchPhaseResults<Result> resultConsumer,
int maxConcurrentRequestsPerNode,
SearchResponse.Clusters clusters
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
) {
super(name);
final List<SearchShardIterator> toSkipIterators = new ArrayList<>();
Expand Down Expand Up @@ -170,6 +173,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
this.indexRoutings = indexRoutings;
this.results = resultConsumer;
this.clusters = clusters;
this.searchPipelineService = searchPipelineService;
}

@Override
Expand Down Expand Up @@ -696,7 +700,9 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) {
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
*/
final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
executeNextPhase(this, getNextPhase(results, this));
final SearchPhase nextPhase = getNextPhase(results, this);
searchPipelineService.transformSearchPhase(results, this, this.getSearchPhaseName(), nextPhase.getSearchPhaseName());
executeNextPhase(this, nextPhase);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ boolean hasResult(int shardIndex) {
}

@Override
AtomicArray<Result> getAtomicArray() {
public AtomicArray<Result> getAtomicArray() {
return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.MinAndMax;
import org.opensearch.search.sort.SortOrder;
Expand Down Expand Up @@ -90,7 +91,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
ClusterState clusterState,
SearchTask task,
Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory,
SearchResponse.Clusters clusters
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
) {
// We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
super(
Expand All @@ -110,7 +112,8 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
task,
new CanMatchSearchPhaseResults(shardsIts.size()),
shardsIts.size(),
clusters
clusters,
searchPipelineService
);
this.phaseFactory = phaseFactory;
this.shardsIts = shardsIts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ final class DfsQueryPhase extends SearchPhase {
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context
) {
super("dfs_query");
super(SearchPhaseName.DFS_QUERY.name());
this.progressListener = context.getTask().getProgressListener();
this.queryResult = queryResult;
this.searchResults = searchResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ final class ExpandSearchPhase extends SearchPhase {
private final AtomicArray<SearchPhaseResult> queryResults;

ExpandSearchPhase(SearchPhaseContext context, InternalSearchResponse searchResponse, AtomicArray<SearchPhaseResult> queryResults) {
super("expand");
super(SearchPhaseName.EXPAND.name());
this.context = context;
this.searchResponse = searchResponse;
this.queryResults = queryResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class FetchSearchPhase extends SearchPhase {
SearchPhaseContext context,
BiFunction<InternalSearchResponse, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
) {
super("fetch");
super(SearchPhaseName.FETCH.name());
if (context.getNumShards() != resultConsumer.getNumShards()) {
throw new IllegalStateException(
"number of shards must match the length of the query results but doesn't:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.search.dfs.AggregatedDfs;
import org.opensearch.search.dfs.DfsSearchResult;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.transport.Transport;

import java.util.List;
Expand Down Expand Up @@ -76,7 +77,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
final TransportSearchAction.SearchTimeProvider timeProvider,
final ClusterState clusterState,
final SearchTask task,
SearchResponse.Clusters clusters
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
) {
super(
"dfs",
Expand All @@ -95,7 +97,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
task,
new ArraySearchPhaseResults<>(shardsIts.size()),
request.getMaxConcurrentShardRequests(),
clusters
clusters,
searchPipelineService
);
this.queryPhaseResultConsumer = queryPhaseResultConsumer;
this.searchPhaseController = searchPhaseController;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
*
* @opensearch.internal
*/
abstract class SearchPhase implements CheckedRunnable<IOException> {
public abstract class SearchPhase implements CheckedRunnable<IOException> {
private final String name;

protected SearchPhase(String name) {
Expand All @@ -54,4 +54,25 @@ protected SearchPhase(String name) {
public String getName() {
return name;
}

public SearchPhaseName getSearchPhaseName() {
return SearchPhaseName.valueOf(name);
}

/**
* Enum for different Search Phases in OpenSearch
* @opensearch.internal
*/
public enum SearchPhaseName {
QUERY("query"),
FETCH("fetch"),
DFS_QUERY("dfs_query"),
EXPAND("expand");

private final String name;

SearchPhaseName(final String name) {
this.name = name;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
*
* @opensearch.internal
*/
interface SearchPhaseContext extends Executor {
public interface SearchPhaseContext extends Executor {
// TODO maybe we can make this concrete later - for now we just implement this in the base class for all initial phases

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
*
* @opensearch.internal
*/
abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
public abstract class SearchPhaseResults<Result extends SearchPhaseResult> {
private final int numShards;

SearchPhaseResults(int numShards) {
Expand Down Expand Up @@ -75,7 +75,7 @@ final int getNumShards() {

void consumeShardFailure(int shardIndex) {}

AtomicArray<Result> getAtomicArray() {
public AtomicArray<Result> getAtomicArray() {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.transport.Transport;

Expand Down Expand Up @@ -81,7 +82,8 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
final TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
SearchResponse.Clusters clusters
SearchResponse.Clusters clusters,
SearchPipelineService searchPipelineService
) {
super(
"query",
Expand All @@ -100,7 +102,8 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
task,
resultConsumer,
request.getMaxConcurrentShardRequests(),
clusters
clusters,
searchPipelineService
);
this.topDocsSize = SearchPhaseController.getTopDocsSize(request);
this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ protected SearchPhase sendResponsePhase(
SearchPhaseController.ReducedQueryPhase queryPhase,
final AtomicArray<? extends SearchPhaseResult> fetchResults
) {
return new SearchPhase("fetch") {
return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.name()) {
@Override
public void run() throws IOException {
sendResponse(queryPhase, fetchResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ protected void executeInitialPhase(

@Override
protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
return new SearchPhase("fetch") {
return new SearchPhase(SearchPhase.SearchPhaseName.FETCH.name()) {
@Override
public void run() {
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ public AbstractSearchAsyncAction<? extends SearchPhaseResult> asyncSearchAction(
task,
new ArraySearchPhaseResults<>(shardsIts.size()),
searchRequest.getMaxConcurrentShardRequests(),
clusters
clusters,
searchPipelineService
) {
@Override
protected void executePhaseOnShard(
Expand Down Expand Up @@ -1161,7 +1162,8 @@ public void run() {
}
};
},
clusters
clusters,
searchPipelineService
);
} else {
final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(
Expand Down Expand Up @@ -1191,7 +1193,8 @@ public void run() {
timeProvider,
clusterState,
task,
clusters
clusters,
searchPipelineService
);
break;
case QUERY_THEN_FETCH:
Expand All @@ -1211,7 +1214,8 @@ public void run() {
timeProvider,
clusterState,
task,
clusters
clusters,
searchPipelineService
);
break;
default:
Expand Down
Loading

0 comments on commit 7945f01

Please sign in to comment.