diff --git a/build.gradle b/build.gradle
index 9c92f7c62..e507afa77 100644
--- a/build.gradle
+++ b/build.gradle
@@ -35,7 +35,7 @@ buildscript {
         js_resource_folder = "src/test/resources/job-scheduler"
         common_utils_version = System.getProperty("common_utils.version", opensearch_build)
         job_scheduler_version = System.getProperty("job_scheduler.version", opensearch_build)
-        bwcVersionShort = "2.17.0"
+        bwcVersionShort = "2.18.0"
         bwcVersion = bwcVersionShort + ".0"
         bwcOpenSearchADDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + bwcVersionShort + '/latest/linux/x64/tar/builds/' +
                 'opensearch/plugins/opensearch-anomaly-detection-' + bwcVersion + '.zip'
@@ -696,9 +696,8 @@ List<String> jacocoExclusions = [
         // TODO: add test coverage (kaituo)
         'org.opensearch.forecast.*',
+        'org.opensearch.ad.transport.ADHCImputeNodeResponse',
         'org.opensearch.ad.transport.GetAnomalyDetectorTransportAction',
-        'org.opensearch.ad.ml.ADColdStart',
-        'org.opensearch.ad.transport.ADHCImputeNodesResponse',
         'org.opensearch.timeseries.transport.BooleanNodeResponse',
         'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao',
         'org.opensearch.timeseries.transport.JobRequest',
@@ -713,7 +712,6 @@ List<String> jacocoExclusions = [
         'org.opensearch.timeseries.transport.ResultBulkTransportAction',
         'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler',
         'org.opensearch.timeseries.transport.handler.ResultIndexingHandler',
-        'org.opensearch.ad.transport.ADHCImputeNodeResponse',
         'org.opensearch.timeseries.ml.Sample',
         'org.opensearch.timeseries.ratelimit.FeatureRequest',
         'org.opensearch.ad.transport.ADHCImputeNodeRequest',
diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java
index 868317bab..f52fe7439 100644
--- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java
+++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java
@@ -446,7 +446,7 @@ public static AnomalyResult fromRawTRCFResult(
             taskId,
             rcfScore,
             Math.max(0, grade),
-            confidence,
+            Math.min(1, confidence),
             featureData,
             dataStartTime,
             dataEndTime,
diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java
index 0bae4a5ce..46e246191 100644
--- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java
+++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java
@@ -86,10 +86,11 @@ public GetAnomalyDetectorTransportAction(
     }
 
     @Override
-    protected void fillInHistoricalTaskforBwc(Map<String, ADTask> tasks, Optional<ADTask> historicalAdTask) {
+    protected Optional<ADTask> fillInHistoricalTaskforBwc(Map<String, ADTask> tasks) {
         if (tasks.containsKey(ADTaskType.HISTORICAL.name())) {
-            historicalAdTask = Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name()));
+            return Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name()));
         }
+        return Optional.empty();
     }
 
     @Override
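Reviewer note: the signature change above (and the matching one in BaseGetConfigTransportAction below) fixes a real bug, not just style. Java passes object references by value, so reassigning an Optional parameter inside the callee is invisible to the caller; the old void/out-parameter version could never hand the historical task back. A minimal standalone sketch of the gotcha, with hypothetical names (not plugin code):

    import java.util.Map;
    import java.util.Optional;

    public class OutParamDemo {
        // Old shape: the reassignment below is lost when the method returns.
        static void fillBroken(Map<String, String> tasks, Optional<String> out) {
            out = Optional.ofNullable(tasks.get("HISTORICAL"));
        }

        // New shape: return the value instead of mutating a parameter.
        static Optional<String> fill(Map<String, String> tasks) {
            return Optional.ofNullable(tasks.get("HISTORICAL"));
        }

        public static void main(String[] args) {
            Map<String, String> tasks = Map.of("HISTORICAL", "task-1");
            Optional<String> historical = Optional.empty();
            fillBroken(tasks, historical);
            System.out.println(historical.isPresent()); // false: the out-param was never filled
            System.out.println(fill(tasks).get());      // task-1
        }
    }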
diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java b/src/main/java/org/opensearch/forecast/model/ForecastResult.java
index 69dd4d6ea..49f51e2e9 100644
--- a/src/main/java/org/opensearch/forecast/model/ForecastResult.java
+++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java
@@ -188,7 +188,7 @@ public static List<ForecastResult> fromRawRCFCasterResult(
                 new ForecastResult(
                     forecasterId,
                     taskId,
-                    dataQuality,
+                    Math.min(1, dataQuality),
                     featureData,
                     dataStartTime,
                     dataEndTime,
@@ -218,7 +218,7 @@ public static List<ForecastResult> fromRawRCFCasterResult(
                 new ForecastResult(
                     forecasterId,
                     taskId,
-                    dataQuality,
+                    Math.min(1, dataQuality),
                     null,
                     dataStartTime,
                     dataEndTime,
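These call sites cap model-output values at their documented bounds: RCF can report confidence (and forecast data quality) marginally above 1 due to floating-point arithmetic, just as grade was already floored at 0. A self-contained sketch of the invariant (illustrative class, not plugin code):

    public final class ScoreBounds {
        // Grade: negative values mean "not anomalous", so floor at 0.
        static double clampGrade(double grade) {
            return Math.max(0, grade);
        }

        // Confidence / data quality: a probability-like value, so cap at 1.
        static double clampConfidence(double confidence) {
            return Math.min(1, confidence);
        }

        public static void main(String[] args) {
            System.out.println(clampConfidence(1.03)); // 1.0, as asserted in AnomalyResultTests below
            System.out.println(clampGrade(-0.2));      // 0.0
        }
    }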
diff --git a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java
index 043c197cf..1d984be46 100644
--- a/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java
+++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java
@@ -463,15 +463,16 @@ public Pair<List<Entity>, List<Entity>> selectUpdateCandidate(
         return Pair.of(hotEntities, coldEntities);
     }
 
-    private CacheBufferType computeBufferIfAbsent(Config config, String configId) {
+    public CacheBufferType computeBufferIfAbsent(Config config, String configId) {
         CacheBufferType buffer = activeEnities.get(configId);
         if (buffer == null) {
-            long requiredBytes = getRequiredMemory(config, config.isHighCardinality() ? hcDedicatedCacheSize : 1);
+            long bytesPerEntityModel = getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees);
+            long requiredBytes = bytesPerEntityModel * (config.isHighCardinality() ? hcDedicatedCacheSize : 1);
             if (memoryTracker.canAllocateReserved(requiredBytes)) {
                 memoryTracker.consumeMemory(requiredBytes, true, origin);
                 buffer = createEmptyCacheBuffer(
                     config,
-                    requiredBytes,
+                    bytesPerEntityModel,
                     priorityTrackerMap
                         .getOrDefault(
                             configId,
@@ -496,16 +497,6 @@ private CacheBufferType computeBufferIfAbsent(Config config, String configId) {
         return buffer;
     }
 
-    /**
-     *
-     * @param config Detector config accessor
-     * @param numberOfEntity number of entities
-     * @return Memory in bytes required for hosting numberOfEntity entities
-     */
-    private long getRequiredMemory(Config config, int numberOfEntity) {
-        return numberOfEntity * getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees);
-    }
-
     /**
      * Whether the candidate entity can replace any entity in the shared cache.
      * We can have race conditions when multiple threads try to evaluate this
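The substance of this fix: the old code passed requiredBytes, the reservation for all dedicated slots, into createEmptyCacheBuffer as the per-model memory, so a high-cardinality buffer recorded hcDedicatedCacheSize times the real per-model footprint. A sketch of the corrected arithmetic under that reading (illustrative names, not the plugin API):

    final class CacheSizingSketch {
        // Reserve room for all dedicated slots up front...
        static long reservation(long bytesPerEntityModel, boolean highCardinality, int hcDedicatedCacheSize) {
            int slots = highCardinality ? hcDedicatedCacheSize : 1;
            return bytesPerEntityModel * slots;
        }

        public static void main(String[] args) {
            long perModel = 698_336L; // the value asserted in PriorityCacheTests.testAllocation below
            System.out.println(reservation(perModel, true, 10)); // 6983360 consumed from the tracker
            System.out.println(perModel);                        // ...but record only this much per model
        }
    }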
diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java
index 3b6ad29d9..b803a4851 100644
--- a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java
+++ b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java
@@ -180,7 +180,7 @@ public void doExecute(Task task, ActionRequest request, ActionListener<GetConfigResponse>
-    protected void fillInHistoricalTaskforBwc(Map<String, TaskClass> tasks, Optional<TaskClass> historicalAdTask) {}
+    protected Optional<TaskClass> fillInHistoricalTaskforBwc(Map<String, TaskClass> tasks) {
+        return Optional.empty();
+    }
 
     protected void getExecuteProfile(
         GetConfigRequest request,
diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java
index b244ee4ac..f412ce84e 100644
--- a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java
+++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java
@@ -229,8 +229,6 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
                 pageIterator.next(this);
             }
             if (entityFeatures != null && false == entityFeatures.isEmpty()) {
-                sentOutPages.incrementAndGet();
-
                 LOG
                     .info(
                         "Sending an HC request to process data from timestamp {} to {} for config {}",
@@ -285,6 +283,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) {
 
             final AtomicReference<Exception> failure = new AtomicReference<>();
             node2Entities.stream().forEach(nodeEntity -> {
+                sentOutPages.incrementAndGet();
                 DiscoveryNode node = nodeEntity.getKey();
                 transportService
                     .sendRequest(
@@ -370,7 +369,15 @@ public void run() {
                     cancellable.get().cancel();
                 }
             } else if (Instant.now().toEpochMilli() >= timeoutMillis) {
-                LOG.warn("Scheduled impute HC task is cancelled due to timeout");
+                LOG
+                    .warn(
+                        "Scheduled impute HC task is cancelled due to timeout, current epoch {}, timeout epoch {}, dataEndTime {}, sent out {}, received {}",
+                        Instant.now().toEpochMilli(),
+                        timeoutMillis,
+                        dataEndTime,
+                        sentOutPages.get(),
+                        receivedPages.get()
+                    );
                 if (cancellable != null) {
                     cancellable.get().cancel();
                 }
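Reviewer note on the counter move: sentOutPages now increments once per node-level request inside the forEach rather than once per page. Assuming receivedPages is bumped once per node response (which is what lets the timeout log above compare the two meaningfully), the counters now share a unit; with the old per-page increment, a page fanned out to N nodes could never reconcile sent against received. A small sketch of the intended accounting (hypothetical names, not the plugin's API):

    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;

    class PageAccounting {
        final AtomicInteger sentOutPages = new AtomicInteger();
        final AtomicInteger receivedPages = new AtomicInteger();

        void dispatch(List<String> nodes) {
            for (String node : nodes) {
                sentOutPages.incrementAndGet(); // one unit per node request
                // transportService.sendRequest(node, ...) would go here; its
                // response handler calls receivedPages.incrementAndGet()
            }
        }

        boolean allResponsesIn() {
            return receivedPages.get() >= sentOutPages.get();
        }
    }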
diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java
index d71b3370b..a7391f894 100644
--- a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java
+++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java
@@ -269,6 +269,29 @@ protected List<Object> waitUntilTaskReachState(String detectorId, Set<String> ta
         return results;
     }
 
+    protected List<Object> waitUntilTaskReachNumberOfEntities(String detectorId, int categoricalValuesCount) throws InterruptedException {
+        List<Object> results = new ArrayList<>();
+        int i = 0;
+        ADTaskProfile adTaskProfile = null;
+        // Increase retryTimes if some task can't reach done state
+        while ((adTaskProfile == null
+            || adTaskProfile.getTotalEntitiesCount() == null
+            || adTaskProfile.getTotalEntitiesCount().intValue() != categoricalValuesCount) && i < MAX_RETRY_TIMES) {
+            try {
+                adTaskProfile = getADTaskProfile(detectorId);
+            } catch (Exception e) {
+                logger.error("failed to get ADTaskProfile", e);
+            } finally {
+                Thread.sleep(1000);
+            }
+            i++;
+        }
+        assertNotNull(adTaskProfile);
+        results.add(adTaskProfile);
+        results.add(i);
+        return results;
+    }
+
     protected List<Object> waitUntilEntityCountAvailable(String detectorId) throws InterruptedException {
         List<Object> results = new ArrayList<>();
         int i = 0;
diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java
index 7c09709f0..3da08575d 100644
--- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java
+++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java
@@ -26,6 +26,7 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import java.io.IOException;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayDeque;
@@ -62,6 +63,7 @@
 import org.opensearch.threadpool.Scheduler.ScheduledCancellable;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.MemoryTracker;
+import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.breaker.CircuitBreakerService;
 import org.opensearch.timeseries.common.exception.LimitExceededException;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
@@ -788,4 +790,48 @@ public void testGetTotalUpdates_orElseGetBranchWithNullSamples() {
         // Assert that the result is 0L
         assertEquals(0L, result);
     }
+
+    public void testAllocation() throws IOException {
+        JvmService jvmService = mock(JvmService.class);
+        JvmInfo info = mock(JvmInfo.class);
+
+        when(jvmService.info()).thenReturn(info);
+
+        Mem mem = mock(Mem.class);
+        when(mem.getHeapMax()).thenReturn(new ByteSizeValue(800_000_000L));
+        when(info.getMem()).thenReturn(mem);
+
+        CircuitBreakerService circuitBreaker = mock(CircuitBreakerService.class);
+        when(circuitBreaker.isOpen()).thenReturn(false);
+        MemoryTracker tracker = new MemoryTracker(jvmService, 0.1, clusterService, circuitBreaker);
+
+        dedicatedCacheSize = 10;
+        ADPriorityCache cache = new ADPriorityCache(
+            checkpoint,
+            dedicatedCacheSize,
+            AnomalyDetectorSettings.AD_CHECKPOINT_TTL,
+            AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES,
+            tracker,
+            TimeSeriesSettings.NUM_TREES,
+            clock,
+            clusterService,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
+            threadPool,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
+            Settings.EMPTY,
+            AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ,
+            checkpointWriteQueue,
+            checkpointMaintainQueue
+        );
+
+        List<String> categoryFields = Arrays.asList("category_field_1", "category_field_2");
+        AnomalyDetector anomalyDetector = TestHelpers.AnomalyDetectorBuilder
+            .newInstance(5)
+            .setShingleSize(8)
+            .setCategoryFields(categoryFields)
+            .build();
+        ADCacheBuffer buffer = cache.computeBufferIfAbsent(anomalyDetector, anomalyDetector.getId());
+        assertEquals(698336, buffer.getMemoryConsumptionPerModel());
+        assertEquals(698336 * dedicatedCacheSize, tracker.getTotalMemoryBytes());
+    }
 }
diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java
index 030f98d20..d1e7a100a 100644
--- a/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java
+++ b/src/test/java/org/opensearch/ad/e2e/AbstractRuleTestCase.java
@@ -47,7 +47,7 @@ protected TrainResult ingestTrainDataAndCreateDetector(
         int trainTestSplit,
         boolean useDateNanos
     ) throws Exception {
-        return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1);
+        return ingestTrainDataAndCreateDetector(datasetName, intervalMinutes, numberOfEntities, trainTestSplit, useDateNanos, -1, true);
     }
 
     protected TrainResult ingestTrainDataAndCreateDetector(
@@ -56,7 +56,8 @@ protected TrainResult ingestTrainDataAndCreateDetector(
         int numberOfEntities,
         int trainTestSplit,
         boolean useDateNanos,
-        int ingestDataSize
+        int ingestDataSize,
+        boolean relative
     ) throws Exception {
         TrainResult trainResult = ingestTrainData(
             datasetName,
@@ -67,7 +68,7 @@ protected TrainResult ingestTrainDataAndCreateDetector(
             ingestDataSize
         );
 
-        String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult);
+        String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult, relative);
         String detectorId = createDetector(client(), detector);
         LOG.info("Created detector {}", detectorId);
         trainResult.detectorId = detectorId;
@@ -75,7 +76,22 @@ protected TrainResult ingestTrainDataAndCreateDetector(
         return trainResult;
     }
 
-    protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult) {
+    protected String genDetector(String datasetName, int intervalMinutes, int trainTestSplit, TrainResult trainResult, boolean relative) {
+        // Determine threshold types and values based on the 'relative' parameter
+        String thresholdType1;
+        String thresholdType2;
+        double value;
+        if (relative) {
+            thresholdType1 = "actual_over_expected_ratio";
+            thresholdType2 = "expected_over_actual_ratio";
+            value = 0.3;
+        } else {
+            thresholdType1 = "actual_over_expected_margin";
+            thresholdType2 = "expected_over_actual_margin";
+            value = 3000.0;
+        }
+
+        // Generate the detector JSON string with the appropriate threshold types and values
         String detector = String
             .format(
                 Locale.ROOT,
@@ -87,15 +103,20 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT
                     + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}},"
                     + "\"history\": %d,"
                     + "\"schema_version\": 0,"
-                    + "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": [{\"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_ratio\", \"operator\": \"lte\", \"value\": 0.3}, "
-                    + "{\"feature_name\": \"feature 1\", \"threshold_type\": \"expected_over_actual_ratio\", \"operator\": \"lte\", \"value\": 0.3}"
+                    + "\"rules\": [{\"action\": \"ignore_anomaly\", \"conditions\": ["
+                    + "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"%s\", \"operator\": \"lte\", \"value\": %f }, "
+                    + "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"%s\", \"operator\": \"lte\", \"value\": %f }"
                     + "]}]"
                     + "}",
                 datasetName,
                 intervalMinutes,
                 categoricalField,
                 trainResult.windowDelay.toMinutes(),
-                trainTestSplit - 1
+                trainTestSplit - 1,
+                thresholdType1,
+                value,
+                thresholdType2,
+                value
             );
         return detector;
     }
diff --git a/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java b/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java
index 8b481e3c9..e5194bb63 100644
--- a/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java
+++ b/src/test/java/org/opensearch/ad/e2e/PreviewRuleIT.java
@@ -32,7 +32,7 @@ public void testRule() throws Exception {
             (trainTestSplit + 1) * numberOfEntities
         );
 
-        String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult);
+        String detector = genDetector(datasetName, intervalMinutes, trainTestSplit, trainResult, true);
         Map<String, Object> result = preview(detector, trainResult.firstDataTime, trainResult.finalDataTime, client());
         List<Object> results = (List<Object>) XContentMapValues.extractValue(result, "anomaly_result");
         assertTrue(results.size() > 100);
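For reference, here is what the parameterized rule block in genDetector above renders to in each mode (hand-expanded from the format string; %f prints six decimal places). Relative rules suppress anomalies by ratio, absolute rules by a margin in the feature's own units:

    // relative == true: ratio thresholds with value 0.3
    String relativeCondition =
        "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_ratio\", \"operator\": \"lte\", \"value\": 0.300000 }";

    // relative == false: margin thresholds with value 3000.0
    String absoluteCondition =
        "{ \"feature_name\": \"feature 1\", \"threshold_type\": \"actual_over_expected_margin\", \"operator\": \"lte\", \"value\": 3000.000000 }";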
diff --git a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java
index 650fec64d..a4d4a855b 100644
--- a/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java
+++ b/src/test/java/org/opensearch/ad/e2e/RealTimeRuleIT.java
@@ -15,7 +15,7 @@ import com.google.gson.JsonObject;
 
 public class RealTimeRuleIT extends AbstractRuleTestCase {
-    public void testRuleWithDateNanos() throws Exception {
+    private void template(boolean relative) throws Exception {
         // TODO: this test case will run for a much longer time and timeout with security enabled
         if (!isHttps()) {
             disableResourceNotFoundFaultTolerence();
@@ -32,7 +32,8 @@ public void testRuleWithDateNanos() throws Exception {
             trainTestSplit,
             true,
             // ingest just enough to finish the test
-            (trainTestSplit + 1) * numberOfEntities
+            (trainTestSplit + 1) * numberOfEntities,
+            relative
         );
 
         startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, false);
@@ -90,4 +91,12 @@ public void testRuleWithDateNanos() throws Exception {
             }
         }
     }
+
+    public void testRelativeRule() throws Exception {
+        template(true);
+    }
+
+    public void testAbsoluteRule() throws Exception {
+        template(false);
+    }
 }
diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
index 3feccd298..d89e03128 100644
--- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
+++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
@@ -1275,4 +1275,25 @@ public void testNotEnoughTrainingData() throws IOException, InterruptedException
         checkSemaphoreRelease();
         assertTrue(modelState.getModel().isEmpty());
     }
+
+    public void testTrainModelFromInvalidSamplesNotEnoughSamples() {
+        Deque<Sample> samples = new ArrayDeque<>();
+        // ensure we have at least numMinSamples samples so training reaches the null check in trainModelFromDataSegments
+        for (int i = 0; i < numMinSamples; i++) {
+            samples.add(new Sample());
+        }
+
+        modelState = new ModelState<>(
+            null,
+            modelId,
+            detectorId,
+            ModelManager.ModelType.TRCF.getName(),
+            clock,
+            priority,
+            Optional.of(entity),
+            samples
+        );
+        entityColdStarter.trainModelFromExistingSamples(modelState, detector, "123");
+        assertTrue(modelState.getModel().isEmpty());
+    }
 }
diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java
index 28245aa31..cac5e8a4d 100644
--- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java
+++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java
@@ -12,10 +12,15 @@
 package org.opensearch.ad.model;
 
 import java.io.IOException;
+import java.time.Instant;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
 import java.util.Locale;
+import java.util.Optional;
 
 import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.commons.authuser.User;
 import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
 import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
 import org.opensearch.core.xcontent.ToXContent;
@@ -24,6 +29,8 @@
 import org.opensearch.test.OpenSearchSingleNodeTestCase;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.model.FeatureData;
 
 import com.google.common.base.Objects;
 
@@ -152,4 +159,64 @@ public void testSerializeAnomalyResultWithEntity() throws IOException {
         AnomalyResult parsedDetectResult = new AnomalyResult(input);
         assertTrue(parsedDetectResult.equals(detectResult));
     }
+
+    public void testFromRawTRCFResultWithHighConfidence() {
+        // Set up test parameters
+        String detectorId = "test-detector-id";
+        long intervalMillis = 60000; // Example interval
+        String taskId = "test-task-id";
+        Double rcfScore = 0.5;
+        Double grade = 0.0; // Non-anomalous
+        Double confidence = 1.03; // Confidence greater than 1
+        List<FeatureData> featureData = Collections.emptyList(); // Assuming empty for simplicity
+        Instant dataStartTime = Instant.now();
+        Instant dataEndTime = dataStartTime.plusMillis(intervalMillis);
+        Instant executionStartTime = Instant.now();
+        Instant executionEndTime = executionStartTime.plusMillis(500);
+        String error = null;
+        Optional<Entity> entity = Optional.empty();
+        User user = null; // Replace with an actual user if needed
+        Integer schemaVersion = 1;
+        String modelId = "test-model-id";
+        double[] relevantAttribution = null;
+        Integer relativeIndex = null;
+        double[] pastValues = null;
+        double[][] expectedValuesList = null;
+        double[] likelihoodOfValues = null;
+        Double threshold = null;
+        double[] currentData = null;
+        boolean[] featureImputed = null;
+
+        // Invoke the method under test
+        AnomalyResult result = AnomalyResult
+            .fromRawTRCFResult(
+                detectorId,
+                intervalMillis,
+                taskId,
+                rcfScore,
+                grade,
+                confidence,
+                featureData,
+                dataStartTime,
+                dataEndTime,
+                executionStartTime,
+                executionEndTime,
+                error,
+                entity,
+                user,
+                schemaVersion,
+                modelId,
+                relevantAttribution,
+                relativeIndex,
+                pastValues,
+                expectedValuesList,
+                likelihoodOfValues,
+                threshold,
+                currentData,
+                featureImputed
+            );
+
+        // Assert that the confidence is capped at 1.0
+        assertEquals("Confidence should be capped at 1.0", 1.0, result.getConfidence(), 0.00001);
+    }
 }
diff --git a/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java
new file mode 100644
index 000000000..64295e4e2
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/model/GetAnomalyDetectorTransportActionTests.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad.model;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
+import org.opensearch.ad.transport.GetAnomalyDetectorTransportAction;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.timeseries.AbstractTimeSeriesTest;
+import org.opensearch.timeseries.transport.GetConfigRequest;
+import org.opensearch.transport.TransportService;
+
+public class GetAnomalyDetectorTransportActionTests extends AbstractTimeSeriesTest {
+    @SuppressWarnings("unchecked")
+    public void testHistoricalTaskFilledInForBwc() throws Exception {
+        // Arrange
+        String configID = "test-config-id";
+
+        // Create a historical task so the BWC fill-in path is exercised
+        Map<String, ADTask> tasks = new HashMap<>();
+        ADTask adTask = ADTask.builder().taskType(ADTaskType.HISTORICAL.name()).build();
+        tasks.put(ADTaskType.HISTORICAL.name(), adTask);
+
+        // Mock taskManager to return the tasks
+        ADTaskManager taskManager = mock(ADTaskManager.class);
+        doAnswer(invocation -> {
+            List<ADTask> taskList = new ArrayList<>(tasks.values());
+            ((Consumer<List<ADTask>>) invocation.getArguments()[4]).accept(taskList);
+            return null;
+        }).when(taskManager).getAndExecuteOnLatestTasks(anyString(), any(), any(), any(), any(), any(), anyBoolean(), anyInt(), any());
+
+        // Mock listener
+        ActionListener<GetAnomalyDetectorResponse> listener = mock(ActionListener.class);
+
+        ClusterService clusterService = mock(ClusterService.class);
+        ClusterSettings settings = new ClusterSettings(
+            Settings.EMPTY,
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES)))
+        );
+        when(clusterService.getClusterSettings()).thenReturn(settings);
+        GetAnomalyDetectorTransportAction getDetector = spy(
+            new GetAnomalyDetectorTransportAction(
+                mock(TransportService.class),
+                null,
+                mock(ActionFilters.class),
+                clusterService,
+                null,
+                null,
+                Settings.EMPTY,
+                null,
+                taskManager,
+                null
+            )
+        );
+
+        // Act
+        GetConfigRequest request = new GetConfigRequest(configID, 0L, true, true, "", "", true, null);
+        getDetector.getExecute(request, listener);
+
+        // Assert
+        // Verify that the historical task from the task manager is the one passed
+        // through to getConfigAndJob for BWC
+        verify(getDetector).getConfigAndJob(eq(configID), anyBoolean(), anyBoolean(), any(), eq(Optional.of(adTask)), eq(listener));
+    }
+}
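Both new transport tests stub their task manager the same way: the method under test reports results through a callback argument, so the stub captures that argument and invokes it synchronously. A generic, self-contained version of the pattern (hypothetical interface; only the Mockito calls match the tests above):

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    import java.util.List;
    import java.util.function.Consumer;

    interface TaskService {
        void latestTasks(String configId, Consumer<List<String>> onTasks);
    }

    class CallbackStubDemo {
        @SuppressWarnings("unchecked")
        static TaskService stubbed(List<String> cannedTasks) {
            TaskService service = mock(TaskService.class);
            doAnswer(invocation -> {
                // grab the callback (argument index 1) and feed it the canned result
                ((Consumer<List<String>>) invocation.getArguments()[1]).accept(cannedTasks);
                return null;
            }).when(service).latestTasks(any(), any());
            return service;
        }
    }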
diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java
index 8661259c0..80d8f7509 100644
--- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java
+++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java
@@ -105,14 +105,17 @@ public void testHistoricalAnalysisForMultiCategoryHC() throws Exception {
     }
 
     private void checkIfTaskCanFinishCorrectly(String detectorId, String taskId, Set<String> states) throws InterruptedException {
-        List<Object> results = waitUntilTaskDone(detectorId);
+        List<Object> results = waitUntilTaskReachState(detectorId, states);
         TaskProfile endTaskProfile = (TaskProfile) results.get(0);
         Integer retryCount = (Integer) results.get(1);
         ADTask stoppedAdTask = endTaskProfile.getTask();
         assertEquals(taskId, stoppedAdTask.getTaskId());
         if (retryCount < MAX_RETRY_TIMES) {
             // It's possible that historical analysis is still running after max retry times
-            assertTrue(states.contains(stoppedAdTask.getState()));
+            assertTrue(
+                "expected one of: " + states + ", but got: " + stoppedAdTask.getState(),
+                states.contains(stoppedAdTask.getState())
+            );
         }
     }
 
@@ -134,6 +137,11 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul
         if (!TaskState.RUNNING.name().equals(adTaskProfile.getTask().getState())) {
             adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0);
         }
+        if (adTaskProfile == null
+            || adTaskProfile.getTotalEntitiesCount() == null
+            || (int) Math.pow(categoryFieldDocCount, categoryFieldSize) != adTaskProfile.getTotalEntitiesCount().intValue()) {
+            adTaskProfile = (ADTaskProfile) waitUntilTaskReachNumberOfEntities(detectorId, (int) Math.pow(categoryFieldDocCount, categoryFieldSize)).get(0);
+        }
         assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue());
         assertTrue(adTaskProfile.getPendingEntitiesCount() > 0);
         assertTrue(adTaskProfile.getRunningEntitiesCount() > 0);
diff --git a/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java
new file mode 100644
index 000000000..f2657f21d
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/transport/ADHCImputeNodesResponseTests.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad.transport;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import org.opensearch.Version;
+import org.opensearch.action.FailedNodeException;
+import org.opensearch.cluster.ClusterName;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.test.OpenSearchTestCase;
+
+public class ADHCImputeNodesResponseTests extends OpenSearchTestCase {
+
+    public void testADHCImputeNodesResponseSerialization() throws IOException {
+        // Arrange
+        DiscoveryNode node = new DiscoveryNode(
+            "nodeId",
+            buildNewFakeTransportAddress(),
+            Collections.emptyMap(),
+            Collections.emptySet(),
+            Version.CURRENT
+        );
+        Exception previousException = new Exception("Test exception message");
+
+        ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, previousException);
+        List<ADHCImputeNodeResponse> nodes = Collections.singletonList(nodeResponse);
+        List<FailedNodeException> failures = Collections.emptyList();
+        ClusterName clusterName = new ClusterName("test-cluster");
+
+        ADHCImputeNodesResponse response = new ADHCImputeNodesResponse(clusterName, nodes, failures);
+
+        // Act: Serialize the response
+        BytesStreamOutput output = new BytesStreamOutput();
+        response.writeTo(output);
+
+        // Deserialize the response
+        StreamInput input = output.bytes().streamInput();
+        ADHCImputeNodesResponse deserializedResponse = new ADHCImputeNodesResponse(input);
+
+        // Assert
+        assertEquals(clusterName, deserializedResponse.getClusterName());
+        assertEquals(response.getNodes().size(), deserializedResponse.getNodes().size());
+        assertEquals(response.failures().size(), deserializedResponse.failures().size());
+
+        // Check the node response
+        ADHCImputeNodeResponse deserializedNodeResponse = deserializedResponse.getNodes().get(0);
+        assertEquals(node, deserializedNodeResponse.getNode());
+        assertNotNull(deserializedNodeResponse.getPreviousException());
+        assertEquals("exception: " + previousException.getMessage(), deserializedNodeResponse.getPreviousException().getMessage());
+    }
+
+    public void testReadNodesFromAndWriteNodesTo() throws IOException {
+        // Arrange
+        DiscoveryNode node = new DiscoveryNode(
+            "nodeId",
+            buildNewFakeTransportAddress(),
+            Collections.emptyMap(),
+            Collections.emptySet(),
+            Version.CURRENT
+        );
+        Exception previousException = new Exception("Test exception message");
+
+        ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, previousException);
+        List<ADHCImputeNodeResponse> nodes = Collections.singletonList(nodeResponse);
+        ClusterName clusterName = new ClusterName("test-cluster");
+        ADHCImputeNodesResponse response = new ADHCImputeNodesResponse(clusterName, nodes, Collections.emptyList());
+
+        // Act: Write nodes to output
+        BytesStreamOutput output = new BytesStreamOutput();
+        response.writeNodesTo(output, nodes);
+
+        // Read nodes from input
+        StreamInput input = output.bytes().streamInput();
+        List<ADHCImputeNodeResponse> readNodes = response.readNodesFrom(input);
+
+        // Assert
+        assertEquals(nodes.size(), readNodes.size());
+        ADHCImputeNodeResponse readNodeResponse = readNodes.get(0);
+        assertEquals(node, readNodeResponse.getNode());
+        assertNotNull(readNodeResponse.getPreviousException());
+        assertEquals("exception: " + previousException.getMessage(), readNodeResponse.getPreviousException().getMessage());
+    }
+
+    public void testADHCImputeNodeResponseSerialization() throws IOException {
+        // Arrange
+        DiscoveryNode node = new DiscoveryNode(
+            "nodeId",
+            buildNewFakeTransportAddress(),
+            Collections.emptyMap(),
+            Collections.emptySet(),
+            Version.CURRENT
+        );
+        Exception previousException = new Exception("Test exception message");
+
+        ADHCImputeNodeResponse nodeResponse = new ADHCImputeNodeResponse(node, previousException);
+
+        // Act: Serialize the node response
+        BytesStreamOutput output = new BytesStreamOutput();
+        nodeResponse.writeTo(output);
+
+        // Deserialize the node response
+        StreamInput input = output.bytes().streamInput();
+        ADHCImputeNodeResponse deserializedNodeResponse = new ADHCImputeNodeResponse(input);
+
+        // Assert
+        assertEquals(node, deserializedNodeResponse.getNode());
+        assertNotNull(deserializedNodeResponse.getPreviousException());
+        assertEquals("exception: " + previousException.getMessage(), deserializedNodeResponse.getPreviousException().getMessage());
+    }
+}
diff --git a/src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java b/src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java
new file mode 100644
index 000000000..c71470231
--- /dev/null
+++ b/src/test/java/org/opensearch/forecast/transport/GetForecasterTransportActionTests.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.forecast.transport;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.forecast.model.ForecastTask;
+import org.opensearch.forecast.model.ForecastTaskType;
+import org.opensearch.forecast.settings.ForecastSettings;
+import org.opensearch.forecast.task.ForecastTaskManager;
+import org.opensearch.timeseries.AbstractTimeSeriesTest;
+import org.opensearch.timeseries.transport.GetConfigRequest;
+import org.opensearch.transport.TransportService;
+
+public class GetForecasterTransportActionTests extends AbstractTimeSeriesTest {
+    @SuppressWarnings("unchecked")
+    public void testRealtimeTaskAssignedWithSingleStreamRealTimeTaskName() throws Exception {
+        // Arrange
+        String configID = "test-config-id";
+
+        // Create a task with singleStreamRealTimeTaskName
+        Map<String, ForecastTask> tasks = new HashMap<>();
+        ForecastTask forecastTask = ForecastTask.builder().taskType(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM.name()).build();
+        tasks.put(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM.name(), forecastTask);
+
+        // Mock taskManager to return the tasks
+        ForecastTaskManager taskManager = mock(ForecastTaskManager.class);
+        doAnswer(invocation -> {
+            List<ForecastTask> taskList = new ArrayList<>(tasks.values());
+            ((Consumer<List<ForecastTask>>) invocation.getArguments()[4]).accept(taskList);
+            return null;
+        }).when(taskManager).getAndExecuteOnLatestTasks(anyString(), any(), any(), any(), any(), any(), anyBoolean(), anyInt(), any());
+
+        // Mock listener
+        ActionListener<GetForecasterResponse> listener = mock(ActionListener.class);
+
+        ClusterService clusterService = mock(ClusterService.class);
+        ClusterSettings settings = new ClusterSettings(
+            Settings.EMPTY,
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES)))
+        );
+        when(clusterService.getClusterSettings()).thenReturn(settings);
+        GetForecasterTransportAction getForecaster = spy(
+            new GetForecasterTransportAction(
+                mock(TransportService.class),
+                null,
+                mock(ActionFilters.class),
+                clusterService,
+                null,
+                null,
+                Settings.EMPTY,
+                null,
+                taskManager,
+                null
+            )
+        );
+
+        // Act
+        GetConfigRequest request = new GetConfigRequest(configID, 0L, true, true, "", "", true, null);
+        getForecaster.getExecute(request, listener);
+
+        // Assert
+        // Verify that the real-time task stored under singleStreamRealTimeTaskName is
+        // the one passed through to getConfigAndJob
+        verify(getForecaster).getConfigAndJob(eq(configID), anyBoolean(), anyBoolean(), eq(Optional.of(forecastTask)), any(), eq(listener));
+    }
+}