Skip to content

Commit

Permalink
add/remove local ref on ObjectRef construction/finalize, scaffolding …
Browse files Browse the repository at this point in the history
…for testing reference counts (#126)

* add/remove local ref in Ray.jl code

* let's try this...

* use lambda because of template inference failure

* GetAllLocalReferences with wrapped return map type

* ref counting tests to exercise add/remove local reference

* add local ref false

* ==, hash, and show for ObjectID; don't return ref from Nil

* get all reference counts in Ray.jl; specialized deepcopy

* use hex string in objectref, finalize not async

* restore async to finalizer and yield in tests

* also exercise the task return and construction counting

* Update ray_julia_jll/src/wrappers/any.jl

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>

* Update ray_julia_jll/src/wrappers/any.jl

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>

* Update test/object_ref.jl

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>

* attempt to fix segfault on tests

* Update ray_julia_jll/test/reference_counting.jl

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>

* Update src/object_ref.jl

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>

* Apply suggestions from code review

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>

* the ol' test shuffle

* revert un hexing

---------

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>
  • Loading branch information
kleinschmidt and omus committed Sep 22, 2023
1 parent d275606 commit 6f0db4e
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 12 deletions.
33 changes: 31 additions & 2 deletions ray_julia_jll/deps/wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
// the function. If you fail to do this you'll get a "No appropriate factory for type" upon
// attempting to use the shared library in Julia.

// Resource map type wrapper
mod.add_type<std::unordered_map<std::string, double>>("CxxMapStringDouble");
mod.method("_setindex!", [](std::unordered_map<std::string, double> &map,
double val,
Expand Down Expand Up @@ -473,7 +474,11 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
mod.add_type<ObjectID>("ObjectID")
.method("ObjectIDFromHex", &ObjectID::FromHex)
.method("ObjectIDFromRandom", &ObjectID::FromRandom)
.method("ObjectIDFromNil", &ObjectID::Nil)
// Return the nil ObjectID to Julia by value (a copy), not as a reference.
.method("ObjectIDFromNil", []() -> ObjectID {
    return ObjectID::Nil();
})
.method("Hex", &ObjectID::Hex);

// enum Language
Expand Down Expand Up @@ -672,6 +677,24 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
return std::unique_ptr<TaskArgByValue>(t);
});

// ObjectID reference count map wrapper
typedef std::unordered_map<ObjectID, std::pair<size_t, size_t>> ReferenceCountMap;
mod.add_type<ReferenceCountMap>("CxxMapObjectIDPairIntInt");
// Collect the ObjectID keys of a reference-count map into a vector so they can
// be iterated from Julia.
mod.method("_keys", [](ReferenceCountMap &map) {
    std::vector<ObjectID> keys;
    // Reserve capacity instead of size-constructing the vector:
    // `std::vector<ObjectID> keys(n)` creates n default-constructed IDs, and the
    // subsequent push_back calls would append the real keys *after* them.
    keys.reserve(map.size());
    for (const auto &kv : map) {
        keys.push_back(kv.first);
    }
    return keys;
});
// Look up the pair of counts stored for `key` and return it as a two-element
// vector {first, second}. A key that is not present yields {0, 0}.
mod.method("_getindex", [](ReferenceCountMap &map, ObjectID &key) {
    // Use find() rather than operator[]: operator[] would insert a
    // default-constructed entry for a missing key, so a read would mutate the map.
    std::pair<size_t, size_t> counts{0, 0};
    auto it = map.find(key);
    if (it != map.end()) {
        counts = it->second;
    }
    return std::vector<size_t>{counts.first, counts.second};
});

// class WorkerContext
// https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/context.h#L30
mod.add_type<ray::core::WorkerContext>("WorkerContext")
Expand All @@ -687,7 +710,13 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod)
.method("GetOwnerAddress", &ray::core::CoreWorker::GetOwnerAddress)
.method("GetOwnershipInfo", &ray::core::CoreWorker::GetOwnershipInfo)
.method("GetObjectRefs", &ray::core::CoreWorker::GetObjectRefs)
.method("RegisterOwnershipInfoAndResolveFuture", &ray::core::CoreWorker::RegisterOwnershipInfoAndResolveFuture);
.method("RegisterOwnershipInfoAndResolveFuture", &ray::core::CoreWorker::RegisterOwnershipInfoAndResolveFuture)
// AddLocalReference is exposed through a lambda instead of the direct
// member-function-pointer binding kept (commented out) below; the commit notes
// say the lambda is used "because of template inference failure".
// NOTE(review): presumably CoreWorker::AddLocalReference is overloaded — confirm.
// .method("AddLocalReference", &ray::core::CoreWorker::AddLocalReference)
.method("AddLocalReference", [](ray::core::CoreWorker &worker, ObjectID &object_id) {
return worker.AddLocalReference(object_id);
})
.method("RemoveLocalReference", &ray::core::CoreWorker::RemoveLocalReference)
.method("GetAllReferenceCounts", &ray::core::CoreWorker::GetAllReferenceCounts);
mod.method("_GetCoreWorker", &_GetCoreWorker);

