Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
store threshold model training data in double array
Browse files Browse the repository at this point in the history
  • Loading branch information
ylwu-amzn committed Dec 22, 2020
1 parent ebc33dd commit 43f77d8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.TIME_DECAY;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel;
Expand All @@ -46,7 +45,8 @@ public class ADBatchTaskCache {
private ThresholdingModel thresholdModel;
private boolean thresholdModelTrained;
private Deque<Map.Entry<Long, Optional<double[]>>> shingle;
private List<Double> thresholdModelTrainingData;
private AtomicInteger thresholdModelTrainingDataSize = new AtomicInteger(0);
private double[] thresholdModelTrainingData;
private AtomicBoolean cancelled = new AtomicBoolean(false);
private AtomicLong cacheMemorySize = new AtomicLong(0);
private String cancelReason;
Expand Down Expand Up @@ -74,7 +74,7 @@ protected ADBatchTaskCache(ADTask adTask) {
AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES,
AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES
);
this.thresholdModelTrainingData = new ArrayList<>(THRESHOLD_MODEL_TRAINING_SIZE);
this.thresholdModelTrainingData = new double[THRESHOLD_MODEL_TRAINING_SIZE];
this.thresholdModelTrained = false;
this.shingle = new ArrayDeque<>(detector.getShingleSize());
}
Expand Down Expand Up @@ -103,10 +103,19 @@ protected boolean isThresholdModelTrained() {
return thresholdModelTrained;
}

protected List<Double> getThresholdModelTrainingData() {
protected double[] getThresholdModelTrainingData() {
return thresholdModelTrainingData;
}

protected void clearTrainingData() {
this.thresholdModelTrainingData = null;
this.thresholdModelTrainingDataSize.set(0);
}

public AtomicInteger getThresholdModelTrainingDataSize() {
return thresholdModelTrainingDataSize;
}

protected AtomicLong getCacheMemorySize() {
return cacheMemorySize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@
import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE;

import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;

import com.amazon.opendistroforelasticsearch.ad.MemoryTracker;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingModel;
import com.amazon.opendistroforelasticsearch.ad.model.ADTask;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.randomcutforest.RandomCutForest;

public class ADTaskCacheManager {
Expand Down Expand Up @@ -126,10 +126,19 @@ public ThresholdingModel getThresholdModel(String taskId) {
* @param taskId AD task id
* @return threshold model training data
*/
public List<Double> getThresholdModelTrainingData(String taskId) {
public double[] getThresholdModelTrainingData(String taskId) {
return getBatchTaskCache(taskId).getThresholdModelTrainingData();
}

public int addThresholdModelTrainingData(String taskId, double... data) {
ADBatchTaskCache taskCache = getBatchTaskCache(taskId);
double[] thresholdModelTrainingData = taskCache.getThresholdModelTrainingData();
AtomicInteger size = taskCache.getThresholdModelTrainingDataSize();
int dataPointsAdded = Math.min(data.length, THRESHOLD_MODEL_TRAINING_SIZE - size.get());
System.arraycopy(data, 0, thresholdModelTrainingData, size.get(), dataPointsAdded);
return size.addAndGet(dataPointsAdded);
}

/**
* Threshold model trained or not.
* If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}.
Expand All @@ -151,9 +160,9 @@ protected void setThresholdModelTrained(String taskId, boolean trained) {
ADBatchTaskCache taskCache = getBatchTaskCache(taskId);
taskCache.setThresholdModelTrained(trained);
if (trained) {
int size = taskCache.getThresholdModelTrainingData().size();
int size = taskCache.getThresholdModelTrainingDataSize().get();
long cacheSize = trainingDataMemorySize(size);
taskCache.getThresholdModelTrainingData().clear();
taskCache.clearTrainingData();
taskCache.getCacheMemorySize().getAndAdd(-cacheSize);
memoryTracker.releaseMemory(cacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR);
}
Expand Down Expand Up @@ -287,14 +296,14 @@ public void clear() {

/**
* Estimate max memory usage of model training data.
* The training data is double and will cache in {@link java.util.ArrayList}.
* Check {@link ADBatchTaskCache#getThresholdModelTrainingData()}
* The training data is double and will cache in double array.
* One double consumes 8 bytes.
*
* @param size training data point count
* @return how many bytes will consume
*/
public long trainingDataMemorySize(int size) {
return 24 * size;
return 8 * size;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import com.amazon.opendistroforelasticsearch.ad.model.ADTask;
import com.amazon.opendistroforelasticsearch.ad.model.ADTaskState;
import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
import com.google.common.collect.ImmutableList;

public class ADTaskCacheManagerTests extends ESTestCase {
private MemoryTracker memoryTracker;
Expand Down Expand Up @@ -126,8 +125,7 @@ public void testThresholdModelTrained() throws IOException {
ADTask adTask = TestHelpers.randomAdTask();
adTaskCacheManager.put(adTask);
assertEquals(1, adTaskCacheManager.size());
adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId()).addAll(ImmutableList.of(randomDouble(), randomDouble()));
int size = adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId()).size();
int size = adTaskCacheManager.addThresholdModelTrainingData(adTask.getTaskId(), randomDouble(), randomDouble());
long cacheSize = adTaskCacheManager.trainingDataMemorySize(size);
adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), false);
verify(memoryTracker, never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR));
Expand Down

0 comments on commit 43f77d8

Please sign in to comment.