Ignore url query parameters when storing models (.mar files) #2085

Open
MuhsinFatih opened this issue Jan 20, 2023 · 7 comments
Labels
triaged Issue has been reviewed and triaged


MuhsinFatih commented Jan 20, 2023

🚀 The feature

Ignore URL query parameters when registering/accessing/deleting models. Instead of saving .mar files as filename.mar?queryparams, save them as filename.mar

Motivation, pitch

I'm using S3 presigned URLs to access models stored on S3. I'm aware of the Encrypted model serving option, which allows accessing encrypted models on S3, but that method requires granting torchserve access to the bucket containing my models. This doesn't work for my use case: the architecture I'm working with supplies presigned AWS URLs to torchserve via an API that governs access to models on S3. This follows the least-privilege principle, so the torchserve instance has no access to the production bucket except through short-lived presigned URLs.

The problem with using S3 presigned URLs is that torchserve registers models under a filename that includes the entire filename.mar?queryparams string. Since the query parameters change every time I access the model, the model-store cache always misses.

Example output of ls model-store:

╰─ ls model-store 
'model1.mar?AWSAccessKeyId=access-key&Signature=signature&Expires=expdate'
'model1.mar?AWSAccessKeyId=access-key&Signature=signature2&Expires=expdate2'
'model1.mar?AWSAccessKeyId=access-key&Signature=signature3&Expires=expdate3'
...

In simpler terms, this is what happens step by step:

  1. My API generates a presigned S3 URL: https://bucketname.s3.amazonaws.com/filename.mar?AWSAccessKeyId=****&Signature=****&Expires=****
  2. Torchserve receives the URL
  3. Torchserve downloads the model. I expect it to save it as filename.mar, but torchserve saves it as filename.mar?AWSAccessKeyId=****&Signature=****&Expires=****
  4. I do this again with another model, then want to use the first model again. I expect a cache hit (i.e. torchserve finds the model in model-store and avoids re-downloading)
  5. Torchserve receives a new presigned URL for the same file. Its filename is again filename.mar?AWSAccessKeyId=****&Signature=****&Expires=****, but the signature and expiration fields differ, so the cache misses
  6. Torchserve downloads the model again. The model store fills with duplicate copies of the same model, and we pay the full download time every time this happens

Proposed solution:
Ignore the query parameters in the URL when saving the model. The filename is used both when registering a model and when deleting one. Torchserve currently extracts it with a utility designed to read filenames from OS paths, not URLs; parsing the URL and stripping the query string before taking the name should fix the filename.
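
As a minimal sketch of the idea in Python (torchserve's management frontend is not Python, so this illustrates the parsing, not the actual patch):

from pathlib import Path
from urllib.parse import urlparse

def mar_filename_from_url(url: str) -> str:
    # urlparse splits the query string ("?AWSAccessKeyId=...") into .query,
    # leaving only "/filename.mar" in .path
    return Path(urlparse(url).path).name

assert mar_filename_from_url(
    "https://bucket.s3.amazonaws.com/model1.mar?AWSAccessKeyId=k&Signature=s&Expires=e"
) == "model1.mar"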

I can create a PR for this, but since this is my first look at the torchserve source code I'd like to know: 1) whether this solution is reasonable, 2) whether other parts of the code would need to change, and 3) whether anyone can suggest an alternative that would make this change unnecessary.

Thank you!

Alternatives

Alternative solutions could be as follows:

  1. Use the Content-Disposition filename
    I attempted to pass a filename via the Content-Disposition header when supplying the S3 presigned URL, but torchserve ignored the header. It would be better if torchserve checked for a Content-Disposition header and, when present, used it to name the stored model (see the boto3 sketch after this list).

  2. Use model_name to store the model file
    Instead of scraping the model name from the URL, use the model_name provided to the model registration endpoint (POST /models).

  3. Take the filename as an argument when registering a model
    Accept an optional filename argument in the model registration endpoint (POST /models). If provided, use it as the model's filename.

In my opinion this last alternative is better than changing the existing filename convention, because it wouldn't be a breaking change.
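
For reference, this is roughly how the Content-Disposition attempt from alternative 1 can be set up with boto3; bucket and key names here are placeholders:

import boto3

s3 = boto3.client("s3")

# Ask S3 to return a Content-Disposition header with the download, which
# torchserve could honor when choosing the stored filename.
url = s3.generate_presigned_url(
    "get_object",
    Params={
        "Bucket": "my-model-bucket",  # placeholder
        "Key": "models/model1.mar",   # placeholder
        "ResponseContentDisposition": 'attachment; filename="model1.mar"',
    },
    ExpiresIn=3600,
)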

Additional context

No response

mreso (Collaborator) commented Jan 21, 2023

Hi,

thanks for filing the issue! The behavior is indeed unfortunate in your case, and the proposed solution sounds reasonable to me.
The only edge case I could come up with is somebody using bogus query parameters to create new filenames for models with the same name when they can't control the filenames on the server. But that's very exotic.

mreso added the triaged (Issue has been reviewed and triaged) label on Jan 21, 2023
MuhsinFatih (Author) commented

@mreso awesome! I'll try to make a PR when I get the chance. I believe the bogus-filename problem is best solved with alternative 3 (take an optional filename argument in the model registration endpoint). It would also be the most flexible option for the developer. What do you think?

mreso (Collaborator) commented Jan 25, 2023

@MuhsinFatih Great! Looking forward to your PR! Let me know if you need any help with it!

Agreed on the solution for users using bogus filenames. It would be great to have this option in case users leverage that "flaw" in TS to make it work in their situation. If you want to take that on as well, I'd suggest creating a second PR for it to keep the reviews manageable.

cfculhane commented

Hi @MuhsinFatih - curious how you got presigned S3 URLs to work in the first place? I keep hitting the same errors as #1293, and even with the cache misses this would let us move forward with serving models from S3.

MuhsinFatih (Author) commented Jun 6, 2023

@cfculhane You might be having trouble with the signature version and/or with URL parsing. Try the following two changes:

import boto3
from botocore.config import Config

# specify signature version v4
s3_client = boto3.client("s3", config=Config(signature_version="s3v4"), region_name="your-region")
# percent-encode ampersands so torchserve doesn't split the URL at the first "&"
model_url = model_url.replace("&", "%26")

(source: #669 (comment))
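
Putting the two together, registration might look like the sketch below; bucket, key, region, and the default management port 8081 are placeholders/assumptions:

import boto3
import requests
from botocore.config import Config

s3_client = boto3.client("s3", config=Config(signature_version="s3v4"), region_name="us-east-1")
presigned = s3_client.generate_presigned_url(
    "get_object", Params={"Bucket": "my-bucket", "Key": "model1.mar"}, ExpiresIn=3600
)
# Encode the ampersands by hand and splice the URL into the query string
# ourselves; torchserve decodes %26 back to "&" before fetching the model.
register_url = f"http://localhost:8081/models?url={presigned.replace('&', '%26')}"
requests.post(register_url)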

cfculhane commented Jun 9, 2023

Ahh thank you - this worked.

I ended up writing a FastAPI web service inside the torchserve container to handle these URLs and download them into the model store. That let us cache models properly and handle both presigned URLs (without modification) and S3 URIs natively. We then send the register-model POST request with just the local filename, and torchserve is none the wiser :)

E.g., with some internal modules redacted:

"""
Due to torchserve not supporting S3 presigned URLs or boto3 compatible S3 Paths, we need to put a tiny python
webserver in front of the torchserve management API.

Relevant torchserve issues:
- https://github.com/pytorch/serve/issues/2085
- https://github.com/pytorch/serve/issues/1293

This only handles / forwards requests to the torchserve management API, and does not handle inference requests.

"""
import functools
import os
import shutil
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import requests
from <internal>.s3_uploader import AmazonS3
from <internal>.utils import create_logger
from fastapi import FastAPI
from s3path import S3Path
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

# noinspection HttpUrlsUsage
TS_MANAGEMENT_ADDRESS = os.environ.get("TS_MANAGEMENT_ADDRESS", "http://localhost:8001")
CONTAINER_MODEL_STORE_DIR = Path(os.environ.get("CONTAINER_MODEL_STORE_DIR", Path(__file__).parent / "model_store"))

logger = create_logger("api_proxy")
logger.info(f"Starting API proxy, forwarding management API requests to {TS_MANAGEMENT_ADDRESS}")


async def catch_all(request: Request):
    """
    This just forwards all un-routed requests to the torchserve management API
    """
    ts_response = requests.request(
        method=request.method,
        url=f"{TS_MANAGEMENT_ADDRESS}/{request.path_params['full_path']}",
        headers=request.headers,
        data=await request.body(),
        allow_redirects=False,
    )
    return JSONResponse(ts_response.json(), status_code=ts_response.status_code)


routes = [
    Route("/{full_path:path}", endpoint=catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
]
catch_all_app = FastAPI(routes=routes)

app = FastAPI()


@app.get("/ping")
async def ping(request: Request):
    return JSONResponse({"message": "PONG"})


def download_file_from_url(url: str, path: Path) -> Path:
    from tqdm.auto import tqdm

    r = requests.get(url, stream=True, allow_redirects=True)
    if r.status_code != 200:
        r.raise_for_status()  # Raises for 4xx/5xx; the line below catches any other non-200
        raise RuntimeError(f"Request to {url} returned status code {r.status_code}")
    file_size = int(r.headers.get("Content-Length", 0))

    path.parent.mkdir(parents=True, exist_ok=True)

    desc = "(Unknown total file size)" if file_size == 0 else ""
    r.raw.read = functools.partial(r.raw.read, decode_content=True)  # Decompress if needed
    with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
        with path.open("wb") as f:
            shutil.copyfileobj(r_raw, f)

    return path


@app.post("/models")
async def models_post(
    url: str, model_name: Optional[str] = None, initial_workers: Optional[int] = 1, batch_size: Optional[int] = 1
):
    if url.startswith("s3://"):  # S3 URI
        model_path_remote = S3Path.from_uri(url)
        model_name = model_name or model_path_remote.stem
        local_model_path = CONTAINER_MODEL_STORE_DIR / f"{model_name}.mar"
        logger.info(f"Detected S3 URI, downloading from {model_path_remote} to {local_model_path}")
        AmazonS3().download_file_from_bucket(
            bucket_name=model_path_remote.bucket, source_path=model_path_remote.key, dest_path=local_model_path
        )

    elif url.startswith("https://"):  # presigned or plain HTTPS URL
        parsed_url = urlparse(url)
        model_name = model_name or Path(parsed_url.path).stem
        local_model_path = CONTAINER_MODEL_STORE_DIR / f"{model_name}.mar"
        logger.info(f"Detected HTTPS URI, downloading from {url} to {local_model_path}")
        download_file_from_url(url, local_model_path)
    else:
        # Assume the url is already a filename local to the model store
        model_name = model_name or Path(url).stem
        local_model_path = CONTAINER_MODEL_STORE_DIR / f"{model_name}.mar"

    # Validate we now have the file locally
    if not local_model_path.exists():
        raise FileNotFoundError(f"Model file not found at {local_model_path}")

    ts_response = requests.post(
        url=f"{TS_MANAGEMENT_ADDRESS}/models",
        params={
            "url": local_model_path.name,
            "model_name": model_name,
            "initial_workers": initial_workers,
            "batch_size": batch_size,
        },
    )
    return JSONResponse(ts_response.json(), status_code=ts_response.status_code)


# mount last so that it doesn't override other routes defined above
app.mount("/", catch_all_app)
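
For context, registering a model through this proxy then looks something like the sketch below; the proxy port and the S3 URI are assumptions for illustration:

import requests

# The proxy downloads the .mar into the local model store and forwards a
# clean local filename to torchserve's management API.
resp = requests.post(
    "http://localhost:8080/models",
    params={"url": "s3://my-bucket/models/model1.mar", "model_name": "model1"},
)
print(resp.status_code, resp.json())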

MuhsinFatih (Author) commented Oct 5, 2023

@mreso Apologies for not checking in on this - I never got around to it until today. I just noticed that the latest version of torchserve no longer has this problem: models are saved under just the .mar filename. If I'm not missing anything, this issue can be closed.
