Skip to content

Commit

Permalink
Re-release of atomic deployments through AppPublish (#2066)
Browse files Browse the repository at this point in the history
* Reapply "Run / deploy apps via new AppPublish RPC (#2043)" (#2056)

This reverts commit f80ddb8.

* Add defer_updates=True to FunctionCreateRequest
  • Loading branch information
mwaskom committed Aug 2, 2024
1 parent 1d58709 commit 662d2a9
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 65 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ We appreciate your patience while we speedily work towards a stable release of t

<!-- NEW CONTENT GENERATED BELOW. PLEASE PRESERVE THIS COMMENT. -->

### 0.64.0 (2024-07-29)

- App deployment events are now atomic, reducing the risk that a failed deploy will leave the App in a bad state.



### 0.63.87 (2024-07-24)

Expand Down
6 changes: 3 additions & 3 deletions modal/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,10 @@ def run(ctx, detach, quiet, interactive, env):

def deploy(
app_ref: str = typer.Argument(..., help="Path to a Python file with an app."),
name: str = typer.Option(None, help="Name of the deployment."),
name: str = typer.Option("", help="Name of the deployment."),
env: str = ENV_OPTION,
stream_logs: bool = typer.Option(False, help="Stream logs from the app upon deployment."),
tag: str = typer.Option(None, help="Tag the deployment with a version."),
tag: str = typer.Option("", help="Tag the deployment with a version."),
):
# this ensures that `modal.lookup()` without environment specification uses the same env as specified
env = ensure_env(env)
Expand All @@ -292,7 +292,7 @@ def deploy(
name = app.name

with enable_output():
res = deploy_app(app, name=name, environment_name=env, tag=tag)
res = deploy_app(app, name=name, environment_name=env or "", tag=tag)

if stream_logs:
stream_app_logs(res.app_id)
Expand Down
10 changes: 10 additions & 0 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,16 @@ async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[s
)
)
resp = await resolver.client.stub.ClassCreate(req)
# Even though we already have the function_handle_metadata for this method locally,
# The RPC is going to replace it with function_handle_metadata derived from the server.
# We need to overwrite the definition_id sent back from the server here with the definition_id
# previously stored in function metadata, which may have been sent back from FunctionCreate.
# The problem is that this metadata propagates back and overwrites the metadata on the Function
# object itself. This is really messy. Maybe better to exclusively populate the method metadata
# from the function metadata we already have locally? Really a lot to clean up here...
for method in resp.handle_metadata.methods:
f_metadata = self._method_functions[method.function_name]._get_metadata()
method.function_handle_metadata.definition_id = f_metadata.definition_id
self._hydrate(resp.class_id, resolver.client, resp.handle_metadata)

rep = f"Cls({user_cls.__name__})"
Expand Down
6 changes: 5 additions & 1 deletion modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing
function=function_definition,
# method_bound_function.object_id usually gets set by preload
existing_function_id=existing_object_id or method_bound_function.object_id or "",
defer_updates=True,
)
assert resolver.client.stub is not None # client should be connected when load is called
with FunctionCreationStatus(resolver, full_name) as function_creation_status:
Expand Down Expand Up @@ -834,6 +835,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
function=function_definition,
schedule=schedule.proto_message if schedule is not None else None,
existing_function_id=existing_object_id or "",
defer_updates=True,
)
try:
response: api_pb2.FunctionCreateResponse = await retry_transient_errors(
Expand Down Expand Up @@ -1092,14 +1094,15 @@ def _initialize_from_empty(self):

def _hydrate_metadata(self, metadata: Optional[Message]):
    """Overridden concrete implementation of the base class method.

    Copies server-provided handle metadata onto this Function object's
    local attributes after hydration.
    """
    # Only FunctionHandleMetadata is accepted here; the duplicated, broader
    # assertion that also allowed api_pb2.Function was leftover from the
    # previous implementation and has been removed.
    assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata)
    self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
    self._web_url = metadata.web_url
    self._function_name = metadata.function_name
    self._is_method = metadata.is_method
    self._use_function_id = metadata.use_function_id
    self._use_method_name = metadata.use_method_name
    self._class_parameter_info = metadata.class_parameter_info
    self._definition_id = metadata.definition_id
def _invocation_function_id(self) -> str:
return self._use_function_id or self.object_id
Expand All @@ -1119,6 +1122,7 @@ def _get_metadata(self):
use_function_id=self._use_function_id,
is_method=self._is_method,
class_parameter_info=self._class_parameter_info,
definition_id=self._definition_id,
)