mod.method("_submit_task", &_submit_task);
Expand Down
8 changes: 7 additions & 1 deletion ray_julia_jll/src/wrappers/any.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ FromInt(::Type{JobID}, num::Integer) = JobIDFromInt(num)

FromHex(::Type{ObjectID}, str::AbstractString) = ObjectIDFromHex(str)
FromRandom(::Type{ObjectID}) = ObjectIDFromRandom()
Nil(::Type{ObjectID}) = ObjectIDFromNil()
FromNil(::Type{ObjectID}) = ObjectIDFromNil()

ObjectID(str::AbstractString) = FromHex(ObjectID, str)

# Display an `ObjectID` as `ObjectID("<hex>")`.
#
# `print` is used rather than `show`: calling `show` on a `String` wraps it in
# quotes and escapes the inner quotes, which would render the whole thing as a
# quoted string literal instead of a constructor-like expression.
function Base.show(io::IO, x::ObjectID)
    print(io, "ObjectID(\"", Hex(x), "\")")
    return nothing
end
# Equality and hashing are both defined on the hex encoding of the ID, so two
# IDs that compare equal also hash equally.
function Base.:(==)(lhs::ObjectID, rhs::ObjectID)
    return Hex(lhs) == Hex(rhs)
end

function Base.hash(id::ObjectID, h::UInt)
    return hash(ObjectID, hash(Hex(id), h))
end

#####
##### TaskArg
Expand Down
31 changes: 31 additions & 0 deletions ray_julia_jll/test/reference_counting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@testset "Reference counting" begin
    using ray_julia_jll: FromRandom, ObjectID, Hex, AddLocalReference, RemoveLocalReference,
                         GetAllReferenceCounts, _keys, _getindex, GetCoreWorker

    core_worker = GetCoreWorker()

    # IDs are compared via their hex strings because the wrapped ObjectIDs are
    # just pointers.
    function has_count(id)
        tracked = Hex.(_keys(GetAllReferenceCounts(core_worker)))
        return Hex(id) in tracked
    end

    # First element of the count vector is the local reference count.
    function local_count(id)
        counts = _getindex(GetAllReferenceCounts(core_worker), id)
        return first(counts)
    end

    object_id = FromRandom(ObjectID)

    # a fresh random ID is not tracked yet
    @test !has_count(object_id)

    # each AddLocalReference bumps the local count by one
    AddLocalReference(core_worker, object_id)
    @test has_count(object_id)
    @test local_count(object_id) == 1

    AddLocalReference(core_worker, object_id)
    @test local_count(object_id) == 2

    # each RemoveLocalReference drops it by one...
    RemoveLocalReference(core_worker, object_id)
    @test local_count(object_id) == 1

    # ...and the entry disappears once the count reaches zero
    RemoveLocalReference(core_worker, object_id)
    @test !has_count(object_id)
    @test local_count(object_id) == 0

    # removing a reference that is no longer tracked leaves things unchanged
    RemoveLocalReference(core_worker, object_id)
    @test !has_count(object_id)
    @test local_count(object_id) == 0
end
1 change: 1 addition & 0 deletions ray_julia_jll/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ end
include("gcs_client.jl")
setup_core_worker() do
include("put_get.jl")
include("reference_counting.jl")
end
end
end
60 changes: 53 additions & 7 deletions src/object_ref.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,59 @@
"""
    ObjectRef(oid_hex, owner_address, serialized_object_status; add_local_ref=true)

Reference to a Ray object, stored as the hex encoding of its object ID together
with an optional owner address and a serialized object status string.

The object ID itself is not a stored field: `obj_ref.oid` is served by the
`Base.getproperty` method for `ObjectRef`, which rebuilds it from `oid_hex`.

# Keywords
- `add_local_ref=true`: when `true`, `AddLocalReference` is called on the core
  worker for this object ID at construction time. Pass `false` when the local
  reference has already been counted elsewhere.

A finalizer is always registered; it schedules `finalize_object_ref` (which removes
one local reference) as an `@async` task wrapped in `errormonitor` rather than
running it inline in the finalizer.
"""
mutable struct ObjectRef
    # NOTE: there is intentionally no `oid` field. The inner constructor passes
    # exactly three values to `new`, one per field below.
    oid_hex::String
    owner_address::Union{ray_jll.AddressAllocated,Nothing}
    serialized_object_status::String

    function ObjectRef(oid_hex, owner_address, serialized_object_status;
                       add_local_ref=true)
        objref = new(oid_hex, owner_address, serialized_object_status)
        if add_local_ref
            worker = ray_jll.GetCoreWorker()
            # `objref.oid` goes through `getproperty` and yields a fresh ObjectID
            ray_jll.AddLocalReference(worker, objref.oid)
        end
        finalizer(objref) do objref
            # defer the actual API call to a task; `errormonitor` surfaces failures
            errormonitor(@async finalize_object_ref(objref))
            return nothing
        end
        return objref
    end
