Skip to content

Commit

Permalink
[RUST][FRONTEND] Add rust frontend v0.1 (apache#2292)
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsanmok authored and AWS Neo committed Feb 20, 2019
1 parent 92f7d68 commit adcc5c2
Show file tree
Hide file tree
Showing 80 changed files with 5,642 additions and 2,275 deletions.
8 changes: 4 additions & 4 deletions rust/.rustfmt.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
max_width = 100
hard_tabs = false
tab_spaces = 2
tab_spaces = 4
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
Expand Down Expand Up @@ -38,7 +38,7 @@ trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2015"
edition = "2018"
merge_derives = true
use_try_shorthand = true
use_field_init_shorthand = false
Expand All @@ -50,8 +50,8 @@ unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
error_on_line_overflow = true
error_on_unformatted = true
report_todo = "Never"
report_fixme = "Never"
ignore = []
Expand Down
39 changes: 11 additions & 28 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
[package]
name = "tvm"
version = "0.1.0"
license = "Apache-2.0"
description = "TVM Rust runtime"
repository = "https://github.com/dmlc/tvm"
readme = "README.md"
keywords = ["tvm", "nnvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]

[features]
default = ["nom/std"]
sgx = ["nom/alloc"]

[dependencies]
bounded-spsc-queue = "0.4.0"
error-chain = { version = "0.12.0", default-features = false }
itertools = "0.7.8"
lazy_static = "1.1.0"
ndarray = "0.11.2"
nom = {version = "4.0.0", default-features = false }
serde = "1.0.59"
serde_derive = "1.0.79"
serde_json = "1.0.17"

[target.'cfg(not(target_env = "sgx"))'.dependencies]
num_cpus = "1.8.0"
[workspace]
members = [
"common",
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_nnvm",
"frontend",
"frontend/tests/basics",
"frontend/tests/callback",
"frontend/examples/resnet"
]
4 changes: 4 additions & 0 deletions rust/common/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
target
**/*.rs.bk
Cargo.lock
/tvm-sys/src/bindgen.rs
13 changes: 13 additions & 0 deletions rust/common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "tvm-common"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"

[features]
runtime = []
frontend = ["tvm-sys"]

[dependencies]
error-chain = { version = "0.12.0", default-features = false }
tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true }
File renamed without changes.
15 changes: 15 additions & 0 deletions rust/common/src/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//! Error types for `TVMArgValue` and `TVMRetValue` conversions.

error_chain! {
errors {
TryFromTVMArgValueError(expected: String, actual: String) {
description("mismatched types while converting from TVMArgValue")
display("expected `{}` but given `{}`", expected, actual)
}

TryFromTVMRetValueError(expected: String, actual: String) {
description("mismatched types while downcasting TVMRetValue")
display("invalid downcast: expected `{}` but given `{}`", expected, actual)
}
}
}
39 changes: 39 additions & 0 deletions rust/common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! This crate contains the refactored basic components required
//! for `runtime` and `frontend` TVM crates.

#![crate_name = "tvm_common"]
#![recursion_limit = "1024"]
#![allow(non_camel_case_types, unused_imports)]
#![feature(box_syntax, try_from)]

#[macro_use]
extern crate error_chain;

/// Unified ffi module for both runtime and frontend crates.
pub mod ffi {
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]

#[cfg(feature = "frontend")]
pub extern crate tvm_sys as ts;

#[cfg(feature = "runtime")]
pub mod runtime {
use std::os::raw::{c_char, c_int, c_void};

include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));

pub type BackendPackedCFunc = extern "C" fn(
args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
) -> c_int;
}
}

pub mod errors;
pub mod ty;
pub mod value;

pub use errors::*;
pub use ty::TVMTypeCode;
pub use value::{TVMArgValue, TVMRetValue, TVMValue};
144 changes: 144 additions & 0 deletions rust/common/src/ty.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods.
//!
//! # Example
//!
//! ```
//! let dtype = TVMType::from("float");
//! println!("dtype is: {}", dtype);
//! ```

use std::{
ffi::{CStr, CString},
fmt::{self, Display, Formatter},
};

/// TVM type codes.
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TVMTypeCode {
kDLInt = 0,
kDLUInt = 1,
kDLFloat = 2,
kHandle = 3,
kNull = 4,
kTVMType = 5,
kTVMContext = 6,
kArrayHandle = 7,
kNodeHandle = 8,
kModuleHandle = 9,
kFuncHandle = 10,
kStr = 11,
kBytes = 12,
kNDArrayContainer = 13,
}

impl Default for TVMTypeCode {
fn default() -> Self {
TVMTypeCode::kDLInt
}
}

impl From<TVMTypeCode> for i64 {
fn from(arg: TVMTypeCode) -> i64 {
match arg {
TVMTypeCode::kDLInt => 0,
TVMTypeCode::kDLUInt => 1,
TVMTypeCode::kDLFloat => 2,
TVMTypeCode::kHandle => 3,
TVMTypeCode::kNull => 4,
TVMTypeCode::kTVMType => 5,
TVMTypeCode::kTVMContext => 6,
TVMTypeCode::kArrayHandle => 7,
TVMTypeCode::kNodeHandle => 8,
TVMTypeCode::kModuleHandle => 9,
TVMTypeCode::kFuncHandle => 10,
TVMTypeCode::kStr => 11,
TVMTypeCode::kBytes => 12,
TVMTypeCode::kNDArrayContainer => 13,
}
}
}

impl Into<TVMTypeCode> for i64 {
fn into(self) -> TVMTypeCode {
match self {
0 => TVMTypeCode::kDLInt,
1 => TVMTypeCode::kDLUInt,
2 => TVMTypeCode::kDLFloat,
3 => TVMTypeCode::kHandle,
4 => TVMTypeCode::kNull,
5 => TVMTypeCode::kTVMType,
6 => TVMTypeCode::kTVMContext,
7 => TVMTypeCode::kArrayHandle,
8 => TVMTypeCode::kNodeHandle,
9 => TVMTypeCode::kModuleHandle,
10 => TVMTypeCode::kFuncHandle,
11 => TVMTypeCode::kStr,
12 => TVMTypeCode::kBytes,
13 => TVMTypeCode::kNDArrayContainer,
_ => unreachable!(),
}
}
}

impl Display for TVMTypeCode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"{}",
match self {
TVMTypeCode::kDLInt => "int",
TVMTypeCode::kDLUInt => "uint",
TVMTypeCode::kDLFloat => "float",
TVMTypeCode::kHandle => "handle",
TVMTypeCode::kNull => "null",
TVMTypeCode::kTVMType => "TVM type",
TVMTypeCode::kTVMContext => "TVM context",
TVMTypeCode::kArrayHandle => "Array handle",
TVMTypeCode::kNodeHandle => "Node handle",
TVMTypeCode::kModuleHandle => "Module handle",
TVMTypeCode::kFuncHandle => "Function handle",
TVMTypeCode::kStr => "string",
TVMTypeCode::kBytes => "bytes",
TVMTypeCode::kNDArrayContainer => "ndarray container",
}
)
}
}

macro_rules! impl_prim_type {
($type:ty, $variant:ident) => {
impl<'a> From<&'a $type> for TVMTypeCode {
fn from(_arg: &$type) -> Self {
TVMTypeCode::$variant
}
}

impl<'a> From<&'a mut $type> for TVMTypeCode {
fn from(_arg: &mut $type) -> Self {
TVMTypeCode::$variant
}
}
};
}

impl_prim_type!(usize, kDLInt);
impl_prim_type!(i64, kDLInt);
impl_prim_type!(i32, kDLInt);
impl_prim_type!(i16, kDLInt);
impl_prim_type!(i8, kDLInt);

impl_prim_type!(u64, kDLUInt);
impl_prim_type!(u32, kDLUInt);
impl_prim_type!(u16, kDLUInt);
impl_prim_type!(u8, kDLUInt);

impl_prim_type!(f64, kDLFloat);
impl_prim_type!(f32, kDLFloat);

impl_prim_type!(str, kStr);
impl_prim_type!(CStr, kStr);
impl_prim_type!(String, kStr);
impl_prim_type!(CString, kStr);

impl_prim_type!([u8], kBytes);
Loading

0 comments on commit adcc5c2

Please sign in to comment.