
[TOPI] VNNI support for batch matmul #10332

Merged: 9 commits into apache:main on Feb 23, 2022

Conversation

@masahi (Member) commented on Feb 21, 2022

Following #10230, I added VNNI support for batch_matmul as well. The cool part is that I reuse the same dense schedule from #10230 to schedule the GEMM part and parallelize over the batch dimension. See the perf results in #10332 (comment)
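Below is a minimal sketch of that scheduling idea (an illustration, not the code added in this PR): hand the GEMM tiles to the reused dense VNNI schedule and parallelize over the fused batch/outer axes, which is where the `ax0.ax1.outer.ax2.outer.fused.fused` loop in the lowered IR below comes from. `dense_vnni_schedule_inner` is a hypothetical stand-in for the helper reused from #10230.

```python
import tvm
from tvm import te

def schedule_batch_matmul_vnni(outs, dense_vnni_schedule_inner):
    # Assumes outs[0] is a batch_matmul-like compute with spatial axes (batch, i, j).
    s = te.create_schedule([x.op for x in outs])
    C = outs[0]
    batch, i, j = s[C].op.axis

    # Placeholder split of the row axis; in the real schedule the tiling and the
    # VNNI tensorization come from the dense schedule reused from #10230.
    io, ii = s[C].split(i, factor=32)

    # Fuse the batch axis with the outer row tiles and run the fused loop in
    # parallel, so each thread works on one (batch, row-tile) slice.
    fused = s[C].fuse(batch, io)
    s[C].parallel(fused)

    # Hypothetical hook standing in for the reused dense VNNI GEMM schedule.
    dense_vnni_schedule_inner(s, C, ii, j)
    return s
```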

After this PR, I'll add int8 + int8 input support to VNNI dense and batch_matmul (UPDATE: done). That will allow us to benchmark e2e performance on the QAT BERT made possible by @Icemist in #10239.

Unlike the dense case, the second input to batch_matmul is typically not a constant tensor, so I don't use alter_layout with a compile-time layout transform. Instead, the layout transform is done at runtime, and the lowered IR for batch_matmul + post ops looks like this:

  parallel (ax0.ax1.outer.ax2.outer.fused.fused, 0, 128) {
    // attr [T_layout_trans] storage_alignment = 128
    let T_layout_trans = tir.TVMBackendAllocWorkspace(1, dev_id, (uint64)1536, 0, 8)
    allocate compute[int32x16 * 1], storage_scope = global
    for (ax2, 0, 24) {
      let cse_var_2 = (ax2*64)
      let cse_var_1 = ((ax0.ax1.outer.ax2.outer.fused.fused*1536) + (ax2*4))
      T_layout_trans[ramp(cse_var_2, 1, 4)] = placeholder[ramp(cse_var_1, 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 4), 1, 4)] = placeholder[ramp((cse_var_1 + 96), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 8), 1, 4)] = placeholder[ramp((cse_var_1 + 192), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 12), 1, 4)] = placeholder[ramp((cse_var_1 + 288), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 16), 1, 4)] = placeholder[ramp((cse_var_1 + 384), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 20), 1, 4)] = placeholder[ramp((cse_var_1 + 480), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 24), 1, 4)] = placeholder[ramp((cse_var_1 + 576), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 28), 1, 4)] = placeholder[ramp((cse_var_1 + 672), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 32), 1, 4)] = placeholder[ramp((cse_var_1 + 768), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 36), 1, 4)] = placeholder[ramp((cse_var_1 + 864), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 40), 1, 4)] = placeholder[ramp((cse_var_1 + 960), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 44), 1, 4)] = placeholder[ramp((cse_var_1 + 1056), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 48), 1, 4)] = placeholder[ramp((cse_var_1 + 1152), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 52), 1, 4)] = placeholder[ramp((cse_var_1 + 1248), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 56), 1, 4)] = placeholder[ramp((cse_var_1 + 1344), 1, 4)]
      T_layout_trans[ramp((cse_var_2 + 60), 1, 4)] = placeholder[ramp((cse_var_1 + 1440), 1, 4)]
    }
    for (ax1.inner, 0, 32) {
      let cse_var_3 = (((tir.shift_right(ax0.ax1.outer.ax2.outer.fused.fused, 3)*4096) + (ax1.inner*128)) + (tir.bitwise_and(ax0.ax1.outer.ax2.outer.fused.fused, 7)*16))
      compute[ramp(0, 1, 16)] = x16(0)
      for (k.outer, 0, 24) {
        compute[ramp(0, 1, 16)] = (tir.call_llvm_pure_intrin((uint32)9785, (uint32)0, x16(0), x16(tir.reinterpret(placeholder[ramp((((tir.shift_right(ax0.ax1.outer.ax2.outer.fused.fused, 3)*3072) + (ax1.inner*96)) + (k.outer*4)), 1, 4)])), tir.reinterpret(T_layout_trans[ramp((k.outer*64), 1, 64)])) + compute[ramp(0, 1, 16)])
      }
      T_add[ramp(cse_var_3, 1, 16)] = (compute[ramp(0, 1, 16)] + placeholder[ramp(cse_var_3, 1, 16)])
    }
  }

Future work can explore eliminating the runtime layout transform, or pipelining the layout transform with the compute to hide its overhead.
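As a usage illustration (not part of this PR), the sketch below builds an int8 batch_matmul through Relay on a VNNI-capable CPU so the path described above gets exercised; the shapes, dtypes, and `-mcpu` string are assumptions chosen for the example.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

B, M, N, K = 8, 128, 768, 768
x = relay.var("x", shape=(B, M, K), dtype="uint8")
# With the default transpose_b=True, the second operand has shape (B, N, K).
y = relay.var("y", shape=(B, N, K), dtype="int8")
out = relay.nn.batch_matmul(x, y, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

# Any VNNI-capable -mcpu works here; cascadelake is just an example.
target = "llvm -mcpu=cascadelake"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)

dev = tvm.cpu(0)
rt = graph_executor.GraphModule(lib["default"](dev))
rt.set_input("x", np.random.randint(0, 64, size=(B, M, K)).astype("uint8"))
rt.set_input("y", np.random.randint(-64, 64, size=(B, N, K)).astype("int8"))
rt.run()
res = rt.get_output(0).numpy()  # (B, M, N) int32 result
```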

@elvin-n @mbrookhart @tkonolige @junrushao1994 @vinx13

@elvin-n (Contributor) commented on Feb 21, 2022

Could you please share float batch_matmul perf data vs. the newly introduced int8 batch_matmul?

@elvin-n (Contributor) left a comment

LGTM

@masahi (Member, Author) commented on Feb 21, 2022

OK, here is the GOPS comparison between the new VNNI implementation and the existing generic code. Note that the VNNI numbers were obtained after only one or two minutes of tuning, while the existing generic implementation has a very large tuning space and took more than 12 hours under the same tuning options to reach these numbers. The script is at https://github.com/masahi/int8_experiment/blob/main/relay_bench.py

This is on a Rocket Lake i5-11400 @ 2.60 GHz, 6 threads.

| B | M | N | K | TVM VNNI (new), GOPS | TVM existing (old), GOPS |
|---|---|---|---|---:|---:|
| 8 | 64 | 800 | 320 | 1863.0 | 471.9 |
| 8 | 64 | 768 | 512 | 1957.2 | 254.2 |
| 8 | 16 | 256 | 512 | 481.8 | 249.4 |
| 8 | 128 | 128 | 128 | 1940.8 | 372.8 |
| 8 | 256 | 512 | 256 | 2381.0 | 496.8 |
| 8 | 1024 | 1024 | 1024 | 2275.1 | 219.5 |
| 8 | 128 | 768 | 3072 | 1449.9 | 219.9 |
| 8 | 128 | 768 | 768 | 1883.4 | 234.4 |
| 8 | 128 | 3072 | 768 | 1595.6 | 196.1 |
| 16 | 384 | 384 | 64 | 2487.8 | 418.9 |
| 16 | 384 | 64 | 384 | 2441.7 | 301.4 |
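For reference, the GOPS figures are total multiply-add work divided by runtime; a minimal sketch of the arithmetic, assuming a multiply-accumulate counts as two operations (the 0.64 ms runtime below is illustrative, not a measurement from this table):

```python
# Count 2*B*M*N*K operations (one multiply plus one add per inner-product term)
# and divide by the measured time in seconds.
def gops(B, M, N, K, seconds):
    return 2.0 * B * M * N * K / seconds / 1e9

print(gops(8, 128, 768, 768, 0.64e-3))  # ~1887 GOPS
```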

@tmoreau89 (Contributor) left a comment

Thank you @masahi, the speedups you've reported are extremely impressive! LGTM

@masahi masahi merged commit 8947729 into apache:main Feb 23, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
* add test

* compute added

* schedule works

* reuse dense_vnni schedule

* try an alternative approach to scheduling layout transform

* introduce a tunable knob to decide if compute_root

* check transpose condition

* support s8 + s8 input

* pylint