end

# Drop one local reference for the object that `obj` points at. Called (via an
# async task) from the finalizer registered in the `ObjectRef` constructor.
function finalize_object_ref(obj::ObjectRef)
    @debug "Removing local ref for ObjectID $(obj.oid_hex)"
    ray_jll.RemoveLocalReference(ray_jll.GetCoreWorker(),
                                 ray_jll.FromHex(ray_jll.ObjectID, obj.oid_hex))
    return nothing
end

# `oid` is a virtual property: it is rebuilt from the stored hex string on every
# access. All other property names are plain field lookups.
function Base.getproperty(x::ObjectRef, prop::Symbol)
    prop === :oid && return ray_jll.FromHex(ray_jll.ObjectID, getfield(x, :oid_hex))
    return getfield(x, prop)
end

# `deepcopy` of an ObjectRef must go through the constructor so that the copy
# increments the local ref count and gets its own finalizer. This specialization
# deep-copies each field and then calls the constructor with the copies.
function Base.deepcopy_internal(x::ObjectRef, stackdict::IdDict)
    # Honour the `deepcopy` contract: an object reached more than once within a
    # single `deepcopy` call is copied only once, so aliasing is preserved and we
    # do not add an extra local reference for the same copy.
    haskey(stackdict, x) && return stackdict[x]::ObjectRef

    names = fieldnames(typeof(x))
    fieldcopies = map(names) do name
        @debug "deep copying x.$(name)"
        return Base.deepcopy_internal(getfield(x, name), stackdict)
    end

    xcp = ObjectRef(fieldcopies...; add_local_ref=true)
    stackdict[x] = xcp

    return xcp
end

ObjectRef(oid::ray_jll.ObjectIDAllocated) = ObjectRef(oid, nothing, "")
ObjectRef(hex_str::AbstractString) = ObjectRef(ray_jll.FromHex(ray_jll.ObjectID, hex_str))
hex_identifier(obj_ref::ObjectRef) = String(ray_jll.Hex(obj_ref.oid))
# Convenience constructors; keyword arguments (e.g. `add_local_ref`) are forwarded.
# From an ObjectID: convert it to its hex string and delegate.
ObjectRef(oid::ray_jll.ObjectIDAllocated; kwargs...) = ObjectRef(ray_jll.Hex(oid); kwargs...)
# From a hex string: no owner address (`nothing`) and an empty serialized status.
ObjectRef(oid_hex::AbstractString; kwargs...) = ObjectRef(oid_hex, nothing, ""; kwargs...)
# The hex identifier is the stored `oid_hex` field.
hex_identifier(obj_ref::ObjectRef) = obj_ref.oid_hex
# Two ObjectRefs are equal when their hex identifiers are equal; the other fields
# are not compared.
Base.:(==)(a::ObjectRef, b::ObjectRef) = hex_identifier(a) == hex_identifier(b)