def _set_mute_cancellation(self, value: bool = True):
Expand Down
101 changes: 56 additions & 45 deletions modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
import os
from multiprocessing.synchronize import Event
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Dict, List, Optional, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Dict, List, Optional, TypeVar

from grpclib import GRPCError, Status
from rich.console import Console
Expand Down Expand Up @@ -39,6 +39,9 @@
_App = TypeVar("_App")


V = TypeVar("V")


async def _heartbeat(client: _Client, app_id: str) -> None:
request = api_pb2.AppHeartbeatRequest(app_id=app_id)
# TODO(erikbern): we should capture exceptions here
Expand Down Expand Up @@ -104,7 +107,6 @@ async def _create_all_objects(
client: _Client,
running_app: RunningApp,
indexed_objects: Dict[str, _Object],
new_app_state: int,
environment_name: str,
) -> None:
"""Create objects that have been defined but not created on the server."""
Expand Down Expand Up @@ -150,21 +152,46 @@ async def _load(tag, obj):

await TaskContext.gather(*(_load(tag, obj) for tag, obj in indexed_objects.items()))

# Create the app (and send a list of all tagged obs)
# TODO(erikbern): we should delete objects from a previous version that are no longer needed
# We just delete them from the app, but the actual objects will stay around
indexed_object_ids = running_app.tag_to_object_id
assert indexed_object_ids == running_app.tag_to_object_id
all_objects = resolver.objects()

