Skip to content

Commit

Permalink
Addressed Sagar's comments
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Alfonsi <petealft@amazon.com>
  • Loading branch information
Peter Alfonsi committed Apr 29, 2024
1 parent c14a159 commit ab3a356
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.common.cache.policy.CachedQueryResult;
import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder;
import org.opensearch.common.cache.store.config.CacheConfig;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
Expand Down Expand Up @@ -81,8 +82,7 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {
/**
* Maintains caching tiers in ascending order of cache latency.
*/
private final Map<ICache<K, V>, Boolean> caches;
private final Map<ICache<K, V>, String> tierValueMap;
private final Map<ICache<K, V>, TierInfo> caches;
private final List<Predicate<V>> policies;

TieredSpilloverCache(Builder<K, V> builder) {
Expand Down Expand Up @@ -125,13 +125,12 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {
builder.cacheFactories
);
Boolean isDiskCacheEnabled = DISK_CACHE_ENABLED_SETTING_MAP.get(builder.cacheType).get(builder.cacheConfig.getSettings());
LinkedHashMap<ICache<K, V>, Boolean> cacheListMap = new LinkedHashMap<>();
cacheListMap.put(onHeapCache, true);
cacheListMap.put(diskCache, isDiskCacheEnabled);
LinkedHashMap<ICache<K, V>, TierInfo> cacheListMap = new LinkedHashMap<>();
cacheListMap.put(onHeapCache, new TierInfo(true, TIER_DIMENSION_VALUE_ON_HEAP));
cacheListMap.put(diskCache, new TierInfo(isDiskCacheEnabled, TIER_DIMENSION_VALUE_DISK));
this.caches = Collections.synchronizedMap(cacheListMap);

this.dimensionNames = builder.cacheConfig.getDimensionNames();
this.tierValueMap = Map.of(onHeapCache, TIER_DIMENSION_VALUE_ON_HEAP, diskCache, TIER_DIMENSION_VALUE_DISK);
// Pass "tier" as the innermost dimension name, in addition to whatever dimensions are specified for the cache as a whole
this.statsHolder = new TieredSpilloverCacheStatsHolder(dimensionNames, isDiskCacheEnabled);
this.policies = builder.policies; // Will never be null; builder initializes it to an empty list
Expand All @@ -153,13 +152,17 @@ ICache<K, V> getDiskCache() {
void enableDisableDiskCache(Boolean isDiskCacheEnabled) {
// When disk cache is disabled, we are not clearing up the disk cache entries yet as that should be part of
// separate cache/clear API.
this.caches.put(diskCache, isDiskCacheEnabled);
this.caches.put(diskCache, new TierInfo(isDiskCacheEnabled, TIER_DIMENSION_VALUE_DISK));
this.statsHolder.setDiskCacheEnabled(isDiskCacheEnabled);
}

@Override
public V get(ICacheKey<K> key) {
return getValueFromTieredCache().apply(key);
Tuple<V, String> cacheValueTuple = getValueFromTieredCache(true).apply(key);
if (cacheValueTuple == null) {
return null;
}
return cacheValueTuple.v1();
}

@Override
Expand All @@ -172,22 +175,50 @@ public void put(ICacheKey<K> key, V value) {

@Override
public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V> loader) throws Exception {
V cacheValue = getValueFromTieredCache().apply(key);
if (cacheValue == null) {
// Don't capture stats in the initial getValueFromTieredCache(). If we have concurrent requests for the same key,
// and it only has to be loaded one time, we should report one miss and the rest hits. But, if we do stats in
// getValueFromTieredCache(),
// we will see all misses. Instead, handle stats in computeIfAbsent().
Tuple<V, String> cacheValueTuple = getValueFromTieredCache(false).apply(key);
List<String> heapDimensionValues = statsHolder.getDimensionsWithTierValue(key.dimensions, TIER_DIMENSION_VALUE_ON_HEAP);
List<String> diskDimensionValues = statsHolder.getDimensionsWithTierValue(key.dimensions, TIER_DIMENSION_VALUE_DISK);

if (cacheValueTuple == null) {
// Add the value to the onHeap cache. We are calling computeIfAbsent which does another get inside.
// This is needed as there can be many requests for the same key at the same time and we only want to load
// the value once.
V value = null;
try (ReleasableLock ignore = writeLock.acquire()) {
value = onHeapCache.computeIfAbsent(key, loader);
if (loader.isLoaded()) {
// The value was just computed and added to the cache
updateStatsOnPut(TIER_DIMENSION_VALUE_ON_HEAP, key, value);
}
// Handle stats
if (loader.isLoaded()) {
// The value was just computed and added to the cache by this thread. Register a miss for the heap cache, and the disk cache
// if present
updateStatsOnPut(TIER_DIMENSION_VALUE_ON_HEAP, key, value);
statsHolder.incrementMisses(heapDimensionValues);
if (caches.get(diskCache).isEnabled) {
statsHolder.incrementMisses(diskDimensionValues);
}
} else {
// Another thread requesting this key already loaded the value. Register a hit for the heap cache
statsHolder.incrementHits(heapDimensionValues);
}
return value;
}
return cacheValue;

else {
// Handle stats for an initial hit from getValueFromTieredCache()
if (cacheValueTuple.v2().equals(TIER_DIMENSION_VALUE_ON_HEAP)) {
// A hit for the heap tier
statsHolder.incrementHits(heapDimensionValues);
} else {
// Miss for the heap tier, hit for the disk tier
statsHolder.incrementMisses(heapDimensionValues);
statsHolder.incrementHits(diskDimensionValues);
}
}
return cacheValueTuple.v1();
}

@Override
Expand All @@ -197,10 +228,9 @@ public void invalidate(ICacheKey<K> key) {
// also trigger a hit/miss listener event, so ignoring it for now.
// We don't update stats here, as this is handled by the removal listeners for the tiers.
try (ReleasableLock ignore = writeLock.acquire()) {
// for (Tuple<ICache<K, V>, String> pair : cacheAndTierValueList) {
for (Map.Entry<ICache<K, V>, String> cacheEntry : tierValueMap.entrySet()) {
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
if (key.getDropStatsForDimensions()) {
List<String> dimensionValues = statsHolder.getDimensionsWithTierValue(key.dimensions, cacheEntry.getValue());
List<String> dimensionValues = statsHolder.getDimensionsWithTierValue(key.dimensions, cacheEntry.getValue().tierName);
statsHolder.removeDimensions(dimensionValues);
}
if (key.key != null) {
Expand All @@ -213,7 +243,7 @@ public void invalidate(ICacheKey<K> key) {
@Override
public void invalidateAll() {
try (ReleasableLock ignore = writeLock.acquire()) {
for (Map.Entry<ICache<K, V>, Boolean> cacheEntry : caches.entrySet()) {
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
cacheEntry.getKey().invalidateAll();
}
}
Expand All @@ -228,7 +258,7 @@ public void invalidateAll() {
@Override
public Iterable<ICacheKey<K>> keys() {
List<Iterable<ICacheKey<K>>> iterableList = new ArrayList<>();
for (Map.Entry<ICache<K, V>, Boolean> cacheEntry : caches.entrySet()) {
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
iterableList.add(cacheEntry.getKey().keys());
}
Iterable<ICacheKey<K>>[] iterables = (Iterable<ICacheKey<K>>[]) iterableList.toArray(new Iterable<?>[0]);
Expand All @@ -245,15 +275,15 @@ public long count() {
@Override
public void refresh() {
try (ReleasableLock ignore = writeLock.acquire()) {
for (Map.Entry<ICache<K, V>, Boolean> cacheEntry : caches.entrySet()) {
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
cacheEntry.getKey().refresh();
}
}
}

@Override
public void close() throws IOException {
for (Map.Entry<ICache<K, V>, Boolean> cacheEntry : caches.entrySet()) {
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
// Close all the caches here irrespective of whether they are enabled or not.
cacheEntry.getKey().close();
}
Expand All @@ -264,19 +294,26 @@ public ImmutableCacheStatsHolder stats() {
return statsHolder.getImmutableCacheStatsHolder();
}

private Function<ICacheKey<K>, V> getValueFromTieredCache() {
/**
* Get a value from the tiered cache, and the name of the tier it was found in.
* @param captureStats Whether to record hits/misses for this call of the function
* @return A tuple of the value and the name of the tier it was found in.
*/
private Function<ICacheKey<K>, Tuple<V, String>> getValueFromTieredCache(boolean captureStats) {
return key -> {
try (ReleasableLock ignore = readLock.acquire()) {
for (Map.Entry<ICache<K, V>, Boolean> cacheEntry : caches.entrySet()) {
if (cacheEntry.getValue()) {
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
if (cacheEntry.getValue().isEnabled) {
V value = cacheEntry.getKey().get(key);
// Get the tier value corresponding to this cache
String tierValue = tierValueMap.get(cacheEntry.getKey());
String tierValue = cacheEntry.getValue().tierName;
List<String> dimensionValues = statsHolder.getDimensionsWithTierValue(key.dimensions, tierValue);
if (value != null) {
statsHolder.incrementHits(dimensionValues);
return value;
} else {
if (captureStats) {
statsHolder.incrementHits(dimensionValues);
}
return new Tuple<>(value, tierValue);
} else if (captureStats) {
statsHolder.incrementMisses(dimensionValues);
}
}
Expand All @@ -289,7 +326,7 @@ private Function<ICacheKey<K>, V> getValueFromTieredCache() {
void handleRemovalFromHeapTier(RemovalNotification<ICacheKey<K>, V> notification) {
ICacheKey<K> key = notification.getKey();
boolean wasEvicted = SPILLOVER_REMOVAL_REASONS.contains(notification.getRemovalReason());
if (caches.get(diskCache) && wasEvicted && evaluatePolicies(notification.getValue())) {
if (caches.get(diskCache).isEnabled && wasEvicted && evaluatePolicies(notification.getValue())) {
try (ReleasableLock ignore = writeLock.acquire()) {
diskCache.put(key, notification.getValue()); // spill over to the disk tier and increment its stats
updateStatsOnPut(TIER_DIMENSION_VALUE_DISK, key, notification.getValue());
Expand Down Expand Up @@ -426,6 +463,16 @@ public void remove() {
}
}

private class TierInfo {
boolean isEnabled;
String tierName;

TierInfo(boolean isEnabled, String tierName) {
this.isEnabled = isEnabled;
this.tierName = tierName;
}
}

/**
* Factory to create TieredSpilloverCache objects.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ public void testComputeIfAbsentWithEvictionsFromOnHeapCache() throws Exception {
assertFalse(loadAwareCacheLoader.isLoaded());
}
}
for (int iter = 0; iter < randomIntBetween(50, 200); iter++) {
int numRandom = randomIntBetween(50, 200);
for (int iter = 0; iter < numRandom; iter++) {
// Hit cache with randomized key which is expected to miss cache always.
LoadAwareCacheLoader<ICacheKey<String>, String> tieredCacheLoader = getLoadAwareCacheLoader();
tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), tieredCacheLoader);
Expand Down Expand Up @@ -812,6 +813,9 @@ public String load(ICacheKey<String> key) {
}
}
assertEquals(1, numberOfTimesKeyLoaded); // It should be loaded only once.
// We should see only one heap miss, and the rest hits
assertEquals(1, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(numberOfSameKeys - 1, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
}

public void testConcurrencyForEvictionFlowFromOnHeapToDiskTier() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ public String getCacheName() {
return NAME;
}

private boolean useNoopStats(Settings settings, boolean configUseNoopStats) {
private boolean useNoopStats(Settings settings, boolean useNoopStatsConfig) {
// Use noop stats when pluggable caching is off, or when explicitly set in the CacheConfig
return !FeatureFlags.PLUGGABLE_CACHE_SETTING.get(settings) || configUseNoopStats;
return !FeatureFlags.PLUGGABLE_CACHE_SETTING.get(settings) || useNoopStatsConfig;
}
}

Expand Down

0 comments on commit ab3a356

Please sign in to comment.