diff --git a/src/object_ref.jl b/src/object_ref.jl index bab6585a..fe8892ac 100644 --- a/src/object_ref.jl +++ b/src/object_ref.jl @@ -1,10 +1,8 @@ mutable struct ObjectRef oid_hex::String - owner_address::ray_jll.Address - serialized_object_status::String - function ObjectRef(oid_hex, owner_address, serialized_object_status; add_local_ref=true) - objref = new(oid_hex, owner_address, serialized_object_status) + function ObjectRef(oid_hex; add_local_ref=true) + objref = new(oid_hex) if add_local_ref worker = ray_jll.GetCoreWorker() ray_jll.AddLocalReference(worker, objref.oid) @@ -17,10 +15,6 @@ mutable struct ObjectRef end end -function ObjectRef(oid_hex::AbstractString; kwargs...) - return ObjectRef(oid_hex, ray_jll.Address(), ""; kwargs...) -end - ObjectRef(oid::ray_jll.ObjectID; kwargs...) = ObjectRef(ray_jll.Hex(oid); kwargs...) function finalize_object_ref(obj::ObjectRef) @@ -93,32 +87,47 @@ function has_owner(obj_ref::ObjectRef) return !isempty(ray_jll.SerializeAsString(owner_address)) end +function _get_ownership_info(obj_ref::ObjectRef) + worker = ray_jll.GetCoreWorker() + owner_address = ray_jll.Address() + serialized_object_status = StdString() + + ray_jll.GetOwnershipInfo(worker, obj_ref.oid, CxxPtr(owner_address), + CxxPtr(serialized_object_status)) + + return owner_address, serialized_object_status +end + # TODO: this is not currently used pending investigation of how to properly handle ownership # see https://github.com/beacon-biosignals/Ray.jl/issues/77#issuecomment-1717675779 # and https://github.com/beacon-biosignals/Ray.jl/pull/108 -function _register_ownership(obj_ref::ObjectRef, outer_obj_ref::Union{ObjectRef,Nothing}) - @debug """Registering ownership for $(obj_ref) - owner address: $(obj_ref.owner_address) - status: $(bytes2hex(codeunits(obj_ref.serialized_object_status))) - contained in $(outer_obj_ref)""" - - outer_object_id = if outer_obj_ref !== nothing - outer_obj_ref.oid - else - ray_jll.FromNil(ray_jll.ObjectID) - end - - worker = ray_jll.GetCoreWorker() +function _register_ownership(obj_ref::ObjectRef, outer_obj_ref::Union{ObjectRef,Nothing}, + owner_address::ray_jll.Address, + serialized_object_status::String) if !has_owner(obj_ref) - serialized_object_status = safe_convert(StdString, obj_ref.serialized_object_status) + @debug """ + Registering ownership for $(obj_ref) + owner address: $(owner_address) + status: $(bytes2hex(codeunits(serialized_object_status))) + contained in: $(outer_obj_ref)""" + + worker = ray_jll.GetCoreWorker() + + outer_object_id = if outer_obj_ref !== nothing + outer_obj_ref.oid + else + ray_jll.FromNil(ray_jll.ObjectID) + end + + serialized_object_status = safe_convert(StdString, serialized_object_status) # https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L3329 # https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L543 ray_jll.RegisterOwnershipInfoAndResolveFuture(worker, obj_ref.oid, outer_object_id, - obj_ref.owner_address, + owner_address, serialized_object_status) else - @debug "attempted to register ownership but object already has known owner: $(obj_ref)" + @debug "Skipping registering ownership for $(obj_ref) as object has known owner" end return nothing @@ -126,32 +135,11 @@ end # We cannot serialize pointers between processes function Serialization.serialize(s::AbstractSerializer, obj_ref::ObjectRef) - worker = ray_jll.GetCoreWorker() - - hex_str = hex_identifier(obj_ref) - owner_address = ray_jll.Address() - serialized_object_status = StdString() - - # Prefer serializing ownership information from the core worker backend - ray_jll.GetOwnershipInfo(worker, obj_ref.oid, CxxPtr(owner_address), - CxxPtr(serialized_object_status)) - - @debug "serialize ObjectRef:\noid: $hex_str\nowner address $owner_address" - serialize_type(s, typeof(obj_ref)) - serialize(s, hex_str) - serialize(s, owner_address) - serialize(s, safe_convert(String, serialized_object_status)) - + serialize(s, hex_identifier(obj_ref)) return nothing end function Serialization.deserialize(s::AbstractSerializer, ::Type{ObjectRef}) - hex_str = deserialize(s) - owner_address = deserialize(s) - serialized_object_status = deserialize(s) - - @debug "deserialize ObjectRef:\noid: $hex_str\nowner address: $owner_address" - - return ObjectRef(hex_str, owner_address, serialized_object_status) + return ObjectRef(deserialize(s)) end diff --git a/src/ray_serializer.jl b/src/ray_serializer.jl index 74485eef..09ebbb9b 100644 --- a/src/ray_serializer.jl +++ b/src/ray_serializer.jl @@ -1,3 +1,8 @@ +struct OwnershipInfo + owner_address::ray_jll.Address + serialized_object_status::String +end + mutable struct RaySerializer{I<:IO} <: AbstractSerializer # Fields required by all AbstractSerializers io::I @@ -6,11 +11,18 @@ mutable struct RaySerializer{I<:IO} <: AbstractSerializer pending_refs::Vector{Int} version::Int - # Inlined object references encountered during serializing + # Inlined object references encountered during serializing/deserialization object_refs::Set{ObjectRef} + # Deserialized object reference metadata used for registering ownership + object_owner::Dict{ObjectRef,OwnershipInfo} + function RaySerializer{I}(io::I) where {I<:IO} - return new(io, 0, IdDict(), Int[], Serialization.ser_version, Set{ObjectRef}()) + version = Serialization.ser_version + object_refs = Set{ObjectRef}() + object_owner = Dict{ObjectRef,OwnershipInfo}() + + return new(io, 0, IdDict(), Int[], version, object_refs, object_owner) end end @@ -32,12 +44,29 @@ end function Serialization.serialize(s::RaySerializer, obj_ref::ObjectRef) push!(s.object_refs, obj_ref) - return invoke(serialize, Tuple{AbstractSerializer,ObjectRef}, s, obj_ref) + + owner_address, serialized_object_status = _get_ownership_info(obj_ref) + + invoke(serialize, Tuple{AbstractSerializer,ObjectRef}, s, obj_ref) + + # Append ownership information when serializing an `ObjectRef` with the `RaySerializer`. + # This information will be deserialized another worker process and used during object + # reference registration. + serialize(s, owner_address) + serialize(s, safe_convert(String, serialized_object_status)) + + return nothing end function Serialization.deserialize(s::RaySerializer, T::Type{ObjectRef}) obj_ref = invoke(deserialize, Tuple{AbstractSerializer,Type{ObjectRef}}, s, T) + + owner_address = deserialize(s) + serialized_object_status = deserialize(s) + s.object_owner[obj_ref] = OwnershipInfo(owner_address, serialized_object_status) + push!(s.object_refs, obj_ref) + return obj_ref end @@ -110,7 +139,9 @@ function deserialize_from_ray_object(ray_obj::SharedPtr{ray_jll.RayObject}, end for inner_object_ref in s.object_refs - _register_ownership(inner_object_ref, outer_object_ref) + (; owner_address, serialized_object_status) = s.object_owner[inner_object_ref] + _register_ownership(inner_object_ref, outer_object_ref, owner_address, + serialized_object_status) end # TODO: add an option to not rethrow diff --git a/test/object_ref.jl b/test/object_ref.jl index 7d2a691b..5bef6f9b 100644 --- a/test/object_ref.jl +++ b/test/object_ref.jl @@ -13,16 +13,10 @@ end obj_ref = ObjectRef(hex_str) @test Ray.hex_identifier(obj_ref) == hex_str @test obj_ref.oid == ray_jll.FromHex(ray_jll.ObjectID, hex_str) - @test obj_ref.owner_address == ray_jll.Address() @test obj_ref == ObjectRef(hex_str) @test hash(obj_ref) == hash(ObjectRef(hex_str)) end - @testset "no owner address constructor" begin - hex_str = "f"^(2 * 28) - @test ObjectRef(hex_str, ray_jll.Address(), "").owner_address == ray_jll.Address() - end - @testset "show" begin hex_str = "f"^(2 * 28) obj_ref = ObjectRef(hex_str) diff --git a/test/object_store.jl b/test/object_store.jl index 42532f17..16fee9af 100644 --- a/test/object_store.jl +++ b/test/object_store.jl @@ -56,28 +56,8 @@ end @testset "Object owner" begin - obj = Ray.put(1) - # ownership only embedded in ObjectRef on serialization - result = Ray.deserialize_from_ray_object(Ray.serialize_to_ray_object(obj)) - @test result.owner_address == Ray.get_owner_address(obj) - end - - @testset "deepcopy object reference owner address" begin - obj1 = Ray.put(42) - addr = Ray.get_owner_address(obj1) - obj2 = ObjectRef(Ray.hex_identifier(obj1), addr, "") - obj3 = deepcopy(obj2) - - @test obj1.owner_address != addr # Usually only populated upon deserialization - @test obj2.owner_address == addr - @test obj3.owner_address == addr - - finalize(obj2) - yield() - - # Avoid comparing against `addr` here as the finalizer could modify it in place - # allowing this test to pass. - @test obj3.owner_address == Ray.get_owner_address(obj1) + obj_ref = Ray.put(1) + @test Ray.has_owner(obj_ref) end end diff --git a/test/task.jl b/test/task.jl index 6abef060..c33634bb 100644 --- a/test/task.jl +++ b/test/task.jl @@ -97,11 +97,9 @@ end @test Ray.has_owner(return_ref) @test Ray.has_owner(remote_ref) - # Convert address to string to compare return_ref_addr = Ray.get_owner_address(return_ref) remote_ref_addr = Ray.get_owner_address(remote_ref) @test return_ref_addr != remote_ref_addr - @test remote_ref_addr == remote_ref.owner_address @test Ray.get(remote_ref) == 2 end