forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RUST][FRONTEND] Add rust frontend v0.1 (apache#2292)
- Loading branch information
1 parent
1b61f2f
commit 8008fd0
Showing
80 changed files
with
5,642 additions
and
2,275 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,11 @@ | ||
[package] | ||
name = "tvm" | ||
version = "0.1.0" | ||
license = "Apache-2.0" | ||
description = "TVM Rust runtime" | ||
repository = "https://github.com/dmlc/tvm" | ||
readme = "README.md" | ||
keywords = ["tvm", "nnvm"] | ||
categories = ["api-bindings", "science"] | ||
authors = ["TVM Contributors"] | ||
|
||
[features] | ||
default = ["nom/std"] | ||
sgx = ["nom/alloc"] | ||
|
||
[dependencies] | ||
bounded-spsc-queue = "0.4.0" | ||
error-chain = { version = "0.12.0", default-features = false } | ||
itertools = "0.7.8" | ||
lazy_static = "1.1.0" | ||
ndarray = "0.11.2" | ||
nom = {version = "4.0.0", default-features = false } | ||
serde = "1.0.59" | ||
serde_derive = "1.0.79" | ||
serde_json = "1.0.17" | ||
|
||
[target.'cfg(not(target_env = "sgx"))'.dependencies] | ||
num_cpus = "1.8.0" | ||
[workspace] | ||
members = [ | ||
"common", | ||
"runtime", | ||
"runtime/tests/test_tvm_basic", | ||
"runtime/tests/test_nnvm", | ||
"frontend", | ||
"frontend/tests/basics", | ||
"frontend/tests/callback", | ||
"frontend/examples/resnet" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
target | ||
**/*.rs.bk | ||
Cargo.lock | ||
/tvm-sys/src/bindgen.rs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
[package] | ||
name = "tvm-common" | ||
version = "0.1.0" | ||
authors = ["TVM Contributors"] | ||
license = "Apache-2.0" | ||
|
||
[features] | ||
runtime = [] | ||
frontend = ["tvm-sys"] | ||
|
||
[dependencies] | ||
error-chain = { version = "0.12.0", default-features = false } | ||
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true } |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
//! Error types for `TVMArgValue` and `TVMRetValue` conversions. | ||
|
||
error_chain! { | ||
errors { | ||
TryFromTVMArgValueError(expected: String, actual: String) { | ||
description("mismatched types while converting from TVMArgValue") | ||
display("expected `{}` but given `{}`", expected, actual) | ||
} | ||
|
||
TryFromTVMRetValueError(expected: String, actual: String) { | ||
description("mismatched types while downcasting TVMRetValue") | ||
display("invalid downcast: expected `{}` but given `{}`", expected, actual) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
//! This crate contains the refactored basic components required | ||
//! for `runtime` and `frontend` TVM crates. | ||
|
||
#![crate_name = "tvm_common"] | ||
#![recursion_limit = "1024"] | ||
#![allow(non_camel_case_types, unused_imports)] | ||
#![feature(box_syntax, try_from)] | ||
|
||
#[macro_use] | ||
extern crate error_chain; | ||
|
||
/// Unified ffi module for both runtime and frontend crates. | ||
pub mod ffi { | ||
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] | ||
|
||
#[cfg(feature = "frontend")] | ||
pub extern crate tvm_sys as ts; | ||
|
||
#[cfg(feature = "runtime")] | ||
pub mod runtime { | ||
use std::os::raw::{c_char, c_int, c_void}; | ||
|
||
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); | ||
|
||
pub type BackendPackedCFunc = extern "C" fn( | ||
args: *const TVMValue, | ||
type_codes: *const c_int, | ||
num_args: c_int, | ||
) -> c_int; | ||
} | ||
} | ||
|
||
pub mod errors; | ||
pub mod ty; | ||
pub mod value; | ||
|
||
pub use errors::*; | ||
pub use ty::TVMTypeCode; | ||
pub use value::{TVMArgValue, TVMRetValue, TVMValue}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods. | ||
//! | ||
//! # Example | ||
//! | ||
//! ``` | ||
//! let dtype = TVMType::from("float"); | ||
//! println!("dtype is: {}", dtype); | ||
//! ``` | ||
|
||
use std::{ | ||
ffi::{CStr, CString}, | ||
fmt::{self, Display, Formatter}, | ||
}; | ||
|
||
/// TVM type codes. | ||
#[repr(u32)] | ||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] | ||
pub enum TVMTypeCode { | ||
kDLInt = 0, | ||
kDLUInt = 1, | ||
kDLFloat = 2, | ||
kHandle = 3, | ||
kNull = 4, | ||
kTVMType = 5, | ||
kTVMContext = 6, | ||
kArrayHandle = 7, | ||
kNodeHandle = 8, | ||
kModuleHandle = 9, | ||
kFuncHandle = 10, | ||
kStr = 11, | ||
kBytes = 12, | ||
kNDArrayContainer = 13, | ||
} | ||
|
||
impl Default for TVMTypeCode { | ||
fn default() -> Self { | ||
TVMTypeCode::kDLInt | ||
} | ||
} | ||
|
||
impl From<TVMTypeCode> for i64 { | ||
fn from(arg: TVMTypeCode) -> i64 { | ||
match arg { | ||
TVMTypeCode::kDLInt => 0, | ||
TVMTypeCode::kDLUInt => 1, | ||
TVMTypeCode::kDLFloat => 2, | ||
TVMTypeCode::kHandle => 3, | ||
TVMTypeCode::kNull => 4, | ||
TVMTypeCode::kTVMType => 5, | ||
TVMTypeCode::kTVMContext => 6, | ||
TVMTypeCode::kArrayHandle => 7, | ||
TVMTypeCode::kNodeHandle => 8, | ||
TVMTypeCode::kModuleHandle => 9, | ||
TVMTypeCode::kFuncHandle => 10, | ||
TVMTypeCode::kStr => 11, | ||
TVMTypeCode::kBytes => 12, | ||
TVMTypeCode::kNDArrayContainer => 13, | ||
} | ||
} | ||
} | ||
|
||
impl Into<TVMTypeCode> for i64 { | ||
fn into(self) -> TVMTypeCode { | ||
match self { | ||
0 => TVMTypeCode::kDLInt, | ||
1 => TVMTypeCode::kDLUInt, | ||
2 => TVMTypeCode::kDLFloat, | ||
3 => TVMTypeCode::kHandle, | ||
4 => TVMTypeCode::kNull, | ||
5 => TVMTypeCode::kTVMType, | ||
6 => TVMTypeCode::kTVMContext, | ||
7 => TVMTypeCode::kArrayHandle, | ||
8 => TVMTypeCode::kNodeHandle, | ||
9 => TVMTypeCode::kModuleHandle, | ||
10 => TVMTypeCode::kFuncHandle, | ||
11 => TVMTypeCode::kStr, | ||
12 => TVMTypeCode::kBytes, | ||
13 => TVMTypeCode::kNDArrayContainer, | ||
_ => unreachable!(), | ||
} | ||
} | ||
} | ||
|
||
impl Display for TVMTypeCode { | ||
fn fmt(&self, f: &mut Formatter) -> fmt::Result { | ||
write!( | ||
f, | ||
"{}", | ||
match self { | ||
TVMTypeCode::kDLInt => "int", | ||
TVMTypeCode::kDLUInt => "uint", | ||
TVMTypeCode::kDLFloat => "float", | ||
TVMTypeCode::kHandle => "handle", | ||
TVMTypeCode::kNull => "null", | ||
TVMTypeCode::kTVMType => "TVM type", | ||
TVMTypeCode::kTVMContext => "TVM context", | ||
TVMTypeCode::kArrayHandle => "Array handle", | ||
TVMTypeCode::kNodeHandle => "Node handle", | ||
TVMTypeCode::kModuleHandle => "Module handle", | ||
TVMTypeCode::kFuncHandle => "Function handle", | ||
TVMTypeCode::kStr => "string", | ||
TVMTypeCode::kBytes => "bytes", | ||
TVMTypeCode::kNDArrayContainer => "ndarray container", | ||
} | ||
) | ||
} | ||
} | ||
|
||
macro_rules! impl_prim_type { | ||
($type:ty, $variant:ident) => { | ||
impl<'a> From<&'a $type> for TVMTypeCode { | ||
fn from(_arg: &$type) -> Self { | ||
TVMTypeCode::$variant | ||
} | ||
} | ||
|
||
impl<'a> From<&'a mut $type> for TVMTypeCode { | ||
fn from(_arg: &mut $type) -> Self { | ||
TVMTypeCode::$variant | ||
} | ||
} | ||
}; | ||
} | ||
|
||
impl_prim_type!(usize, kDLInt); | ||
impl_prim_type!(i64, kDLInt); | ||
impl_prim_type!(i32, kDLInt); | ||
impl_prim_type!(i16, kDLInt); | ||
impl_prim_type!(i8, kDLInt); | ||
|
||
impl_prim_type!(u64, kDLUInt); | ||
impl_prim_type!(u32, kDLUInt); | ||
impl_prim_type!(u16, kDLUInt); | ||
impl_prim_type!(u8, kDLUInt); | ||
|
||
impl_prim_type!(f64, kDLFloat); | ||
impl_prim_type!(f32, kDLFloat); | ||
|
||
impl_prim_type!(str, kStr); | ||
impl_prim_type!(CStr, kStr); | ||
impl_prim_type!(String, kStr); | ||
impl_prim_type!(CString, kStr); | ||
|
||
impl_prim_type!([u8], kBytes); |
Oops, something went wrong.