tgi example styling updates #604

Merged · 1 commit · Mar 15, 2024
36 changes: 19 additions & 17 deletions examples/tgi-inference-aws-ec2/tgi_mistral_ec2.py
@@ -1,4 +1,4 @@
-# # Deploy Mistral's 7B TGI Model with AWS EC2
+# # Deploy Mistral's 7B Model with TGI on AWS EC2
 
 # This example demonstrates how to deploy a
 # [TGI model](https://huggingface.co/docs/text-generation-inference/messages_api) on AWS EC2 using Runhouse.
@@ -7,14 +7,13 @@
 # Zephyr is a 7B fine-tuned version of [Mistral's 7B-v0.1 model](https://huggingface.co/mistralai/Mistral-7B-v0.1).
 #
 # ## Setup credentials and dependencies
-# ```
+# Install the required dependencies:
+# ```shell
 # $ pip install -r requirements.txt
 # ```
 #
-# We'll be launching an AWS EC2 instance via SkyPilot, so we need to make sure our AWS credentials are set up
-# with SkyPilot:
+# We'll be launching an AWS EC2 instance via [SkyPilot](https://github.com/skypilot-org/skypilot), so we need to make
+# sure our AWS credentials are set up with SkyPilot:
 # ```shell
 # $ aws configure
 # $ sky check
@@ -127,19 +126,18 @@ def restart_container(self):
         # Deploy a new container
         self.deploy()
 
-# ---------------------------------------------------
 
 
 # ## Setting up Runhouse primitives
 #
 # Now, we define the main function that will run locally when we run this script, and set up
 # our Runhouse module on a remote cluster. First, we create a cluster with the desired instance type and provider.
 # Our `instance_type` here is defined as `g5.4xlarge`, which is
 # an [AWS instance type on EC2](https://aws.amazon.com/ec2/instance-types/g5/) with a GPU.
-# (For this model we'll need a GPU and at least 16GB of RAM)
 #
+# For this model we'll need a GPU and at least 16GB of RAM.
 # We also open port 8080, which is the port that the TGI model will be running on.
 #
-# Learn more in the [Runhouse docs on clusters](/docs/tutorials/api-clusters).
+# Learn more about clusters in the [Runhouse docs](/docs/tutorials/api-clusters).
 #
 # NOTE: Make sure that your code runs within an `if __name__ == "__main__":` block, as shown below. Otherwise,
 # the script code will run when Runhouse attempts to run code remotely.
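
The cluster-creation code itself is collapsed in this diff. As a minimal sketch of what the section above describes (the cluster name is a hypothetical placeholder, not the file's actual identifier):

```python
import runhouse as rh

if __name__ == "__main__":
    port = 8080
    # Launch (or reuse) an EC2 GPU instance via SkyPilot, with the TGI port open
    cluster = rh.cluster(
        name="rh-tgi-ec2",           # hypothetical cluster name
        instance_type="g5.4xlarge",  # the AWS GPU instance type referenced above
        provider="aws",
        open_ports=[port],           # expose the port TGI will serve on
    ).up_if_not()
```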
@@ -173,10 +171,14 @@ def restart_container(self):

 # ## Sharing an inference endpoint
 # We can publish this module for others to use:
+# ```python
 # remote_tgi_model.share(visibility="public")
+# ```
 
 # Alternatively, we can share with specific users:
+# ```python
 # remote_tgi_model.share(["user1@gmail.com", "user2@gmail.com"], access_level="read")
+# ```
 
 # Note: For more info on fine-grained access controls, see the
 # [Runhouse docs on sharing](https://www.run.house/docs/tutorials/quick-start-den#sharing).
@@ -185,7 +187,7 @@ def restart_container(self):
 # We can call the `deploy` method on the model class instance as if it were running locally.
 # This will load and run the model on the remote cluster.
 # We only need to do this setup step once, as further calls will use the existing Docker container deployed
-# on the cluster and maintain state between calls
+# on the cluster and maintain state between calls:
 remote_tgi_model.deploy()
 
 # ## Sending a prompt to the model
@@ -199,27 +201,27 @@ def restart_container(self):
{"role": "user", "content": "Do you have mayonnaise recipes?"},
]

# We'll use the Messages API to send the prompt to the model
# We'll use the Messages API to send the prompt to the model.
# See [here](https://huggingface.co/docs/text-generation-inference/messages_api#streaming) for more info
# on the Messages API, and using the OpenAI python client
base_url = f"http://{cluster.address}:{port}/v1"

# Initialize the OpenAI client
# Initialize the OpenAI client, with the URL set to the cluster's address:
base_url = f"http://{cluster.address}:{port}/v1"
client = OpenAI(base_url=base_url, api_key="-")

# Call the model with the prompt messages
# Call the model with the prompt messages:
chat_completion = client.chat.completions.create(
model="tgi", messages=prompt_messages, stream=False
)
print(chat_completion)

# For streaming results, set `stream=True` and iterate over the results:
# ```python
# for message in chat_completion:
# print(message)
# print(message)
# ```
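
For reference, a fuller sketch of that streaming variant, assuming the same `client` and `prompt_messages` as above (`delta.content` can be `None` on the first chunk, hence the `or ""`):

```python
# Request a streaming completion; each iteration yields an incremental chunk
stream = client.chat.completions.create(
    model="tgi", messages=prompt_messages, stream=True
)
for chunk in stream:
    # Print each delta of generated text as it arrives
    print(chunk.choices[0].delta.content or "", end="")
print()
```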

-# Alternatively, we can also call the model via HTTP
-print("------------")
-print("To call the model via HTTP, use the following cURL command:")
+# Alternatively, we can also call the model via HTTP:
 print(
     f"curl http://{cluster.address}:{port}/v1/chat/completions -X POST -d '"
     '{"model": "tgi", "stream": false, "messages": [{"role": "system", "content": "You are a helpful assistant."},'