diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java
index 521f98971..31fc8c979 100644
--- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java
+++ b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java
@@ -37,6 +37,7 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.support.ThreadedActionListener;
 import org.opensearch.ad.AnomalyDetectorPlugin;
+import org.opensearch.ad.CleanState;
 import org.opensearch.ad.MaintenanceState;
 import org.opensearch.ad.NodeStateManager;
 import org.opensearch.ad.caching.DoorKeeper;
@@ -63,7 +64,7 @@
  * Training models for HCAD detectors
  *
  */
-public class EntityColdStarter implements MaintenanceState {
+public class EntityColdStarter implements MaintenanceState, CleanState {
     private static final Logger logger = LogManager.getLogger(EntityColdStarter.class);
     private final Clock clock;
     private final ThreadPool threadPool;
@@ -743,4 +744,9 @@ public void maintenance() {
             }
         });
     }
+
+    @Override
+    public void clear(String detectorId) {
+        doorKeepers.remove(detectorId);
+    }
 }
diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java
index 766464b5d..839e32666 100644
--- a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java
+++ b/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java
@@ -23,6 +23,7 @@
 import org.opensearch.ad.NodeStateManager;
 import org.opensearch.ad.caching.CacheProvider;
 import org.opensearch.ad.feature.FeatureManager;
+import org.opensearch.ad.ml.EntityColdStarter;
 import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.cluster.service.ClusterService;
@@ -39,6 +40,7 @@ public class DeleteModelTransportAction extends
     private FeatureManager featureManager;
     private CacheProvider cache;
     private ADTaskCacheManager adTaskCacheManager;
+    private EntityColdStarter coldStarter;
 
     @Inject
     public DeleteModelTransportAction(
@@ -50,7 +52,8 @@ public DeleteModelTransportAction(
         ModelManager modelManager,
         FeatureManager featureManager,
         CacheProvider cache,
-        ADTaskCacheManager adTaskCacheManager
+        ADTaskCacheManager adTaskCacheManager,
+        EntityColdStarter coldStarter
     ) {
         super(
             DeleteModelAction.NAME,
@@ -68,6 +71,7 @@ public DeleteModelTransportAction(
         this.featureManager = featureManager;
         this.cache = cache;
         this.adTaskCacheManager = adTaskCacheManager;
+        this.coldStarter = coldStarter;
     }
 
     @Override
@@ -121,6 +125,8 @@ protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request)
 
         cache.get().clear(adID);
 
+        coldStarter.clear(adID);
+
         // delete realtime task cache
         adTaskCacheManager.removeRealtimeTaskCache(adID);
 
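Reviewer note: to make the intent of the two changes above concrete, here is a minimal, self-contained sketch of the delete path they extend. The CleanState interface name and the clear(String) signature come from the diff; the collaborators are simplified lambda stand-ins (the real NodeStateManager, FeatureManager, and CacheProvider have much richer APIs), so treat this as an illustration rather than the plugin's actual wiring.

import java.util.Arrays;
import java.util.List;

interface CleanState {
    void clear(String detectorId); // remove all state kept for one detector
}

final class DeleteFlowSketch {
    public static void main(String[] args) {
        String adID = "detector-1";

        // each collaborator owning per-detector state implements CleanState...
        CleanState stateManager = id -> System.out.println("node state cleared for " + id);
        CleanState featureManager = id -> System.out.println("features cleared for " + id);
        CleanState cache = id -> System.out.println("entity cache cleared for " + id);
        // ...and this patch adds the cold starter, so its door keeper is dropped too;
        // otherwise a deleted-and-recreated detector could be refused cold start
        // until the stale door keeper expired on its own
        CleanState coldStarter = id -> System.out.println("door keeper cleared for " + id);

        List<CleanState> perDetectorStates = Arrays.asList(stateManager, featureManager, cache, coldStarter);
        perDetectorStates.forEach(s -> s.clear(adID));
    }
}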
diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java
index ba369dbb1..16c136fe2 100644
--- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java
+++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java
@@ -18,17 +18,24 @@
 import java.io.File;
 import java.io.FileReader;
 import java.io.IOException;
+import java.io.InputStreamReader;
+import java.text.SimpleDateFormat;
+import java.time.Clock;
 import java.time.Instant;
 import java.time.format.DateTimeFormatter;
 import java.time.temporal.ChronoUnit;
 import java.util.AbstractMap.SimpleEntry;
 import java.util.ArrayList;
+import java.util.Calendar;
+import java.util.Date;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.TimeZone;
+import java.util.concurrent.TimeUnit;
 
 import org.apache.http.HttpHeaders;
 import org.apache.http.message.BasicHeader;
@@ -83,9 +90,10 @@ private void verifyAnomaly(
         List<JsonObject> data = getData(dataFileName);
         List<Entry<Instant, Instant>> anomalies = getAnomalyWindows(labelFileName);
 
-        bulkIndexTrainData(datasetName, data, trainTestSplit, client);
-        String detectorId = createDetector(datasetName, intervalMinutes, client);
-        startDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client);
+        bulkIndexTrainData(datasetName, data, trainTestSplit, client, null);
+        // a single-stream detector can use window delay 0 here because we give the run API the actual data time
+        String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0);
+        simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client);
         bulkIndexTestData(data, datasetName, trainTestSplit, client);
         double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client);
         verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError);
@@ -160,7 +168,18 @@ private double[] getTestResults(
         return new double[] { positives, truePositives, positiveAnomalies.size(), errors };
     }
 
-    private void startDetector(
+    /**
+     * Simulate starting a detector without waiting for the job scheduler to run. Our build is already very slow
+     * (10+ minutes) to finish integration tests, so this method triggers the run API to simulate job scheduler
+     * execution quickly.
+     * @param detectorId Detector id
+     * @param data Data in JSON format
+     * @param trainTestSplit Training data size
+     * @param shingleSize Shingle size
+     * @param intervalMinutes Detector interval
+     * @param client OpenSearch client
+     * @throws Exception when failing to query/index from/to OpenSearch
+     */
+    private void simulateSingleStreamStartDetector(
         String detectorId,
         List<JsonObject> data,
         int trainTestSplit,
         int shingleSize,
         int intervalMinutes,
         RestClient client
@@ -197,20 +216,100 @@
         } while (true);
     }
 
-    private String createDetector(String datasetName, int intervalMinutes, RestClient client) throws Exception {
-        Request request = new Request("POST", "/_opendistro/_anomaly_detection/detectors/");
-        String requestBody = String
-            .format(
-                Locale.ROOT,
-                "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\""
-                    + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": "
-                    + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\""
-                    + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": "
-                    + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, "
-                    + "\"schema_version\": 0 }",
-                datasetName,
-                intervalMinutes
-            );
+    /**
+     * Simulate starting the given HCAD detector.
+     * @param detectorId Detector id
+     * @param data Data in JSON format
+     * @param trainTestSplit Training data size
+     * @param shingleSize Shingle size
+     * @param intervalMinutes Detector interval
+     * @param client OpenSearch client
+     * @throws Exception when failing to query/index from/to OpenSearch
+     */
+    private void simulateHCADStartDetector(
+        String detectorId,
+        List<JsonObject> data,
+        int trainTestSplit,
+        int shingleSize,
+        int intervalMinutes,
+        RestClient client
+    ) throws Exception {
+
+        Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString()));
+
+        Instant begin = null;
+        Instant end = null;
+        for (int i = 0; i < shingleSize; i++) {
+            begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES);
+            end = begin.plus(intervalMinutes, ChronoUnit.MINUTES);
+            try {
+                getDetectionResult(detectorId, begin, end, client);
+            } catch (Exception e) {}
+        }
+        // It takes time for the models to initialize
+        long startTime = System.currentTimeMillis();
+        long duration = 0;
+        do {
+            /*
+             * Single-stream detectors throw an exception in the callback when models are
+             * not found, while HCAD detectors return early, record the exception in node
+             * state, and throw it in the next run. HCAD works this way because it does
+             * not know when the current run will finish (e.g., we may have millions of
+             * entities to process in one run). So for the single-stream test case we can
+             * check the exception to see whether models are initialized; for HCAD we
+             * have to either wait for the next runs or use the profile API. Here I chose
+             * the profile API since it is faster.
+             */
+            Thread.sleep(5_000);
+            String initProgress = profileDetectorInitProgress(detectorId, client);
+            if (initProgress.equals("100%")) {
+                break;
+            }
+            try {
+                getDetectionResult(detectorId, begin, end, client);
+            } catch (Exception e) {}
+            duration = System.currentTimeMillis() - startTime;
+        } while (duration <= 60_000);
+    }
+
+    private String createDetector(String datasetName, int intervalMinutes, RestClient client, String categoryField, long windowDelayInMins)
+        throws Exception {
+        Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/");
+        String requestBody = null;
+        if (Strings.isEmpty(categoryField)) {
+            requestBody = String
+                .format(
+                    Locale.ROOT,
+                    "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\""
+                        + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": "
+                        + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\""
+                        + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": "
+                        + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, "
+                        + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}},"
+                        + "\"schema_version\": 0 }",
+                    datasetName,
+                    intervalMinutes,
+                    windowDelayInMins
+                );
+        } else {
+            requestBody = String
+                .format(
+                    Locale.ROOT,
+                    "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\""
+                        + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": "
+                        + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\""
+                        + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": "
\"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"category_field\": [\"%s\"], " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + categoryField, + windowDelayInMins + ); + } + request.setJsonEntity(requestBody); Map response = entityAsMap(client.performRequest(request)); String detectorId = (String) response.get("_id"); @@ -232,10 +331,24 @@ private List> getAnomalyWindows(String labalFileName) th return anomalies; } - private void bulkIndexTrainData(String datasetName, List data, int trainTestSplit, RestClient client) throws Exception { + private void bulkIndexTrainData(String datasetName, List data, int trainTestSplit, RestClient client, String categoryField) + throws Exception { Request request = new Request("PUT", datasetName); - String requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," - + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," + + "\"%s\": { \"type\": \"keyword\"} } } }", + categoryField + ); + } + request.setJsonEntity(requestBody); setWarningHandler(request, false); client.performRequest(request); @@ -256,6 +369,7 @@ private void bulkIndexTrainData(String datasetName, List data, int t ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); Thread.sleep(1_000); + waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); } private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { @@ -274,6 +388,46 @@ private void bulkIndexTestData(List data, String datasetName, int tr ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); Thread.sleep(1_000); + waitAllSyncheticDataIngested(data.size(), datasetName, client); + } + + private void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = new JsonParser().parse(new InputStreamReader(response.getEntity().getContent())).getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 
+                break;
+            } else {
+                request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName));
+                client.performRequest(request);
+            }
+        } while (maxWaitCycles-- >= 0);
+    }
 
     private void setWarningHandler(Request request, boolean strictDeprecationMode) {
@@ -440,4 +594,125 @@ private void indexTrainData(String datasetName, List<JsonObject> data, int trainTestSplit,
         });
         Thread.sleep(1_000);
     }
+
+    public void testRestartHCADDetector() throws Exception {
+        // TODO: this test case will run for a much longer time and time out with security enabled
+        if (!isHttps()) {
+            int maxRetries = 3;
+            int i = 0;
+            for (; i < maxRetries; i++) {
+                try {
+                    disableResourceNotFoundFaultTolerence();
+                    verifyRestart("synthetic", 1, 8);
+                    break;
+                } catch (Throwable throwable) {
+                    LOG.info("Retry restart test case", throwable);
+                    cleanUpCluster();
+                    wipeAllODFEIndices();
+                }
+            }
+            assertTrue("failed all retries", i < maxRetries);
+        }
+    }
+
+    private void verifyRestart(String datasetName, int intervalMinutes, int shingleSize) throws Exception {
+        RestClient client = client();
+
+        String dataFileName = String.format("data/%s.data", datasetName);
+
+        List<JsonObject> data = getData(dataFileName);
+
+        String categoricalField = "host";
+        String tsField = "timestamp";
+
+        Clock clock = Clock.systemUTC();
+        long currentMilli = clock.millis();
+        int trainTestSplit = 1500;
+
+        // e.g., 2019-11-01T00:03:00Z
+        String pattern = "yyyy-MM-dd'T'HH:mm:ss'Z'";
+        SimpleDateFormat simpleDateFormat = new SimpleDateFormat(pattern);
+        simpleDateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
+        // Calculate the gap between the current time and the beginning of the last
+        // training shingle. The gap is used to shift the training data's timestamps so
+        // that the last few training items map to the current time. We need this
+        // adjustment because CompositeRetriever compares the expiry time with the
+        // current time in its hasNext method. The expiry time is calculated from the
+        // request end time (one parameter of the run API) plus some fraction of the
+        // interval. If the expiry time is earlier than the current time,
+        // CompositeRetriever considers the request expired and refuses to start
+        // querying. This adjustment makes the following simulateHCADStartDetector
+        // calls work.
+        String lastTrainShingleStartTime = data.get(trainTestSplit - shingleSize).getAsJsonPrimitive(tsField).getAsString();
+        Date date = simpleDateFormat.parse(lastTrainShingleStartTime);
+        long diff = currentMilli - date.getTime();
+        TimeUnit time = TimeUnit.MINUTES;
+        // By the time we trigger the run API, a few seconds have passed; add 5 minutes so the adjusted time stays ahead of the current time.
+        long gap = time.convert(diff, TimeUnit.MILLISECONDS) + 5;
+
+        Calendar c = Calendar.getInstance();
+        c.setTimeZone(TimeZone.getTimeZone("UTC"));
+
+        // only change the training data, as we only need to make sure the detector is fully initialized
+        for (int i = 0; i < trainTestSplit; i++) {
+            JsonObject row = data.get(i);
+            // add the categorical field since the original data is for single-stream detectors
+            row.addProperty(categoricalField, "host1");
+
+            String dateString = row.getAsJsonPrimitive(tsField).getAsString();
+            date = simpleDateFormat.parse(dateString);
+            c.setTime(date);
+            c.add(Calendar.MINUTE, (int) gap);
+            String adjustedDate = simpleDateFormat.format(c.getTime());
+            row.addProperty(tsField, adjustedDate);
+        }
+
+        bulkIndexTrainData(datasetName, data, trainTestSplit, client, categoricalField);
+
+        String detectorId = createDetector(datasetName, intervalMinutes, client, categoricalField, 0);
+        // we cannot stop the detector without starting it first: AD complains that the AD job index is missing
+        startDetector(detectorId, client);
+        // waiting for the job to actually run the work periodically would take long; speed it up with simulateHCADStartDetector
+        simulateHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client);
+        String initProgress = profileDetectorInitProgress(detectorId, client);
+        assertEquals("init progress is " + initProgress, "100%", initProgress);
+        stopDetector(detectorId, client);
+        // restart the detector
+        startDetector(detectorId, client);
+        simulateHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client);
+        initProgress = profileDetectorInitProgress(detectorId, client);
+        assertEquals("init progress is " + initProgress, "100%", initProgress);
+    }
+
+    private void stopDetector(String detectorId, RestClient client) throws Exception {
+        Request request = new Request("POST", String.format(Locale.ROOT, "/_plugins/_anomaly_detection/detectors/%s/_stop", detectorId));
+
+        Map<String, Object> response = entityAsMap(client.performRequest(request));
+        String responseDetectorId = (String) response.get("_id");
+        assertEquals(detectorId, responseDetectorId);
+    }
+
+    private void startDetector(String detectorId, RestClient client) throws Exception {
+        Request request = new Request("POST", String.format(Locale.ROOT, "/_plugins/_anomaly_detection/detectors/%s/_start", detectorId));
+
+        Map<String, Object> response = entityAsMap(client.performRequest(request));
+        String responseDetectorId = (String) response.get("_id");
+        assertEquals(detectorId, responseDetectorId);
+    }
+
+    private String profileDetectorInitProgress(String detectorId, RestClient client) throws Exception {
+        Request request = new Request(
+            "GET",
+            String.format(Locale.ROOT, "/_plugins/_anomaly_detection/detectors/%s/_profile/init_progress", detectorId)
+        );
+
+        Map<String, Object> response = entityAsMap(client.performRequest(request));
+        /*
+         * Example response:
+         * {
+         *     "init_progress": {
+         *         "percentage": "100%"
+         *     }
+         * }
+         */
+        return (String) ((Map<String, Object>) response.get("init_progress")).get("percentage");
+    }
 }
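Reviewer note: the timestamp adjustment in verifyRestart above is the subtle part of this test. The sketch below restates the same gap computation with java.time for clarity; the class name and the sample timestamp are made up for illustration and are not taken from the synthetic dataset.

import java.time.Duration;
import java.time.Instant;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;

public final class GapSketch {
    public static void main(String[] args) {
        // start of the last training shingle, as parsed from the data file
        Instant lastTrainShingleStart = Instant.from(DateTimeFormatter.ISO_INSTANT.parse("2019-11-01T00:03:00Z"));
        Instant now = Instant.now();

        // whole-minute distance to now, plus 5 minutes of slack so the shifted
        // timestamps are still "current" by the time the run API is invoked
        long gapMinutes = Duration.between(lastTrainShingleStart, now).toMinutes() + 5;

        // every training row is moved forward by the same gap
        Instant adjusted = lastTrainShingleStart.plus(gapMinutes, ChronoUnit.MINUTES);
        System.out.println("gap=" + gapMinutes + "m, adjusted=" + adjusted);
    }
}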
diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
index 1150cf860..87ed9b95f 100644
--- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
+++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
@@ -874,4 +874,97 @@ public void testAccuracyTenMinuteInterval() throws Exception {
     public void testAccuracyThirteenMinuteInterval() throws Exception {
         accuracyTemplate(13);
     }
+
+    private ModelState<EntityModel> createStateForCacheRelease() {
+        inProgressLatch = new CountDownLatch(1);
+        releaseSemaphore = () -> {
+            released.set(true);
+            inProgressLatch.countDown();
+        };
+        listener = ActionListener.wrap(releaseSemaphore);
+        Queue<double[]> samples = MLUtil.createQueueSamples(1);
+        EntityModel model = new EntityModel(entity, samples, null);
+        return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+    }
+
+    public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedException {
+        ModelState<EntityModel> modelState = createStateForCacheRelease();
+        doAnswer(invocation -> {
+            ActionListener<Optional<Long>> listener = invocation.getArgument(2);
+            listener.onResponse(Optional.of(1602269260000L));
+            return null;
+        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+
+        List<Optional<double[]>> coldStartSamples = new ArrayList<>();
+
+        double[] sample1 = new double[] { 57.0 };
+        double[] sample2 = new double[] { 1.0 };
+        double[] sample3 = new double[] { -19.0 };
+
+        coldStartSamples.add(Optional.of(sample1));
+        coldStartSamples.add(Optional.of(sample2));
+        coldStartSamples.add(Optional.of(sample3));
+        doAnswer(invocation -> {
+            ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
+            listener.onResponse(coldStartSamples);
+            return null;
+        }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());
+
+        entityColdStarter.trainModel(entity, detectorId, modelState, listener);
+        checkSemaphoreRelease();
+        assertTrue(modelState.getModel().getTrcf().isPresent());
+
+        modelState = createStateForCacheRelease();
+        entityColdStarter.trainModel(entity, detectorId, modelState, listener);
+        checkSemaphoreRelease();
+        // the model is not trained because the door keeper remembers it and won't retry training
+        assertTrue(!modelState.getModel().getTrcf().isPresent());
+
+        // make sure that when the next maintenance comes, the current door keeper gets reset;
+        // note our detector interval is 1 minute and the door keeper expires after 60 intervals, i.e., 60 minutes
+        when(clock.instant()).thenReturn(Instant.now().plus(AnomalyDetectorSettings.DOOR_KEEPER_MAINTENANCE_FREQ + 1, ChronoUnit.MINUTES));
+        entityColdStarter.maintenance();
+
+        modelState = createStateForCacheRelease();
+        entityColdStarter.trainModel(entity, detectorId, modelState, listener);
+        checkSemaphoreRelease();
+        // the model is trained because the door keeper has been reset
+        assertTrue(modelState.getModel().getTrcf().isPresent());
+    }
+
+    public void testCacheReleaseAfterClear() throws IOException, InterruptedException {
+        ModelState<EntityModel> modelState = createStateForCacheRelease();
+        doAnswer(invocation -> {
+            ActionListener<Optional<Long>> listener = invocation.getArgument(2);
+            listener.onResponse(Optional.of(1602269260000L));
+            return null;
+        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+
+        List<Optional<double[]>> coldStartSamples = new ArrayList<>();
+
+        double[] sample1 = new double[] { 57.0 };
+        double[] sample2 = new double[] { 1.0 };
+        double[] sample3 = new double[] { -19.0 };
+
+        coldStartSamples.add(Optional.of(sample1));
+        coldStartSamples.add(Optional.of(sample2));
+        coldStartSamples.add(Optional.of(sample3));
+        doAnswer(invocation -> {
+            ActionListener<List<Optional<double[]>>> listener = invocation.getArgument(4);
+            listener.onResponse(coldStartSamples);
+            return null;
+        }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());
+
+        entityColdStarter.trainModel(entity, detectorId, modelState, listener);
+        checkSemaphoreRelease();
+        assertTrue(modelState.getModel().getTrcf().isPresent());
+
+        entityColdStarter.clear(detectorId);
+
+        modelState = createStateForCacheRelease();
+        entityColdStarter.trainModel(entity, detectorId, modelState, listener);
+        checkSemaphoreRelease();
+        // the model is trained because the door keeper is regenerated after the clear call
+        assertTrue(modelState.getModel().getTrcf().isPresent());
+    }
 }
diff --git a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java
index 4178c7391..09040669e 100644
--- a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java
@@ -34,6 +34,7 @@
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.constant.CommonErrorMessages;
 import org.opensearch.ad.feature.FeatureManager;
+import org.opensearch.ad.ml.EntityColdStarter;
 import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.cluster.ClusterName;
@@ -76,7 +77,7 @@ public void setUp() throws Exception {
         EntityCache entityCache = mock(EntityCache.class);
         when(cacheProvider.get()).thenReturn(entityCache);
         ADTaskCacheManager adTaskCacheManager = mock(ADTaskCacheManager.class);
-        NodeStateManager stateManager = mock(NodeStateManager.class);
+        EntityColdStarter coldStarter = mock(EntityColdStarter.class);
 
         action = new DeleteModelTransportAction(
             threadPool,
@@ -87,7 +88,8 @@ public void setUp() throws Exception {
             modelManager,
             featureManager,
             cacheProvider,
-            adTaskCacheManager
+            adTaskCacheManager,
+            coldStarter
         );
     }
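Reviewer note: the two new tests above exercise two different ways a door keeper goes away: time-based expiry during maintenance() and explicit removal via clear(). The sketch below models both paths under the assumption of a 60-minute lifetime (mirroring the test comment about 60 one-minute intervals); the real implementation lives in org.opensearch.ad.caching.DoorKeeper and EntityColdStarter, uses a Bloom-filter-style structure rather than a set, and differs in detail.

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

class DoorKeeperBook {
    // assumed 60-minute lifetime, mirroring the test comment
    private static final Duration LIFETIME = Duration.ofMinutes(60);

    private final Clock clock;
    private final Map<String, Entry> keepers = new ConcurrentHashMap<>();

    DoorKeeperBook(Clock clock) {
        this.clock = clock;
    }

    private static final class Entry {
        final Set<String> seenEntities = ConcurrentHashMap.newKeySet();
        final Instant createdAt;

        Entry(Instant createdAt) {
            this.createdAt = createdAt;
        }
    }

    // returns true only the first time an entity shows up for a detector,
    // i.e., only the first training attempt goes through
    boolean firstSighting(String detectorId, String entityModelId) {
        Entry e = keepers.computeIfAbsent(detectorId, k -> new Entry(clock.instant()));
        return e.seenEntities.add(entityModelId);
    }

    // maintenance(): throw away expired door keepers so cold start can retry later
    void maintenance() {
        Instant now = clock.instant();
        keepers.values().removeIf(e -> Duration.between(e.createdAt, now).compareTo(LIFETIME) > 0);
    }

    // clear(): called when a detector's models are deleted; the next training
    // request starts with a fresh door keeper instead of waiting for expiry
    void clear(String detectorId) {
        keepers.remove(detectorId);
    }
}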