tgi example styling updates
jlewitt1 committed Mar 15, 2024
1 parent 8a58877 commit e787099
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions examples/tgi-inference-aws-ec2/tgi_mistral_ec2.py
@@ -7,14 +7,13 @@
# Zephyr is a 7B fine-tuned version of [Mistral's 7B-v0.1 model](https://huggingface.co/mistralai/Mistral-7B-v0.1).
#
# ## Setup credentials and dependencies
# ```
# Install the required dependencies:
# ```shell
# $ pip install -r requirements.txt
# ```
#
# We'll be launching an AWS EC2 instance via SkyPilot, so we need to make sure our AWS credentials are set up
# with SkyPilot:
# We'll be launching an AWS EC2 instance via [SkyPilot](https://github.com/skypilot-org/skypilot), so we need to make
# sure our AWS credentials are set up with SkyPilot:
# ```shell
# $ aws configure
# $ sky check
@@ -127,19 +126,18 @@ def restart_container(self):
# Deploy a new container
self.deploy()

# ---------------------------------------------------


# ## Setting up Runhouse primitives
#
# Now, we define the main function that will run locally when we run this script, and set up
# our Runhouse module on a remote cluster. First, we create a cluster with the desired instance type and provider.
# Our `instance_type` here is defined as `g5.4xlarge`, which is
# an [AWS instance type on EC2](https://aws.amazon.com/ec2/instance-types/g5/) with a GPU.
# (For this model we'll need a GPU and at least 16GB of RAM)
#
# For this model we'll need a GPU and at least 16GB of RAM
# We also open port 8080, which is the port that the TGI model will be running on.
#
# Learn more in the [Runhouse docs on clusters](/docs/tutorials/api-clusters).
# Learn more about clusters in the [Runhouse docs](/docs/tutorials/api-clusters).
#
# NOTE: Make sure that your code runs within an `if __name__ == "__main__":` block, as shown below. Otherwise,
# the script code will run when Runhouse attempts to run code remotely.
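#
# As a rough, hypothetical sketch, cluster creation along these lines might look like the block below
# (assuming the `rh.cluster` factory accepts `instance_type`, `provider`, and `open_ports`, and that
# clusters expose an `up_if_not()` method):
# ```python
# import runhouse as rh
#
# port = 8080
#
# if __name__ == "__main__":
#     cluster = rh.cluster(
#         name="rh-g5-cluster",        # hypothetical cluster name, for illustration only
#         instance_type="g5.4xlarge",  # GPU instance with enough memory for the model
#         provider="aws",
#         open_ports=[port],           # expose the port TGI will serve on
#     )
#     cluster.up_if_not()              # bring the cluster up if it isn't already running
# ```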
@@ -173,10 +171,14 @@ def restart_container(self):

# ## Sharing an inference endpoint
# We can publish this module for others to use:
# ```python
# remote_tgi_model.share(visibility="public")
# ```

# Alternatively, we can share with specific users:
# ```python
# remote_tgi_model.share(["user1@gmail.com", "user2@gmail.com"], access_level="read")
# ```

# Note: For more info on fine-grained access controls, see the
# [Runhouse docs on sharing](https://www.run.house/docs/tutorials/quick-start-den#sharing).
@@ -186,7 +188,7 @@ def restart_container(self):
# This will load and run the model on the remote cluster.
# We only need to do this setup step once, as further calls will use the existing Docker container deployed
# on the cluster and maintain state between calls.
remote_tgi_model.deploy()
remote_tgi_model.restart_container()

# ## Sending a prompt to the model
prompt_messages = [
@@ -199,12 +201,12 @@ def restart_container(self):
{"role": "user", "content": "Do you have mayonnaise recipes?"},
]

# We'll use the Messages API to send the prompt to the model
# We'll use the Messages API to send the prompt to the model.
# See [here](https://huggingface.co/docs/text-generation-inference/messages_api#streaming) for more info
# on the Messages API and on using the OpenAI Python client.
base_url = f"http://{cluster.address}:{port}/v1"

# Initialize the OpenAI client
base_url = f"http://{cluster.address}:{port}/v1"
client = OpenAI(base_url=base_url, api_key="-")

# Call the model with the prompt messages
@@ -214,11 +216,12 @@ def restart_container(self):
print(chat_completion)

# For streaming results, set `stream=True` and iterate over the results:
# ```python
# for message in chat_completion:
# print(message)
# print(message)
# ```

# Alternatively, we can also call the model via HTTP
print("------------")
# Alternatively, we can also call the model via HTTP:
print("To call the model via HTTP, use the following cURL command:")
print(
f"curl http://{cluster.address}:{port}/v1/chat/completions -X POST -d '"
