Skip to content

Commit

Permalink
VM compiler. (#18)
Browse files Browse the repository at this point in the history
* VM compiler.

* Update.

* Compile IRmodule; expose Python api

* Add dtype contant serialization and type hint.

* Address comments.

* Add todos and fix lint.

* Update

* Update.
  • Loading branch information
YuchenJin authored and junrushao committed Feb 5, 2023
1 parent a392e1a commit 02c8657
Show file tree
Hide file tree
Showing 17 changed files with 602 additions and 42 deletions.
63 changes: 63 additions & 0 deletions include/tvm/relax/attrs/memory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/memory.h
* \brief Attributes for memory operators.
*/
#ifndef TVM_RELAX_ATTRS_MEMORY_H_
#define TVM_RELAX_ATTRS_MEMORY_H_

#include <tvm/ir/attrs.h>

namespace tvm {
namespace relax {
/*!
* \brief Options for allocating storage.
*/
struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
DataType dtype;
int device_id;
int device_type;

TVM_DECLARE_ATTRS(AllocStorageAttrs, "relax.attrs.AllocStorageAttrs") {
TVM_ATTR_FIELD(dtype)
.describe("The dtype of the tensor to allocate.")
.set_default(DataType::Float(32, 1));
TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory.");
TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory.");
}
};

/*!
* \brief Options for allocating tensors.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(dtype)
.describe("The dtype of the tensor to allocate.")
.set_default(DataType::Float(32, 1));
}
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_ATTRS_MEMORY_H_
3 changes: 2 additions & 1 deletion include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ using relay::Call;
* \param diag_ctx The diagnostic context for reporting errors.
* \return The inferred output shape expression.
*/
using FInferShape = runtime::TypedPackedFunc<Optional<RelayExpr>(const Call& call, DiagnosticContext diag_ctx)>;
using FInferShape =
runtime::TypedPackedFunc<Optional<RelayExpr>(const Call& call, DiagnosticContext diag_ctx)>;

/*!
* \brief Infer the output type for operators. This function will
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/relax/vm/exec_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* \file tvm/relax/vm/exec_builder.h
* \brief
*/
#ifndef TVM_RELAX_EXEC_BUILDER_H_
#define TVM_RELAX_EXEC_BUILDER_H_
#ifndef TVM_RELAX_VM_EXEC_BUILDER_H_
#define TVM_RELAX_VM_EXEC_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/node/reflection.h>
Expand Down Expand Up @@ -52,7 +52,7 @@ class ExecBuilderNode : public Object {
* \param func The function name.
* \param num_inputs The number of inputs.
*/
void Function(std::string func, int64_t num_inputs);
void EmitFunction(std::string func, int64_t num_inputs);
/*!
* \brief Emit a call instruction for a packed function.
* \param func The packed function name.
Expand All @@ -69,7 +69,7 @@ class ExecBuilderNode : public Object {
* \brief Emit a constant value to the constant pool.
* \return The index that represents the constant.
*/
vm::Index EmitConstant(ObjectRef obj);
vm::Index EmitConstant(TVMRetValue obj);
/*!
* \brief Get the built executable.
* \return The built executable.
Expand Down Expand Up @@ -102,4 +102,4 @@ class ExecBuilder : public ObjectRef {
} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_EXEC_BUILDER_H_
#endif // TVM_RELAX_VM_EXEC_BUILDER_H_
2 changes: 1 addition & 1 deletion include/tvm/relax/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ExecutableNode : public Object {
/*! \brief A map from globals (as strings) to their index in the function map. */
std::unordered_map<std::string, Index> global_map;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
std::vector<TVMRetValue> constants;
/*! \brief The name of packed functions. */
std::vector<std::string> func_names;
/*! \brief A mapping from the packed function (as string) to the index that
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from . import parser
from . import analysis
from . import transform
from . import vm_compiler


# Expr
Expand Down Expand Up @@ -61,9 +62,11 @@
ExecBuilder = exec_builder.ExecBuilder
VirtualMachine = vm.VirtualMachine
load_exec_from_file = vm.load_exec_from_file
compile = vm_compiler.compile

# Operator
from .op.base import call_dps
from .op.op_attrs import AllocStorageAttrs, AllocTensorAttrs

# IRBuilder
IRBuilder = ir_builder.IRBuilder
Expand Down
4 changes: 0 additions & 4 deletions python/tvm/relax/base.py

This file was deleted.

2 changes: 1 addition & 1 deletion python/tvm/relax/exec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def emit_call(self, name, args=[], dst=None):
dst = SpecialReg.VOID_ARG
args_ = []
for arg in args:
if isinstance(arg, tvm.nd.NDArray):
if isinstance(arg, tvm.nd.NDArray) or isinstance(arg, tvm.DataType):
new_arg = self.emit_constant(arg)
args_.append(new_arg)
else:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
# Operators
from .base import *
from .tensor import *
from .op_attrs import *
28 changes: 28 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The attributes node used for Relax operators"""
from tvm.ir import Attrs
import tvm._ffi

@tvm._ffi.register_object("relax.attrs.AllocStorageAttrs")
class AllocStorageAttrs(Attrs):
"""Attributes used in alloc_storage operators"""


@tvm._ffi.register_object("relax.attrs.AllocTensorAttrs")
class AllocTensorAttrs(Attrs):
"""Attributes used in alloc_tensor operators"""
17 changes: 17 additions & 0 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import inspect
Expand Down
70 changes: 70 additions & 0 deletions python/tvm/relax/vm_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
"""
The Relax Virtual Machine compiler.
"""
from typing import List, Optional, Union, Dict
import tvm
from . import vm, _ffi_api


def compile(mod: tvm.IRModule) -> vm.Executable:
"""Compile the module to VM executable. A helper function for VMCompiler.
Parameters
----------
mod : tvm.IRModule
The Relay module to build.
Returns
-------
exec : tvm.relax.Executable
The VM executable that contains the bytecode.
"""
compiler = VMCompiler()
compiler.compile(mod)
return compiler.get_exec()


class VMCompiler(object):
"""Compiler that compiles module to VM executable."""

def __init__(self):
self.mod = _ffi_api.VMCompiler()
self._compile = self.mod["compile"]
self._get_exec = self.mod["get_executable"]

def compile(self, mod: tvm.IRModule) -> None:
"""Compile the module to VM executable.
Parameters
----------
mod : tvm.IRModule
The IRModule to build.
"""
self._compile(mod)

def get_exec(self) -> vm.Executable:
"""Get the VM executable.
Returns
-------
exec : tvm.relax.Executable
The VM executable that contains bytecode.
"""
return self._get_exec()
4 changes: 4 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/op.h>

Expand All @@ -24,6 +25,9 @@
namespace tvm {
namespace relax {

TVM_REGISTER_NODE_TYPE(AllocStorageAttrs);
TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);

bool EqualConstInt(const PrimExpr& lhs, int64_t value) {
if (const int64_t* pvalue = tir::as_const_int(lhs)) {
return pvalue[0] == value;
Expand Down
Loading

0 comments on commit 02c8657

Please sign in to comment.