diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fac9dcc4..ba68d259c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ We appreciate your patience while we speedily work towards a stable release of t +### 0.64.0 (2024-07-29) + +- App deployment events are now atomic, reducing the risk that a failed deploy will leave the App in a bad state. + + ### 0.63.87 (2024-07-24) diff --git a/modal/cli/run.py b/modal/cli/run.py index 5f01ae490..032de2e36 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -278,10 +278,10 @@ def run(ctx, detach, quiet, interactive, env): def deploy( app_ref: str = typer.Argument(..., help="Path to a Python file with an app."), - name: str = typer.Option(None, help="Name of the deployment."), + name: str = typer.Option("", help="Name of the deployment."), env: str = ENV_OPTION, stream_logs: bool = typer.Option(False, help="Stream logs from the app upon deployment."), - tag: str = typer.Option(None, help="Tag the deployment with a version."), + tag: str = typer.Option("", help="Tag the deployment with a version."), ): # this ensures that `modal.lookup()` without environment specification uses the same env as specified env = ensure_env(env) @@ -292,7 +292,7 @@ def deploy( name = app.name with enable_output(): - res = deploy_app(app, name=name, environment_name=env, tag=tag) + res = deploy_app(app, name=name, environment_name=env or "", tag=tag) if stream_logs: stream_app_logs(res.app_id) diff --git a/modal/cls.py b/modal/cls.py index 52e025e17..b2339c97f 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -274,6 +274,16 @@ async def _load(self: "_Cls", resolver: Resolver, existing_object_id: Optional[s ) ) resp = await resolver.client.stub.ClassCreate(req) + # Even though we already have the function_handle_metadata for this method locally, + # The RPC is going to replace it with function_handle_metadata derived from the server. 
+ # We need to overwrite the definition_id sent back from the server here with the definition_id + # previously stored in function metadata, which may have been sent back from FunctionCreate. + # The problem is that this metadata propagates back and overwrites the metadata on the Function + # object itself. This is really messy. Maybe better to exclusively populate the method metadata + # from the function metadata we already have locally? Really a lot to clean up here... + for method in resp.handle_metadata.methods: + f_metadata = self._method_functions[method.function_name]._get_metadata() + method.function_handle_metadata.definition_id = f_metadata.definition_id self._hydrate(resp.class_id, resolver.client, resp.handle_metadata) rep = f"Cls({user_cls.__name__})" diff --git a/modal/functions.py b/modal/functions.py index a01a44ada..1d446c31f 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -351,6 +351,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing function=function_definition, # method_bound_function.object_id usually gets set by preload existing_function_id=existing_object_id or method_bound_function.object_id or "", + defer_updates=True, ) assert resolver.client.stub is not None # client should be connected when load is called with FunctionCreationStatus(resolver, full_name) as function_creation_status: @@ -834,6 +835,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona function=function_definition, schedule=schedule.proto_message if schedule is not None else None, existing_function_id=existing_object_id or "", + defer_updates=True, ) try: response: api_pb2.FunctionCreateResponse = await retry_transient_errors( @@ -1092,7 +1094,7 @@ def _initialize_from_empty(self): def _hydrate_metadata(self, metadata: Optional[Message]): # Overridden concrete implementation of base class method - assert metadata and isinstance(metadata, (api_pb2.Function, api_pb2.FunctionHandleMetadata)) + 
assert metadata and isinstance(metadata, api_pb2.FunctionHandleMetadata) self._is_generator = metadata.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR self._web_url = metadata.web_url self._function_name = metadata.function_name @@ -1100,6 +1102,7 @@ def _hydrate_metadata(self, metadata: Optional[Message]): self._use_function_id = metadata.use_function_id self._use_method_name = metadata.use_method_name self._class_parameter_info = metadata.class_parameter_info + self._definition_id = metadata.definition_id def _invocation_function_id(self) -> str: return self._use_function_id or self.object_id @@ -1119,6 +1122,7 @@ def _get_metadata(self): use_function_id=self._use_function_id, is_method=self._is_method, class_parameter_info=self._class_parameter_info, + definition_id=self._definition_id, ) def _set_mute_cancellation(self, value: bool = True): diff --git a/modal/runner.py b/modal/runner.py index 39cb97da3..0663a301f 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -3,7 +3,7 @@ import dataclasses import os from multiprocessing.synchronize import Event -from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Dict, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Dict, List, Optional, TypeVar from grpclib import GRPCError, Status from rich.console import Console @@ -39,6 +39,9 @@ _App = TypeVar("_App") +V = TypeVar("V") + + async def _heartbeat(client: _Client, app_id: str) -> None: request = api_pb2.AppHeartbeatRequest(app_id=app_id) # TODO(erikbern): we should capture exceptions here @@ -104,7 +107,6 @@ async def _create_all_objects( client: _Client, running_app: RunningApp, indexed_objects: Dict[str, _Object], - new_app_state: int, environment_name: str, ) -> None: """Create objects that have been defined but not created on the server.""" @@ -150,21 +152,46 @@ async def _load(tag, obj): await TaskContext.gather(*(_load(tag, obj) for tag, obj in indexed_objects.items())) - # Create the app 
(and send a list of all tagged obs) - # TODO(erikbern): we should delete objects from a previous version that are no longer needed - # We just delete them from the app, but the actual objects will stay around - indexed_object_ids = running_app.tag_to_object_id - assert indexed_object_ids == running_app.tag_to_object_id - all_objects = resolver.objects() - unindexed_object_ids = list(set(obj.object_id for obj in all_objects) - set(running_app.tag_to_object_id.values())) - req_set = api_pb2.AppSetObjectsRequest( +async def _publish_app( + client: _Client, + running_app: RunningApp, + app_state: int, # api_pb2.AppState.value + indexed_objects: Dict[str, _Object], + name: str = "", # Only relevant for deployments + tag: str = "", # Only relevant for deployments +) -> str: + """Wrapper for AppPublish RPC.""" + + # Could simplify this function somewhat by changing the internal representation to use + # function_ids / class_ids rather than the current tag_to_object_id (i.e. "indexed_objects") + def filter_values(full_dict: Dict[str, V], condition: Callable[[V], bool]) -> Dict[str, V]: + return {k: v for k, v in full_dict.items() if condition(v)} + + # The entity prefixes are defined in the monorepo; is there any way to share them here? 
+ function_ids = filter_values(running_app.tag_to_object_id, lambda v: v.startswith("fu-")) + class_ids = filter_values(running_app.tag_to_object_id, lambda v: v.startswith("cs-")) + + function_objs = filter_values(indexed_objects, lambda v: v.object_id in function_ids.values()) + definition_ids = {obj.object_id: obj._get_metadata().definition_id for obj in function_objs.values()} # type: ignore + + request = api_pb2.AppPublishRequest( app_id=running_app.app_id, - indexed_object_ids=indexed_object_ids, - unindexed_object_ids=unindexed_object_ids, - new_app_state=new_app_state, # type: ignore + name=name, + deployment_tag=tag, + app_state=app_state, # type: ignore : should be a api_pb2.AppState.value + function_ids=function_ids, + class_ids=class_ids, + definition_ids=definition_ids, ) - await retry_transient_errors(client.stub.AppSetObjects, req_set) + try: + response = await retry_transient_errors(client.stub.AppPublish, request) + except GRPCError as exc: + if exc.status == Status.INVALID_ARGUMENT or exc.status == Status.FAILED_PRECONDITION: + raise InvalidError(exc.message) + raise + + return response.url async def _disconnect( @@ -252,7 +279,10 @@ async def _run_app( exc_info: Optional[BaseException] = None try: # Create all members - await _create_all_objects(client, running_app, app._indexed_objects, app_state, environment_name) + await _create_all_objects(client, running_app, app._indexed_objects, environment_name) + + # Publish the app + await _publish_app(client, running_app, app_state, app._indexed_objects) # Show logs from dynamically created images. 
# TODO: better way to do this @@ -337,10 +367,12 @@ async def _serve_update( client, running_app, app._indexed_objects, - api_pb2.APP_STATE_UNSPECIFIED, environment_name, ) + # Publish the updated app + await _publish_app(client, running_app, api_pb2.APP_STATE_UNSPECIFIED, app._indexed_objects) + # Communicate to the parent process is_ready.set() except asyncio.exceptions.CancelledError: @@ -361,7 +393,7 @@ async def _deploy_app( namespace: Any = api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, client: Optional[_Client] = None, environment_name: Optional[str] = None, - tag: Optional[str] = None, + tag: str = "", ) -> DeployResult: """Deploy an app and export its objects persistently. @@ -387,9 +419,8 @@ async def _deploy_app( if environment_name is None: environment_name = config.get("environment") - if name is None: - name = app.name - if name is None: + name = name or app.name + if not name: raise InvalidError( "You need to either supply an explicit deployment name to the deploy command, " "or have a name set on the app.\n" @@ -402,9 +433,9 @@ async def _deploy_app( else: check_object_name(name, "App") - if tag is not None and not is_valid_tag(tag): + if tag and not is_valid_tag(tag): raise InvalidError( - f"Tag {tag} is invalid." + f"Deployment tag {tag!r} is invalid." 
"\n\nTags may only contain alphanumeric characters, dashes, periods, and underscores, " "and must be 50 characters or less" ) @@ -420,38 +451,18 @@ async def _deploy_app( # Start heartbeats loop to keep the client alive tc.infinite_loop(lambda: _heartbeat(client, running_app.app_id), sleep=HEARTBEAT_INTERVAL) - # Don't change the app state - deploy state is set by AppDeploy - post_init_state = api_pb2.APP_STATE_UNSPECIFIED - try: # Create all members await _create_all_objects( client, running_app, app._indexed_objects, - post_init_state, environment_name=environment_name, ) - # Deploy app - # TODO(erikbern): not needed if the app already existed - deploy_req = api_pb2.AppDeployRequest( - app_id=running_app.app_id, - name=name, - tag=tag, - namespace=namespace, - object_entity="ap", - visibility=api_pb2.APP_DEPLOY_VISIBILITY_WORKSPACE, + app_url = await _publish_app( + client, running_app, api_pb2.APP_STATE_DEPLOYED, app._indexed_objects, name, tag ) - try: - deploy_response = await retry_transient_errors(client.stub.AppDeploy, deploy_req) - except GRPCError as exc: - if exc.status == Status.INVALID_ARGUMENT: - raise InvalidError(exc.message) - if exc.status == Status.FAILED_PRECONDITION: - raise InvalidError(exc.message) - raise - url = deploy_response.url except Exception as e: # Note that AppClientDisconnect only stops the app if it's still initializing, and is a no-op otherwise. await _disconnect(client, running_app.app_id, reason=api_pb2.APP_DISCONNECT_REASON_DEPLOYMENT_EXCEPTION) @@ -459,7 +470,7 @@ async def _deploy_app( if output_mgr := OutputManager.get(): output_mgr.print(step_completed("App deployed! 
🎉")) - output_mgr.print(f"\nView Deployment: [magenta]{url}[/magenta]") + output_mgr.print(f"\nView Deployment: [magenta]{app_url}[/magenta]") return DeployResult(app_id=running_app.app_id) diff --git a/modal_version/__init__.py b/modal_version/__init__.py index 7ccd74820..1e2b1ba50 100644 --- a/modal_version/__init__.py +++ b/modal_version/__init__.py @@ -7,7 +7,7 @@ major_number = 0 # Bump this manually on breaking changes, then reset the number in _version_generated.py -minor_number = 63 +minor_number = 64 # Right now, automatically increment the patch number in CI __version__ = f"{major_number}.{minor_number}.{max(build_number, 0)}" diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py index cc01dc140..fd8c13099 100644 --- a/modal_version/_version_generated.py +++ b/modal_version/_version_generated.py @@ -1,4 +1,4 @@ # Copyright Modal Labs 2024 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client. -build_number = 99 # git: ddee9b4 +build_number = 0 # git: d8f403b diff --git a/test/app_test.py b/test/app_test.py index 7dba92af5..7bd3a9453 100644 --- a/test/app_test.py +++ b/test/app_test.py @@ -302,8 +302,8 @@ def test_hydrated_other_app_object_gets_referenced(servicer, client): with Volume.ephemeral(client=client) as vol: app.function(volumes={"/vol": vol})(dummy) # implicitly load vol deploy_app(app, client=client) - app_set_objects_req = ctx.pop_request("AppSetObjects") - assert vol.object_id in app_set_objects_req.unindexed_object_ids + function_create_req: api_pb2.FunctionCreateRequest = ctx.pop_request("FunctionCreate") + assert vol.object_id in {obj.object_id for obj in function_create_req.function.object_dependencies} def test_hasattr(): diff --git a/test/conftest.py b/test/conftest.py index 4bc10e3e0..b94890c1c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -171,6 +171,7 @@ def __init__(self, blob_host, blobs): self.app_client_disconnect_count = 0 self.app_get_logs_initial_count = 
0 self.app_set_objects_count = 0 + self.app_publish_count = 0 self.volume_counter = 0 # Volume-id -> commit/reload count @@ -310,8 +311,17 @@ async def AppGetObjects(self, stream): object_ids = self.app_objects.get(request.app_id, {}) objects = list(object_ids.items()) if request.include_unindexed: - unindexed_object_ids = self.app_unindexed_objects.get(request.app_id, []) + unindexed_object_ids = set() + for object_id in object_ids.values(): + if object_id.startswith("fu-"): + definition = self.app_functions[object_id] + unindexed_object_ids |= {obj.object_id for obj in definition.object_dependencies} objects += [(None, object_id) for object_id in unindexed_object_ids] + # TODO(michael) This perpetuates a hack! The container_test tests rely on hardcoded unindexed_object_ids + # but we now look those up dynamically from the indexed objects in (the real) AppGetObjects. But the + # container tests never actually set indexed objects on the app. We need a total rewrite here. + if (None, "im-1") not in objects: + objects.append((None, "im-1")) items = [ api_pb2.AppGetObjectsItem(tag=tag, object=self.get_object_metadata(object_id)) for tag, object_id in objects ] @@ -334,6 +344,21 @@ async def AppDeploy(self, stream): self.app_state_history[request.app_id].append(api_pb2.APP_STATE_DEPLOYED) await stream.send_message(api_pb2.AppDeployResponse(url="http://test.modal.com/foo/bar")) + async def AppPublish(self, stream): + request: api_pb2.AppPublishRequest = await stream.recv_message() + for key, val in request.definition_ids.items(): + assert key.startswith("fu-") + assert val.startswith("de-") + # TODO(michael) add some other assertions once we make the mock server represent real RPCs more accurately + self.app_publish_count += 1 + self.app_objects[request.app_id] = {**request.function_ids, **request.class_ids} + self.app_state_history[request.app_id].append(request.app_state) + if request.app_state == api_pb2.AppState.APP_STATE_DEPLOYED: + self.deployed_apps[request.name] 
= request.app_id + await stream.send_message(api_pb2.AppPublishResponse(url="http://test.modal.com/foo/bar")) + else: + await stream.send_message(api_pb2.AppPublishResponse()) + async def AppGetByDeploymentName(self, stream): request: api_pb2.AppGetByDeploymentNameRequest = await stream.recv_message() await stream.send_message(api_pb2.AppGetByDeploymentNameResponse(app_id=self.deployed_apps.get(request.name))) @@ -694,6 +719,7 @@ async def FunctionCreate(self, stream): web_url=function.web_url, use_function_id=function.use_function_id or function_id, use_method_name=function.use_method_name, + definition_id=f"de-{self.n_functions}", ), ) ) @@ -842,9 +868,9 @@ async def FunctionUpdateSchedulingParams(self, stream): async def ImageGetOrCreate(self, stream): request: api_pb2.ImageGetOrCreateRequest = await stream.recv_message() - for k in self.images: - if request.image.SerializeToString() == self.images[k].SerializeToString(): - await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=k)) + for image_id, image in self.images.items(): + if request.image.SerializeToString() == image.SerializeToString(): + await stream.send_message(api_pb2.ImageGetOrCreateResponse(image_id=image_id)) return idx = len(self.images) + 1 image_id = f"im-{idx}" diff --git a/test/container_test.py b/test/container_test.py index a6b2702ec..02d82c39b 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -132,7 +132,6 @@ def _container_args( ) else: webhook_config = None - function_def = api_pb2.Function( module_name=module_name, function_name=function_name, @@ -723,7 +722,7 @@ def test_cls_web_endpoint(servicer): @skip_github_non_linux def test_cls_web_asgi_construction(servicer): servicer.app_objects.setdefault("ap-1", {}).setdefault("square", "fu-2") - servicer.app_functions["fu-2"] = api_pb2.FunctionHandleMetadata() + servicer.app_functions["fu-2"] = api_pb2.Function() inputs = _get_web_inputs(method_name="asgi_web") ret = _run_container( diff --git 
a/test/live_reload_test.py b/test/live_reload_test.py index 60ef642b8..e8bf93b0f 100644 --- a/test/live_reload_test.py +++ b/test/live_reload_test.py @@ -21,7 +21,7 @@ def app_ref(test_dir): async def test_live_reload(app_ref, server_url_env, servicer): async with serve_app.aio(app, app_ref): await asyncio.sleep(3.0) - assert servicer.app_set_objects_count == 1 + assert servicer.app_publish_count == 1 assert servicer.app_client_disconnect_count == 1 assert servicer.app_get_logs_initial_count == 0 @@ -31,7 +31,7 @@ async def test_live_reload_with_logs(app_ref, server_url_env, servicer): with enable_output(): async with serve_app.aio(app, app_ref): await asyncio.sleep(3.0) - assert servicer.app_set_objects_count == 1 + assert servicer.app_publish_count == 1 assert servicer.app_client_disconnect_count == 1 assert servicer.app_get_logs_initial_count == 1 @@ -51,8 +51,8 @@ async def fake_watch(): # TODO ideally we would assert the specific expected number here, but this test # is consistently flaking in CI and I cannot reproduce locally to debug. # I'm relaxing the assertion for now to stop the test from blocking deployments. 
- # assert servicer.app_set_objects_count == 4 # 1 + number of file changes - assert servicer.app_set_objects_count > 1 + # assert servicer.app_publish_count == 4 # 1 + number of file changes + assert servicer.app_publish_count > 1 assert servicer.app_client_disconnect_count == 1 foo = app.indexed_objects["foo"] assert isinstance(foo, Function) @@ -69,7 +69,7 @@ async def fake_watch(): async with serve_app.aio(app, app_ref, _watcher=fake_watch()): pass - assert servicer.app_set_objects_count == 1 # Should create the initial app once + assert servicer.app_publish_count == 1 # Should create the initial app once assert servicer.app_client_disconnect_count == 1 diff --git a/test/runner_test.py b/test/runner_test.py index 8e179bc8e..7207ed4b8 100644 --- a/test/runner_test.py +++ b/test/runner_test.py @@ -18,7 +18,7 @@ def test_run_app(servicer, client): pass ctx.pop_request("AppCreate") - ctx.pop_request("AppSetObjects") + ctx.pop_request("AppPublish") ctx.pop_request("AppClientDisconnect")