From b2b58228bfb2235cabe4673ce894d5c8b712afbc Mon Sep 17 00:00:00 2001
From: jlewitt1
Date: Fri, 15 Mar 2024 15:34:33 +0200
Subject: [PATCH] tgi example styling updates

---
 .../tgi-inference-aws-ec2/tgi_mistral_ec2.py | 35 ++++++++++---------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/examples/tgi-inference-aws-ec2/tgi_mistral_ec2.py b/examples/tgi-inference-aws-ec2/tgi_mistral_ec2.py
index 32bee039f..dfd27e50f 100644
--- a/examples/tgi-inference-aws-ec2/tgi_mistral_ec2.py
+++ b/examples/tgi-inference-aws-ec2/tgi_mistral_ec2.py
@@ -1,4 +1,4 @@
-# # Deploy Mistral's 7B TGI Model with AWS EC2
+# # Deploy Mistral's 7B Model with TGI on AWS EC2
 
 # This example demonstrates how to deploy a
 # [TGI model](https://huggingface.co/docs/text-generation-inference/messages_api) on AWS EC2 using Runhouse.
@@ -7,14 +7,13 @@
 # Zephyr is a 7B fine-tuned version of [Mistral's 7B-v0.1 model](https://huggingface.co/mistralai/Mistral-7B-v0.1).
 #
 # ## Setup credentials and dependencies
-# ```
 # Install the required dependencies:
 # ```shell
 # $ pip install -r requirements.txt
 # ```
 #
-# We'll be launching an AWS EC2 instance via SkyPilot, so we need to make sure our AWS credentials are set up
-# with SkyPilot:
+# We'll be launching an AWS EC2 instance via [SkyPilot](https://github.com/skypilot-org/skypilot), so we need to make
+# sure our AWS credentials are set up with SkyPilot:
 # ```shell
 # $ aws configure
 # $ sky check
@@ -127,8 +126,6 @@ def restart_container(self):
         # Deploy a new container
         self.deploy()
 
-    # ---------------------------------------------------
-
 
 # ## Setting up Runhouse primitives
 #
@@ -136,10 +133,11 @@ def restart_container(self):
 # our Runhouse module on a remote cluster. First, we create a cluster with the desired instance type and provider.
 # Our `instance_type` here is defined as `g5.4xlarge`, which is
 # an [AWS instance type on EC2](https://aws.amazon.com/ec2/instance-types/g5/) with a GPU.
-# (For this model we'll need a GPU and at least 16GB of RAM)
+#
+# For this model, we'll need a GPU and at least 16GB of RAM.
 # We also open port 8080, which is the port that the TGI model will be running on.
 #
-# Learn more in the [Runhouse docs on clusters](/docs/tutorials/api-clusters).
+# Learn more about clusters in the [Runhouse docs](/docs/tutorials/api-clusters).
 #
 # NOTE: Make sure that your code runs within an `if __name__ == "__main__":` block, as shown below. Otherwise,
 # the script code will run when Runhouse attempts to run code remotely.
@@ -173,10 +171,14 @@
 
     # ## Sharing an inference endpoint
     # We can publish this module for others to use:
+    # ```python
    # remote_tgi_model.share(visibility="public")
+    # ```
 
     # Alternatively, we can share with specific users:
+    # ```python
    # remote_tgi_model.share(["user1@gmail.com", "user2@gmail.com"], access_level="read")
+    # ```
 
     # Note: For more info on fine-grained access controls, see the
     # [Runhouse docs on sharing](https://www.run.house/docs/tutorials/quick-start-den#sharing).
@@ -185,7 +187,7 @@
     # ## Deploying the model
     # We can call the `deploy` method on the model class instance as if it were running locally.
     # This will load and run the model on the remote cluster.
     # We only need to do this setup step once, as further calls will use the existing docker container deployed
-    # on the cluster and maintain state between calls
+    # on the cluster and maintain state between calls:
     remote_tgi_model.deploy()
 
     # ## Sending a prompt to the model
@@ -199,26 +201,27 @@
         {"role": "user", "content": "Do you have mayonnaise recipes?"},
     ]
 
-    # We'll use the Messages API to send the prompt to the model
+    # We'll use the Messages API to send the prompt to the model.
     # See [here](https://huggingface.co/docs/text-generation-inference/messages_api#streaming) for more info
     # on the Messages API and on using the OpenAI Python client.
-    base_url = f"http://{cluster.address}:{port}/v1"
-    # Initialize the OpenAI client
+    # Initialize the OpenAI client:
+    base_url = f"http://{cluster.address}:{port}/v1"
     client = OpenAI(base_url=base_url, api_key="-")
 
-    # Call the model with the prompt messages
+    # Call the model with the prompt messages:
     chat_completion = client.chat.completions.create(
         model="tgi", messages=prompt_messages, stream=False
     )
     print(chat_completion)
 
     # For streaming results, set `stream=True` and iterate over the results:
+    # ```python
     # for message in chat_completion:
-    # print(message)
+    #     print(message)
+    # ```
 
-    # Alternatively, we can also call the model via HTTP
-    print("------------")
+    # Alternatively, we can also call the model via HTTP:
     print("To call the model via HTTP, use the following cURL command:")
     print(
         f"curl http://{cluster.address}:{port}/v1/chat/completions -X POST -d '"