From 2545e9caecadd66c72fbb6734c30d100e823b0fb Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sat, 28 Aug 2021 12:59:20 -0700 Subject: [PATCH] [Frontend][Onnx] Simplify onnx input since name accesses are not reliable. (#8867) * Simplify onnx input since name accesses are no longer supported. * move Celu importer. --- python/tvm/relay/frontend/onnx.py | 82 +++++++++---------------------- 1 file changed, 22 insertions(+), 60 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5471f67ea106..9144d3e145c8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -63,54 +63,16 @@ } -class onnx_input: - """Dual purpose list or dictionary access object.""" - - def __init__(self): - self.input_keys = [] - self.input_dict = {} +class onnx_input(list): + """A helper extension to list that returns None for out of bound indices.""" def __getitem__(self, item): - if isinstance(item, int): - if item > (len(self.input_keys) - 1): - return None - return self.input_dict[self.input_keys[item]] - if isinstance(item, str): - if item not in self.input_keys: - return None - return self.input_dict[item] if isinstance(item, slice): - keys = self.input_keys[item] - return [self.input_dict[key] for key in keys] - - raise ValueError("Only integer, string, and slice accesses allowed.") - - def __setitem__(self, item, value): + indices = list(range(item.stop)[item]) + return [self[i] for i in indices] if isinstance(item, int): - self.input_dict[self.input_keys[item]] = value - elif isinstance(item, str): - self.input_keys.append(item) - self.input_dict[item] = value - else: - raise ValueError("Only integer and string indexed writes allowed.") - - def keys(self): - return self.input_keys - - def __len__(self): - return len(self.input_keys) - - def __iter__(self): - self.n = 0 - return self - - def __next__(self): - if self.n < len(self.input_keys): - output = self.input_dict[self.input_keys[self.n]] - self.n += 1 - return output - - raise StopIteration + return list(self)[item] if item < len(self) else None + raise TypeError("list indices must be integers or slices, not %s" % type(item).__name__) def get_numpy(tensor_proto): @@ -2672,6 +2634,19 @@ def _impl_v10(cls, inputs, attr, params): return isinf +class Celu(OnnxOpConverter): + """Operator convereter for celu""" + + @classmethod + def _impl_v12(cls, inputs, attr, params): + x = inputs[0] + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(attr.get("alpha", 1.0), dtype) + zero = _op.const(0, dtype) + one = _op.const(1, dtype) + return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + + class MaxRoiPool(OnnxOpConverter): """Operator converter for MaxRoiPool.""" @@ -3822,13 +3797,13 @@ def from_onnx(self, graph, opset, get_output_expr=False): for node in graph.node: op_name = node.op_type attr = self._parse_attr(node.attribute) - # Create and populate onnx input object. + # Create and populate input list. inputs = onnx_input() for i in node.input: if i != "": - inputs[i] = self._nodes[self._renames.get(i, i)] + inputs.append(self._nodes[self._renames.get(i, i)]) else: - inputs[i] = None + inputs.append(None) i_name = self._parse_value_proto(node) node_output = self._fix_outputs(op_name, node.output) attr["tvm_custom"] = {} @@ -3981,19 +3956,6 @@ def _fix_outputs(self, op_name, outputs): return outputs -class Celu(OnnxOpConverter): - """Operator convereter for celu""" - - @classmethod - def _impl_v12(cls, inputs, attr, params): - x = inputs[0] - dtype = infer_type(x).checked_type.dtype - alpha = _op.const(attr.get("alpha", 1.0), dtype) - zero = _op.const(0, dtype) - one = _op.const(1, dtype) - return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) - - def from_onnx( model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None ):