Commit: make it tunable
masahi committed Feb 13, 2022
1 parent f7de3bf commit cbfe979
Showing 2 changed files with 114 additions and 70 deletions.
python/tvm/relay/op/strategy/x86.py (21 additions, 5 deletions)
@@ -530,11 +530,27 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
    """dense_pack x86 strategy"""
    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.x86.dense_pack),
-        wrap_topi_schedule(topi.x86.schedule_dense_pack),
-        name="dense_pack.x86",
-    )

    if (
        inputs[0].dtype == "uint8"
        and inputs[1].dtype == "int8"
        and out_type.dtype == "int32"
        and attrs["weight_layout"] == "NC16n4c"
    ):
        strategy.add_implementation(
            wrap_compute_dense(topi.x86.dense_vnni),
            wrap_topi_schedule(topi.x86.schedule_dense_vnni),
            name="dense_vnni.x86",
            plevel=12,
        )
    else:
        strategy.add_implementation(
            wrap_compute_dense(topi.x86.dense_pack),
            wrap_topi_schedule(topi.x86.schedule_dense_pack),
            name="dense_pack.x86",
            plevel=10,
        )

    return strategy
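
With this change, dense_vnni.x86 is registered only for uint8 x int8 -> int32 dense whose weight is pre-packed into the NC16n4c layout; everything else falls back to dense_pack.x86. As a sketch of a workload that should take the VNNI path (shapes and target are illustrative assumptions, and the NC16n4c packing is expected to be produced by the alter-op-layout pass on a VNNI-capable target):

import tvm
from tvm import relay

# Hypothetical shapes: in_dim a multiple of 4, out_dim a multiple of 16.
data = relay.var("data", shape=(32, 96), dtype="uint8")
weight = relay.var("weight", shape=(128, 96), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=cascadelake")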


python/tvm/topi/x86/dense.py (93 additions, 65 deletions)
@@ -207,65 +207,6 @@ def _callback(op):
    return s


-def dense_vnni_compute(X, packedW, bias=None):
-    """Compute for uint8 x int8 -> int32 dense"""
-    m, k = X.shape
-    n_o, _, n_i, _ = packedW.shape
-    ak = te.reduce_axis((0, k), name="k")
-
-    C = te.compute(
-        (m, n_o * n_i),
-        lambda i, j: te.sum(
-            X[i, ak].astype("int32")
-            * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
-                "int32"
-            ),
-            axis=ak,
-        ),
-        tag="dense_vnni",
-    )
-
-    if bias is not None:
-        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)
-
-    return C
-
-
-def dense_vnni_schedule(s, C, O):
-    """Schedule dense compute using VNNI vpdpbusd instruction"""
-    # C: The output of GEMM
-    # O: The output of the fused op
-    if C != O:
-        a_y, a_x = O.op.axis
-        a_yo, a_yi = s[O].split(a_y, factor=32)
-        a_xo, a_xi = s[O].split(a_x, factor=16)
-
-        s[O].reorder(a_yo, a_xo, a_yi, a_xi)
-        fused = s[O].fuse(a_yo, a_xo)
-        s[O].vectorize(a_xi)
-        s[O].parallel(fused)
-
-        s[C].compute_at(s[O], a_yi)
-
-    a_y, a_x = C.op.axis
-    (a_k,) = C.op.reduce_axis
-
-    a_ko, a_ki = s[C].split(a_k, factor=4)
-    a_yo, a_yi = s[C].split(a_y, factor=32)
-    a_xo, a_xi = s[C].split(a_x, factor=16)
-
-    s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
-
-    pc = dot_16x1x16_uint8_int8_int32_cascadelake()
-    s[C].tensorize(a_xi, pc)
-
-    if C == O:
-        fused = s[O].fuse(a_yo, a_xo)
-        s[O].parallel(fused)
-
-    return s


@autotvm.register_topi_compute("dense_pack.x86")
def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
    """Compute dense with transformed weight."""
@@ -275,10 +216,6 @@ def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
    if len(weight.shape) == 3:
        N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim
        N = N * packw_bn
-    elif len(weight.shape) == 4:
-        N, K, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
-        assert n_inner == 16 and k_inner == 4
-        return dense_vnni_compute(data, weight, bias)
    else:
        N, _ = get_const_tuple(weight.shape)  # out_dim
    # create tuning space
@@ -336,15 +273,106 @@ def schedule_dense_pack(cfg, outs):
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
-        if "dense_vnni" in op.tag:
-            dense_vnni_schedule(s, op.output(0), outs[0])
        if "dense_pack" in op.tag:
            _schedule_dense_pack_template(cfg, s, op.output(0), outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s


def dense_vnni_compute(cfg, X, packed_w, bias=None):
    """Compute for uint8 x int8 -> int32 dense"""
    m, k = X.shape
    n_o, _, n_i, _ = packed_w.shape
    ak = te.reduce_axis((0, k), name="k")

    C = te.compute(
        (m, n_o * n_i),
        lambda i, j: te.sum(
            X[i, ak].astype("int32")
            * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
                "int32"
            ),
            axis=ak,
        ),
        tag="dense_vnni",
    )

    if bias is not None:
        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)

    a_y, _ = C.op.axis
    cfg.define_split("tile_y", a_y, num_outputs=2)

    return C
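
The cfg.define_split("tile_y", ...) call above is what the commit title refers to: instead of a hard-coded factor, the two-way split of the output's y axis becomes an AutoTVM knob. One way to inspect the resulting search space is to create the task directly (shapes here are hypothetical):

from tvm import autotvm, te

data = te.placeholder((128, 768), dtype="uint8")
weight = te.placeholder((48, 192, 16, 4), dtype="int8")  # NC16n4c: n = 48 * 16, k = 192 * 4

task = autotvm.task.create(
    "dense_vnni.x86",
    args=(data, weight, None, "int32"),
    target="llvm -mcpu=cascadelake",
)
print(task.config_space)  # lists the candidate "tile_y" splits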


def dense_vnni_schedule(cfg, s, C, O):
    """Schedule dense compute using VNNI vpdpbusd instruction"""
    # C: The output of GEMM
    # O: The output of the fused op
    def split_y(out):
        default_y_split_factor = 32
        a_y = out.op.axis[0]

        if cfg.is_fallback:
            return s[out].split(a_y, factor=default_y_split_factor)

        return cfg["tile_y"].apply(s, out, a_y)

    (a_k,) = C.op.reduce_axis

    a_yo, a_yi = split_y(C)
    a_xo, a_xi = s[C].split(C.op.axis[1], factor=16)
    a_ko, a_ki = s[C].split(a_k, factor=4)

    s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)

    pc = dot_16x1x16_uint8_int8_int32_cascadelake()
    s[C].tensorize(a_xi, pc)

    if C == O:
        fused = s[O].fuse(a_yo, a_xo)
        s[O].parallel(fused)
    else:
        a_yo, a_yi = split_y(O)
        a_xo, a_xi = s[O].split(O.op.axis[1], factor=16)

        s[O].reorder(a_yo, a_xo, a_yi, a_xi)
        fused = s[O].fuse(a_yo, a_xo)
        s[O].vectorize(a_xi)
        s[O].parallel(fused)

        s[C].compute_at(s[O], a_yi)

    return s
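
Given a task for this template, the tile_y knob is searched like any other AutoTVM workload; when no tuning record matches, cfg.is_fallback keeps the previous fixed factor of 32. A minimal tuning loop, reusing the task object from the sketch above (measurement settings are illustrative):

from tvm import autotvm

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10),
)
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(
    n_trial=min(32, len(task.config_space)),
    measure_option=measure_option,
    callbacks=[autotvm.callback.log_to_file("dense_vnni.log")],
)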


@autotvm.register_topi_compute("dense_vnni.x86")
def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
    """Compute for uint8 x int8 -> int32 dense"""
    if out_dtype is None:
        out_dtype = data.dtype
    assert len(weight.shape) == 4
    assert data.dtype == "uint8" and weight.dtype == "int8"
    _, _, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
    assert n_inner == 16 and k_inner == 4
    return dense_vnni_compute(cfg, data, weight, bias)


@autotvm.register_topi_schedule("dense_vnni.x86")
def schedule_dense_vnni(cfg, outs):
    """Create a schedule for dense_vnni"""
    s = te.create_schedule([x.op for x in outs])

    def _callback(op):
        if "dense_vnni" in op.tag:
            dense_vnni_schedule(cfg, s, op.output(0), outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s
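
Once tuned, the logged configuration is applied the usual way at compile time; under apply_history_best the best measured tile_y split replaces the fallback, e.g. when building the Relay module from the first sketch (log file name is illustrative):

import tvm
from tvm import autotvm, relay

with autotvm.apply_history_best("dense_vnni.log"):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm -mcpu=cascadelake")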


def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
    """Compute matmul/dense using a BLAS library"""
    M, K = get_const_tuple(tensor_a.shape)
