Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clustering function - RCFSummarize #355

Merged
merged 2 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ public enum FunctionName {
LOCAL_SAMPLE_CALCULATOR,
FIT_RCF,
BATCH_RCF,
ANOMALY_LOCALIZATION;
ANOMALY_LOCALIZATION,
RCF_SUMMARIZE;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.parameter.clustering;

import lombok.Builder;
import lombok.Data;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;

import java.io.IOException;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

@Data
@MLAlgoParameter(algorithms={FunctionName.RCF_SUMMARIZE})
public class RCFSummarizeParams implements MLAlgoParams {
public static final String PARSE_FIELD_NAME = FunctionName.RCF_SUMMARIZE.name();
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(PARSE_FIELD_NAME),
it -> parse(it)
);

public static final String MAX_K_FIELD = "max_k";
public static final String INITIAL_K_FIELD = "initial_k";
public static final String DISTANCE_TYPE_FIELD = "distance_type";
public static final String PHASE1_REASSIGN_FIELD = "phase1_reassign";
public static final String PARALLEL__FIELD = "parallel";

// The max of K allowed
private Integer maxK;
// The initial K used
private Integer initialK;
// The distance function
private DistanceType distanceType;
// Whether to also use Reassign in Phase1
private Boolean phase1Reassign;
// Whether to train in parallel
private Boolean parallel;
// TODO: expose seed?

@Builder(toBuilder = true)
public RCFSummarizeParams(Integer maxK, Integer initialK, DistanceType distanceType, Boolean phase1Reassign, Boolean parallel) {

this.maxK = maxK;
this.initialK = initialK;
this.distanceType = distanceType;
this.phase1Reassign = phase1Reassign;
this.parallel = parallel;
}

public RCFSummarizeParams(StreamInput in) throws IOException {
this.maxK = in.readOptionalInt();
this.initialK = in.readOptionalInt();
this.phase1Reassign = in.readOptionalBoolean();
this.parallel = in.readOptionalBoolean();

if (in.readBoolean()) {
this.distanceType = in.readEnum(DistanceType.class);
}
}

public static MLAlgoParams parse(XContentParser parser) throws IOException {
Integer maxK = null;
Integer initialK = null;
Boolean phase1Reassign = null;
Boolean parallel = null;
DistanceType distanceType = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MAX_K_FIELD:
maxK = parser.intValue(false);
break;
case INITIAL_K_FIELD:
initialK = parser.intValue(false);
break;
case PHASE1_REASSIGN_FIELD:
phase1Reassign = parser.booleanValue();
break;
case PARALLEL__FIELD:
parallel = parser.booleanValue();
break;
case DISTANCE_TYPE_FIELD:
distanceType = DistanceType.from(parser.text().toUpperCase());
break;
default:
parser.skipChildren();
break;
}
}

return new RCFSummarizeParams(maxK, initialK, distanceType, phase1Reassign, parallel);
}

@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(maxK);
out.writeOptionalInt(initialK);
out.writeOptionalBoolean(phase1Reassign);
out.writeOptionalBoolean(parallel);
if (distanceType != null) {
out.writeBoolean(true);
out.writeEnum(distanceType);
} else {
out.writeBoolean(false);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (maxK != null) {
builder.field(MAX_K_FIELD, maxK);
}

if (initialK != null) {
builder.field(INITIAL_K_FIELD, initialK);
}

if (phase1Reassign != null) {
builder.field(PHASE1_REASSIGN_FIELD, phase1Reassign);
}

if (parallel != null) {
builder.field(PARALLEL__FIELD, parallel);
}

if (distanceType != null) {
builder.field(DISTANCE_TYPE_FIELD, distanceType.name());
}
builder.endObject();
return builder;
}

@Override
public int getVersion() {
return 1;
}

public enum DistanceType {
L1,
L2,
LInfinity;

public static DistanceType from(String value) {
try {
return DistanceType.valueOf(value);
} catch (Exception e) {
throw new IllegalArgumentException("Wrong distance type");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.parameter.clustering;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;

public class RCFSummarizeParamsTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

RCFSummarizeParams params;
private Function<XContentParser, RCFSummarizeParams> function = parser -> {
try {
return (RCFSummarizeParams)RCFSummarizeParams.parse(parser);
} catch (IOException e) {
throw new RuntimeException("failed to parse RCFSummarizeParams", e);
}
};

@Before
public void setUp() {
params = RCFSummarizeParams.builder()
.maxK(2)
.initialK(10)
.distanceType(RCFSummarizeParams.DistanceType.L1)
.build();
}

@Test
public void parseRCFSummarizeParams() throws IOException {
TestHelper.testParse(params, function);
}

@Test
public void parseRCFSummarizeParamsExceptionOnInvalidDoubleValue() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("2.01 cannot be converted to Integer without data loss");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("\"max_k\":2,", "\"max_k\":2.01,"), function);
}

@Test
public void parseRCFSummarizeParamsExceptionOnInvalidDoubleString() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Integer value passed as String");
String paramsStr = contentObjectToString(params);
testParseFromString(params, paramsStr.replace("\"max_k\":2,", "\"max_k\":\"2.01\","), function);
}

@Test
public void parseEmptyRCFSummarizeParams() throws IOException {
TestHelper.testParse(RCFSummarizeParams.builder().build(), function);
}

@Test
public void readInputStreamSuccess() throws IOException {
readInputStream(params);
}

@Test
public void readInputStream_Success_EmptyParams() throws IOException {
readInputStream(RCFSummarizeParams.builder().build());
}

private void readInputStream(RCFSummarizeParams params) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
params.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
RCFSummarizeParams parsedParams = new RCFSummarizeParams(streamInput);
assertEquals(params, parsedParams);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ public class KMeans implements TrainAndPredictable {
private static int DEFAULT_CENTROIDS = 2;
private static int DEFAULT_ITERATIONS = 10;

//The number of threads.
// Parameters
private KMeansParams parameters;

//The number of threads.
private int numThreads = Math.max(Runtime.getRuntime().availableProcessors() / 2, 1); //Assume cpu-bound.

//The random seed.
private long seed = System.currentTimeMillis();
private KMeansTrainer.Distance distance;
Expand Down
Loading