Correctly clone IndexInputs to avoid race
In PR opensearch-project#6345 I removed a duplicate clone; however, this
resulted in cloning the IndexInput in the wrong place. When requesting a
file that needs to be downloaded, we have a mechanism that ensures
concurrent calls do not duplicate the download, which means multiple
threads can be handed the same instance. The clone must happen _after_
this point so that each thread gets its own clone.

Signed-off-by: Andrew Ross <andrross@amazon.com>
andrross committed Feb 20, 2023
1 parent dae1566 commit 4495659
Showing 4 changed files with 122 additions and 33 deletions.
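
To make the race concrete, here is a minimal sketch of the before/after shape of the fix. The names (CloneAfterDedupSketch, inFlight, download) are hypothetical stand-ins rather than the actual OpenSearch classes; the only point is where clone() sits relative to the deduplicated computation.

import java.nio.file.Path;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.lucene.store.IndexInput;

class CloneAfterDedupSketch {
    private final ConcurrentMap<Path, CompletableFuture<IndexInput>> inFlight = new ConcurrentHashMap<>();

    // Buggy shape: the clone happens inside the shared computation, so every
    // caller that joins the same in-flight future receives the *same* clone,
    // and the threads race on its shared file pointer.
    IndexInput fetchRacy(Path path) throws Exception {
        return inFlight.computeIfAbsent(path, p -> CompletableFuture.supplyAsync(() -> download(p).clone())).get();
    }

    // Fixed shape: the shared computation returns the cached origin, and each
    // caller clones it only after the future completes, so every thread gets
    // an IndexInput with its own independent position.
    IndexInput fetchSafe(Path path) throws Exception {
        return inFlight.computeIfAbsent(path, p -> CompletableFuture.supplyAsync(() -> download(p))).get().clone();
    }

    private IndexInput download(Path path) {
        throw new UnsupportedOperationException("stand-in for the real blob download");
    }
}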
server/src/main/java/org/opensearch/index/store/remote/file/OnDemandBlockSnapshotIndexInput.java
@@ -8,17 +8,13 @@
 
 package org.opensearch.index.store.remote.file;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.store.IndexInput;
 import org.opensearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo;
 import org.opensearch.index.store.remote.utils.BlobFetchRequest;
 import org.opensearch.index.store.remote.utils.TransferManager;
 
 import java.io.IOException;
-import java.util.concurrent.ExecutionException;
 
 /**
  * This is an implementation of {@link OnDemandBlockIndexInput} where this class provides the main IndexInput using shard snapshot files.
@@ -28,8 +24,6 @@
  * @opensearch.internal
  */
 public class OnDemandBlockSnapshotIndexInput extends OnDemandBlockIndexInput {
-    private static final Logger logger = LogManager.getLogger(OnDemandBlockSnapshotIndexInput.class);
-
     /**
      * Where this class fetches IndexInput parts from
      */
@@ -146,12 +140,7 @@ protected IndexInput fetchBlock(int blockId) throws IOException {
             .directory(directory)
             .fileName(blockFileName)
             .build();
-        try {
-            return transferManager.asyncFetchBlob(blobFetchRequest).get();
-        } catch (InterruptedException | ExecutionException e) {
-            logger.error(() -> new ParameterizedMessage("unexpected failure while fetching [{}]", blobFetchRequest), e);
-            throw new IllegalStateException(e);
-        }
+        return transferManager.fetchBlob(blobFetchRequest);
     }
 
     @Override
server/src/main/java/org/opensearch/index/store/remote/utils/TransferManager.java
@@ -20,12 +20,12 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.io.UncheckedIOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
-import java.util.function.Supplier;
 
 /**
  * This acts as entry point to fetch {@link BlobFetchRequest} and return actual {@link IndexInput}. Utilizes the BlobContainer interface to
@@ -50,26 +50,37 @@ public TransferManager(final BlobContainer blobContainer, final ExecutorService
     /**
      * Given a blobFetchRequest, return it's corresponding IndexInput.
      * @param blobFetchRequest to fetch
-     * @return future of IndexInput augmented with internal caching maintenance tasks
+     * @return The IndexInput of the requested data
      */
-    public CompletableFuture<IndexInput> asyncFetchBlob(BlobFetchRequest blobFetchRequest) {
-        return asyncFetchBlob(blobFetchRequest.getFilePath(), () -> {
+    public IndexInput fetchBlob(BlobFetchRequest blobFetchRequest) {
+        final CompletableFuture<IndexInput> originFuture = invocationLinearizer.linearize(blobFetchRequest.getFilePath(), p -> {
             try {
-                return fetchBlob(blobFetchRequest);
+                return fetchOriginBlob(blobFetchRequest);
             } catch (IOException e) {
-                throw new IllegalStateException(e);
+                throw new UncheckedIOException(e);
             }
         });
-    }
 
-    private CompletableFuture<IndexInput> asyncFetchBlob(Path path, Supplier<IndexInput> indexInputSupplier) {
-        return invocationLinearizer.linearize(path, p -> indexInputSupplier.get());
+        try {
+            final IndexInput origin = originFuture.get();
+            // The origin instances stays in the cache with a ref count of zero
+            // and must be cloned before being returned.
+            return origin.clone();
+        } catch (InterruptedException | ExecutionException e) {
+            if (e instanceof InterruptedException) {
+                Thread.currentThread().interrupt();
+            }
+            logger.info("Unexpected error fetching blob: {}", blobFetchRequest);
+            throw new IllegalStateException(e);
+        }
     }
 
-    /*
-    This method accessed through the ConcurrentInvocationLinearizer so read-check-write is acceptable here
+    /**
+     * Fetches the "origin" IndexInput used in the cache. This instance must
+     * always be cloned before being returned from this class. This method uses
+     * a read-check-write pattern and must be externally synchronized.
      */
-    private IndexInput fetchBlob(BlobFetchRequest blobFetchRequest) throws IOException {
+    private IndexInput fetchOriginBlob(BlobFetchRequest blobFetchRequest) throws IOException {
         // check if the origin is already in block cache
         IndexInput origin = fileCache.computeIfPresent(blobFetchRequest.getFilePath(), (path, cachedIndexInput) -> {
             if (cachedIndexInput.isClosed()) {
@@ -88,7 +99,7 @@ private IndexInput fetchBlob(BlobFetchRequest blobFetchRequest) throws IOException {
             return cachedIndexInput;
         });
 
-        if (Objects.isNull(origin)) {
+        if (origin == null) {
             // origin is not in file cache, download origin
 
             // open new origin
@@ -101,8 +112,7 @@ private IndexInput fetchBlob(BlobFetchRequest blobFetchRequest) throws IOException {
             fileCache.put(blobFetchRequest.getFilePath(), newOrigin);
             origin = newOrigin;
         }
-        // always, need to clone to do refcount += 1, and rely on GC to clean these IndexInput which will refcount -= 1
-        return origin.clone();
+        return origin;
     }
 
     private IndexInput downloadBlockLocally(BlobFetchRequest blobFetchRequest) throws IOException {
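
The deduplication that fetchBlob relies on lives in ConcurrentInvocationLinearizer#linearize, which this diff does not touch. As a rough sketch of that pattern, under the assumption of a simple future-per-key map (the real class may differ in detail):

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;

// Per-key linearizer: concurrent calls with the same key share one in-flight
// computation instead of each running the download themselves.
class InvocationLinearizerSketch<K, V> {
    private final ConcurrentMap<K, CompletableFuture<V>> inFlight = new ConcurrentHashMap<>();
    private final ExecutorService executor;

    InvocationLinearizerSketch(ExecutorService executor) {
        this.executor = executor;
    }

    CompletableFuture<V> linearize(K key, Function<K, V> work) {
        final CompletableFuture<V> future = new CompletableFuture<>();
        final CompletableFuture<V> existing = inFlight.putIfAbsent(key, future);
        if (existing != null) {
            return existing; // join the invocation already in flight for this key
        }
        executor.submit(() -> {
            try {
                future.complete(work.apply(key));
            } catch (Throwable t) {
                future.completeExceptionally(t);
            } finally {
                inFlight.remove(key); // a later call for this key recomputes
            }
        });
        return future;
    }
}

Because every caller joins the same future, whatever that future produces is shared across threads, which is exactly why fetchBlob above clones the result after get() instead of inside the linearized computation.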
server/src/test/java/org/opensearch/index/store/remote/file/OnDemandBlockSnapshotIndexInputTests.java
@@ -30,7 +30,6 @@
 import java.io.EOFException;
 import java.io.IOException;
 import java.nio.file.Path;
-import java.util.concurrent.CompletableFuture;
 
 import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
 
@@ -142,10 +141,8 @@ private OnDemandBlockSnapshotIndexInput createOnDemandBlockSnapshotIndexInput(in
 
         doAnswer(invocation -> {
             BlobFetchRequest blobFetchRequest = invocation.getArgument(0);
-            return CompletableFuture.completedFuture(
-                blobFetchRequest.getDirectory().openInput(blobFetchRequest.getFileName(), IOContext.READ)
-            );
-        }).when(transferManager).asyncFetchBlob(any());
+            return blobFetchRequest.getDirectory().openInput(blobFetchRequest.getFileName(), IOContext.READ);
+        }).when(transferManager).fetchBlob(any());
 
         FSDirectory directory = null;
         try {
server/src/test/java/org/opensearch/index/store/remote/utils/TransferManagerTests.java
@@ -0,0 +1,93 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.index.store.remote.utils;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.store.SimpleFSLockFactory;
+import org.junit.After;
+import org.junit.Before;
+import org.opensearch.common.blobstore.BlobContainer;
+import org.opensearch.index.store.remote.file.CleanerDaemonThreadLeakFilter;
+import org.opensearch.index.store.remote.filecache.FileCache;
+import org.opensearch.index.store.remote.filecache.FileCacheFactory;
+import org.opensearch.test.OpenSearchTestCase;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
+
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+
+@ThreadLeakFilters(filters = CleanerDaemonThreadLeakFilter.class)
+public class TransferManagerTests extends OpenSearchTestCase {
+    private final FileCache fileCache = FileCacheFactory.createConcurrentLRUFileCache(1024 * 1024);
+    private final ExecutorService executor = Executors.newSingleThreadExecutor();
+    private MMapDirectory directory;
+    private BlobContainer blobContainer;
+    private TransferManager transferManager;
+
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        directory = new MMapDirectory(createTempDir(), SimpleFSLockFactory.INSTANCE);
+        blobContainer = mock(BlobContainer.class);
+        doAnswer(i -> new ByteArrayInputStream(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 })).when(blobContainer).readBlob("blob", 0, 8);
+        transferManager = new TransferManager(blobContainer, executor, fileCache);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        super.tearDown();
+        executor.shutdown();
+        assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
+    }
+
+    public void testSingleAccess() throws IOException {
+        try (IndexInput i = fetchBlob()) {
+            i.seek(7);
+        }
+    }
+
+    public void testConcurrentAccess() throws Exception {
+        // Kick off multiple threads that all concurrently request the same resource
+        final ExecutorService testRunner = Executors.newFixedThreadPool(8);
+        final List<Future<IndexInput>> futures = new ArrayList<>();
+        for (int i = 0; i < 8; i++) {
+            futures.add(testRunner.submit(this::fetchBlob));
+        }
+        // Wait for all threads to complete
+        for (Future<IndexInput> future : futures) {
+            future.get(1, TimeUnit.SECONDS);
+        }
+        // Assert that all IndexInputs are independently positioned by seeking
+        // to the end and closing each one. If not independent, then this would
+        // result in EOFExceptions and/or NPEs.
+        for (Future<IndexInput> future : futures) {
+            future.get().seek(7);
+            future.get().close();
+        }
+        testRunner.shutdown();
+        assertTrue(testRunner.awaitTermination(1, TimeUnit.SECONDS));
+    }
+
+    private IndexInput fetchBlob() {
+        return transferManager.fetchBlob(
+            BlobFetchRequest.builder().blobName("blob").position(0).fileName("file").directory(directory).length(8).build()
+        );
+    }
+}
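
The assertions in testConcurrentAccess lean on a Lucene IndexInput property worth spelling out: a clone shares the underlying bytes with its origin but keeps an independent file pointer, and clones must not be used after the origin is closed. A standalone illustration, using the same Lucene APIs as the test (the blob name and bytes here are arbitrary):

import java.nio.file.Files;

import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;

// Demonstrates that IndexInput clones have independent positions.
public class CloneSemanticsDemo {
    public static void main(String[] args) throws Exception {
        try (MMapDirectory dir = new MMapDirectory(Files.createTempDirectory("demo"))) {
            try (IndexOutput out = dir.createOutput("blob", IOContext.DEFAULT)) {
                out.writeBytes(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }, 8);
            }
            try (IndexInput origin = dir.openInput("blob", IOContext.READ)) {
                IndexInput a = origin.clone();
                IndexInput b = origin.clone();
                a.seek(7);                              // moves only a's position
                System.out.println(b.getFilePointer()); // prints 0: b is unaffected
            } // closing the origin also invalidates its clones
        }
    }
}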
