Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Fix memory leak #2 #8725

Merged
merged 13 commits into from
Aug 24, 2021
8 changes: 8 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);

/*!
* \brief Convert type index to type key.
* \param tindex The type index.
* \param out_type_key The output type key.
* \return 0 when success, nonzero when failure happens
*/
TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key);

/*!
* \brief Increase the reference count of an object.
*
Expand Down
13 changes: 3 additions & 10 deletions rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,20 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
}

impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
use std::ffi::c_void;
let object_ptr = &object_ref.0;
match object_ptr {
None => {
#tvm_rt_crate::ArgValue::
ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
}
Some(value) => value.clone().into()
Some(value) => value.into()
}
}
}

impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> {
fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> {
let oref: #ref_id = object_ref.clone();
#tvm_rt_crate::ArgValue::<'a>::from(oref)
}
}

impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
type Error = #error;

Expand Down
13 changes: 8 additions & 5 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,22 @@ external! {
fn array_size(array: ObjectRef) -> i64;
}

impl<T: IsObjectRef> IsObjectRef for Array<T> {
impl<T: IsObjectRef + 'static> IsObjectRef for Array<T> {
type Object = Object;
fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
self.object.as_ptr()
}

fn into_ptr(self) -> Option<ObjectPtr<Self::Object>> {
self.object.into_ptr()
}

fn from_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
let object_ref = match object_ptr {
Some(o) => o.into(),
_ => panic!(),
};

Array {
object: object_ref,
_data: PhantomData,
Expand All @@ -67,7 +70,7 @@ impl<T: IsObjectRef> IsObjectRef for Array<T> {

impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
let iter = data.into_iter().map(T::into_arg_value).collect();
let iter = data.iter().map(T::into_arg_value).collect();

let func = Function::get("runtime.Array").expect(
"runtime.Array function is not registered, this is most likely a build or linking error",
Expand Down Expand Up @@ -151,9 +154,9 @@ impl<T: IsObjectRef> FromIterator<T> for Array<T> {
}
}

impl<'a, T: IsObjectRef> From<Array<T>> for ArgValue<'a> {
fn from(array: Array<T>) -> ArgValue<'a> {
array.object.into()
impl<'a, T: IsObjectRef> From<&'a Array<T>> for ArgValue<'a> {
fn from(array: &'a Array<T>) -> ArgValue<'a> {
(&array.object).into()
}
}

Expand Down
17 changes: 9 additions & 8 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use std::{

use crate::errors::Error;

pub use super::to_function::{ToFunction, Typed};
pub use super::to_function::{RawArgs, ToFunction, Typed};
use crate::object::AsArgValue;
pub use tvm_sys::{ffi, ArgValue, RetValue};

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -153,12 +154,12 @@ macro_rules! impl_to_fn {
where
Error: From<Err>,
Out: TryFrom<RetValue, Error = Err>,
$($t: Into<ArgValue<'static>>),*
$($t: for<'a> AsArgValue<'a>),*
{
fn from(func: Function) -> Self {
#[allow(non_snake_case)]
Box::new(move |$($t : $t),*| {
let args = vec![ $($t.into()),* ];
let args = vec![ $((&$t).as_arg_value()),* ];
Ok(func.invoke(args)?.try_into()?)
})
}
Expand Down Expand Up @@ -196,8 +197,8 @@ impl TryFrom<RetValue> for Function {
}
}

impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
impl<'a> From<&'a Function> for ArgValue<'a> {
fn from(func: &'a Function) -> ArgValue<'a> {
if func.handle().is_null() {
ArgValue::Null
} else {
Expand Down Expand Up @@ -291,12 +292,12 @@ where
}

pub fn register_untyped<S: Into<String>>(
f: fn(Vec<ArgValue<'static>>) -> Result<RetValue>,
f: for<'a> fn(Vec<ArgValue<'a>>) -> Result<RetValue>,
name: S,
override_: bool,
) -> Result<()> {
// TODO(@jroesch): can we unify all the code.
let func = f.to_function();
//TODO(@jroesch): can we unify the untpyed and typed registration functions.
let func = ToFunction::<RawArgs, RetValue>::to_function(f);
let name = name.into();
// Not sure about this code
let handle = func.handle();
Expand Down
7 changes: 4 additions & 3 deletions rust/tvm-rt/src/graph_rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ impl GraphRt {

let runtime_create_fn_ret = runtime_create_fn.invoke(vec![
graph.into(),
lib.into(),
(&lib).into(),
(&dev.device_type).into(),
// NOTE you must pass the device id in as i32 because that's what TVM expects
(dev.device_id as i32).into(),
]);

let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?;
Ok(Self {
module: graph_executor_module,
Expand All @@ -79,7 +80,7 @@ impl GraphRt {
pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> {
let ref set_input_fn = self.module.get_function("set_input", false)?;

set_input_fn.invoke(vec![name.into(), input.into()])?;
set_input_fn.invoke(vec![name.into(), (&input).into()])?;
Ok(())
}

Expand All @@ -101,7 +102,7 @@ impl GraphRt {
/// Extract the ith output from the graph executor and write the results into output.
pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> {
let get_output_fn = self.module.get_function("get_output", false)?;
get_output_fn.invoke(vec![i.into(), output.into()])?;
get_output_fn.invoke(vec![i.into(), (&output).into()])?;
Ok(())
}
}
21 changes: 11 additions & 10 deletions rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,17 @@ mod tests {
);
}

#[test]
fn bytearray() {
let w = vec![1u8, 2, 3, 4, 5];
let v = ByteArray::from(w.as_slice());
let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
assert_eq!(
tvm.data(),
w.iter().copied().collect::<Vec<u8>>().as_slice()
);
}
// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership.
// #[test]
// fn bytearray() {
// let w = vec![1u8, 2, 3, 4, 5];
// let v = ByteArray::from(w.as_slice());
// let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
// assert_eq!(
// tvm.data(),
// w.iter().copied().collect::<Vec<u8>>().as_slice()
// );
// }

#[test]
fn ty() {
Expand Down
16 changes: 8 additions & 8 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ external! {
fn map_items(map: ObjectRef) -> Array<ObjectRef>;
}

impl<K, V> FromIterator<(K, V)> for Map<K, V>
impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map<K, V>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
fn from_iter<T: IntoIterator<Item = (&'a K, &'a V)>>(iter: T) -> Self {
let iter = iter.into_iter();
let (lower_bound, upper_bound) = iter.size_hint();
let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
for (k, v) in iter {
buffer.push(k.into());
buffer.push(v.into())
buffer.push(k.into_arg_value());
buffer.push(v.into_arg_value());
}
Self::from_data(buffer).expect("failed to convert from data")
}
Expand Down Expand Up @@ -202,13 +202,13 @@ where
}
}

impl<'a, K, V> From<Map<K, V>> for ArgValue<'a>
impl<'a, K, V> From<&'a Map<K, V>> for ArgValue<'a>
where
K: IsObjectRef,
V: IsObjectRef,
{
fn from(map: Map<K, V>) -> ArgValue<'a> {
map.object.into()
fn from(map: &'a Map<K, V>) -> ArgValue<'a> {
(&map.object).into()
}
}

Expand Down Expand Up @@ -268,7 +268,7 @@ mod test {
let mut std_map: HashMap<TString, TString> = HashMap::new();
std_map.insert("key1".into(), "value1".into());
std_map.insert("key2".into(), "value2".into());
let tvm_map = Map::from_iter(std_map.clone().into_iter());
let tvm_map = Map::from_iter(std_map.iter());
let back_map = tvm_map.into();
assert_eq!(std_map, back_map);
}
Expand Down
15 changes: 15 additions & 0 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ impl NDArrayContainer {
.cast::<NDArrayContainer>()
}
}

pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr<NDArrayContainer>) -> *mut NDArrayContainer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems legit though some of the other functions on NDArrayContainer seem concerning from a rust memory model perspective.

pub fn leak<'a>(object_ptr: ObjectPtr<NDArrayContainer>) -> &'a mut NDArrayContainer seems like it can be used to hand out multiple &mut NDArrayContainer and should be considered unsafe.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robo-corg I think this is modeled after the other leak family of functions in Rust which take ownership https://doc.rust-lang.org/std/boxed/struct.Box.html#method.leak.

I admit this code is very tricky and scary it might be worth opening up follow up issue.

where
NDArrayContainer: 'a,
{
let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize;
unsafe {
object_ptr
.ptr
.as_ptr()
.cast::<u8>()
.offset(base_offset)
.cast::<NDArrayContainer>()
}
}
}

fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> {
Expand Down
19 changes: 16 additions & 3 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ mod object_ptr;

pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef};

pub trait AsArgValue<'a> {
fn as_arg_value(&'a self) -> ArgValue<'a>;
}

impl<'a, T: 'static> AsArgValue<'a> for T
where
&'a T: Into<ArgValue<'a>>,
{
fn as_arg_value(&'a self) -> ArgValue<'a> {
self.into()
}
}

// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we
// can't because of coherence rules. Instead, we generate them in the macro, and
// add what we can (including Into instead of From) as subtraits.
Expand All @@ -37,8 +50,8 @@ pub trait IsObjectRef:
Sized
+ Clone
+ Into<RetValue>
+ for<'a> AsArgValue<'a>
+ TryFrom<RetValue, Error = Error>
+ for<'a> Into<ArgValue<'a>>
+ for<'a> TryFrom<ArgValue<'a>, Error = Error>
+ std::fmt::Debug
{
Expand All @@ -51,8 +64,8 @@ pub trait IsObjectRef:
Self::from_ptr(None)
}

fn into_arg_value<'a>(self) -> ArgValue<'a> {
self.into()
fn into_arg_value<'a>(&'a self) -> ArgValue<'a> {
self.as_arg_value()
}

fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result<Self, Error> {
Expand Down
Loading