Skip to content

Commit

Permalink
WIP 2
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Aug 11, 2021
1 parent ecd325d commit 4878df8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
36 changes: 34 additions & 2 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
use std::convert::TryFrom;
use std::ffi::CString;
use std::fmt;
use std::os::raw::c_char;
use std::ptr::NonNull;
use std::sync::atomic::AtomicI32;

use tvm_macros::Object;
use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index};
use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index, TVMObjectTypeIndex2Key};
use tvm_sys::{ArgValue, RetValue};

use crate::errors::Error;
Expand Down Expand Up @@ -98,6 +99,16 @@ impl Object {
}
}

fn get_type_key(&self) -> String {
let mut cstring: * mut c_char = std::ptr::null_mut();
unsafe {
if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as * mut _) != 0 {
panic!("{}", crate::get_last_error());
}
CString::from_raw(cstring).into_string().unwrap()
}
}

fn get_type_index<T: IsObject>() -> u32 {
let type_key = T::TYPE_KEY;
let cstring = CString::new(type_key).expect("type key must not contain null characters");
Expand Down Expand Up @@ -143,6 +154,7 @@ impl Object {
pub(self) fn dec_ref(&self) {
let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
unsafe {
println!("ref_count={} type_key={}", self.count(), self.get_type_key());
assert_eq!(TVMObjectFree(raw_ptr), 0);
}
}
Expand Down Expand Up @@ -308,14 +320,34 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
}
}

// This is the problem, if we generate a pointer with ref-count = k, then we we enter the section
// in which we will call the packed function we will leak the ref-count at k, the problem we have
// often start with k - 1 ref-count, clone the pointer to pass to the function leaving at least
// one dangling reference.
//
// The effect is that arguments to packed functions will have their ref count always grow by one
// each time they are invoked by a function. The identity case is ease to see as the inner ref-count
// seems off by one, because we have leaked a reference, then convert the arg value back into an object
// allowing us to observe a reference count which should have gone away. Putting a dec ref in the conversion
// to ret value works in a limited case as it counteracts for the idenitty function the effect of bumping
// the reference allowing us to get back to normal.
//
// The problem is that ArgValue is now an owned view on to memory and so we probably need to thread a phantom
// lifetime which is used to constrain that &'a ObjectPtr T ~ ArgValue<'a> and you can only manipulate ArgValue
// as long as the data works.
//
// This probably reintroduces the need to transmute a phantom lifetime for the incoming ArgValue which may complicate
// the presentation of invoking and auto-boxing the functions.
//
// Will fix this tomorrow.
impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
debug_assert!(object_ptr.count() >= 1);
println!("to arg value {}", object_ptr.count());
let object_ptr = object_ptr.upcast::<Object>();
match T::TYPE_KEY {
"runtime.NDArray" => {
use crate::ndarray::NDArrayContainer;
// TODO(this is probably not optimal)
let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap())
as *mut NDArrayContainer as *mut std::ffi::c_void;
assert!(!raw_ptr.is_null());
Expand Down
27 changes: 15 additions & 12 deletions rust/tvm/examples/resnet/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,27 @@ fn main() -> anyhow::Result<()> {
"/deploy_lib.so"
)))?;

let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?;

// parse parameters and convert to TVMByteArray
let params: Vec<u8> = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?;

println!("param bytes: {}", params.len());

graph_rt.load_params(&params)?;
graph_rt.set_input("data", input)?;
graph_rt.run()?;
let mut output: Vec<f32>;

loop {
let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?;

// prepare to get the output
let output_shape = &[1, 1000];
let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1));
graph_rt.get_output_into(0, output.clone())?;
graph_rt.load_params(&params)?;
graph_rt.set_input("data", input.clone())?;
graph_rt.run()?;

// flatten the output as Vec<f32>
let output = output.to_vec::<f32>()?;
// prepare to get the output
let output_shape = &[1, 1000];
let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1));
graph_rt.get_output_into(0, output_nd.clone())?;

// flatten the output as Vec<f32>
output = output_nd.to_vec::<f32>()?;
}

// find the maximum entry in the output and its index
let (argmax, max_prob) = output
Expand Down

0 comments on commit 4878df8

Please sign in to comment.