From 18b1d82073557d1ebe9f2c6761dfcf5d9ed9081e Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 7 Jul 2022 17:10:51 -0700 Subject: [PATCH 1/4] Add Logistic Regression algorithm Signed-off-by: Xun Zhang --- .../opensearch/ml/common/FunctionName.java | 3 +- .../ml/common/dataframe/DataFrame.java | 6 + .../ml/common/dataframe/DefaultDataFrame.java | 18 ++ .../regression/LogisticRegressionParams.java | 231 ++++++++++++++++++ ml-algorithms/build.gradle | 1 + .../regression/LogisticRegression.java | 145 +++++++++++ .../ml/engine/contants/TribuoOutputType.java | 4 +- .../ml/engine/utils/TribuoUtil.java | 113 +++++++-- .../regression/LogisticRegressionTest.java | 109 +++++++++ .../helper/LogisticRegressionHelper.java | 55 +++++ .../ml/action/MLCommonsIntegTestCase.java | 33 +++ .../action/prediction/PredictionITTests.java | 10 + .../ml/action/training/TrainingITTests.java | 24 +- 13 files changed, 726 insertions(+), 26 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegressionTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/LogisticRegressionHelper.java diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 44f9eff7e1..697c9b68f8 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -14,7 +14,8 @@ public enum FunctionName { FIT_RCF, BATCH_RCF, ANOMALY_LOCALIZATION, - RCF_SUMMARIZE; + RCF_SUMMARIZE, + LOGISTIC_REGRESSION; public static FunctionName from(String value) { try { diff --git 
a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrame.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrame.java index 7f0daa5d4e..f7675a07e4 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrame.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrame.java @@ -58,4 +58,10 @@ public interface DataFrame extends Iterable, Writeable, ToXContentObject { */ DataFrame select(int[] columns); + /** + * Find the zero-based index of the column whose name equals the target in columnMetas + * @param target the column name to look up + * @return column index of the target in the list of columnMetas; the DefaultDataFrame implementation throws IllegalArgumentException when no column name matches + */ + int getColumnIndex(String target); } diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java index ccb7aa6097..228dd40157 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java @@ -131,6 +131,24 @@ public DataFrame select(int[] columns) { return new DefaultDataFrame(newColumnMetas, rows.stream().map(row-> row.select(columns)).collect(Collectors.toList())); } + @Override + public int getColumnIndex(String target) { + List featureNames = Arrays.stream(this.columnMetas()).map(ColumnMeta::getName).collect(Collectors.toList()); + /* -1 means no match yet; the scan below stops at the first column whose name equals target */ + int targetIndex = -1; + for (int i = 0; i < featureNames.size(); ++i) { + if (featureNames.get(i).equals(target)) { + targetIndex = i; + break; + } + } + if (targetIndex == -1) { + throw new IllegalArgumentException("No matched target when generating dataset from data frame."); + } + + return targetIndex; + } + @Override public Iterator iterator() { return rows.iterator(); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java
b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java new file mode 100644 index 0000000000..506a5289ea --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.regression; + +import lombok.Builder; +import lombok.Data; +import org.opensearch.common.ParseField; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; + +import java.io.IOException; +import java.util.Locale; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +/** Settings for the LOGISTIC_REGRESSION function: objective, optimiser, learning rate, epsilon, epochs, batch size, seed and target. Every field is nullable. */ +@Data +@MLAlgoParameter(algorithms={FunctionName.LOGISTIC_REGRESSION}) +public class LogisticRegressionParams implements MLAlgoParams { + + public static final String PARSE_FIELD_NAME = FunctionName.LOGISTIC_REGRESSION.name(); + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + + public static final String OBJECTIVE_FIELD = "objective"; + public static final String OPTIMISER_FIELD = "optimiser"; + public static final String LEARNING_RATE_FIELD = "learning_rate"; + public static final String EPSILON_FIELD = "epsilon"; + public static final String EPOCHS_FIELD = "epochs"; + public static final String BATCH_SIZE_FIELD = "batch_size"; + public static final String SEED_FIELD = "seed"; + public static final String
TARGET_FIELD = "target"; + + private LogisticRegressionParams.ObjectiveType objectiveType; + private LogisticRegressionParams.OptimizerType optimizerType; + private Double learningRate; + private Double epsilon; + private Integer epochs; + private Integer batchSize; + private Long seed; + private String target; + + @Builder(toBuilder = true) + public LogisticRegressionParams( + ObjectiveType objectiveType, + OptimizerType optimizerType, + Double learningRate, + Double epsilon, + Integer epochs, + Integer batchSize, + Long seed, + String target + ) { + this.objectiveType = objectiveType; + this.optimizerType = optimizerType; + this.learningRate = learningRate; + this.epsilon = epsilon; + this.epochs = epochs; + this.batchSize = batchSize; + this.seed = seed; + this.target = target; + } + /** Deserializes the parameters, reading fields in the same order writeTo emits them; each enum is preceded by a boolean presence flag. */ + public LogisticRegressionParams(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.objectiveType = in.readEnum(ObjectiveType.class); + } + if (in.readBoolean()) { + this.optimizerType = in.readEnum(OptimizerType.class); + } + this.learningRate = in.readOptionalDouble(); + + this.epsilon = in.readOptionalDouble(); + this.epochs = in.readOptionalInt(); + this.batchSize = in.readOptionalInt(); + this.seed = in.readOptionalLong(); + this.target = in.readOptionalString(); + } + /** Parses the parameters from an XContent object; objective and optimiser text is upper-cased before the enum lookup and unrecognized fields are skipped. */ + public static MLAlgoParams parse(XContentParser parser) throws IOException { + ObjectiveType objective = null; + OptimizerType optimizerType = null; + Double learningRate = null; + Double epsilon = null; + Integer epochs = null; + Integer batchSize = null; + Long seed = null; + String target = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case OBJECTIVE_FIELD: + objective = ObjectiveType.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case OPTIMISER_FIELD: + optimizerType =
OptimizerType.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case LEARNING_RATE_FIELD: + learningRate = parser.doubleValue(false); + break; + case EPSILON_FIELD: + epsilon = parser.doubleValue(false); + break; + case EPOCHS_FIELD: + epochs = parser.intValue(false); + break; + case BATCH_SIZE_FIELD: + batchSize = parser.intValue(false); + break; + case SEED_FIELD: + seed = parser.longValue(false); + break; + case TARGET_FIELD: + target = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return new LogisticRegressionParams(objective, optimizerType, learningRate, epsilon, epochs, batchSize, seed, target); + } + /** Serializes the parameters: a boolean presence flag precedes each enum, and the remaining fields go through the optional writers. */ + @Override + public void writeTo(StreamOutput out) throws IOException { + if (objectiveType != null) { + out.writeBoolean(true); + out.writeEnum(objectiveType); + } else { + out.writeBoolean(false); + } + if (optimizerType != null) { + out.writeBoolean(true); + out.writeEnum(optimizerType); + } else { + out.writeBoolean(false); + } + out.writeOptionalDouble(learningRate); + out.writeOptionalDouble(epsilon); + out.writeOptionalInt(epochs); + out.writeOptionalInt(batchSize); + out.writeOptionalLong(seed); + out.writeOptionalString(target); + } + /** Renders the parameters as one object that contains only the fields that are non-null. */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (objectiveType != null) { + builder.field(OBJECTIVE_FIELD, objectiveType); + } + if (optimizerType != null) { + builder.field(OPTIMISER_FIELD, optimizerType); + } + if (learningRate != null) { + builder.field(LEARNING_RATE_FIELD, learningRate); + } + if (epsilon != null) { + builder.field(EPSILON_FIELD, epsilon); + } + if (epochs != null) { + builder.field(EPOCHS_FIELD, epochs); + } + if (batchSize != null) { + builder.field(BATCH_SIZE_FIELD, batchSize); + } + if (seed != null) { + builder.field(SEED_FIELD, seed); + } + if (target != null) { + builder.field(TARGET_FIELD, target); + } + builder.endObject(); + return builder; + } + + @Override +
public String getWriteableName() { + return PARSE_FIELD_NAME; + } + + @Override + public int getVersion() { + return 1; + } + /** Objective types accepted for the objective field; from is a case-sensitive lookup that throws IllegalArgumentException for any value valueOf rejects. */ + public enum ObjectiveType { + HINGE, + LOGMULTICLASS; + public static ObjectiveType from(String value) { + try{ + return ObjectiveType.valueOf(value); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong objective type"); + } + } + } + /** Optimizer types accepted for the optimiser field; from is a case-sensitive lookup that throws IllegalArgumentException for any value valueOf rejects. */ + public enum OptimizerType { + SIMPLE_SGD, + LINEAR_DECAY_SGD, + SQRT_DECAY_SGD, + ADA_GRAD, + ADA_DELTA, + ADAM, + RMS_PROP; + + public static OptimizerType from(String value) { + try{ + return OptimizerType.valueOf(value); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong optimizer type"); + } + } + } +} diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 4e766e9279..889cfa76e2 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -20,6 +20,7 @@ dependencies { implementation group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.1' implementation group: 'org.tribuo', name: 'tribuo-regression-sgd', version: '4.2.1' implementation group: 'org.tribuo', name: 'tribuo-anomaly-libsvm', version: '4.2.1' + implementation group: 'org.tribuo', name: 'tribuo-classification-sgd', version: '4.2.1' implementation group: 'commons-io', name: 'commons-io', version: '2.11.0' implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc3' implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3' diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java new file mode 100644 index 0000000000..54086db8e4 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package
org.opensearch.ml.engine.algorithms.regression; + +import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.engine.Trainable; +import org.opensearch.ml.engine.annotation.Function; +import org.opensearch.ml.engine.contants.TribuoOutputType; +import org.opensearch.ml.engine.utils.ModelSerDeSer; +import org.opensearch.ml.engine.utils.TribuoUtil; +import org.tribuo.MutableDataset; +import org.tribuo.Prediction; +import org.tribuo.Trainer; +import org.tribuo.classification.Label; +import org.tribuo.classification.LabelFactory; +import org.tribuo.classification.sgd.LabelObjective; +import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer; +import org.tribuo.classification.sgd.objectives.Hinge; +import org.tribuo.classification.sgd.objectives.LogMulticlass; +import org.tribuo.math.StochasticGradientOptimiser; +import org.tribuo.math.optimisers.AdaGrad; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +@Function(FunctionName.LOGISTIC_REGRESSION) +public class LogisticRegression implements Trainable, Predictable { + + private static final LogisticRegressionParams.ObjectiveType DEFAULT_OBJECTIVE_TYPE = LogisticRegressionParams.ObjectiveType.LOGMULTICLASS; + private static final LogisticRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LogisticRegressionParams.OptimizerType.ADA_GRAD; + private static final double DEFAULT_LEARNING_RATE = 0.01; + + //AdaGrad, AdaDelta, AdaGradRDA, Adam, RMSProp + private 
static final double DEFAULT_EPSILON = 1e-6; + + private static final int DEFAULT_EPOCHS = 10; + + private LogisticRegressionParams parameters; + private StochasticGradientOptimiser optimiser; + private LabelObjective objective; + /** + * Initialize a logistic regression algorithm: validates the parameters, then builds the objective and the optimiser. + * @param parameters the parameters for logistic regression algorithm; when null an empty LogisticRegressionParams is used so the defaults apply + */ + public LogisticRegression(MLAlgoParams parameters) { + this.parameters = parameters == null ? LogisticRegressionParams.builder().build() : (LogisticRegressionParams)parameters; + validateParameters(); + createObjective(); + createOptimiser(); + } + /** Throws IllegalArgumentException when learning rate, epsilon, epochs or batch size is negative; null and zero values are accepted. */ + private void validateParameters() { + if (parameters.getLearningRate() != null && parameters.getLearningRate() < 0) { + throw new IllegalArgumentException("Learning rate should not be negative."); + } + + if (parameters.getEpsilon() != null && parameters.getEpsilon() < 0) { + throw new IllegalArgumentException("Epsilon should not be negative."); + } + + if (parameters.getEpochs() != null && parameters.getEpochs() < 0) { + throw new IllegalArgumentException("Epochs should not be negative."); + } + + if (parameters.getBatchSize() != null && parameters.getBatchSize() < 0) { + throw new IllegalArgumentException("MiniBatchSize should not be negative."); + } + } + /** Sets the objective: Hinge when HINGE is requested, LogMulticlass for every other value, including an unset objective type. */ + private void createObjective() { + LogisticRegressionParams.ObjectiveType objectiveType = Optional.ofNullable(parameters.getObjectiveType()).orElse(DEFAULT_OBJECTIVE_TYPE); + switch (objectiveType) { + case HINGE: + objective = new Hinge(); + break; + default: + objective = new LogMulticlass(); + break; + } + } + /** Sets the optimiser from the configured or default learning rate and epsilon; both switch branches currently build AdaGrad. */ + private void createOptimiser() { + LogisticRegressionParams.OptimizerType optimizerType = Optional.ofNullable(parameters.getOptimizerType()).orElse(DEFAULT_OPTIMIZER_TYPE); + Double learningRate = Optional.ofNullable(parameters.getLearningRate()).orElse(DEFAULT_LEARNING_RATE); + Double epsilon = Optional.ofNullable(parameters.getEpsilon()).orElse(DEFAULT_EPSILON); + + switch (optimizerType) { + // ToDo:
Add more possible optimizer. Tribuo only provides AdaGrad for logistic regression. + case ADA_GRAD: + optimiser = new AdaGrad(learningRate, epsilon); + break; + default: + //Fallback for every other optimizer type: AdaGrad is used here as well. + optimiser = new AdaGrad(learningRate, epsilon); + break; + } + } + + @Override + public Model train(DataFrame dataFrame) { + MutableDataset