Skip to content

Commit

Permalink
Enhance profile API to add model centric result controlled by view pa…
Browse files Browse the repository at this point in the history
…rameter (opensearch-project#714)

* Enhance profile API to add model centric result controled by view paramter

Signed-off-by: Zan Niu <zaniu@amazon.com>

* Enhance profile API to add model centric result controled by view parameter

Signed-off-by: Zan Niu <zaniu@amazon.com>

* Enhance profile API to add model centric result controled by view parameter

Signed-off-by: Zan Niu <zaniu@amazon.com>

---------

Signed-off-by: Zan Niu <zaniu@amazon.com>
  • Loading branch information
zane-neo authored and ylwu-amzn committed Feb 17, 2023
1 parent eaa8652 commit 2c8f2cf
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.action.profile;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContentFragment;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.profile.MLModelProfile;

@Getter
@NoArgsConstructor
public class MLProfileModelResponse implements ToXContentFragment, Writeable {
@Setter
private String[] targetWorkerNodes;

@Setter
private String[] workerNodes;

private Map<String, MLModelProfile> mlModelProfileMap = new HashMap<>();

private Map<String, MLTask> mlTaskMap = new HashMap<>();

public MLProfileModelResponse(String[] targetWorkerNodes, String[] workerNodes) {
this.targetWorkerNodes = targetWorkerNodes;
this.workerNodes = workerNodes;
}

public MLProfileModelResponse(StreamInput in) throws IOException {
this.workerNodes = in.readOptionalStringArray();
this.targetWorkerNodes = in.readOptionalStringArray();
if (in.readBoolean()) {
this.mlModelProfileMap = in.readMap(StreamInput::readString, MLModelProfile::new);
}
if (in.readBoolean()) {
this.mlTaskMap = in.readMap(StreamInput::readString, MLTask::new);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (targetWorkerNodes != null) {
builder.field("target_worker_nodes", targetWorkerNodes);
}
if (workerNodes != null) {
builder.field("worker_nodes", workerNodes);
}
if (mlModelProfileMap.size() > 0) {
builder.startObject("nodes");
for (Map.Entry<String, MLModelProfile> entry : mlModelProfileMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
if (mlTaskMap.size() > 0) {
builder.startObject("tasks");
for (Map.Entry<String, MLTask> entry : mlTaskMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeOptionalStringArray(workerNodes);
streamOutput.writeOptionalStringArray(targetWorkerNodes);
if (mlModelProfileMap.size() > 0) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(mlModelProfileMap, StreamOutput::writeString, (o, r) -> r.writeTo(o));
} else {
streamOutput.writeBoolean(false);
}
if (mlTaskMap.size() > 0) {
streamOutput.writeBoolean(true);
streamOutput.writeMap(mlTaskMap, StreamOutput::writeString, (o, r) -> r.writeTo(o));
} else {
streamOutput.writeBoolean(false);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

Expand All @@ -28,20 +30,29 @@
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.action.profile.MLProfileAction;
import org.opensearch.ml.action.profile.MLProfileModelResponse;
import org.opensearch.ml.action.profile.MLProfileNodeResponse;
import org.opensearch.ml.action.profile.MLProfileRequest;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.profile.MLModelProfile;
import org.opensearch.ml.profile.MLProfileInput;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestStatus;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

@Log4j2
public class RestMLProfileAction extends BaseRestHandler {
private static final String PROFILE_ML_ACTION = "profile_ml";

private static final String VIEW = "view";
private static final String MODEL_VIEW = "model";
private static final String NODE_VIEW = "node";

private ClusterService clusterService;

/**
Expand Down Expand Up @@ -80,6 +91,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
} else {
mlProfileInput = createMLProfileInputFromRequestParams(request);
}
String view = RestActionUtils.getStringParam(request, VIEW).orElse(NODE_VIEW);
String[] nodeIds = mlProfileInput.retrieveProfileOnAllNodes()
? getAllNodes(clusterService)
: mlProfileInput.getNodeIds().toArray(new String[0]);
Expand All @@ -93,7 +105,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
List<MLProfileNodeResponse> nodeProfiles = r.getNodes().stream().filter(s -> !s.isEmpty()).collect(Collectors.toList());
log.debug("Build MLProfileNodeResponse for size of {}", nodeProfiles.size());
if (nodeProfiles.size() > 0) {
r.toXContent(builder, ToXContent.EMPTY_PARAMS);
if (NODE_VIEW.equals(view)) {
r.toXContent(builder, ToXContent.EMPTY_PARAMS);
} else if (MODEL_VIEW.equals(view)) {
Map<String, MLProfileModelResponse> modelCentricProfileMap = buildModelCentricResult(nodeProfiles);
builder.startObject("models");
for (Map.Entry<String, MLProfileModelResponse> entry : modelCentricProfileMap.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
}
}
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
Expand All @@ -105,6 +126,59 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
};
}

/**
* The data structure for node centric is:
* MLProfileNodeResponse:
* taskMap: Map<String, MLTask>
* modelMap: Map<String, MLModelProfile> model_id, MLModelProfile
* And we need to convert to format like this:
* modelMap: Map<String, Map<String, MLModelProfile>>
*/
private Map<String, MLProfileModelResponse> buildModelCentricResult(List<MLProfileNodeResponse> nodeResponses) {
// aggregate model information into one final map.
Map<String, MLProfileModelResponse> modelCentricMap = new HashMap<>();
for (MLProfileNodeResponse mlProfileNodeResponse : nodeResponses) {
String nodeId = mlProfileNodeResponse.getNode().getId();
Map<String, MLModelProfile> modelProfileMap = mlProfileNodeResponse.getMlNodeModels();
Map<String, MLTask> taskProfileMap = mlProfileNodeResponse.getMlNodeTasks();
for (Map.Entry<String, MLModelProfile> entry : modelProfileMap.entrySet()) {
MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(entry.getKey());
if (mlProfileModelResponse == null) {
mlProfileModelResponse = new MLProfileModelResponse(
entry.getValue().getTargetWorkerNodes(),
entry.getValue().getWorkerNodes()
);
modelCentricMap.put(entry.getKey(), mlProfileModelResponse);
}
if (mlProfileModelResponse.getTargetWorkerNodes() == null || mlProfileModelResponse.getWorkerNodes() == null) {
mlProfileModelResponse.setTargetWorkerNodes(entry.getValue().getTargetWorkerNodes());
mlProfileModelResponse.setWorkerNodes(entry.getValue().getWorkerNodes());
}
// Create a new object and remove targetWorkerNodes and workerNodes.
MLModelProfile modelProfile = new MLModelProfile(
entry.getValue().getModelState(),
entry.getValue().getPredictor(),
null,
null,
entry.getValue().getModelInferenceStats(),
entry.getValue().getPredictRequestStats()
);
mlProfileModelResponse.getMlModelProfileMap().putAll(ImmutableMap.of(nodeId, modelProfile));
}

for (Map.Entry<String, MLTask> entry : taskProfileMap.entrySet()) {
String modelId = entry.getValue().getModelId();
MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(modelId);
if (mlProfileModelResponse == null) {
mlProfileModelResponse = new MLProfileModelResponse();
modelCentricMap.put(modelId, mlProfileModelResponse);
}
mlProfileModelResponse.getMlTaskMap().putAll(ImmutableMap.of(entry.getKey(), entry.getValue()));
}
}
return modelCentricMap;
}

MLProfileInput createMLProfileInputFromRequestParams(RestRequest request) {
MLProfileInput mlProfileInput = new MLProfileInput();
Optional<String[]> modelIds = splitCommaSeparatedParam(request, PARAMETER_MODEL_ID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,9 @@ private static String coalesceToEmpty(@Nullable String s) {
private static boolean isNullOrEmpty(@Nullable String s) {
return s == null || s.isEmpty();
}

public static Optional<String> getStringParam(RestRequest request, String paramName) {
return Optional.ofNullable(request.param(paramName));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.action.profile;

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.profile.MLModelProfile;
import org.opensearch.ml.profile.MLPredictRequestStats;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.test.OpenSearchTestCase;

public class MLProfileModelResponseTests extends OpenSearchTestCase {

MLTask mlTask;
MLModelProfile mlModelProfile;

@Before
public void setup() {
mlTask = MLTask
.builder()
.taskId("test_id")
.modelId("model_id")
.taskType(MLTaskType.TRAINING)
.functionName(FunctionName.AD_LIBSVM)
.state(MLTaskState.CREATED)
.inputType(MLInputDataType.DATA_FRAME)
.progress(0.4f)
.outputIndex("test_index")
.workerNodes(Arrays.asList("test_node"))
.createTime(Instant.ofEpochMilli(123))
.lastUpdateTime(Instant.ofEpochMilli(123))
.error("error")
.user(new User())
.async(false)
.build();
mlModelProfile = MLModelProfile
.builder()
.predictor("test_predictor")
.workerNodes(new String[] { "node1", "node2" })
.modelState(MLModelState.LOADED)
.modelInferenceStats(MLPredictRequestStats.builder().count(10L).average(11.0).max(20.0).min(5.0).build())
.build();
}

public void test_create_MLProfileModelResponse_withArgs() throws IOException {
String[] targetWorkerNodes = new String[] { "node1", "node2" };
String[] workerNodes = new String[] { "node1" };
Map<String, MLModelProfile> profileMap = new HashMap<>();
Map<String, MLTask> taskMap = new HashMap<>();
profileMap.put("node1", mlModelProfile);
taskMap.put("node1", mlTask);
MLProfileModelResponse response = new MLProfileModelResponse(targetWorkerNodes, workerNodes);
response.getMlModelProfileMap().putAll(profileMap);
response.getMlTaskMap().putAll(taskMap);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLProfileModelResponse newResponse = new MLProfileModelResponse(output.bytes().streamInput());
assertNotNull(newResponse.getTargetWorkerNodes());
assertNotNull(response.getTargetWorkerNodes());
assertEquals(newResponse.getTargetWorkerNodes().length, response.getTargetWorkerNodes().length);
assertEquals(newResponse.getMlModelProfileMap().size(), response.getMlModelProfileMap().size());
assertEquals(newResponse.getMlTaskMap().size(), response.getMlTaskMap().size());
}

public void test_create_MLProfileModelResponse_NoArgs() throws IOException {
MLProfileModelResponse response = new MLProfileModelResponse();
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLProfileModelResponse newResponse = new MLProfileModelResponse(output.bytes().streamInput());
assertNull(response.getWorkerNodes());
assertNull(newResponse.getWorkerNodes());
}

public void test_toXContent() throws IOException {
String[] targetWorkerNodes = new String[] { "node1", "node2" };
String[] workerNodes = new String[] { "node1" };
Map<String, MLModelProfile> profileMap = new HashMap<>();
Map<String, MLTask> taskMap = new HashMap<>();
profileMap.put("node1", mlModelProfile);
taskMap.put("node1", mlTask);
MLProfileModelResponse response = new MLProfileModelResponse(targetWorkerNodes, workerNodes);
response.getMlModelProfileMap().putAll(profileMap);
response.getMlTaskMap().putAll(taskMap);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String xContentString = TestHelper.xContentBuilderToString(builder);
System.out.println(xContentString);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.utils.TestHelper.getProfileRestRequest;
import static org.opensearch.ml.utils.TestHelper.setupTestClusterState;
import static org.opensearch.ml.utils.TestHelper.*;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -68,6 +67,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

public class RestMLProfileActionTests extends OpenSearchTestCase {
@Rule
Expand Down Expand Up @@ -286,6 +286,14 @@ public void test_PrepareRequest_Failure() throws Exception {
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
}

public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception {
MLProfileInput mlProfileInput = new MLProfileInput();
RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "model"));
profileAction.handleRequest(request, channel, client);
ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
}

private RestRequest getRestRequest() {
Map<String, String> params = new HashMap<>();
params.put("task_id", "test_id");
Expand Down
Loading

0 comments on commit 2c8f2cf

Please sign in to comment.