From 6f0db4ec2a54932df1ca319f14d23e4b5dc1cb39 Mon Sep 17 00:00:00 2001 From: Dave Kleinschmidt Date: Fri, 22 Sep 2023 16:30:20 -0400 Subject: [PATCH] add/remove local ref on ObjectRef construction/finalize, scaffolding for testing reference counts (#126) * add/remove local ref in Ray.jl code * let's try this... * use lambda because of template inference failure * GetAllLocalReferences with wrapped return map type * ref counting tests to exercise add/remove local reference * add local ref false * ==, hash, and show for ObjectID; don't return ref from Nil * get all reference counts in Ray.jl; specialized deepcopy * use hex string in objectref, finalize not async * restore async to finalizer and yield in tests * also exercise the task return and construction counting * Update ray_julia_jll/src/wrappers/any.jl Co-authored-by: Curtis Vogt * Update ray_julia_jll/src/wrappers/any.jl Co-authored-by: Curtis Vogt * Update test/object_ref.jl Co-authored-by: Curtis Vogt * attempt to fix segfault on tests * Update ray_julia_jll/test/reference_counting.jl Co-authored-by: Curtis Vogt * Update src/object_ref.jl Co-authored-by: Curtis Vogt * Apply suggestions from code review Co-authored-by: Curtis Vogt * the ol' test shuffle * revert un hexing --------- Co-authored-by: Curtis Vogt --- ray_julia_jll/deps/wrapper.cc | 33 ++++++++++++- ray_julia_jll/src/wrappers/any.jl | 8 +++- ray_julia_jll/test/reference_counting.jl | 31 ++++++++++++ ray_julia_jll/test/runtests.jl | 1 + src/object_ref.jl | 60 +++++++++++++++++++++--- src/object_store.jl | 29 +++++++++++- src/runtime.jl | 4 +- test/object_ref.jl | 2 + test/object_store.jl | 26 ++++++++++ test/task.jl | 10 ++++ test/utils.jl | 6 +++ 11 files changed, 198 insertions(+), 12 deletions(-) create mode 100644 ray_julia_jll/test/reference_counting.jl diff --git a/ray_julia_jll/deps/wrapper.cc b/ray_julia_jll/deps/wrapper.cc index a97d8236..81d304ff 100644 --- a/ray_julia_jll/deps/wrapper.cc +++ b/ray_julia_jll/deps/wrapper.cc @@ -393,6 
+393,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) // the function. If you fail to do this you'll get a "No appropriate factory for type" upon // attempting to use the shared library in Julia. + // Resource map type wrapper mod.add_type>("CxxMapStringDouble"); mod.method("_setindex!", [](std::unordered_map &map, double val, @@ -473,7 +474,11 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) mod.add_type("ObjectID") .method("ObjectIDFromHex", &ObjectID::FromHex) .method("ObjectIDFromRandom", &ObjectID::FromRandom) - .method("ObjectIDFromNil", &ObjectID::Nil) + .method("ObjectIDFromNil", []() { + auto id = ObjectID::Nil(); + ObjectID id_deref = id; + return id_deref; + }) .method("Hex", &ObjectID::Hex); // enum Language @@ -672,6 +677,24 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) return std::unique_ptr(t); }); + // ObjectID reference count map wrapper + typedef std::unordered_map> ReferenceCountMap; + mod.add_type("CxxMapObjectIDPairIntInt"); + mod.method("_keys", [](ReferenceCountMap &map) { + // reserve capacity only: size-constructing would prepend map.size() default IDs + std::vector<ObjectID> keys; + keys.reserve(map.size()); + for (const auto &kv : map) keys.push_back(kv.first); + return keys; + }); + mod.method("_getindex", [](ReferenceCountMap &map, ObjectID &key) { + auto val = map[key]; + std::vector retval; + retval.push_back(val.first); + retval.push_back(val.second); + return retval; + }); + // class WorkerContext // https://github.com/ray-project/ray/blob/ray-2.5.1/src/ray/core_worker/context.h#L30 mod.add_type("WorkerContext") @@ -687,7 +710,13 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) .method("GetOwnerAddress", &ray::core::CoreWorker::GetOwnerAddress) .method("GetOwnershipInfo", &ray::core::CoreWorker::GetOwnershipInfo) .method("GetObjectRefs", &ray::core::CoreWorker::GetObjectRefs) - .method("RegisterOwnershipInfoAndResolveFuture", &ray::core::CoreWorker::RegisterOwnershipInfoAndResolveFuture); + .method("RegisterOwnershipInfoAndResolveFuture", &ray::core::CoreWorker::RegisterOwnershipInfoAndResolveFuture) + // 
.method("AddLocalReference", &ray::core::CoreWorker::AddLocalReference) + .method("AddLocalReference", [](ray::core::CoreWorker &worker, ObjectID &object_id) { + return worker.AddLocalReference(object_id); + }) + .method("RemoveLocalReference", &ray::core::CoreWorker::RemoveLocalReference) + .method("GetAllReferenceCounts", &ray::core::CoreWorker::GetAllReferenceCounts); mod.method("_GetCoreWorker", &_GetCoreWorker); mod.method("_submit_task", &_submit_task); diff --git a/ray_julia_jll/src/wrappers/any.jl b/ray_julia_jll/src/wrappers/any.jl index 95f7cdf7..4bc0f238 100644 --- a/ray_julia_jll/src/wrappers/any.jl +++ b/ray_julia_jll/src/wrappers/any.jl @@ -163,7 +163,13 @@ FromInt(::Type{JobID}, num::Integer) = JobIDFromInt(num) FromHex(::Type{ObjectID}, str::AbstractString) = ObjectIDFromHex(str) FromRandom(::Type{ObjectID}) = ObjectIDFromRandom() -Nil(::Type{ObjectID}) = ObjectIDFromNil() +FromNil(::Type{ObjectID}) = ObjectIDFromNil() + +ObjectID(str::AbstractString) = FromHex(ObjectID, str) + +Base.show(io::IO, x::ObjectID) = print(io, "ObjectID(\"", Hex(x), "\")") # `print`, not `show`: `show` of a String would add quotes/escapes +Base.:(==)(a::ObjectID, b::ObjectID) = Hex(a) == Hex(b) +Base.hash(x::ObjectID, h::UInt) = hash(ObjectID, hash(Hex(x), h)) ##### ##### TaskArg diff --git a/ray_julia_jll/test/reference_counting.jl b/ray_julia_jll/test/reference_counting.jl new file mode 100644 index 00000000..d0096c8b --- /dev/null +++ b/ray_julia_jll/test/reference_counting.jl @@ -0,0 +1,31 @@ +@testset "Reference counting" begin + using ray_julia_jll: FromRandom, ObjectID, Hex, AddLocalReference, RemoveLocalReference, + GetAllReferenceCounts, _keys, _getindex, GetCoreWorker + + worker = GetCoreWorker() + # need to convert to hex string because these are just pointers + has_count(oid) = Hex(oid) in Hex.(_keys(GetAllReferenceCounts(worker))) + local_count(oid) = first(_getindex(GetAllReferenceCounts(worker), oid)) + + oid = FromRandom(ObjectID) + + @test !has_count(oid) + + AddLocalReference(worker, oid) + @test has_count(oid) + @test 
local_count(oid) == 1 + + AddLocalReference(worker, oid) + @test local_count(oid) == 2 + + RemoveLocalReference(worker, oid) + @test local_count(oid) == 1 + + RemoveLocalReference(worker, oid) + @test !has_count(oid) + @test local_count(oid) == 0 + + RemoveLocalReference(worker, oid) + @test !has_count(oid) + @test local_count(oid) == 0 +end diff --git a/ray_julia_jll/test/runtests.jl b/ray_julia_jll/test/runtests.jl index 451cf5f9..8a18e196 100644 --- a/ray_julia_jll/test/runtests.jl +++ b/ray_julia_jll/test/runtests.jl @@ -20,6 +20,7 @@ end include("gcs_client.jl") setup_core_worker() do include("put_get.jl") + include("reference_counting.jl") end end end diff --git a/src/object_ref.jl b/src/object_ref.jl index 85d5bee7..bf8cb966 100644 --- a/src/object_ref.jl +++ b/src/object_ref.jl @@ -1,12 +1,59 @@ mutable struct ObjectRef - oid::ray_jll.ObjectIDAllocated + oid_hex::String owner_address::Union{ray_jll.AddressAllocated,Nothing} serialized_object_status::String + + function ObjectRef(oid_hex, owner_address, serialized_object_status; + add_local_ref=true) + objref = new(oid_hex, owner_address, serialized_object_status) + if add_local_ref + worker = ray_jll.GetCoreWorker() + ray_jll.AddLocalReference(worker, objref.oid) + end + finalizer(objref) do objref + errormonitor(@async finalize_object_ref(objref)) + return nothing + end + return objref + end +end + +function finalize_object_ref(obj::ObjectRef) + @debug "Removing local ref for ObjectID $(obj.oid_hex)" + worker = ray_jll.GetCoreWorker() + oid = ray_jll.FromHex(ray_jll.ObjectID, obj.oid_hex) + ray_jll.RemoveLocalReference(worker, oid) + return nothing +end + +function Base.getproperty(x::ObjectRef, prop::Symbol) + return if prop == :oid + ray_jll.FromHex(ray_jll.ObjectID, x.oid_hex) + else + getfield(x, prop) + end +end + +# in order to actually increment the local ref count appropriately when we +# `deepcopy` an ObjectRef and setup the appropriate finalizer, this +# specialization calls the constructor after 
deepcopying the fields. +function Base.deepcopy_internal(x::ObjectRef, stackdict::IdDict) + haskey(stackdict, x) && return stackdict[x]::ObjectRef # already copied in this pass: keep aliasing, add no extra local ref + fieldcopies = ntuple(fieldcount(ObjectRef)) do i + @debug "deep copying x.$(fieldname(ObjectRef, i))" + fieldval = getfield(x, i) + return Base.deepcopy_internal(fieldval, stackdict) + end + + xcp = ObjectRef(fieldcopies...; add_local_ref=true) + stackdict[x] = xcp + + return xcp end -ObjectRef(oid::ray_jll.ObjectIDAllocated) = ObjectRef(oid, nothing, "") -ObjectRef(hex_str::AbstractString) = ObjectRef(ray_jll.FromHex(ray_jll.ObjectID, hex_str)) -hex_identifier(obj_ref::ObjectRef) = String(ray_jll.Hex(obj_ref.oid)) +ObjectRef(oid::ray_jll.ObjectIDAllocated; kwargs...) = ObjectRef(ray_jll.Hex(oid); kwargs...) +ObjectRef(oid_hex::AbstractString; kwargs...) = ObjectRef(oid_hex, nothing, ""; kwargs...) +hex_identifier(obj_ref::ObjectRef) = obj_ref.oid_hex Base.:(==)(a::ObjectRef, b::ObjectRef) = hex_identifier(a) == hex_identifier(b) function Base.hash(obj_ref::ObjectRef, h::UInt) @@ -47,7 +94,7 @@ function _register_ownership(obj_ref::ObjectRef, outer_obj_ref::Union{ObjectRef, outer_object_id = if outer_obj_ref !== nothing outer_obj_ref.oid else - ray_jll.Nil(ray_jll.ObjectID) + ray_jll.FromNil(ray_jll.ObjectID) end if !isnothing(obj_ref.owner_address) && !has_owner(obj_ref) @@ -87,12 +134,11 @@ function Serialization.deserialize(s::AbstractSerializer, ::Type{ObjectRef}) owner_address_str = deserialize(s) serialized_object_status = deserialize(s) - object_id = ray_jll.FromHex(ray_jll.ObjectID, hex_str) owner_address = nothing if !isempty(owner_address_str) owner_address = ray_jll.Address() ray_jll.ParseFromString(owner_address, owner_address_str) end - return ObjectRef(object_id, owner_address, serialized_object_status) + return ObjectRef(hex_str, owner_address, serialized_object_status) end diff --git a/src/object_store.jl b/src/object_store.jl index f97a46b6..bcdb4af0 100644 --- a/src/object_store.jl +++ b/src/object_store.jl @@ -8,7 
+8,9 @@ function put(data) bytes = serialize_to_bytes(data) buffer = ray_jll.LocalMemoryBuffer(bytes, sizeof(bytes), true) ray_obj = ray_jll.RayObject(buffer) - return ObjectRef(ray_jll.put(ray_obj, StdVector{ray_jll.ObjectID}())) + # `CoreWorker::Put` initializes the local ref count to 1 + return ObjectRef(ray_jll.put(ray_obj, StdVector{ray_jll.ObjectID}()); + add_local_ref=false) end put(obj_ref::ObjectRef) = obj_ref @@ -78,3 +80,28 @@ function Base.wait(obj_ref::ObjectRef) end return nothing end + +##### +##### Reference counting +##### + +""" + get_all_reference_counts() + +For testing/debugging purposes, returns a `Dict` mapping the hex string of each +object ID to a `Tuple{Int,Int}` of reference counts, for every object ID that +the core worker reports. The first count is the "local +reference" count, and the second is the count of submitted tasks depending on +the object. +""" +function get_all_reference_counts() + worker = ray_jll.GetCoreWorker() + counts_raw = ray_jll.GetAllReferenceCounts(worker) + + # we need to convert this to a dict we can actually work with. we use the + # hex representation of the ID so we can avoid messing with the internal + # ObjectID representation... + counts = Dict(ray_jll.Hex(k) => Tuple(Int.(ray_jll._getindex(counts_raw, k))) + for k in ray_jll._keys(counts_raw)) + return counts +end diff --git a/src/runtime.jl b/src/runtime.jl index b39f50a2..28bda19f 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -199,7 +199,9 @@ function submit_task(f::Function, args::Tuple, kwargs::NamedTuple=NamedTuple(); serialized_runtime_env_info, resources) end - return ObjectRef(oid) + # CoreWorker::SubmitTask calls TaskManager::AddPendingTask which initializes + # the local ref count to 1, so we don't need to do that here. 
+ return ObjectRef(oid; add_local_ref=false) end # Adapted from `prepare_args_internal`: diff --git a/test/object_ref.jl b/test/object_ref.jl index 15b3984a..019596eb 100644 --- a/test/object_ref.jl +++ b/test/object_ref.jl @@ -1,3 +1,5 @@ +# this runs inside setup_core_worker() + function serialize_deserialize(x) io = IOBuffer() serialize(io, x) diff --git a/test/object_store.jl b/test/object_store.jl index 1033c25d..8ce8ab91 100644 --- a/test/object_store.jl +++ b/test/object_store.jl @@ -29,4 +29,30 @@ @test obj_ref1 === obj_ref2 @test Ray.get(obj_ref1) == Ray.get(obj_ref2) end + + @testset "Local ref count: put, deepcopy, and constructed object ref" begin + obj = Ray.put(nothing) + oid = obj.oid_hex + + @test local_count(obj) == 1 + + obj2 = deepcopy(obj) + @test local_count(obj) == 2 + + finalize(obj2) + yield() # allows async task that makes the API call to run + + @test local_count(obj) == 1 + + obj3 = ObjectRef(obj.oid_hex) + @test local_count(obj) == 2 + + finalize(obj3) + yield() + @test local_count(obj) == 1 + + finalize(obj) + yield() + @test local_count(oid) == 0 + end end diff --git a/test/task.jl b/test/task.jl index ca9d6a18..c4e315ca 100644 --- a/test/task.jl +++ b/test/task.jl @@ -54,6 +54,16 @@ end @test Ray.get_task_id() != task_id != subtask_id end +@testset "Local ref count: Task return object" begin + obj = Ray.submit_task(getpid, ()) + oid = obj.oid_hex + @test local_count(oid) == 1 + + finalize(obj) + yield() + @test local_count(oid) == 0 +end + @testset "object ownership" begin @testset "unknown owner" begin invalid_ref = ObjectRef(ray_jll.FromRandom(ray_jll.ObjectID)) diff --git a/test/utils.jl b/test/utils.jl index 2b82bccf..6d549038 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -18,10 +18,16 @@ function setup_core_worker(body) try body() finally + # let object ref finalizers run... 
+ GC.gc() + yield() ray_jll.shutdown_driver() end end +local_count(o::Ray.ObjectRef) = local_count(o.oid_hex) +local_count(oid_hex) = first(get(Ray.get_all_reference_counts(), oid_hex, 0)) + # Useful in running tests which require us to re-run `Ray.init` which currently can only be # called once as it is infeasible to reset global `Ref`'s using `isassigned` conditionals. macro process_eval(ex)