unindexed_object_ids = list(set(obj.object_id for obj in all_objects) - set(running_app.tag_to_object_id.values()))
req_set = api_pb2.AppSetObjectsRequest(
async def _publish_app(
    client: _Client,
    running_app: RunningApp,
    app_state: int,  # api_pb2.AppState.value
    indexed_objects: Dict[str, _Object],
    name: str = "",  # Only relevant for deployments
    tag: str = "",  # Only relevant for deployments
) -> str:
    """Wrapper for the AppPublish RPC: atomically publish the App's objects and state.

    Returns the app URL reported by the server (may be empty for
    non-deployed states). Raises InvalidError when the server rejects
    the request with INVALID_ARGUMENT or FAILED_PRECONDITION.
    """

    # Could simplify this function somewhat by changing the internal representation
    # to use function_ids / class_ids rather than the current tag_to_object_id
    # (i.e. "indexed_objects")
    def filter_values(full_dict: Dict[str, V], condition: Callable[[V], bool]) -> Dict[str, V]:
        return {k: v for k, v in full_dict.items() if condition(v)}

    # The entity prefixes are defined in the monorepo; is there any way to share them here?
    function_ids = filter_values(running_app.tag_to_object_id, lambda v: v.startswith("fu-"))
    class_ids = filter_values(running_app.tag_to_object_id, lambda v: v.startswith("cs-"))

    # Map each published function's object_id to its definition_id so the
    # server can associate them atomically during the publish.
    function_objs = filter_values(indexed_objects, lambda v: v.object_id in function_ids.values())
    definition_ids = {obj.object_id: obj._get_metadata().definition_id for obj in function_objs.values()}  # type: ignore

    request = api_pb2.AppPublishRequest(
        app_id=running_app.app_id,
        name=name,
        deployment_tag=tag,
        app_state=app_state,  # type: ignore : should be a api_pb2.AppState.value
        function_ids=function_ids,
        class_ids=class_ids,
        definition_ids=definition_ids,
    )
    try:
        response = await retry_transient_errors(client.stub.AppPublish, request)
    except GRPCError as exc:
        # Surface user-correctable server rejections as InvalidError;
        # re-raise anything else (transient/internal errors).
        if exc.status == Status.INVALID_ARGUMENT or exc.status == Status.FAILED_PRECONDITION:
            raise InvalidError(exc.message)
        raise

    return response.url


async def _disconnect(
Expand Down Expand Up @@ -252,7 +279,10 @@ async def _run_app(
exc_info: Optional[BaseException] = None
try:
# Create all members
await _create_all_objects(client, running_app, app._indexed_objects, app_state, environment_name)
await _create_all_objects(client, running_app, app._indexed_objects, environment_name)

# Publish the app
await _publish_app(client, running_app, app_state, app._indexed_objects)

# Show logs from dynamically created images.
# TODO: better way to do this
Expand Down Expand Up @@ -337,10 +367,12 @@ async def _serve_update(
client,
running_app,
app._indexed_objects,
api_pb2.APP_STATE_UNSPECIFIED,
environment_name,
)

# Publish the updated app
await _publish_app(client, running_app, api_pb2.APP_STATE_UNSPECIFIED, app._indexed_objects)

# Communicate to the parent process
is_ready.set()
except asyncio.exceptions.CancelledError:
Expand All @@ -361,7 +393,7 @@ async def _deploy_app(
namespace: Any = api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
client: Optional[_Client] = None,
environment_name: Optional[str] = None,
tag: Optional[str] = None,
tag: str = "",
) -> DeployResult:
"""Deploy an app and export its objects persistently.
Expand All @@ -387,9 +419,8 @@ async def _deploy_app(
if environment_name is None:
environment_name = config.get("environment")

if name is None:
name = app.name
if name is None:
name = name or app.name
if not name:
raise InvalidError(
"You need to either supply an explicit deployment name to the deploy command, "
"or have a name set on the app.\n"
Expand All @@ -402,9 +433,9 @@ async def _deploy_app(
else:
check_object_name(name, "App")

if tag is not None and not is_valid_tag(tag):
if tag and not is_valid_tag(tag):
raise InvalidError(
f"Tag {tag} is invalid."
f"Deployment tag {tag!r} is invalid."
"\n\nTags may only contain alphanumeric characters, dashes, periods, and underscores, "
"and must be 50 characters or less"
)
Expand All @@ -420,46 +451,26 @@ async def _deploy_app(
# Start heartbeats loop to keep the client alive
tc.infinite_loop(lambda: _heartbeat(client, running_app.app_id), sleep=HEARTBEAT_INTERVAL)

# Don't change the app state - deploy state is set by AppDeploy
post_init_state = api_pb2.APP_STATE_UNSPECIFIED

try:
# Create all members
await _create_all_objects(
client,
running_app,
app._indexed_objects,
post_init_state,
environment_name=environment_name,
)

# Deploy app
# TODO(erikbern): not needed if the app already existed
deploy_req = api_pb2.AppDeployRequest(
app_id=running_app.app_id,
name=name,
tag=tag,
namespace=namespace,
object_entity="ap",
visibility=api_pb2.APP_DEPLOY_VISIBILITY_WORKSPACE,
app_url = await _publish_app(
client, running_app, api_pb2.APP_STATE_DEPLOYED, app._indexed_objects, name, tag
)
try:
deploy_response = await retry_transient_errors(client.stub.AppDeploy, deploy_req)
except GRPCError as exc:
if exc.status == Status.INVALID_ARGUMENT:
raise InvalidError(exc.message)
if exc.status == Status.FAILED_PRECONDITION:
raise InvalidError(exc.message)
raise
url = deploy_response.url
except Exception as e:
# Note that AppClientDisconnect only stops the app if it's still initializing, and is a no-op otherwise.
await _disconnect(client, running_app.app_id, reason=api_pb2.APP_DISCONNECT_REASON_DEPLOYMENT_EXCEPTION)
raise e

if output_mgr := OutputManager.get():
output_mgr.print(step_completed("App deployed! 🎉"))
output_mgr.print(f"\nView Deployment: [magenta]{url}[/magenta]")
output_mgr.print(f"\nView Deployment: [magenta]{app_url}[/magenta]")
return DeployResult(app_id=running_app.app_id)


Expand Down
2 changes: 1 addition & 1 deletion modal_version/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
major_number = 0

# Bump this manually on breaking changes, then reset the number in _version_generated.py
minor_number = 63
minor_number = 64

# Right now, automatically increment the patch number in CI
__version__ = f"{major_number}.{minor_number}.{max(build_number, 0)}"
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Modal Labs 2024

# Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
build_number = 99 # git: ddee9b4
build_number = 0 # git: d8f403b
4 changes: 2 additions & 2 deletions test/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def test_hydrated_other_app_object_gets_referenced(servicer, client):
with Volume.ephemeral(client=client) as vol:
app.function(volumes={"/vol": vol})(dummy) # implicitly load vol
deploy_app(app, client=client)
app_set_objects_req = ctx.pop_request("AppSetObjects")
assert vol.object_id in app_set_objects_req.unindexed_object_ids
function_create_req: api_pb2.FunctionCreateRequest = ctx.pop_request("FunctionCreate")
assert vol.object_id in {obj.object_id for obj in function_create_req.function.object_dependencies}


def test_hasattr():
Expand Down
34 changes: 30 additions & 4 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(self, blob_host, blobs):
self.app_client_disconnect_count = 0
self.app_get_logs_initial_count = 0
self.app_set_objects_count = 0
self.app_publish_count = 0

self.volume_counter = 0
# Volume-id -> commit/reload count
Expand Down Expand Up @@ -310,8 +311,17 @@ async def AppGetObjects(self, stream):
object_ids = self.app_objects.get(request.app_id, {})
objects = list(object_ids.items())
if request.include_unindexed:
unindexed_object_ids = self.app_unindexed_objects.get(request.app_id, [])
unindexed_object_ids = set()
for object_id in object_ids.values():
if object_id.startswith("fu-"):
definition = self.app_functions[object_id]
unindexed_object_ids |= {obj.object_id for obj in definition.object_dependencies}
objects += [(None, object_id) for object_id in unindexed_object_ids]
# TODO(michael) This perpetuates a hack! The container_test tests rely on hardcoded unindexed_object_ids
# but we now look those up dynamically from the indexed objects in (the real) AppGetObjects. But the
# container tests never actually set indexed objects on the app. We need a total rewrite here.
if (None, "im-1") not in objects:
objects.append((None, "im-1"))
items = [
api_pb2.AppGetObjectsItem(tag=tag, object=self.get_object_metadata(object_id)) for tag, object_id in objects
]
Expand All @@ -334,6 +344,21 @@ async def AppDeploy(self, stream):
self.app_state_history[request.app_id].append(api_pb2.APP_STATE_DEPLOYED)
await stream.send_message(api_pb2.AppDeployResponse(url="http://test.modal.com/foo/bar"))

async def AppPublish(self, stream):
    """Mock handler for the AppPublish RPC used by the test servicer."""
    request: api_pb2.AppPublishRequest = await stream.recv_message()
    # Sanity-check the entity-id prefixes in the definition-id mapping.
    for function_id, definition_id in request.definition_ids.items():
        assert function_id.startswith("fu-")
        assert definition_id.startswith("de-")
    # TODO(michael) add some other assertions once we make the mock server represent real RPCs more accurately
    self.app_publish_count += 1
    indexed_objects = dict(request.function_ids)
    indexed_objects.update(request.class_ids)
    self.app_objects[request.app_id] = indexed_objects
    self.app_state_history[request.app_id].append(request.app_state)
    if request.app_state != api_pb2.AppState.APP_STATE_DEPLOYED:
        await stream.send_message(api_pb2.AppPublishResponse())
    else:
        # Deployments are looked up by name, and get a URL in the response.
        self.deployed_apps[request.name] = request.app_id
        await stream.send_message(api_pb2.AppPublishResponse(url="http://test.modal.com/foo/bar"))

async def AppGetByDeploymentName(self, stream):
request: api_pb2.AppGetByDeploymentNameRequest = await stream.recv_message()
await stream.send_message(api_pb2.AppGetByDeploymentNameResponse(app_id=self.deployed_apps.get(request.name)))
Expand Down Expand Up @@ -694,6 +719,7 @@ async def FunctionCreate(self, stream):
web_url=function.web_url,
use_function_id=function.use_function_id or function_id,
use_method_name=function.use_method_name,
definition_id=f"de-{self.n_functions}",
),
)
)
Expand Down Expand Up @@ -842,9 +868,9 @@ async def FunctionUpdateSchedulingParams(self, stream):

async def ImageGetOrCreate(self, stream):
request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message()
for k in self.images:
if request.image.SerializeToString() == self.images[k].SerializeToString():
await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=k))
for image_id, image in self.images.items():
if request.image.SerializeToString() == image.SerializeToString():
await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=image_id))
return
idx = len(self.images) + 1
image_id = f"im-{idx}"
Expand Down
3 changes: 1 addition & 2 deletions test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def _container_args(
)
else:
webhook_config = None

function_def = api_pb2.Function(
module_name=module_name,
function_name=function_name,
Expand Down Expand Up @@ -723,7 +722,7 @@ def test_cls_web_endpoint(servicer):
@skip_github_non_linux
def test_cls_web_asgi_construction(servicer):
servicer.app_objects.setdefault("ap-1", {}).setdefault("square", "fu-2")
servicer.app_functions["fu-2"] = api_pb2.FunctionHandleMetadata()
servicer.app_functions["fu-2"] = api_pb2.Function()

inputs = _get_web_inputs(method_name="asgi_web")
ret = _run_container(
Expand Down
Loading

0 comments on commit 662d2a9

Please sign in to comment.