Fix interpretation of filename from model archive URL #2416

Merged
merged 1 commit into from
Jun 15, 2023
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.archive.utils.InvalidArchiveURLException;
@@ -55,7 +54,7 @@ public static ModelArchive downloadModel(
throw new ModelNotFoundException("empty url");
}

- String marFileName = FilenameUtils.getName(url);
+ String marFileName = ArchiveUtils.getFilenameFromUrl(url);
Collaborator

FilenameUtils.getName is used to extract the mar filename from the URL, and it is the correct function for that. The original URL is passed to ArchiveUtils.downloadArchive.

The issue most likely happens in downloadArchive.

Collaborator Author

@namannandan namannandan Jun 15, 2023

Yes, that is correct. The call path is ModelArchive.downloadModel -> ArchiveUtils.downloadArchive -> HttpUtils.copyURLToFile -> FileUtils.copyURLToFile.
org.apache.commons.io.FileUtils.copyURLToFile throws an IOException because of the destination we pass to it when the source is an S3 presigned URL.

For example, with the current implementation and an S3 presigned URL, we call org.apache.commons.io.FileUtils.copyURLToFile with the following arguments:

source: https://test-account.s3.us-west-2.amazonaws.com/mar_files/resnet-18.mar?response-content-disposition=inline&X-Amz-Security-Token=token&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20230614T182131Z&X-Amz-SignedHeaders=host&X-Amz-Expires=43200&X-Amz-Credential=credential&X-Amz-Signature=signature

destination: resnet-18.mar?response-content-disposition=inline&X-Amz-Security-Token=token&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20230614T182131Z&X-Amz-SignedHeaders=host&X-Amz-Expires=43200&X-Amz-Credential=credential&X-Amz-Signature=signature

and this fails with an IOException.

For testing, I hardcoded the destination in HttpUtils.copyURLToFile to resnet-18.mar, and the download succeeded.

Therefore, I've fixed the implementation in ModelArchive.java so that it correctly identifies the filename from a model archive URL that contains additional parameters after the filename. That filename is what is eventually passed to org.apache.commons.io.FileUtils.copyURLToFile as the destination.

With the fix in this PR, the arguments are passed correctly:
source: https://test-account.s3.us-west-2.amazonaws.com/mar_files/resnet-18.mar?response-content-disposition=inline&X-Amz-Security-Token=token&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20230614T182131Z&X-Amz-SignedHeaders=host&X-Amz-Expires=43200&X-Amz-Credential=credential&X-Amz-Signature=signature

destination: resnet-18.mar
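
The difference between the two destinations above can be reproduced with a small JDK-only sketch. It is an illustration, not the code in this PR: `lastSegment` is a hypothetical stand-in for `FilenameUtils.getName` (it assumes the only relevant behavior is "take the text after the last path separator"), the class and method names are made up, and the URL is a shortened form of the example above.

```java
import java.net.MalformedURLException;
import java.net.URL;

final class PresignedUrlFilenameSketch {
    // Hypothetical stand-in for FilenameUtils.getName:
    // returns everything after the last '/' or '\' in the input.
    static String lastSegment(String s) {
        int idx = Math.max(s.lastIndexOf('/'), s.lastIndexOf('\\'));
        return s.substring(idx + 1);
    }

    // Previous approach: last segment of the raw URL string, query string included.
    static String nameFromRawString(String url) {
        return lastSegment(url);
    }

    // Approach taken in this PR: parse the URL and use only its path component;
    // fall back to the raw string if it does not parse as a URL.
    static String nameFromUrlPath(String url) {
        try {
            return lastSegment(new URL(url).getPath());
        } catch (MalformedURLException e) {
            return lastSegment(url);
        }
    }

    public static void main(String[] args) {
        String url =
                "https://test-account.s3.us-west-2.amazonaws.com/mar_files/resnet-18.mar"
                        + "?response-content-disposition=inline&X-Amz-Expires=43200";
        // Prints the filename with the query string still attached.
        System.out.println(nameFromRawString(url));
        // Prints only the filename.
        System.out.println(nameFromUrlPath(url));
    }
}
```

The design point is that `URL.getPath()` excludes the query string, so the filename lookup never sees the `?...` suffix.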

File modelLocation = new File(modelStore, marFileName);
try {
ArchiveUtils.downloadArchive(
@@ -165,7 +164,7 @@ public void validate() throws InvalidModelException {

public static void removeModel(String modelStore, String marURL) {
if (ArchiveUtils.isValidURL(marURL)) {
- String marFileName = FilenameUtils.getName(marURL);
+ String marFileName = ArchiveUtils.getFilenameFromUrl(marURL);
File modelLocation = new File(modelStore, marFileName);
FileUtils.deleteQuietly(modelLocation);
}
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileAlreadyExistsException;
@@ -16,6 +17,7 @@
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.InvalidModelException;
import org.pytorch.serve.archive.s3.HttpUtils;
@@ -90,6 +92,15 @@ public static boolean isValidURL(String url) {
        return VALID_URL_PATTERN.matcher(url).matches();
    }

    public static String getFilenameFromUrl(String url) {
        try {
            URL archiveUrl = new URL(url);
            return FilenameUtils.getName(archiveUrl.getPath());
        } catch (MalformedURLException e) {
            return FilenameUtils.getName(url);
        }
    }

    public static boolean downloadArchive(
            List<String> allowedUrls,
            File location,
Original file line number Diff line number Diff line change
@@ -13,7 +13,6 @@
import java.nio.file.Files;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.archive.utils.InvalidArchiveURLException;
@@ -53,7 +52,7 @@ public static WorkflowArchive downloadWorkflow(
throw new WorkflowNotFoundException("Workflow store has not been configured.");
}

- String warFileName = FilenameUtils.getName(url);
+ String warFileName = ArchiveUtils.getFilenameFromUrl(url);
File workflowLocation = new File(workflowStore, warFileName);

try {
@@ -144,7 +143,7 @@ public void validate() throws InvalidWorkflowException {

public static void removeWorkflow(String workflowStore, String warURL) {
if (ArchiveUtils.isValidURL(warURL)) {
- String warFileName = FilenameUtils.getName(warURL);
+ String warFileName = ArchiveUtils.getFilenameFromUrl(warURL);
File workflowLocation = new File(workflowStore, warFileName);
FileUtils.deleteQuietly(workflowLocation);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.pytorch.serve.archive.utils;

import org.testng.Assert;
import org.testng.annotations.Test;

public class ArchiveUtilsTest {
    @Test
    public void testGetFilenameFromUrlWithFilename() {
        String testFilename = "resnet-18.mar";
        String expectedFilename = "resnet-18.mar";
        Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFilename), expectedFilename);
    }

    @Test
    public void testGetFilenameFromUrlWithFilepath() {
        String testFilepath = "/home/ubuntu/model_store/resnet-18.mar";
        String expectedFilename = "resnet-18.mar";
        Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFilepath), expectedFilename);
    }

    @Test
    public void testGetFilenameFromUrlWithUrl() {
        String testFileUrl = "https://torchserve.pytorch.org/mar_files/resnet-18.mar";
        String expectedFilename = "resnet-18.mar";
        Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFileUrl), expectedFilename);
    }

    @Test
    public void testGetFilenameFromUrlWithS3PresignedUrl() {
        String testFileUrl =
                "https://test-account.s3.us-west-2.amazonaws.com/mar_files/resnet-18.mar?"
                        + "response-content-disposition=inline&X-Amz-Security-Token=%2Ftoken%2F"
                        + "&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20230614T182131Z&X-Amz-SignedHeaders=host"
                        + "&X-Amz-Expires=43200&X-Amz-Credential=%2Fcredential%2F"
                        + "&X-Amz-Signature=%2Fsignature%2F";
        String expectedFilename = "resnet-18.mar";
        Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFileUrl), expectedFilename);
    }

    @Test
    public void testGetFilenameFromUrlWithInvalidUrl() {
        String testFileUrl = "resnet-18.mar/";
        String expectedFilename = "";
        Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFileUrl), expectedFilename);
    }
}
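
Several of the test inputs above (a bare filename, a local file path, and a string with a trailing slash) are not URLs at all, so they exercise the `MalformedURLException` fallback branch of `getFilenameFromUrl`. The following JDK-only sketch shows which inputs take that branch. It is a hedged illustration: `parsesAsUrl` and `lastSegment` are hypothetical helpers, and `lastSegment` only approximates `FilenameUtils.getName` by taking the text after the last path separator.

```java
import java.net.MalformedURLException;
import java.net.URL;

final class FallbackBranchSketch {
    // Returns true if java.net.URL accepts the string, false if it throws
    // MalformedURLException (for example, when no protocol is present).
    static boolean parsesAsUrl(String s) {
        try {
            new URL(s);
            return true;
        } catch (MalformedURLException e) {
            return false;
        }
    }

    // Hypothetical stand-in for FilenameUtils.getName:
    // returns everything after the last '/' or '\' in the input.
    static String lastSegment(String s) {
        int idx = Math.max(s.lastIndexOf('/'), s.lastIndexOf('\\'));
        return s.substring(idx + 1);
    }

    public static void main(String[] args) {
        String[] inputs = {
            "resnet-18.mar",
            "/home/ubuntu/model_store/resnet-18.mar",
            "resnet-18.mar/",
            "https://torchserve.pytorch.org/mar_files/resnet-18.mar"
        };
        for (String input : inputs) {
            System.out.println(
                    input + " -> parsesAsUrl=" + parsesAsUrl(input)
                            + ", lastSegment=\"" + lastSegment(input) + "\"");
        }
    }
}
```

Only the `https://` input parses as a URL here; the other three have no protocol and would fall through to the raw-string lookup, which is why the trailing-slash case yields an empty filename.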
3 changes: 2 additions & 1 deletion frontend/archive/testng.xml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
<!DOCTYPE suite SYSTEM "https://testng.org/testng-1.0.dtd" >

<suite name="ModelArchiverSuite" verbose="1" >
<test name="TorchServe">
<classes>
<class name="org.pytorch.serve.archive.CoverageTest"/>
<class name="org.pytorch.serve.archive.model.ModelArchiveTest"/>
<class name="org.pytorch.serve.archive.model.ModelConfigTest"/>
<class name="org.pytorch.serve.archive.utils.ArchiveUtilsTest"/>
<class name="org.pytorch.serve.archive.workflow.WorkFlowArchiveTest"/>
</classes>
</test>
Original file line number Diff line number Diff line change
@@ -12,13 +12,13 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.InternalServerException;
import org.pytorch.serve.http.InvalidModelVersionException;
@@ -183,7 +183,7 @@ public static StatusResponse handleRegister(
s3SseKms);
} catch (FileAlreadyExistsException e) {
throw new InternalServerException(
- "Model file already exists " + FilenameUtils.getName(modelUrl), e);
+ "Model file already exists " + ArchiveUtils.getFilenameFromUrl(modelUrl), e);
} catch (IOException | InterruptedException e) {
throw new InternalServerException("Failed to save model: " + modelUrl, e);
}
Original file line number Diff line number Diff line change
@@ -12,9 +12,9 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.io.FilenameUtils;
import org.pytorch.serve.archive.model.ModelArchive;
import org.pytorch.serve.archive.model.ModelConfig;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.WorkerCommands;
@@ -130,7 +130,7 @@ public JsonObject getModelState(boolean isDefaultVersion) {

JsonObject modelInfo = new JsonObject();
modelInfo.addProperty(DEFAULT_VERSION, isDefaultVersion);
- modelInfo.addProperty(MAR_NAME, FilenameUtils.getName(getModelUrl()));
+ modelInfo.addProperty(MAR_NAME, ArchiveUtils.getFilenameFromUrl(getModelUrl()));
modelInfo.addProperty(MIN_WORKERS, getMinWorkers());
modelInfo.addProperty(MAX_WORKERS, getMaxWorkers());
modelInfo.addProperty(BATCH_SIZE, getBatchSize());