Skip to content

Commit

Permalink
[Frontend] Onnx (apache#40)
Browse files Browse the repository at this point in the history
* init onnx

finish onnx frontend

add onnx tests

fix various

backup

use transformer

[Frontend] graph passed

add test forward

test forward

fix doc and lint

fix test graph tuple

from_onnx now take 2 args, output (sym, params)

fix rename

fix input names

fix multiple

fix lint

fix lint check

* better doc
  • Loading branch information
zhreshold authored and tqchen committed May 29, 2018
1 parent dddd8d1 commit 4f664f5
Show file tree
Hide file tree
Showing 20 changed files with 1,105 additions and 223 deletions.
1 change: 1 addition & 0 deletions nnvm/python/nnvm/frontend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""NNVM frontends."""
from __future__ import absolute_import
from .mxnet import from_mxnet
from .onnx import from_onnx
130 changes: 130 additions & 0 deletions nnvm/python/nnvm/frontend/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Shared functions and classes for frontends."""
from __future__ import absolute_import as _abs
import warnings
from .._base import string_types

class Renamer(object):
"""A simply renamer for operators.
Parameters
----------
new_name : str
The new name for the operator
"""
def __init__(self, new_name):
self._new_name = new_name

def __call__(self, attrs):
return self._new_name, attrs


class AttrConverter(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provded, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occured.
disables : list
A list of attributes that is disabled in nnvm. Raise warnings.
ignores : list
A list of attributes that is ignored in nnvm. Silent.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check

def __call__(self, attrs):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, string_types):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name))
elif k in self._ignores:
pass
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return op_name, new_attrs

def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, string_types):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t

def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, string_types):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)

def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
8 changes: 4 additions & 4 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def _pooling(attrs):
def _batch_norm(attrs):
if _parse_bool_str(attrs, 'output_mean_var'):
_raise_not_supported('output_mean_var', 'batch_norm')
if _parse_bool_str(attrs, 'fix_gamma'):
_warn_not_used('fix_gamma', 'batch_norm')
# if _parse_bool_str(attrs, 'fix_gamma'):
# _warn_not_used('fix_gamma', 'batch_norm')
if _parse_bool_str(attrs, 'use_global_stats'):
_warn_not_used('use_global_stats', 'batch_norm')
if _parse_bool_str(attrs, 'momentum'):
_warn_not_used('momentum', 'batch_norm')
# if _parse_bool_str(attrs, 'momentum'):
# _warn_not_used('momentum', 'batch_norm')
op_name, new_attrs = 'batch_norm', {}
new_attrs['axis'] = attrs.get('axis', 1)
new_attrs['epsilon'] = attrs.get('eps', 0.001)
Expand Down
Loading

0 comments on commit 4f664f5

Please sign in to comment.