[TOPI] VNNI support for int8 dense (#10230)
* wip

* revert for now

* simplify blocking

* add bench script

* update type rel

* refactor tests

* end to end compilation working

* parallelize outer loop

* add shape check

* fused schedule first cut

* restore original test

* black

* add vnni check

* add relay test

* skip on ci

* check dtype

* lint

* make it tunable

* minor cleanup
masahi committed Feb 15, 2022
1 parent a1d8f72 commit 0009a30
Showing 6 changed files with 169 additions and 24 deletions.
26 changes: 21 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
@@ -530,11 +530,27 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
"""dense_pack x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
name="dense_pack.x86",
)

if (
inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
and out_type.dtype == "int32"
and attrs["weight_layout"] == "NC16n4c"
):
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_vnni),
wrap_topi_schedule(topi.x86.schedule_dense_vnni),
name="dense_vnni.x86",
plevel=12,
)
else:
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
name="dense_pack.x86",
plevel=10,
)

return strategy


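Note: the VNNI path above is only taken when the activations are uint8, the weight is int8, the output is int32, and the weight has already been repacked into the "NC16n4c" layout. A minimal NumPy sketch of what that layout means (the helper name and shapes are illustrative, not part of this commit):

import numpy as np

# "NC16n4c": an (N, K) int8 weight is stored as (N // 16, K // 4, 16, 4),
# so that packed[j // 16, k // 4, j % 16, k % 4] == weight[j, k].
def pack_nc16n4c(weight):
    n, k = weight.shape
    assert n % 16 == 0 and k % 4 == 0
    return weight.reshape(n // 16, 16, k // 4, 4).transpose(0, 2, 1, 3)

w = np.random.randint(-128, 128, size=(128, 96), dtype="int8")
pw = pack_nc16n4c(w)
assert pw[5 // 16, 7 // 4, 5 % 16, 7 % 4] == w[5, 7]

This is exactly the index expression used by dense_vnni_compute in python/tvm/topi/x86/dense.py below.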
94 changes: 94 additions & 0 deletions python/tvm/topi/x86/dense.py
@@ -29,6 +29,7 @@
from .utils import get_simd_32bit_lanes
from .. import generic, tag
from ..utils import traverse_inline, get_const_tuple
from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake


def _schedule_dense_pack_template(cfg, s, C, O):
@@ -279,6 +280,99 @@ def _callback(op):
return s


def dense_vnni_compute(cfg, X, packed_w, bias=None):
"""Compute for uint8 x int8 -> int32 dense"""
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")

C = te.compute(
(m, n_o * n_i),
lambda i, j: te.sum(
X[i, ak].astype("int32")
* packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
"int32"
),
axis=ak,
),
tag="dense_vnni",
)

if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)

a_y, _ = C.op.axis
cfg.define_split("tile_y", a_y, num_outputs=2)

return C


def dense_vnni_schedule(cfg, s, C, O):
"""Schedule dense compute using VNNI vpdpbusd instruction"""
# C: The output of GEMM
# O: The output of the fused op
def split_y(out):
default_y_split_factor = 32
a_y = out.op.axis[0]

if cfg.is_fallback:
return s[out].split(a_y, factor=default_y_split_factor)

return cfg["tile_y"].apply(s, out, a_y)

(a_k,) = C.op.reduce_axis

a_yo, a_yi = split_y(C)
a_xo, a_xi = s[C].split(C.op.axis[1], factor=16)
a_ko, a_ki = s[C].split(a_k, factor=4)

s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)

pc = dot_16x1x16_uint8_int8_int32_cascadelake()
s[C].tensorize(a_xi, pc)

if C == O:
fused = s[O].fuse(a_yo, a_xo)
else:
a_yo, a_yi = split_y(O)
a_xo, a_xi = s[O].split(O.op.axis[1], factor=16)

s[O].reorder(a_yo, a_xo, a_yi, a_xi)
s[O].vectorize(a_xi)
s[C].compute_at(s[O], a_yi)

fused = s[O].fuse(a_yo, a_xo)

s[O].parallel(fused)

return s


@autotvm.register_topi_compute("dense_vnni.x86")
def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
"""Compute for uint8 x int8 -> int32 dense"""
if out_dtype is None:
out_dtype = data.dtype
assert len(weight.shape) == 4
assert data.dtype == "uint8" and weight.dtype == "int8"
_, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim
assert n_inner == 16 and k_inner == 4
return dense_vnni_compute(cfg, data, weight, bias)


@autotvm.register_topi_schedule("dense_vnni.x86")
def schedule_dense_vnni(cfg, outs):
"""Create a schedule for dense_vnni"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "dense_vnni" in op.tag:
dense_vnni_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s


def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
"""Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(tensor_a.shape)
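Note on the schedule above: dense_vnni_schedule splits the output columns by 16 and the reduction axis by 4, so the innermost (a_xi, a_ki) pair is exactly the shape consumed by the Cascade Lake vpdpbusd instruction that dot_16x1x16_uint8_int8_int32_cascadelake tensorizes: 16 int32 accumulators, each fed by a dot product of 4 uint8 x int8 pairs. A scalar sketch of one tensorized step (names are illustrative only):

# One vpdpbusd-style step: acc holds 16 int32 accumulators, x4 holds the
# 4 uint8 activations broadcast to every lane, w is the 16 x 4 int8 weight tile.
def vnni_dot_step(acc, x4, w):
    for lane in range(16):  # corresponds to a_xi (16 output columns)
        for t in range(4):  # corresponds to a_ki (4 packed reduction elements)
            acc[lane] += int(x4[t]) * int(w[lane][t])
    return acc

The surrounding loops (a_yo and a_xo fused and parallelized, then a_yi and a_ko) iterate over the tiles around this step.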
15 changes: 14 additions & 1 deletion python/tvm/topi/x86/dense_alter_op.py
@@ -24,6 +24,7 @@
from .dense import _default_dense_pack_config
from ..utils import get_const_tuple
from ..nn import dense_alter_layout
from .utils import target_has_vnni


@dense_alter_layout.register(["cpu", "arm_cpu"])
@@ -34,8 +35,20 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
out_dtype = out_type.dtype
M, K = get_const_tuple(data_tensor.shape)
N, _ = get_const_tuple(weight_tensor.shape)
mcpu = tvm.target.Target.current().mcpu

impl, outs = relay.backend.te_compiler.select_implementation(
if (
target_has_vnni(mcpu)
and data_tensor.dtype == "uint8"
and weight_tensor.dtype == "int8"
and weight_tensor.shape[0] % 16 == 0
and weight_tensor.shape[1] % 4 == 0
):
# TODO(masahi): Support int8 x int8 case
weight_layout = "NC16n4c"
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)

_, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.dense"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
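Note: with the change above, an eligible nn.dense is rewritten into nn.contrib_dense_pack with weight_layout="NC16n4c", and the weight repacking itself is materialized by the usual layout-transform machinery. A rough, hedged sketch of how the rewrite could be observed (the pass sequence and target below are assumptions, not part of this commit):

import tvm
from tvm import relay

data = relay.var("data", shape=(32, 96), dtype="uint8")
weight = relay.var("weight", shape=(128, 96), dtype="int8")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, out_dtype="int32"))

# The rewrite only triggers under a VNNI-capable target.
with tvm.target.Target("llvm -mcpu=cascadelake"):
    seq = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.AlterOpLayout()]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)

print(mod)  # expect nn.contrib_dense_pack(..., weight_layout="NC16n4c")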
2 changes: 1 addition & 1 deletion src/relay/op/nn/nn.cc
@@ -259,7 +259,7 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK(param != nullptr);

ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed";
ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect weight to be 3D or 4D";

Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set(1, weight->shape[0] * weight->shape[2]);
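Note: the relaxed check keeps the existing output-shape rule valid for both packings. For the 3D "NC{n}n" layout the weight shape is (N // n, K, n); for the new 4D "NC16n4c" layout it is (N // 16, K // 4, 16, 4). In both cases weight->shape[0] * weight->shape[2] recovers the original output dimension N, so the oshape computation in this hunk is unchanged.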
4 changes: 3 additions & 1 deletion tests/python/contrib/test_gemm_acc32_vnni.py
@@ -57,7 +57,9 @@ def verify(target="llvm -mcpu=cascadelake"):
(m, n),
lambda i, j: te.sum(
X[i, ak].astype("int32")
* packedW[j / 16, (ak / 4) * 16 + j % 16, ak % 4].astype("int32"),
* packedW[
tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4) * 16 + j % 16, ak % 4
].astype("int32"),
axis=ak,
),
name="F",
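Note: this test change swaps Python's / on integer index expressions for tvm.tir.indexdiv, which spells out the intended floor division; newer TVM versions reject or warn about plain / between integer PrimExprs, so this keeps the reference compute unambiguous.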
52 changes: 36 additions & 16 deletions tests/python/relay/test_op_level1.py
@@ -20,7 +20,7 @@
from tvm import te
import scipy
from tvm import relay
from tvm.relay import transform
import pytest
from tvm.relay.testing import run_infer_type
import tvm.topi.testing
from tvm.contrib.nvcc import have_fp16
@@ -634,19 +634,39 @@ def test_bitserial_dense():
assert yy.checked_type == relay.TensorType((m, 32), "int16")


@pytest.mark.skip("Requires cascadelake")
def test_dense_vnni():
data_shape = (32, 96)
weight_shape = (128, 96)

data = relay.var("data", shape=data_shape, dtype="uint8")
weight = relay.var("weight", shape=weight_shape, dtype="int8")
bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
dense = relay.nn.dense(data, weight, out_dtype="int32")
out = relay.nn.bias_add(dense, bias)
mod = tvm.IRModule.from_expr(out)

target = "llvm -mcpu=cascadelake"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target)

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

a = np.random.uniform(1, 10, size=data_shape).astype("uint8")
b = np.random.uniform(1, 10, size=weight_shape).astype("int8")
c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")

runtime.set_input("data", a)
runtime.set_input("weight", b)
runtime.set_input("bias", c)
runtime.run()

out = runtime.get_output(0).numpy()
ref = np.dot(a, b.transpose()) + c

np.testing.assert_equal(out, ref)
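
Exact comparison with np.testing.assert_equal is appropriate here: with inputs drawn from [1, 10) and K = 96, every accumulation stays well within int32 range, and integer arithmetic has no rounding, so the VNNI result must match the NumPy reference bit-for-bit.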


if __name__ == "__main__":
test_concatenate()
test_bias_add()
test_bias_add_type_failure()
test_unary_op()
test_binary_op()
test_expand_dims_infer_type()
test_expand_dims()
test_softmax()
test_log_softmax()
test_dropout()
test_batch_norm()
test_matmul()
test_dense()
test_bitserial_dense()
test_dense_dtype()
pytest.main([__file__])
