Ignore url query parameters when storing models (.mar files) #2085
Comments
Hi, thanks for filing the issue! The behavior is indeed unfortunate in your case and the proposed solution sounds reasonable to me.
@mreso awesome! I'll try to make a PR when I get the chance. I believe the bogus-filename problem is best solved with alternative number 3 (take an optional filename argument in the model registration endpoint). It would be the most flexible option for the developer as well. What do you think?
@MuhsinFatih Great! Looking forward to your PR! Let me know if you need any help with it! Agreed on the solution for users using bogus filenames. It would be great to have this option in case users leverage that "flaw" in TS to make it work in their situation. If you want to take that on as well, I would suggest creating a second PR for it to facilitate the review of the changes.
Hi @MuhsinFatih - curious how you got presigned S3 URLs to work in the first place? I keep hitting the same errors as #1293, and even if the cache is missed, this would allow us to move forward with serving models from S3.
@cfculhane You might be having trouble with the signature version and/or with url parsing. Try the two suggestions in #669 (comment).
Ahh thank you - this worked. I ended up writing a FastAPI web service inside the torchserve container to handle these and download them to the model store, which allowed us to cache them better and to handle both the presigned URLs without modification and S3 URIs natively. We then send the register-model POST request off with just the local filename, and torchserve is none the wiser :) E.g., with some internal modules redacted:

```python
"""
Due to torchserve not supporting S3 presigned URLs or boto3-compatible S3 paths, we need to put a tiny Python
webserver in front of the torchserve management API.

Relevant torchserve issues:
- https://github.com/pytorch/serve/issues/2085
- https://github.com/pytorch/serve/issues/1293

This only handles / forwards requests to the torchserve management API, and does not handle inference requests.
"""
import functools
import os
import shutil
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import requests
from <internal>.s3_uploader import AmazonS3
from <internal>.utils import create_logger
from fastapi import FastAPI
from s3path import S3Path
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

# noinspection HttpUrlsUsage
TS_MANAGEMENT_ADDRESS = os.environ.get("TS_MANAGEMENT_ADDRESS", "http://localhost:8001")
CONTAINER_MODEL_STORE_DIR = Path(os.environ.get("CONTAINER_MODEL_STORE_DIR", Path(__file__).parent / "model_store"))

logger = create_logger("api_proxy")
logger.info(f"Starting API proxy, forwarding management API requests to {TS_MANAGEMENT_ADDRESS}")


async def catch_all(request: Request):
    """
    Forwards all un-routed requests to the torchserve management API
    """
    ts_response = requests.request(
        method=request.method,
        url=f"{TS_MANAGEMENT_ADDRESS}/{request.path_params['full_path']}",
        headers=request.headers,
        data=await request.body(),
        allow_redirects=False,
    )
    return JSONResponse(ts_response.json(), status_code=ts_response.status_code)


routes = [
    Route("/{full_path:path}", endpoint=catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
]
catch_all_app = FastAPI(routes=routes)

app = FastAPI()


@app.get("/ping")
async def ping(request: Request):
    return JSONResponse({"message": "PONG"})


def download_file_from_url(url: str, path: Path) -> Path:
    from tqdm.auto import tqdm

    r = requests.get(url, stream=True, allow_redirects=True)
    if r.status_code != 200:
        r.raise_for_status()  # Only raises for 4xx/5xx codes, so...
        raise RuntimeError(f"Request to {url} returned status code {r.status_code}")  # ...cover other non-200 codes too
    file_size = int(r.headers.get("Content-Length", 0))
    path.parent.mkdir(parents=True, exist_ok=True)
    desc = "(Unknown total file size)" if file_size == 0 else ""
    r.raw.read = functools.partial(r.raw.read, decode_content=True)  # Decompress if needed
    with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
        with path.open("wb") as f:
            shutil.copyfileobj(r_raw, f)
    return path


@app.post("/models")
async def models_post(
    url: str, model_name: Optional[str] = None, initial_workers: Optional[int] = 1, batch_size: Optional[int] = 1
):
    if url.startswith("s3://"):  # S3 URI
        model_path_remote = S3Path.from_uri(url)
        model_name = model_name or model_path_remote.stem
        local_model_path = CONTAINER_MODEL_STORE_DIR / f"{model_name}.mar"
        logger.info(f"Detected S3 URI, downloading from {model_path_remote} to {local_model_path}")
        AmazonS3().download_file_from_bucket(
            bucket_name=model_path_remote.bucket, source_path=model_path_remote.key, dest_path=local_model_path
        )
    elif url.startswith("https://"):
        parsed_url = urlparse(url)
        model_name = model_name or Path(parsed_url.path).stem
        local_model_path = CONTAINER_MODEL_STORE_DIR / f"{model_name}.mar"
        logger.info(f"Detected HTTPS URI, downloading from {url} to {local_model_path}")
        download_file_from_url(url, local_model_path)
    else:
        model_name = model_name or Path(url).stem
        local_model_path = CONTAINER_MODEL_STORE_DIR / f"{model_name}.mar"
    # Validate we now have the file locally
    if not local_model_path.exists():
        raise FileNotFoundError(f"Model file not found at {local_model_path}")
    ts_response = requests.post(
        url=f"{TS_MANAGEMENT_ADDRESS}/models",
        params={
            "url": local_model_path.name,
            "model_name": model_name,
            "initial_workers": initial_workers,
            "batch_size": batch_size,
        },
    )
    return JSONResponse(ts_response.json(), status_code=ts_response.status_code)


# mount last so that it doesn't override other routes defined above
app.mount("/", catch_all_app)
```
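The three branches in `models_post` above all reduce to deriving a local `.mar` filename from whatever form the `url` parameter takes. A stdlib-only sketch of that logic (`local_mar_name` is a hypothetical helper, not part of the service, and it stands in for the `S3Path.from_uri(...).stem` call without requiring `s3path`):

```python
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

def local_mar_name(url: str, model_name: Optional[str] = None) -> str:
    # Mirrors models_post: s3:// and https:// URLs are parsed first so the
    # query string never leaks into the stem; anything else is treated as
    # a bare filename already sitting in the model store.
    if url.startswith(("s3://", "https://")):
        stem = model_name or Path(urlparse(url).path).stem
    else:
        stem = model_name or Path(url).stem
    return f"{stem}.mar"

print(local_mar_name("s3://bucket/models/mnist.mar"))                    # mnist.mar
print(local_mar_name("https://b.s3.amazonaws.com/mnist.mar?Expires=1"))  # mnist.mar
print(local_mar_name("mnist.mar", model_name="digits"))                  # digits.mar
```

Because the proxy hands torchserve only this local name, the presigned query parameters never reach torchserve at all.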
@mreso Apologies for not checking in on this; I never got around to it until today. I noticed just now that the latest version of torchserve doesn't have this problem anymore: the models are saved with only the .mar file name. If I'm not missing anything, this issue can be closed.
🚀 The feature
Ignore url query parameters when registering/accessing/deleting models. Instead of saving `.mar` files as `filename.mar&queryparams`, save them as `filename.mar`.
Motivation, pitch
I'm using S3 presigned URLs to access models stored on S3. I'm aware of the Encrypted model serving option, which allows accessing encrypted models on S3; however, this method requires that I allow torchserve access to the bucket containing my models. This is not useful for my use case, because the software architecture I am working with uses presigned AWS URLs supplied to torchserve by an API that governs access to models on S3. This follows the least-privilege principle: the torchserve instance doesn't have access to the production bucket except via presigned, short-lived URLs.
The problem with using S3 presigned URLs is that torchserve registers models using a filename that includes the entirety of `filename.mar&queryparams`. Since query parameters change every time I want to access the model, the model-store cache always misses. Example output of `ls model-store`:

In simpler terms, this is what happens step-by-step:

1. The API supplies a presigned URL: `https://bucketname.s3.amazonaws.com/filename.mar?AWSAccessKeyId=****&Signature=****&Expires=****`
2. I expect the model to be stored as `filename.mar`, but torchserve saves it as `filename.mar?AWSAccessKeyId=****&Signature=****&Expires=****`
3. The next presigned URL again ends in `filename.mar?AWSAccessKeyId=****&Signature=****&Expires=****`, but the signature and expiration fields are different, thus the cache misses

Proposed solution:
Ignore the query parameters in the url when saving the model. The filename is used here when registering a model and here when deleting a model. It's using this method to get the filename from the url, but that's a function designed to read filenames from OS paths, not urls. Using something like this should fix the filename.
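Since torchserve itself is Java, the following is only an illustrative Python sketch of the proposed fix: parse the URL first, then take the basename, so query parameters never end up in the stored filename (`mar_filename` is a hypothetical helper):

```python
from pathlib import Path
from urllib.parse import urlparse

def mar_filename(url: str) -> str:
    # Parse the URL before extracting the basename, so the query string
    # (e.g. presigned-URL signature fields) is discarded.
    return Path(urlparse(url).path).name

presigned = "https://bucketname.s3.amazonaws.com/filename.mar?AWSAccessKeyId=x&Signature=y&Expires=z"
print(mar_filename(presigned))  # filename.mar
```

Taking the basename of the raw URL string instead would keep everything after the `?`, which is exactly the bug described above.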
I can create a PR for this, but since this is the first time I'm looking at the torchserve source code, I'd like to know: 1) if this solution is reasonable, 2) if I should consider other parts of the code that might need to change, and 3) if anyone can suggest an alternative that would render this solution unnecessary.
Thank you!
Alternatives
Alternative solutions could be as follows:
Use content-disposition-filename Header
I attempted to use the content-disposition-filename header when supplying the S3 presigned url, but torchserve ignored the header. It might be better to check for the Content-Disposition header and, if it exists, use that to store the model.
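For reference, Python's standard library can already extract a filename from a `Content-Disposition` value; a sketch of what honoring the header could look like (illustrative only, since torchserve is Java, and `filename_from_content_disposition` is a hypothetical helper):

```python
from email.message import Message

def filename_from_content_disposition(value: str):
    # email.message.Message parses structured header parameters, including
    # the filename parameter of Content-Disposition; returns None if absent.
    msg = Message()
    msg["Content-Disposition"] = value
    return msg.get_filename()

print(filename_from_content_disposition('attachment; filename="filename.mar"'))  # filename.mar
```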
Use model_name to store model file
Instead of scraping the model name from the url, use the provided model_name in the model registration endpoint (`POST /models`).
Take filename as function argument when registering a model
Take an optional filename argument in the model registration endpoint (`POST /models`). If it is provided, use that as the model filename. In my opinion this alternative is better than changing the existing filename convention, because this option wouldn't be a breaking change.
Additional context
No response