function Base.hash(obj_ref::ObjectRef, h::UInt)
Expand Down Expand Up @@ -47,7 +94,7 @@ function _register_ownership(obj_ref::ObjectRef, outer_obj_ref::Union{ObjectRef,
outer_object_id = if outer_obj_ref !== nothing
outer_obj_ref.oid
else
ray_jll.Nil(ray_jll.ObjectID)
ray_jll.FromNil(ray_jll.ObjectID)
end

if !isnothing(obj_ref.owner_address) && !has_owner(obj_ref)
Expand Down Expand Up @@ -87,12 +134,11 @@ function Serialization.deserialize(s::AbstractSerializer, ::Type{ObjectRef})
owner_address_str = deserialize(s)
serialized_object_status = deserialize(s)

object_id = ray_jll.FromHex(ray_jll.ObjectID, hex_str)
owner_address = nothing
if !isempty(owner_address_str)
owner_address = ray_jll.Address()
ray_jll.ParseFromString(owner_address, owner_address_str)
end

return ObjectRef(object_id, owner_address, serialized_object_status)
return ObjectRef(hex_str, owner_address, serialized_object_status)
end
29 changes: 28 additions & 1 deletion src/object_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ function put(data)
bytes = serialize_to_bytes(data)
buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true)
ray_obj = ray_jll.RayObject(buffer)
return ObjectRef(ray_jll.put(ray_obj, StdVector{ray_jll.ObjectID}()))
# `CoreWorker::Put` initializes the local ref count to 1
return ObjectRef(ray_jll.put(ray_obj, StdVector{ray_jll.ObjectID}());
add_local_ref=false)
end

put(obj_ref::ObjectRef) = obj_ref
Expand Down Expand Up @@ -78,3 +80,28 @@ function Base.wait(obj_ref::ObjectRef)
end
return nothing
end

#####
##### Reference counting
#####

"""
    get_all_reference_counts()

For testing/debugging purposes, return a `Dict{String,Tuple{Int,Int}}` holding the
reference counts reported by the core worker's `GetAllReferenceCounts`.

Keys are the hex representation of each object ID (not `ray_jll.ObjectID` values).
In each value, the first count is the "local reference" count and the second is the
count of submitted tasks depending on the object.
"""
function get_all_reference_counts()
    worker = ray_jll.GetCoreWorker()
    counts_raw = ray_jll.GetAllReferenceCounts(worker)

    # Convert to a dict we can actually work with. The hex representation of the
    # ID is used as the key so we can avoid messing with the internal ObjectID
    # representation; it is converted to a plain `String` so the container is
    # concretely typed.
    counts = Dict{String,Tuple{Int,Int}}()
    for k in ray_jll._keys(counts_raw)
        local_refs, submitted_refs = ray_jll._getindex(counts_raw, k)
        counts[String(ray_jll.Hex(k))] = (Int(local_refs), Int(submitted_refs))
    end
    return counts
end
4 changes: 3 additions & 1 deletion src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple();
serialized_runtime_env_info,
resources)
end
return ObjectRef(oid)
# CoreWorker::SubmitTask calls TaskManager::AddPendingTask which initializes
# the local ref count to 1, so we don't need to do that here.
return ObjectRef(oid; add_local_ref=false)
end

# Adapted from `prepare_args_internal`:
Expand Down
2 changes: 2 additions & 0 deletions test/object_ref.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# this runs inside setup_core_worker()

function serialize_deserialize(x)
io = IOBuffer()
serialize(io, x)
Expand Down
26 changes: 26 additions & 0 deletions test/object_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,30 @@
@test obj_ref1 === obj_ref2
@test Ray.get(obj_ref1) == Ray.get(obj_ref2)
end

# Checks that the local reference count follows the lifetime of ObjectRefs made by
# `Ray.put`, by `deepcopy`, and by constructing one from a hex string.
@testset "Local ref count: put, deepcopy, and constructed object ref" begin
obj = Ray.put(nothing)
# keep the hex ID so the count can still be queried after `obj` is finalized
oid = obj.oid_hex

# a freshly put object has one local reference
@test local_count(obj) == 1

# a deep copy adds a second local reference for the same object ID
obj2 = deepcopy(obj)
@test local_count(obj) == 2

finalize(obj2)
yield() # allows async task that makes the API call to run

@test local_count(obj) == 1

# constructing an ObjectRef from the hex string also adds a local reference
obj3 = ObjectRef(obj.oid_hex)
@test local_count(obj) == 2

finalize(obj3)
yield()
@test local_count(obj) == 1

# finalizing the last ObjectRef brings the count back to zero
finalize(obj)
yield()
@test local_count(oid) == 0
end
end
10 changes: 10 additions & 0 deletions test/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ end
@test Ray.get_task_id() != task_id != subtask_id
end

# Checks that the ObjectRef returned by `Ray.submit_task` holds exactly one local
# reference, and that finalizing it releases that reference.
@testset "Local ref count: Task return object" begin
obj = Ray.submit_task(getpid, ())
# keep the hex ID so the count can still be queried after `obj` is finalized
oid = obj.oid_hex
@test local_count(oid) == 1

finalize(obj)
yield() # let the async task scheduled by the finalizer run
@test local_count(oid) == 0
end

@testset "object ownership" begin
@testset "unknown owner" begin
invalid_ref = ObjectRef(ray_jll.FromRandom(ray_jll.ObjectID))
Expand Down
6 changes: 6 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@ function setup_core_worker(body)
try
body()
finally
# let object ref finalizers run...
GC.gc()
yield()
ray_jll.shutdown_driver()
end
end

# Local reference count for an ObjectRef (looked up by its hex ID) or for a hex
# object ID directly; 0 when the ID has no entry in the reference counts.
local_count(o::Ray.ObjectRef) = local_count(o.oid_hex)
function local_count(oid_hex)
    counts = Ray.get_all_reference_counts()
    haskey(counts, oid_hex) || return 0
    return first(counts[oid_hex])
end

# Useful in running tests which require us to re-run `Ray.init` which currently can only be
# called once as it is infeasible to reset global `Ref`'s using `isassigned` conditionals.
macro process_eval(ex)
Expand Down

0 comments on commit 6f0db4e

Please sign in to comment.