Skip to content

Commit

Permalink
Add async getColdStartData (opendistro-for-elasticsearch#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnbts authored and kaituo committed Apr 13, 2020
1 parent ed40113 commit b2fd4cd
Show file tree
Hide file tree
Showing 4 changed files with 393 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,52 @@ public Optional<double[][]> getColdStartData(AnomalyDetector detector) {
.map(points -> batchShingle(points, shingleSize));
}

/**
 * Asynchronously provides cold-start training data to the listener.
 *
 * The workflow first issues a (costly) search for raw samples, then
 * enlarges the sample set via interpolation, and finally expands each
 * point's dimension via shingling.
 *
 * @param detector contains data info (indices, documents, etc)
 * @param listener onResponse is called with data for cold-start training, or empty if unavailable
 */
public void getColdStartData(AnomalyDetector detector, ActionListener<Optional<double[][]>> listener) {
    // The latest data time anchors the sampling window; failures flow straight through.
    ActionListener<Optional<Long>> latestTimeListener = ActionListener
        .wrap(latestTime -> getColdStartSamples(latestTime, detector, listener), listener::onFailure);
    searchFeatureDao.getLatestDataTime(detector, latestTimeListener);
}

private void getColdStartSamples(Optional<Long> latest, AnomalyDetector detector, ActionListener<Optional<double[][]>> listener) {
    // Without a latest data time there is nothing to sample from.
    if (!latest.isPresent()) {
        listener.onResponse(Optional.empty());
        return;
    }
    ActionListener<Optional<Entry<double[][], Integer>>> samplesListener = ActionListener
        .wrap(samples -> processColdStartSamples(samples, listener), listener::onFailure);
    searchFeatureDao.getFeaturesForSampledPeriods(detector, maxTrainSamples, maxSampleStride, latest.get(), samplesListener);
}

private void processColdStartSamples(Optional<Entry<double[][], Integer>> samples, ActionListener<Optional<double[][]>> listener) {
    // Interpolate the raw samples up to the full period count implied by the
    // stride, then shingle the enlarged series into training points.
    Optional<double[][]> interpolated = samples.map(entry -> {
        double[][] rawPoints = entry.getKey();
        int targetSize = entry.getValue() * (rawPoints.length - 1) + 1;
        return transpose(interpolator.interpolate(transpose(rawPoints), targetSize));
    });
    listener.onResponse(interpolated.map(points -> batchShingle(points, shingleSize)));
}

/**
* Shingles a batch of data points by concatenating neighboring data points.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,12 @@ public SearchFeatureDao(Client client, NamedXContentRegistry xContent, Interpola
/**
* Returns epoch time of the latest data under the detector.
*
* @deprecated use getLatestDataTime with listener instead.
*
* @param detector info about the indices and documents
* @return epoch time of the latest data in milliseconds
*/
@Deprecated
public Optional<Long> getLatestDataTime(AnomalyDetector detector) {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.aggregation(AggregationBuilders.max(AGG_NAME_MAX).field(detector.getTimeField()))
Expand All @@ -105,6 +108,30 @@ public Optional<Long> getLatestDataTime(AnomalyDetector detector) {
.map(agg -> (long) agg.getValue());
}

/**
 * Returns to listener the epoch time of the latest data under the detector.
 *
 * @param detector info about the data
 * @param listener onResponse is called with the epoch time of the latest data under the detector
 */
public void getLatestDataTime(AnomalyDetector detector, ActionListener<Optional<Long>> listener) {
    // Max aggregation over the detector's time field; size 0 because only the
    // aggregation result is needed, not the hit documents.
    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
        .aggregation(AggregationBuilders.max(AGG_NAME_MAX).field(detector.getTimeField()))
        .size(0);
    SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder);
    client
        .search(searchRequest, ActionListener.wrap(response -> listener.onResponse(getLatestDataTime(response)), listener::onFailure));
}

private Optional<Long> getLatestDataTime(SearchResponse searchResponse) {
    // Walk the response defensively: any missing link (null response, no
    // aggregations, missing max agg) collapses to an empty Optional.
    return Optional
        .ofNullable(searchResponse)
        .map(SearchResponse::getAggregations)
        .map(aggregations -> aggregations.asMap())
        .map(aggMap -> (Max) aggMap.get(AGG_NAME_MAX))
        .map(maxAgg -> (long) maxAgg.getValue());
}

/**
* Gets features for the given time period.
* This function also adds given detector to negative cache before sending es request.
Expand Down Expand Up @@ -238,6 +265,8 @@ public void getFeatureSamplesForPeriods(
/**
* Gets features for sampled periods.
*
* @deprecated use getFeaturesForSampledPeriods with listener instead.
*
* Sampling starts with the latest period and goes backwards in time until there are up to {@code maxSamples} samples.
* If the initial stride {@code maxStride} results into a low count of samples, the implementation
* may attempt with (exponentially) reduced strides and interpolate missing points.
Expand All @@ -248,6 +277,7 @@ public void getFeatureSamplesForPeriods(
* @param endTime the end time of the latest period
* @return sampled features and stride, empty when no data found
*/
@Deprecated
public Optional<Entry<double[][], Integer>> getFeaturesForSampledPeriods(
AnomalyDetector detector,
int maxSamples,
Expand Down Expand Up @@ -316,6 +346,184 @@ private Optional<double[][]> getFeaturesForSampledPeriods(
return samples;
}

/**
 * Asynchronously delivers features for sampled periods to the listener.
 *
 * Sampling begins at the latest period and proceeds backwards in time until
 * up to {@code maxSamples} samples are collected. If the initial stride
 * {@code maxStride} yields too few samples, the implementation may retry with
 * (exponentially) reduced strides and interpolate missing points.
 *
 * @param detector info about indices, documents, feature query
 * @param maxSamples the maximum number of samples to return
 * @param maxStride the maximum number of periods between samples
 * @param endTime the end time of the latest period
 * @param listener onResponse is called with sampled features and stride between points, or empty for no data
 */
public void getFeaturesForSampledPeriods(
    AnomalyDetector detector,
    int maxSamples,
    int maxStride,
    long endTime,
    ActionListener<Optional<Entry<double[][], Integer>>> listener
) {
    // Shared across stride attempts so already-fetched periods are not re-queried.
    Map<Long, double[]> sampleCache = new HashMap<>();
    getFeatureSamplesWithCache(detector, maxSamples, maxStride, endTime, sampleCache, maxStride, listener);
}

private void getFeatureSamplesWithCache(
    AnomalyDetector detector,
    int maxSamples,
    int maxStride,
    long endTime,
    Map<Long, double[]> cache,
    int currentStride,
    ActionListener<Optional<Entry<double[][], Integer>>> listener
) {
    // After sampling at the current stride, decide whether to accept the
    // result or retry at a smaller stride; failures propagate unchanged.
    ActionListener<Optional<double[][]>> strideListener = ActionListener
        .wrap(
            features -> processFeatureSamplesForStride(
                features,
                detector,
                maxSamples,
                maxStride,
                currentStride,
                endTime,
                cache,
                listener
            ),
            listener::onFailure
        );
    getFeatureSamplesForStride(detector, maxSamples, maxStride, currentStride, endTime, cache, strideListener);
}

private void processFeatureSamplesForStride(
    Optional<double[][]> features,
    AnomalyDetector detector,
    int maxSamples,
    int maxStride,
    int currentStride,
    long endTime,
    Map<Long, double[]> cache,
    ActionListener<Optional<Entry<double[][], Integer>>> listener
) {
    // No data at all for this stride: report empty and stop.
    if (!features.isPresent()) {
        listener.onResponse(Optional.empty());
        return;
    }
    double[][] points = features.get();
    if (points.length > maxSamples / 2 || currentStride == 1) {
        // Enough samples, or the stride cannot shrink further: accept the result.
        listener.onResponse(Optional.of(new SimpleEntry<>(points, currentStride)));
    } else {
        // Too sparse: halve the stride and resample, reusing cached periods.
        getFeatureSamplesWithCache(detector, maxSamples, maxStride, endTime, cache, currentStride / 2, listener);
    }
}

private void getFeatureSamplesForStride(
    AnomalyDetector detector,
    int maxSamples,
    int maxStride,
    int currentStride,
    long endTime,
    Map<Long, double[]> cache,
    ActionListener<Optional<double[][]>> listener
) {
    // One detection interval, in milliseconds, is the width of a sampled period.
    long periodMillis = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis();
    // Missing periods can only be interpolated when neighbors at a coarser stride exist.
    boolean canInterpolate = currentStride < maxStride;
    ArrayDeque<double[]> collected = new ArrayDeque<>(maxSamples);
    sampleForIteration(detector, cache, maxSamples, endTime, periodMillis, currentStride, collected, canInterpolate, 0, listener);
}

/**
 * Samples one period per iteration, walking backwards in time from {@code endTime},
 * and recurses (via the async callback) until {@code maxSamples} iterations are done
 * or a gap cannot be filled. Samples are prepended so the final matrix is in
 * chronological order. Results are delivered to the listener via {@code toMatrix}.
 *
 * @param detector info about indices, documents, feature query
 * @param cache previously fetched/interpolated features keyed by period end time
 * @param maxSamples the maximum number of samples to collect
 * @param endTime the end time of the latest period
 * @param span the length of one period in milliseconds
 * @param stride the number of periods between consecutive samples
 * @param sampledFeatures accumulator of collected samples, earliest first
 * @param isInterpolatable whether gaps may be filled from coarser-stride neighbors
 * @param iteration the current 0-based iteration index
 * @param listener onResponse is called with the collected samples, or empty if none
 */
private void sampleForIteration(
    AnomalyDetector detector,
    Map<Long, double[]> cache,
    int maxSamples,
    long endTime,
    long span,
    int stride,
    ArrayDeque<double[]> sampledFeatures,
    boolean isInterpolatable,
    int iteration,
    ActionListener<Optional<double[][]>> listener
) {
    if (iteration < maxSamples) {
        // End time of the period sampled in this iteration.
        long end = endTime - span * stride * iteration;
        if (cache.containsKey(end)) {
            // Cache hit (from an earlier, coarser-stride pass): no search needed.
            sampledFeatures.addFirst(cache.get(end));
            sampleForIteration(
                detector,
                cache,
                maxSamples,
                endTime,
                span,
                stride,
                sampledFeatures,
                isInterpolatable,
                iteration + 1,
                listener
            );
        } else {
            getFeaturesForPeriod(detector, end - span, end, ActionListener.wrap(features -> {
                if (features.isPresent()) {
                    // Cache before recursing so later (finer-stride) passes can reuse it.
                    cache.put(end, features.get());
                    sampledFeatures.addFirst(features.get());
                    sampleForIteration(
                        detector,
                        cache,
                        maxSamples,
                        endTime,
                        span,
                        stride,
                        sampledFeatures,
                        isInterpolatable,
                        iteration + 1,
                        listener
                    );
                } else if (isInterpolatable) {
                    // Try to fill the gap from the cached neighbors one stride away
                    // on each side (present only when a coarser stride fetched them).
                    Optional<double[]> previous = Optional.ofNullable(cache.get(end - span * stride));
                    Optional<double[]> next = Optional.ofNullable(cache.get(end + span * stride));
                    if (previous.isPresent() && next.isPresent()) {
                        double[] interpolants = getInterpolants(previous.get(), next.get());
                        cache.put(end, interpolants);
                        sampledFeatures.addFirst(interpolants);
                        sampleForIteration(
                            detector,
                            cache,
                            maxSamples,
                            endTime,
                            span,
                            stride,
                            sampledFeatures,
                            isInterpolatable,
                            iteration + 1,
                            listener
                        );
                    } else {
                        // Unfillable gap: stop and return whatever was collected so far.
                        listener.onResponse(toMatrix(sampledFeatures));
                    }
                } else {
                    // Gap with interpolation disallowed: stop with what we have.
                    listener.onResponse(toMatrix(sampledFeatures));
                }
            }, listener::onFailure));
        }
    } else {
        // Collected the full maxSamples quota.
        listener.onResponse(toMatrix(sampledFeatures));
    }
}

private Optional<double[][]> toMatrix(ArrayDeque<double[]> sampledFeatures) {
    // An empty deque means no usable samples were collected.
    return sampledFeatures.isEmpty() ? Optional.empty() : Optional.of(sampledFeatures.toArray(new double[0][0]));
}

private double[] getInterpolants(double[] previous, double[] next) {
    // Interpolate a 3-point series between the two neighbors; the middle
    // point (index 1) is the estimate for the missing period.
    double[][] neighbors = new double[][] { previous, next };
    double[][] interpolated = transpose(interpolator.interpolate(transpose(neighbors), 3));
    return interpolated[1];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.junit.Test;
import org.junit.runner.RunWith;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand Down Expand Up @@ -263,6 +264,64 @@ public void getColdStartData_returnExpected(Long latestTime, Entry<double[][], I
assertTrue(Arrays.deepEquals(expected, results.orElse(null)));
}

@Test
@SuppressWarnings("unchecked")
@Parameters(method = "getColdStartDataTestData")
public void getColdStartData_returnExpectedToListener(
    Long latestTime,
    Entry<double[][], Integer> data,
    int interpolants,
    double[][] expected
) {
    // Stub the async latest-time lookup to reply immediately with latestTime.
    doAnswer(invocation -> {
        ActionListener<Optional<Long>> callback = invocation.getArgument(1);
        callback.onResponse(Optional.ofNullable(latestTime));
        return null;
    }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class));
    if (latestTime != null) {
        // Stub the async sampling call to hand back the test data.
        doAnswer(invocation -> {
            ActionListener<Optional<Entry<double[][], Integer>>> callback = invocation.getArgument(4);
            callback.onResponse(ofNullable(data));
            return null;
        })
            .when(searchFeatureDao)
            .getFeaturesForSampledPeriods(
                eq(detector),
                eq(maxTrainSamples),
                eq(maxSampleStride),
                eq(latestTime),
                any(ActionListener.class)
            );
    }
    if (data != null) {
        // Interpolation and shingling are identity-stubbed on the sample matrix.
        when(interpolator.interpolate(argThat(new ArrayEqMatcher<>(data.getKey())), eq(interpolants))).thenReturn(data.getKey());
        doReturn(data.getKey()).when(featureManager).batchShingle(argThat(new ArrayEqMatcher<>(data.getKey())), eq(shingleSize));
    }

    ActionListener<Optional<double[][]>> listener = mock(ActionListener.class);
    featureManager.getColdStartData(detector, listener);

    // Capture what the feature manager delivered and compare element-wise.
    ArgumentCaptor<Optional<double[][]>> resultCaptor = ArgumentCaptor.forClass(Optional.class);
    verify(listener).onResponse(resultCaptor.capture());
    assertTrue(Arrays.deepEquals(expected, resultCaptor.getValue().orElse(null)));
}

@Test
@SuppressWarnings("unchecked")
public void getColdStartData_throwToListener_whenSearchFail() {
    // Make the latest-time search fail asynchronously.
    doAnswer(invocation -> {
        ActionListener<Optional<Long>> callback = invocation.getArgument(1);
        callback.onFailure(new RuntimeException());
        return null;
    }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class));

    ActionListener<Optional<double[][]>> listener = mock(ActionListener.class);
    featureManager.getColdStartData(detector, listener);

    // The failure must propagate to the caller's listener.
    verify(listener).onFailure(any(Exception.class));
}

private Object[] batchShingleData() {
return new Object[] {
new Object[] { new double[][] { { 1.0 } }, 1, new double[][] { { 1.0 } } },
Expand Down
Loading

0 comments on commit b2fd4cd

Please sign in to comment.