Skip to content

Commit

Permalink
Improve robustness of ObjectRef ownership info serialization (#203)
Browse files Browse the repository at this point in the history
* Handle null-terminated ObjectRef fields

* Move debug logic into `@debug` block

* Earlier serialization debug message

* Use safe_convert

* Safe convert for object status

* Drop special case for `isempty(owner_address_json)`

* Support `==` for `Address`

* Refactor `==` for `ObjectID`

* Always return an `Address` with `owner_address` property

* Update test to avoid using internal field

* Store serialized owner address instead of JSON

* Move `safe_convert` to `ray_julia_jll`

* Support Julia serialization of `Address`

* Support Julia serialization of `Message`

* Support `==` for `Message` subtypes

* Use `Address` serialization

* Remove empty line

* Add deepcopy test for owner_address
  • Loading branch information
omus committed Oct 18, 2023
1 parent a0c4d26 commit fc586a2
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 61 deletions.
2 changes: 1 addition & 1 deletion src/Ray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export start_worker, submit_task, @ray_import, ObjectRef
export RayError, RaySystemError, RayTaskError

include(joinpath("ray_julia_jll", "ray_julia_jll.jl"))
using .ray_julia_jll: ray_julia_jll, ray_julia_jll as ray_jll
using .ray_julia_jll: ray_julia_jll, ray_julia_jll as ray_jll, safe_convert

include("exceptions.jl")
include("function_manager.jl")
Expand Down
65 changes: 21 additions & 44 deletions src/object_ref.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
mutable struct ObjectRef
oid_hex::String
owner_address_json::Union{String,Nothing}
owner_address::ray_jll.Address
serialized_object_status::String

function ObjectRef(oid_hex, owner_address_json, serialized_object_status;
add_local_ref=true)
if owner_address_json !== nothing && isempty(owner_address_json)
owner_address_json = nothing
end
objref = new(oid_hex, owner_address_json, serialized_object_status)
function ObjectRef(oid_hex, owner_address, serialized_object_status; add_local_ref=true)
objref = new(oid_hex, owner_address, serialized_object_status)
if add_local_ref
worker = ray_jll.GetCoreWorker()
ray_jll.AddLocalReference(worker, objref.oid)
Expand All @@ -21,6 +17,12 @@ mutable struct ObjectRef
end
end

function ObjectRef(oid_hex::AbstractString; kwargs...)
return ObjectRef(oid_hex, ray_jll.Address(), ""; kwargs...)
end

ObjectRef(oid::ray_jll.ObjectID; kwargs...) = ObjectRef(ray_jll.Hex(oid); kwargs...)

function finalize_object_ref(obj::ObjectRef)
@debug "Removing local ref for ObjectID $(obj.oid_hex)"
# XXX: should make sure core worker is still initialized before calling it
Expand All @@ -37,10 +39,6 @@ end
function Base.getproperty(x::ObjectRef, prop::Symbol)
return if prop == :oid
ray_jll.FromHex(ray_jll.ObjectID, getfield(x, :oid_hex))
elseif prop == :owner_address
owner_address_json = getfield(x, :owner_address_json)
isnothing(owner_address_json) && return nothing
ray_jll.JsonStringToMessage(ray_jll.Address, owner_address_json)
else
getfield(x, prop)
end
Expand All @@ -63,8 +61,6 @@ function Base.deepcopy_internal(x::ObjectRef, stackdict::IdDict)
return xcp
end

ObjectRef(oid::ray_jll.ObjectID; kwargs...) = ObjectRef(ray_jll.Hex(oid); kwargs...)
ObjectRef(oid_hex::AbstractString; kwargs...) = ObjectRef(oid_hex, nothing, ""; kwargs...)
hex_identifier(obj_ref::ObjectRef) = obj_ref.oid_hex
Base.:(==)(a::ObjectRef, b::ObjectRef) = hex_identifier(a) == hex_identifier(b)

Expand Down Expand Up @@ -112,23 +108,17 @@ function _register_ownership(obj_ref::ObjectRef, outer_obj_ref::Union{ObjectRef,
ray_jll.FromNil(ray_jll.ObjectID)
end

# we've overloaded getproperty for this one to create the actual owner ref
owner_address = obj_ref.owner_address

worker = ray_jll.GetCoreWorker()
if !isnothing(obj_ref.owner_address) && !has_owner(obj_ref)
if !has_owner(obj_ref)
serialized_object_status = safe_convert(StdString, obj_ref.serialized_object_status)

# https://github.com/ray-project/ray/blob/ray-2.5.1/python/ray/_raylet.pyx#L3329
# https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/core_worker.h#L543
ray_jll.RegisterOwnershipInfoAndResolveFuture(worker, obj_ref.oid, outer_object_id,
owner_address,
obj_ref.serialized_object_status)
obj_ref.owner_address,
serialized_object_status)
else
if isnothing(obj_ref.owner_address)
@debug "attempted to register ownership but owner address is nothing: $(obj_ref)"
end
if has_owner(obj_ref)
@debug "attempted to register ownership but object already has known owner: $(obj_ref)"
end
@debug "attempted to register ownership but object already has known owner: $(obj_ref)"
end

return nothing
Expand All @@ -145,36 +135,23 @@ function Serialization.serialize(s::AbstractSerializer, obj_ref::ObjectRef)
# Prefer serializing ownership information from the core worker backend
ray_jll.GetOwnershipInfo(worker, obj_ref.oid, CxxPtr(owner_address),
CxxPtr(serialized_object_status))
# XXX: we use ~~codeunits~~ JSON here because when there are null bytes
# anywhere in the string, the `String` (or even `Vector{UInt8}`) conversion
# from `CxxWrap.StdString` will truncate the string at the first null byte.
#
# owner_address_bytes = collect(codeunits(ray_jll.SerializeAsString(owner_address)))
owner_address_json = String(ray_jll.MessageToJsonString(owner_address))

@debug "serialize ObjectRef:\noid: $hex_str\nowner address $owner_address"
serialized_object_status = String(serialized_object_status)

serialize_type(s, typeof(obj_ref))
serialize(s, hex_str)
serialize(s, owner_address_json)
serialize(s, serialized_object_status)
serialize(s, owner_address)
serialize(s, safe_convert(String, serialized_object_status))

return nothing
end

function Serialization.deserialize(s::AbstractSerializer, ::Type{ObjectRef})
hex_str = deserialize(s)
owner_address_json = deserialize(s)
owner_address = deserialize(s)
serialized_object_status = deserialize(s)

# this if/else block only exists for debug logging
if owner_address_json === nothing || isempty(owner_address_json)
owner_address_json = nothing
@debug "deserialize ObjectRef:\noid: $hex_str\nowner address: $owner_address_json"
else
owner_address = ray_jll.JsonStringToMessage(ray_jll.Address, owner_address_json)
@debug "deserialize ObjectRef:\noid: $hex_str\nowner address: $owner_address"
end
@debug "deserialize ObjectRef:\noid: $hex_str\nowner address: $owner_address"

return ObjectRef(hex_str, owner_address_json, serialized_object_status)
return ObjectRef(hex_str, owner_address, serialized_object_status)
end
42 changes: 38 additions & 4 deletions src/ray_julia_jll/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,41 @@ function JsonStringToMessage(::Type{T}, json::AbstractString) where {T<:Message}
return message
end

let msg_types = (Address, JobConfig, ObjectReference)
for T in msg_types
types = (Symbol(nameof(T), :Allocated), Symbol(nameof(T), :Dereferenced))
for A in types, B in types
@eval function Base.:(==)(a::$A, b::$B)
serialized_a = safe_convert(String, SerializeAsString(a))
serialized_b = safe_convert(String, SerializeAsString(b))
return serialized_a == serialized_b
end
end
end
end

function Serialization.serialize(s::AbstractSerializer, message::Message)
serialized_message = safe_convert(String, SerializeAsString(message))

serialize_type(s, Message)
serialize(s, supertype(typeof(message)))
serialize(s, serialized_message)

return nothing
end

function Serialization.deserialize(s::AbstractSerializer, ::Type{Message})
T = deserialize(s)
serialized_message = deserialize(s)

message = T()
ParseFromString(message, safe_convert(StdString, serialized_message))

return message
end

#####
##### Address
##### Address <: Message
#####

# there's annoying conversion from protobuf binary blobs for the "fields" so we
Expand Down Expand Up @@ -198,9 +231,10 @@ Base.show(io::IO, x::ObjectID) = write(io, "ObjectID(\"", Hex(x), "\")")
# Base.:(==)(a::ObjectID, b::ObjectID) = Hex(a) == Hex(b)
#
# is shadowed by more specific fallbacks defined by CxxWrap.
const ObjectIDTypes = (ObjectIDAllocated, ObjectIDDereferenced)
for Ta in ObjectIDTypes, Tb in ObjectIDTypes
@eval Base.:(==)(a::$Ta, b::$Tb) = Hex(a) == Hex(b)
let types = (ObjectIDAllocated, ObjectIDDereferenced)
for A in types, B in types
@eval Base.:(==)(a::$A, b::$B) = Hex(a) == Hex(b)
end
end

Base.hash(x::ObjectID, h::UInt) = hash(ObjectID, hash(Hex(x), h))
Expand Down
4 changes: 3 additions & 1 deletion src/ray_julia_jll/ray_julia_jll.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module ray_julia_jll
using Artifacts: @artifact_str
using CxxWrap
using CxxWrap.StdLib: StdVector, SharedPtr
using Serialization
using Serialization: Serialization, AbstractSerializer, deserialize, serialize,
serialize_type
using libcxxwrap_julia_jll

@wrapmodule(joinpath(artifact"ray_julia", "julia_core_worker_lib.so"))
Expand All @@ -12,6 +13,7 @@ function __init__()
@initcxx
end # __init__()

include("upstream_fixes.jl")
include("expr.jl")
include("common.jl")

Expand Down
4 changes: 4 additions & 0 deletions src/ray_julia_jll/upstream_fixes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Using `collect` and `ncodeunits` to ensure that the entire string is captured and not just
# up to the first null character: https://github.com/JuliaInterop/CxxWrap.jl/pull/378
safe_convert(::Type{String}, str::StdString) = String(Vector{UInt8}(collect(str)))
safe_convert(::Type{StdString}, str::String) = StdString(str, ncodeunits(str))
9 changes: 5 additions & 4 deletions test/object_ref.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ end
obj_ref = ObjectRef(hex_str)
@test Ray.hex_identifier(obj_ref) == hex_str
@test obj_ref.oid == ray_jll.FromHex(ray_jll.ObjectID, hex_str)
@test obj_ref.owner_address === nothing
@test obj_ref.owner_address == ray_jll.Address()
@test obj_ref == ObjectRef(hex_str)
@test hash(obj_ref) == hash(ObjectRef(hex_str))
end

# test various "no owner address" constructors
@test ObjectRef(hex_str, nothing, "") == obj_ref
@test ObjectRef(hex_str, "", "") == obj_ref
@testset "no owner address constructor" begin
hex_str = "f"^(2 * 28)
@test ObjectRef(hex_str, ray_jll.Address(), "").owner_address == ray_jll.Address()
end

@testset "show" begin
Expand Down
24 changes: 20 additions & 4 deletions test/object_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,26 @@
@testset "Object owner" begin
obj = Ray.put(1)
# ownership only embedded in ObjectRef on serialization
obj_rt = Ray.deserialize_from_ray_object(Ray.serialize_to_ray_object(obj))
addr_json = ray_jll.MessageToJsonString(Ray.get_owner_address(obj))
addr_rt_json = ray_jll.MessageToJsonString(obj_rt.owner_address)
@test addr_json == addr_rt_json == obj_rt.owner_address_json
result = Ray.deserialize_from_ray_object(Ray.serialize_to_ray_object(obj))
@test result.owner_address == Ray.get_owner_address(obj)
end

@testset "deepcopy object reference owner address" begin
obj1 = Ray.put(42)
addr = Ray.get_owner_address(obj1)
obj2 = ObjectRef(Ray.hex_identifier(obj1), addr, "")
obj3 = deepcopy(obj2)

@test obj1.owner_address != addr # Usually only populated upon deserialization
@test obj2.owner_address == addr
@test obj3.owner_address == addr

finalize(obj2)
yield()

# Avoid comparing against `addr` here as the finalizer could modify it in place
# allowing this test to pass.
@test obj3.owner_address == Ray.get_owner_address(obj1)
end
end

Expand Down
25 changes: 25 additions & 0 deletions test/ray_julia_jll/address.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using .ray_julia_jll: Address

@testset "Address" begin
@testset "equality" begin
addr = Address()
@test addr == Address()
@test addr == CxxPtr(Address())[]
end

@testset "json round-trip" begin
json = Dict(:rayletId => base64encode("raylet"), :ipAddress => "127.0.0.1",
:port => 10000, :workerId => base64encode("worker"))
Expand All @@ -16,4 +22,23 @@ using .ray_julia_jll: Address
result = ray_julia_jll.SerializeAsString(address)
@test result == serialized_str
end

@testset "julia serialization round-trip" begin
addr_alloc = Address()
@test addr_alloc isa ray_julia_jll.AddressAllocated
serialized_addr_alloc = sprint(serialize, addr_alloc)
result = deserialize(IOBuffer(serialized_addr_alloc))
@test result isa ray_julia_jll.AddressAllocated
@test result == addr_alloc

addr_ptr = CxxPtr(addr_alloc)
addr_deref = addr_ptr[]
@test addr_deref isa ray_julia_jll.AddressDereferenced
serialized_addr_deref = sprint(serialize, addr_deref)
result = deserialize(IOBuffer(serialized_addr_deref))
@test result isa ray_julia_jll.AddressAllocated
@test result == addr_deref

@test serialized_addr_deref == serialized_addr_alloc
end
end
1 change: 1 addition & 0 deletions test/ray_julia_jll/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ f(x) = x + 1
end

@testset "ray_julia_jll.jl" begin
include("upstream_fixes.jl")
include("expr.jl")
include("buffer.jl")
include("function_descriptor.jl")
Expand Down
16 changes: 16 additions & 0 deletions test/ray_julia_jll/upstream_fixes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@testset "safe_convert" begin
expected = "¿\0ÿ"
cpp_expected = "¿\0ÿ"

std_str = Ray.safe_convert(StdString, expected)
@test length(std_str) == length(cpp_expected)
@test collect(std_str) == collect(cpp_expected)
@test ncodeunits(std_str) == ncodeunits(expected)
@test codeunits(std_str) == codeunits(expected)

str = Ray.safe_convert(String, std_str)
@test length(str) == length(expected)
@test collect(str) == collect(expected)
@test ncodeunits(str) == ncodeunits(expected)
@test codeunits(str) == codeunits(expected)
end
6 changes: 3 additions & 3 deletions test/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ end
@test Ray.has_owner(remote_ref)

# Convert address to string to compare
return_ref_addr = ray_jll.MessageToJsonString(Ray.get_owner_address(return_ref))
remote_ref_addr = ray_jll.MessageToJsonString(Ray.get_owner_address(remote_ref))
return_ref_addr = Ray.get_owner_address(return_ref)
remote_ref_addr = Ray.get_owner_address(remote_ref)
@test return_ref_addr != remote_ref_addr
@test remote_ref_addr == remote_ref.owner_address_json
@test remote_ref_addr == remote_ref.owner_address

@test Ray.get(remote_ref) == 2
end
Expand Down

0 comments on commit fc586a2

Please sign in to comment.