
[TOPI] Winograd #899

Closed · 3 of 4 tasks
tqchen opened this issue Feb 13, 2018 · 10 comments

tqchen (Member) commented Feb 13, 2018

So far we haven't had Winograd support; #898 brings the first implementation. We want to push it to other backends, so this issue tracks the progress. Ideally, the implementation should also work for larger batch sizes.

  • Mali
  • CUDA
  • AMDGPU
  • ARM
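
For background: Winograd convolution computes each output tile as Y = A^T [(G g G^T) ⊙ (B^T d B)] A, where g is the 3x3 filter, d is the input tile, ⊙ denotes element-wise multiplication, and B, G, A are small fixed transform matrices (Lavin & Gray, 2015). The element-wise products, accumulated over input channels, are what become the batched GEMM discussed in the comments below.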
tqchen (Member, Author) commented Feb 13, 2018

@ZihengJiang @masahi @adityaatluri @Laurawly

aditya4d1 (Contributor) commented

I'm getting started on implementing Winograd kernels; I'll let you know the progress.

masahi (Member) commented Jun 24, 2018

I have a very basic Winograd working for CUDA and AMDGPU here. My code is adapted from the Mali Winograd implementation, which is a very good reference. I will try to optimize the batched GEMM, which is taking 96% of the compute time.

For an AOT compiler like TVM, the filter transform can be precomputed at compile time. The TVM Mali implementation doesn't do this, but it should.
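
For reference, the filter transform in question is the U = G g G^T step of the Winograd algorithm; below is a minimal NumPy sketch for F(2x2, 3x3) using the standard transform matrix. The function name is illustrative, not a TVM API, and a real implementation would apply it to every (output channel, input channel) filter slice:

import numpy as np

# Standard G matrix for Winograd F(2x2, 3x3) (Lavin & Gray, 2015).
G = np.array([[1.0,  0.0, 0.0],
              [0.5,  0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0,  0.0, 1.0]])

def precompute_filter_transform(g):
    # Transform one 3x3 filter g into its 4x4 Winograd-domain form
    # U = G g G^T. Since weights are constants, an AOT compiler can run
    # this once at compile time and store the transformed weights.
    return G @ g @ G.T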

merrymercy (Member) commented Jun 24, 2018

I implemented an experimental CUDA Winograd with the filter transform precomputed. The code is not very clean, so I've only kept it in my local branch.

Ref (op/compute definitions only; the schedule is in a private repo):
https://github.com/merrymercy/nnvm/blob/winograd/python/nnvm/top/contrib.py
https://github.com/merrymercy/tvm/blob/winograd/topi/python/topi/contrib.py

To support precomputing the filter transform, we need to:

  • implement two ops in NNVM: conv2d_winograd_filter_transform and conv2d_winograd_without_filter_transform
  • register an alter-op in NNVM that replaces the original conv2d op with two ops, one for the filter transform and one for everything else; the filter-transform op can then be precomputed by the PrecomputePrune optimization pass

The alter-op registration (implemented by General Layout Support, dmlc/nnvm#447) looks like:
@reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos):
    # ... extract groups, kernel_size, strides from attrs and build
    # copy_inputs / new_attrs (elided) ...
    if groups == 1 and kernel_size == (3, 3) and strides == (1, 1):
        # Replace the weight input with its pre-transformed version, then
        # use the conv2d variant that skips the weight transform at runtime.
        copy_inputs[1] = sym.contrib.conv2d_winograd_6x6_3x3_weight_transform(copy_inputs[1])
        return sym.contrib.conv2d_winograd_6x6_3x3_without_weight_transform(*copy_inputs, **new_attrs)
    else:
        return sym.conv2d(*copy_inputs, **new_attrs)
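
Because the weight input is a constant parameter, the output of the weight-transform op is itself constant, which is what allows the PrecomputePrune pass to fold it away at compile time.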

masahi (Member) commented Jun 24, 2018

Nice. I'm hoping to push my Winograd code to the AMDGPU backend, along with the necessary NNVM changes, soon. Later, when you add the CUDA Winograd, I can update the AMDGPU version to keep it in sync.

For the AMDGPU backend, my basic Winograd is already faster than the existing direct convolution, which uses the CUDA schedules as-is.

masahi (Member) commented Jun 25, 2018

@merrymercy For the input transform, my IR dump is something like this:

 produce V.local {
    V.local[0] = 0.000000f
    V.local[0] = (V.local[0] + d[0])
    V.local[0] = (V.local[0] - d[2])
    V.local[0] = (V.local[0] - d[8])
    V.local[0] = (V.local[0] + d[10])
    V.local[1] = 0.000000f
    V.local[1] = (V.local[1] + d[1])
    V.local[1] = (V.local[1] + d[2])
    V.local[1] = (V.local[1] - d[9])
    V.local[1] = (V.local[1] - d[10])
    V.local[2] = 0.000000f
    V.local[2] = (V.local[2] - d[1])
    V.local[2] = (V.local[2] + d[2])
    V.local[2] = (V.local[2] + d[9])
    V.local[2] = (V.local[2] - d[10])
    V.local[3] = 0.000000f
    V.local[3] = (V.local[3] + d[1])
    V.local[3] = (V.local[3] - d[3])
    V.local[3] = (V.local[3] - d[9])
    V.local[3] = (V.local[3] + d[11])
    V.local[4] = 0.000000f
    V.local[4] = (V.local[4] + d[4])
    V.local[4] = (V.local[4] - d[6])
    V.local[4] = (V.local[4] + d[8])
    V.local[4] = (V.local[4] - d[10])
    V.local[5] = 0.000000f
    V.local[5] = (V.local[5] + d[5])
    V.local[5] = (V.local[5] + d[6])
    V.local[5] = (V.local[5] + d[9])
    V.local[5] = (V.local[5] + d[10])
    V.local[6] = 0.000000f
    V.local[6] = (V.local[6] - d[5])
    V.local[6] = (V.local[6] + d[6])
    V.local[6] = (V.local[6] - d[9])
    V.local[6] = (V.local[6] + d[10])
    V.local[7] = 0.000000f
    V.local[7] = (V.local[7] + d[5])
    V.local[7] = (V.local[7] - d[7])
    V.local[7] = (V.local[7] + d[9])
    V.local[7] = (V.local[7] - d[11])
    ...

But ideally I want the minimum number of adds and subtracts, and to remove the additions of 0.0f. The desired code is something like this.

Is achieving minimal math possible with TVM? The same goes for the inverse transform.
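
For illustration, here is a minimal NumPy sketch (not from the thread) of the hand-minimized input transform for F(2x2, 3x3): the 2-D transform B^T d B is factored into two 1-D passes, so the adds/subs are roughly halved and the accumulations into 0.0f disappear. The IR dump above is the unfactored form of exactly this computation:

import numpy as np

def input_transform_min(d):
    # Hand-minimized B^T d B for Winograd F(2x2, 3x3), where
    # B^T = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, 1, 0, -1]]
    # and d is a 4x4 input tile.
    t = np.empty((4, 4), dtype=d.dtype)
    for j in range(4):  # first pass: apply B^T down each column
        t[0, j] = d[0, j] - d[2, j]
        t[1, j] = d[1, j] + d[2, j]
        t[2, j] = d[2, j] - d[1, j]
        t[3, j] = d[1, j] - d[3, j]
    v = np.empty((4, 4), dtype=d.dtype)
    for i in range(4):  # second pass: apply B^T along each row (t @ B)
        v[i, 0] = t[i, 0] - t[i, 2]
        v[i, 1] = t[i, 1] + t[i, 2]
        v[i, 2] = t[i, 2] - t[i, 1]
        v[i, 3] = t[i, 1] - t[i, 3]
    return v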

aditya4d1 (Contributor) commented

Great work, @masahi. Can you share your performance numbers?

masahi (Member) commented Jun 25, 2018

It's still preliminary, but I put some numbers here.

I'm sure there are many opportunities for improvement.

tqchen (Member, Author) commented Jul 26, 2018

#1487 provides Winograd for CPU.

tqchen (Member, Author) commented Aug 8, 2018

Closing, as most of the Winograd implementations are checked in. Let's open new threads for specific work items.
