Skip to content

Commit

Permalink
Switch parallel embedding example to use async primitives.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed May 3, 2024
1 parent 41a8c44 commit bfe1b56
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions examples/parallel-hf-embedding-ec2/parallel_hf_embedding_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# Imports of libraries that are needed on the remote machine (in this case, the `langchain` dependencies)
# can happen within the functions that will be sent to the Runhouse cluster.

import concurrent.futures
import asyncio
import time
from typing import List
from urllib.parse import urljoin, urlparse
Expand Down Expand Up @@ -163,7 +163,7 @@ def embed_docs(self, urls: List[str]):
# Make sure that your code runs within a `if __name__ == "__main__":` block, as shown below. Otherwise,
# the script code will run when Runhouse attempts to import your code remotely.
# :::
if __name__ == "__main__":
async def main():
cluster = rh.cluster("rh-a10g", instance_type="A10G:4").save().up_if_not()

# We set up some parameters for our embedding task.
Expand Down Expand Up @@ -211,33 +211,34 @@ def embed_docs(self, urls: List[str]):
print(f"Time to initialize {num_replicas} replicas: {time.time() - start_time}")

# ## Calling the Runhouse modules in parallel
# We set up a loop to call each replica in parallel with the partitioned URLs.
# We'll simply use the `embed_docs` function on the remote module to embed all the URLs in parallel. Note that
# we can call this function exactly as if it were a local module. In this case, because we are using the
# `asyncio` library to make parallel calls, we need to use a special `run_async=True` argument to the
# Runhouse function. This tells Runhouse to return a coroutine that we can await on, rather than making
# a blocking network call to the server.
start_time = time.time()
results = []

# Note again that we can call the `embed_docs` function on the
# remote module exactly as if it were a local module.
# This function is simply a wrapper to use with the ThreadPoolExecutor.
def call_on_replica(replica, urls):
return replica.embed_docs(urls)

# This is standard Python code that uses the ThreadPoolExecutor to make `num_parallel_tasks` calls at once
# to the replicas. We then collect the results and print the total time taken.
with concurrent.futures.ThreadPoolExecutor(
max_workers=num_parallel_tasks
) as executor:
futs = [
executor.submit(call_on_replica, replicas[i % num_replicas], [urls[i]])
for i in range(len(urls))
]
for fut in concurrent.futures.as_completed(futs):
res = fut.result()
if res is not None:
results.extend(res)
else:
print("An embedding call failed.")

print(f"Received {len(results)} total embeddings.")
futs = [
asyncio.create_task(
replicas[i % num_replicas].embed_docs([urls[i]], run_async=True)
)
for i in range(len(urls))
]

all_embeddings = []
failures = 0
task_results = await asyncio.gather(*futs)
for res in task_results:
if res is not None:
all_embeddings.extend(res)
else:
print("An embedding call failed.")
failures += 1

print(f"Received {len(all_embeddings)} total embeddings, with {failures} failures.")
print(
f"Embedded {len(urls)} docs across {num_replicas} replicas with {num_parallel_tasks} concurrent calls: {time.time() - start_time}"
)


if __name__ == "__main__":
asyncio.run(main())

0 comments on commit bfe1b56

Please sign in to comment.