diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 0d71b93a60..e950c3ccfc 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -8,7 +8,6 @@ import java.util.List; import java.util.Map; import org.apache.commons.io.FileUtils; -import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.utils.ArchiveUtils; import org.pytorch.serve.archive.utils.InvalidArchiveURLException; @@ -55,7 +54,7 @@ public static ModelArchive downloadModel( throw new ModelNotFoundException("empty url"); } - String marFileName = FilenameUtils.getName(url); + String marFileName = ArchiveUtils.getFilenameFromUrl(url); File modelLocation = new File(modelStore, marFileName); try { ArchiveUtils.downloadArchive( @@ -165,7 +164,7 @@ public void validate() throws InvalidModelException { public static void removeModel(String modelStore, String marURL) { if (ArchiveUtils.isValidURL(marURL)) { - String marFileName = FilenameUtils.getName(marURL); + String marFileName = ArchiveUtils.getFilenameFromUrl(marURL); File modelLocation = new File(modelStore, marFileName); FileUtils.deleteQuietly(modelLocation); } diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java index 82c4681dd6..ff24e483e3 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/utils/ArchiveUtils.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.io.InputStreamReader; import java.io.Reader; +import java.net.MalformedURLException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.nio.file.FileAlreadyExistsException; @@ -16,6 +17,7 @@ import java.util.Map; import java.util.regex.Pattern; import org.apache.commons.io.FileUtils; +import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.InvalidModelException; import org.pytorch.serve.archive.s3.HttpUtils; @@ -90,6 +92,15 @@ public static boolean isValidURL(String url) { return VALID_URL_PATTERN.matcher(url).matches(); } + public static String getFilenameFromUrl(String url) { + try { + URL archiveUrl = new URL(url); + return FilenameUtils.getName(archiveUrl.getPath()); + } catch (MalformedURLException e) { + return FilenameUtils.getName(url); + } + } + public static boolean downloadArchive( List allowedUrls, File location, diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java index aaf185528e..556ef33fa5 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/workflow/WorkflowArchive.java @@ -13,7 +13,6 @@ import java.nio.file.Files; import java.util.List; import org.apache.commons.io.FileUtils; -import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.utils.ArchiveUtils; import org.pytorch.serve.archive.utils.InvalidArchiveURLException; @@ -53,7 +52,7 @@ public static WorkflowArchive downloadWorkflow( throw new WorkflowNotFoundException("Workflow store has not been configured."); } - String warFileName = FilenameUtils.getName(url); + String warFileName = ArchiveUtils.getFilenameFromUrl(url); File workflowLocation = new File(workflowStore, warFileName); try { @@ -144,7 +143,7 @@ public void validate() throws InvalidWorkflowException { public static void removeWorkflow(String workflowStore, String warURL) { if (ArchiveUtils.isValidURL(warURL)) { - String warFileName = FilenameUtils.getName(warURL); + String warFileName = ArchiveUtils.getFilenameFromUrl(warURL); File workflowLocation = new File(workflowStore, warFileName); FileUtils.deleteQuietly(workflowLocation); } diff --git a/frontend/archive/src/test/java/org/pytorch/serve/archive/utils/ArchiveUtilsTest.java b/frontend/archive/src/test/java/org/pytorch/serve/archive/utils/ArchiveUtilsTest.java new file mode 100644 index 0000000000..aa560de547 --- /dev/null +++ b/frontend/archive/src/test/java/org/pytorch/serve/archive/utils/ArchiveUtilsTest.java @@ -0,0 +1,46 @@ +package org.pytorch.serve.archive.utils; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class ArchiveUtilsTest { + @Test + public void testGetFilenameFromUrlWithFilename() { + String testFilename = "resnet-18.mar"; + String expectedFilename = "resnet-18.mar"; + Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFilename), expectedFilename); + } + + @Test + public void testGetFilenameFromUrlWithFilepath() { + String testFilepath = "/home/ubuntu/model_store/resnet-18.mar"; + String expectedFilename = "resnet-18.mar"; + Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFilepath), expectedFilename); + } + + @Test + public void testGetFilenameFromUrlWithUrl() { + String testFileUrl = "https://torchserve.pytorch.org/mar_files/resnet-18.mar"; + String expectedFilename = "resnet-18.mar"; + Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFileUrl), expectedFilename); + } + + @Test + public void testGetFilenameFromUrlWithS3PresignedUrl() { + String testFileUrl = + "https://test-account.s3.us-west-2.amazonaws.com/mar_files/resnet-18.mar?" + + "response-content-disposition=inline&X-Amz-Security-Token=%2Ftoken%2F" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20230614T182131Z&X-Amz-SignedHeaders=host" + + "&X-Amz-Expires=43200&X-Amz-Credential=%2Fcredential%2F" + + "&X-Amz-Signature=%2Fsignature%2F"; + String expectedFilename = "resnet-18.mar"; + Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFileUrl), expectedFilename); + } + + @Test + public void testGetFilenameFromUrlWithInvalidUrl() { + String testFileUrl = "resnet-18.mar/"; + String expectedFilename = ""; + Assert.assertEquals(ArchiveUtils.getFilenameFromUrl(testFileUrl), expectedFilename); + } +} diff --git a/frontend/archive/testng.xml b/frontend/archive/testng.xml index 0d050dfbcd..f6bb44cd49 100644 --- a/frontend/archive/testng.xml +++ b/frontend/archive/testng.xml @@ -1,11 +1,12 @@ - + + diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java index 43e2954272..80c31b5bbc 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java @@ -12,13 +12,13 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.function.Function; -import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.Manifest; import org.pytorch.serve.archive.model.ModelArchive; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; +import org.pytorch.serve.archive.utils.ArchiveUtils; import org.pytorch.serve.http.BadRequestException; import org.pytorch.serve.http.InternalServerException; import org.pytorch.serve.http.InvalidModelVersionException; @@ -183,7 +183,7 @@ public static StatusResponse handleRegister( s3SseKms); } catch (FileAlreadyExistsException e) { throw new InternalServerException( - "Model file already exists " + FilenameUtils.getName(modelUrl), e); + "Model file already exists " + ArchiveUtils.getFilenameFromUrl(modelUrl), e); } catch (IOException | InterruptedException e) { throw new InternalServerException("Failed to save model: " + modelUrl, e); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index 846f80a7c2..0e54dfd72b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -12,9 +12,9 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; -import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.model.ModelArchive; import org.pytorch.serve.archive.model.ModelConfig; +import org.pytorch.serve.archive.utils.ArchiveUtils; import org.pytorch.serve.job.Job; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.messages.WorkerCommands; @@ -130,7 +130,7 @@ public JsonObject getModelState(boolean isDefaultVersion) { JsonObject modelInfo = new JsonObject(); modelInfo.addProperty(DEFAULT_VERSION, isDefaultVersion); - modelInfo.addProperty(MAR_NAME, FilenameUtils.getName(getModelUrl())); + modelInfo.addProperty(MAR_NAME, ArchiveUtils.getFilenameFromUrl(getModelUrl())); modelInfo.addProperty(MIN_WORKERS, getMinWorkers()); modelInfo.addProperty(MAX_WORKERS, getMaxWorkers()); modelInfo.addProperty(BATCH_SIZE, getBatchSize());