Added Adam FP32 JIT assembly kernel #39158

Merged (2 commits) on Feb 7, 2022

Conversation

Contributor

@jakpiase jakpiase commented Jan 24, 2022

PR types

New features

PR changes

OPs

Describe

Added an Adam FP32 JIT assembly kernel. This feature was requested in #39005.
All benchmarks were run with the VGG training script "test_image_classification.py", using 100 batches of 128 images.
Benchmarks were run on an Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz.
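
For readers who want a reminder of what an Adam kernel computes per element, here is a minimal sketch of the textbook Adam update in plain Python. It is not the JIT kernel from this PR (which is generated assembly) and may differ from Paddle's exact formulation; the function name `adam_step` and the default hyperparameters are illustrative assumptions.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one textbook Adam update in place to equal-length lists of floats.

    param, grad: parameters and their gradients
    m, v:        running first and second moment estimates
    t:           1-based step count, used for bias correction
    (All names and defaults here are hypothetical, chosen for illustration.)
    """
    bias1 = 1.0 - beta1 ** t
    bias2 = 1.0 - beta2 ** t
    for i, g in enumerate(grad):
        # Exponential moving averages of the gradient and its square.
        m[i] = beta1 * m[i] + (1.0 - beta1) * g
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g
        # Bias-corrected estimates, then the parameter update.
        m_hat = m[i] / bias1
        v_hat = v[i] / bias2
        param[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param


params = adam_step([1.0], [0.5], [0.0], [0.0], t=1)
print(params)
```

Each element is updated independently of the others, which is why this loop is a natural candidate for SIMD vectorization and for splitting across threads.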

Performance comparison:

| Threads | Eigen Adam | JIT assembly Adam | JIT speed ratio |
|---------|------------|-------------------|-----------------|
| 1       | 8407ms     | 5595ms            | 1.50x           |
| 20      | 8911ms     | 933ms             | 9.55x           |
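
As a quick check of the arithmetic, the speed ratios can be reproduced from the timings above. This is only a sketch: the `speedup` helper and the `adam_ms` dictionary are names made up for this example, and the numbers are the rounded millisecond totals from the table.

```python
def speedup(baseline_ms, optimized_ms):
    """Return how many times faster the optimized run is than the baseline."""
    return baseline_ms / optimized_ms


# Total time spent in the adam op, in milliseconds, keyed by thread count.
adam_ms = {
    1: {"eigen": 8407, "jit": 5595},
    20: {"eigen": 8911, "jit": 933},
}

for threads, t in adam_ms.items():
    print(f"{threads} threads: {speedup(t['eigen'], t['jit']):.2f}x")
# prints "1 threads: 1.50x" and "20 threads: 9.55x"

# How much each implementation gains when going from 1 to 20 threads.
jit_scaling = speedup(adam_ms[1]["jit"], adam_ms[20]["jit"])
eigen_scaling = speedup(adam_ms[1]["eigen"], adam_ms[20]["eigen"])
print(f"JIT 1->20 threads: {jit_scaling:.2f}x, Eigen 1->20 threads: {eigen_scaling:.2f}x")
```

The second print shows where most of the 20-thread ratio comes from: in these runs the JIT kernel gets faster with more threads, while the Eigen time does not improve.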

Adam JIT FP32 20 threads

-------------------------     Overhead Summary      -------------------------

Total time: 255238
  Computation time       Total: 254831      Ratio: 99.8404%
  Framework overhead     Total: 407.465     Ratio: 0.159639%

-------------------------       Event Summary       -------------------------

Event                                Calls       Total       CPU Time (Ratio)        GPU Time (Ratio)        Min.        Max.        Ave.        Ratio.      
thread0::conv2d_grad                 1300        126588      126588.182723 (1.000000) 0.000000 (0.000000)    6.78277     432.587     97.3755     0.49596     
thread0::conv2d                      1430        55082.9     55082.937061 (1.000000) 0.000000 (0.000000)     5.56672     182.007     38.5195     0.21581     
thread0::dropout                     1100        25468       25468.024844 (1.000000) 0.000000 (0.000000)     0.048006    134.681     23.1527     0.0997813   
thread0::elementwise_add_grad        1600        9103.74     9103.737412 (1.000000)  0.000000 (0.000000)     0.020911    28.7549     5.68984     0.0356676   
thread0::batch_norm_grad             1400        7034.51     7034.513896 (1.000000)  0.000000 (0.000000)     0.929491    28.8089     5.02465     0.0275606   
thread0::batch_norm                  1540        6037.51     6037.509643 (1.000000)  0.000000 (0.000000)     0.441304    15.7899     3.92046     0.0236544   
thread0::pool2d_grad                 500         5520.84     5520.839670 (1.000000)  0.000000 (0.000000)     0.919202    36.7582     11.0417     0.0216301   
thread0::relu_grad                   1400        5299.6      5299.604213 (1.000000)  0.000000 (0.000000)     0.38113     19.8601     3.78543     0.0207633   
thread0::elementwise_add             1760        4277.46     4277.464973 (1.000000)  0.000000 (0.000000)     0.016676    13.9591     2.43038     0.0167587   
thread0::pool2d                      550         2861.43     2861.433259 (1.000000)  0.000000 (0.000000)     0.743152    15.2987     5.20261     0.0112108   
thread0::relu                        1540        2846.08     2846.084105 (1.000000)  0.000000 (0.000000)     0.10837     10.9373     1.84811     0.0111507   
thread0::dropout_grad                1000        2452.26     2452.256979 (1.000000)  0.000000 (0.000000)     0.105874    16.9852     2.45226     0.00960771  
thread0::adam                        6000        933.149     933.148779 (1.000000)   0.000000 (0.000000)     0.015754    5.56512     0.155525    0.00365599
thread0::mul_grad                    300         710.188     710.188304 (1.000000)   0.000000 (0.000000)     0.197352    7.76205     2.36729     0.00278245  
thread0::mul                         330         458.734     458.734473 (1.000000)   0.000000 (0.000000)     0.07948     8.14803     1.3901      0.00179728  
thread0::gaussian_random             13          313.263     313.262564 (1.000000)   0.000000 (0.000000)     0.141104    51.069      24.0971     0.00122733  
thread0::uniform_random              3           163.252     163.252015 (1.000000)   0.000000 (0.000000)     0.353768    144.607     54.4173     0.000639606 
thread0::fill_constant               413         67.7006     67.700618 (1.000000)    0.000000 (0.000000)     0.005264    15.2662     0.163924    0.000265245 
thread0::top_k_v2                    110         3.25123     3.251234 (1.000000)     0.000000 (0.000000)     0.02469     0.051828    0.0295567   1.2738e-05  
thread0::softmax                     110         2.9233      2.923296 (1.000000)     0.000000 (0.000000)     0.022344    0.059341    0.0265754   1.14532e-05 
thread0::cross_entropy_grad2         100         2.85892     2.858918 (1.000000)     0.000000 (0.000000)     0.024797    0.044293    0.0285892   1.1201e-05  
thread0::cross_entropy2              110         2.37179     2.371794 (1.000000)     0.000000 (0.000000)     0.01654     0.046266    0.0215618   9.29246e-06 
thread0::accuracy                    110         2.18248     2.182484 (1.000000)     0.000000 (0.000000)     0.015547    0.052988    0.0198408   8.55076e-06 
thread0::mean_grad                   100         1.79176     1.791760 (1.000000)     0.000000 (0.000000)     0.014726    0.046336    0.0179176   7.01995e-06 
thread0::softmax_grad                100         1.66405     1.664051 (1.000000)     0.000000 (0.000000)     0.013539    0.031328    0.0166405   6.51959e-06 
thread0::mean                        110         1.42385     1.423853 (1.000000)     0.000000 (0.000000)     0.010557    0.02575     0.0129441   5.57852e-06 
thread0::feed                        220         0.956544    0.956544 (1.000000)     0.000000 (0.000000)     0.00087     0.017007    0.00434793  3.74765e-06 
thread0::fetch                       20          0.159529    0.159529 (1.000000)     0.000000 (0.000000)     0.002542    0.025782    0.00797645  6.25019e-07

Adam native FP32 20 threads


-------------------------     Overhead Summary      -------------------------

Total time: 262827
  Computation time       Total: 262393      Ratio: 99.8349%
  Framework overhead     Total: 433.863     Ratio: 0.165075%

-------------------------       Event Summary       -------------------------

Event                                Calls       Total       CPU Time (Ratio)        GPU Time (Ratio)        Min.        Max.        Ave.        Ratio.      
thread0::conv2d_grad                 1300        125850      125850.330032 (1.000000) 0.000000 (0.000000)    6.90526     425.479     96.8079     0.478834    
thread0::conv2d                      1430        56455.7     56455.663905 (1.000000) 0.000000 (0.000000)     5.66905     169.554     39.4795     0.214802    
thread0::dropout                     1100        25422.8     25422.821121 (1.000000) 0.000000 (0.000000)     0.044683    134.634     23.1117     0.0967285   
thread0::elementwise_add_grad        1600        9082.94     9082.935860 (1.000000)  0.000000 (0.000000)     0.022123    28.3086     5.67683     0.0345587   
thread0::adam                        6000        8911.09     8911.090247 (1.000000)  0.000000 (0.000000)     0.013306    49.1169     1.48518     0.0339048   
thread0::batch_norm_grad             1400        6519.14     6519.135756 (1.000000)  0.000000 (0.000000)     0.870459    18.8263     4.65653     0.0248039   
thread0::pool2d_grad                 500         5656.66     5656.656907 (1.000000)  0.000000 (0.000000)     0.893712    37.4331     11.3133     0.0215224   
thread0::batch_norm                  1540        5483.95     5483.947854 (1.000000)  0.000000 (0.000000)     0.407364    15.2745     3.56101     0.0208653   
thread0::relu_grad                   1400        5200.33     5200.327987 (1.000000)  0.000000 (0.000000)     0.37653     18.0424     3.71452     0.0197862   
thread0::elementwise_add             1760        4236.93     4236.934522 (1.000000)  0.000000 (0.000000)     0.015608    13.7837     2.40735     0.0161206   
thread0::pool2d                      550         2853.04     2853.043354 (1.000000)  0.000000 (0.000000)     0.754373    14.8566     5.18735     0.0108552   
thread0::relu                        1540        2697.41     2697.414724 (1.000000)  0.000000 (0.000000)     0.120004    10.6862     1.75157     0.0102631   
thread0::dropout_grad                1000        2440.85     2440.850576 (1.000000)  0.000000 (0.000000)     0.106722    15.3806     2.44085     0.00928692  
thread0::mul_grad                    300         906.724     906.724042 (1.000000)   0.000000 (0.000000)     0.193557    11.5946     3.02241     0.00344989  
thread0::mul                         330         546.119     546.118975 (1.000000)   0.000000 (0.000000)     0.081294    6.63446     1.65491     0.00207787  
thread0::gaussian_random             13          313.92      313.920139 (1.000000)   0.000000 (0.000000)     0.156115    51.0499     24.1477     0.0011944   
thread0::uniform_random              3           159.991     159.990833 (1.000000)   0.000000 (0.000000)     0.344665    142.06      53.3303     0.000608732 
thread0::fill_constant               413         68.7027     68.702711 (1.000000)    0.000000 (0.000000)     0.008581    15.2479     0.16635     0.000261399 
thread0::top_k_v2                    110         3.38352     3.383520 (1.000000)     0.000000 (0.000000)     0.024547    0.052362    0.0307593   1.28736e-05 
thread0::softmax                     110         2.94007     2.940068 (1.000000)     0.000000 (0.000000)     0.02325     0.044121    0.0267279   1.11863e-05 
thread0::cross_entropy_grad2         100         2.73175     2.731754 (1.000000)     0.000000 (0.000000)     0.023973    0.034872    0.0273175   1.03938e-05 
thread0::cross_entropy2              110         2.39475     2.394752 (1.000000)     0.000000 (0.000000)     0.018195    0.030596    0.0217705   9.11153e-06 
thread0::accuracy                    110         2.18767     2.187670 (1.000000)     0.000000 (0.000000)     0.016385    0.030565    0.0198879   8.32363e-06 
thread0::mean_grad                   100         1.9015      1.901496 (1.000000)     0.000000 (0.000000)     0.01599     0.028001    0.019015    7.23479e-06 
thread0::softmax_grad                100         1.70133     1.701331 (1.000000)     0.000000 (0.000000)     0.013232    0.025132    0.0170133   6.47321e-06 
thread0::mean                        110         1.40548     1.405481 (1.000000)     0.000000 (0.000000)     0.010287    0.03371     0.0127771   5.34756e-06 
thread0::feed                        220         1.17216     1.172162 (1.000000)     0.000000 (0.000000)     0.00098     0.015056    0.00532801  4.45983e-06 
thread0::fetch                       20          0.144533    0.144533 (1.000000)     0.000000 (0.000000)     0.002516    0.025719    0.00722665  5.49918e-07 

Adam JIT FP32 1 thread

-------------------------     Overhead Summary      -------------------------

Total time: 397260
  Computation time       Total: 396849      Ratio: 99.8963%
  Framework overhead     Total: 411.861     Ratio: 0.103676%

-------------------------       Event Summary       -------------------------
Event                                Calls       Total       CPU Time (Ratio)        GPU Time (Ratio)        Min.        Max.        Ave.        Ratio.      
thread0::conv2d_grad                 1300        215251      215251.345686 (1.000000) 0.000000 (0.000000)    6.75745     374.658     165.578     0.541839    
thread0::conv2d                      1430        95355.9     95355.893177 (1.000000) 0.000000 (0.000000)     9.10349     150.765     66.6824     0.240034    
thread0::dropout                     1100        23543.7     23543.724907 (1.000000) 0.000000 (0.000000)     0.051755    115.039     21.4034     0.0592652   
thread0::elementwise_add_grad        1600        8004.05     8004.047934 (1.000000)  0.000000 (0.000000)     0.020271    25.6009     5.00253     0.0201481   
thread0::mul                         330         7230.75     7230.750323 (1.000000)  0.000000 (0.000000)     0.248336    60.1106     21.9114     0.0182015   
thread0::batch_norm_grad             1400        6995.54     6995.542689 (1.000000)  0.000000 (0.000000)     0.88085     27.0562     4.99682     0.0176095   
thread0::mul_grad                    300         6786.89     6786.886619 (1.000000)  0.000000 (0.000000)     0.58572     69.4208     22.623      0.0170842   
thread0::batch_norm                  1540        5934.5      5934.499931 (1.000000)  0.000000 (0.000000)     0.4169      14.9896     3.85357     0.0149386   
thread0::adam                        6000        5595.49     5595.491115 (1.000000)  0.000000 (0.000000)     0.011879    48.9321     0.932582    0.0140852   
thread0::pool2d_grad                 500         5461.87     5461.865221 (1.000000)  0.000000 (0.000000)     0.844246    40.6597     10.9237     0.0137488   
thread0::relu_grad                   1400        4991.32     4991.315824 (1.000000)  0.000000 (0.000000)     0.354873    24.2478     3.56523     0.0125643   
thread0::elementwise_add             1760        3721.69     3721.694840 (1.000000)  0.000000 (0.000000)     0.016622    13.2252     2.1146      0.0093684   
thread0::relu                        1540        2866.11     2866.111641 (1.000000)  0.000000 (0.000000)     0.104892    10.6162     1.86111     0.00721469  
thread0::pool2d                      550         2645.78     2645.783084 (1.000000)  0.000000 (0.000000)     0.681516    12.7018     4.81051     0.00666007  
thread0::dropout_grad                1000        2315.56     2315.558503 (1.000000)  0.000000 (0.000000)     0.089981    16.5313     2.31556     0.00582882  
thread0::gaussian_random             13          314.031     314.030642 (1.000000)   0.000000 (0.000000)     0.142971    50.7919     24.1562     0.00079049  
thread0::uniform_random              3           158.698     158.697587 (1.000000)   0.000000 (0.000000)     0.336269    140.795     52.8992     0.00039948  
thread0::fill_constant               413         67.201      67.200967 (1.000000)    0.000000 (0.000000)     0.005053    15.1799     0.162714    0.000169161 
thread0::top_k_v2                    110         5.0654      5.065401 (1.000000)     0.000000 (0.000000)     0.035961    0.087723    0.0460491   1.27508e-05 
thread0::softmax                     110         2.77214     2.772139 (1.000000)     0.000000 (0.000000)     0.022062    0.058787    0.0252013   6.97814e-06 
thread0::cross_entropy_grad2         100         2.46541     2.465410 (1.000000)     0.000000 (0.000000)     0.021999    0.036969    0.0246541   6.20603e-06 
thread0::cross_entropy2              110         2.23243     2.232426 (1.000000)     0.000000 (0.000000)     0.01678     0.035619    0.0202948   5.61955e-06 
thread0::accuracy                    110         1.93446     1.934458 (1.000000)     0.000000 (0.000000)     0.015228    0.026053    0.017586    4.8695e-06  
thread0::mean_grad                   100         1.63768     1.637677 (1.000000)     0.000000 (0.000000)     0.014113    0.037605    0.0163768   4.12243e-06 
thread0::softmax_grad                100         1.40975     1.409747 (1.000000)     0.000000 (0.000000)     0.011879    0.025825    0.0140975   3.54867e-06 
thread0::mean                        110         1.38516     1.385164 (1.000000)     0.000000 (0.000000)     0.010333    0.022291    0.0125924   3.48679e-06 
thread0::feed                        220         1.01252     1.012522 (1.000000)     0.000000 (0.000000)     0.00092     0.017252    0.00460237  2.54876e-06 
thread0::fetch                       20          0.140892    0.140892 (1.000000)     0.000000 (0.000000)     0.002419    0.0329      0.0070446   3.54659e-07 

Adam native FP32 1 thread


-------------------------     Overhead Summary      -------------------------

Total time: 394698
  Computation time       Total: 394300      Ratio: 99.8992%
  Framework overhead     Total: 397.736     Ratio: 0.100768%

-------------------------       Event Summary       -------------------------

Event                                Calls       Total       CPU Time (Ratio)        GPU Time (Ratio)        Min.        Max.        Ave.        Ratio.      
thread0::conv2d_grad                 1300        212333      212332.646480 (1.000000) 0.000000 (0.000000)    6.68817     375.206     163.333     0.537962    
thread0::conv2d                      1430        94353.3     94353.326454 (1.000000) 0.000000 (0.000000)     9.09199     150.12      65.9813     0.239052    
thread0::dropout                     1100        23547.6     23547.582396 (1.000000) 0.000000 (0.000000)     0.055426    112.812     21.4069     0.0596597   
thread0::adam                        6000        8407.12     8407.115654 (1.000000)  0.000000 (0.000000)     0.011201    43.8112     1.40119     0.0213001   
thread0::elementwise_add_grad        1600        7995.02     7995.015518 (1.000000)  0.000000 (0.000000)     0.020061    25.8042     4.99688     0.020256    
thread0::mul                         330         7179.66     7179.658238 (1.000000)  0.000000 (0.000000)     0.232449    60.5752     21.7565     0.0181903   
thread0::mul_grad                    300         6732.35     6732.346253 (1.000000)  0.000000 (0.000000)     0.59177     68.4031     22.4412     0.017057    
thread0::batch_norm_grad             1400        6507.14     6507.141797 (1.000000)  0.000000 (0.000000)     0.862287    26.3169     4.64796     0.0164864   
thread0::batch_norm                  1540        5507.27     5507.270106 (1.000000)  0.000000 (0.000000)     0.409352    14.7317     3.57615     0.0139531   
thread0::pool2d_grad                 500         5322.61     5322.610129 (1.000000)  0.000000 (0.000000)     0.760214    40.4335     10.6452     0.0134853   
thread0::relu_grad                   1400        4973.64     4973.639114 (1.000000)  0.000000 (0.000000)     0.348559    23.9322     3.5526      0.0126011   
thread0::elementwise_add             1760        3671.73     3671.733153 (1.000000)  0.000000 (0.000000)     0.017316    14.6762     2.08621     0.00930264  
thread0::relu                        1540        2688.45     2688.447007 (1.000000)  0.000000 (0.000000)     0.148117    10.2019     1.74574     0.0068114   
thread0::pool2d                      550         2641.52     2641.524530 (1.000000)  0.000000 (0.000000)     0.697208    13.136      4.80277     0.00669252  
thread0::dropout_grad                1000        2280.26     2280.257859 (1.000000)  0.000000 (0.000000)     0.089909    16.325      2.28026     0.00577722  
thread0::gaussian_random             13          312.219     312.218866 (1.000000)   0.000000 (0.000000)     0.152403    50.1217     24.0168     0.000791032 
thread0::uniform_random              3           158.901     158.900562 (1.000000)   0.000000 (0.000000)     0.334984    140.981     52.9669     0.000402588 
thread0::fill_constant               413         66.9596     66.959555 (1.000000)    0.000000 (0.000000)     0.005256    15.2514     0.16213     0.000169648 
thread0::top_k_v2                    110         4.99694     4.996938 (1.000000)     0.000000 (0.000000)     0.034938    0.087521    0.0454267   1.26602e-05 
thread0::softmax                     110         2.74408     2.744077 (1.000000)     0.000000 (0.000000)     0.022586    0.060981    0.0249462   6.95235e-06 
thread0::cross_entropy_grad2         100         2.41326     2.413263 (1.000000)     0.000000 (0.000000)     0.021924    0.035475    0.0241326   6.1142e-06  
thread0::cross_entropy2              110         2.14528     2.145282 (1.000000)     0.000000 (0.000000)     0.017247    0.032467    0.0195026   5.43525e-06 
thread0::accuracy                    110         1.87733     1.877333 (1.000000)     0.000000 (0.000000)     0.013957    0.030046    0.0170667   4.75638e-06 
thread0::mean_grad                   100         1.61598     1.615976 (1.000000)     0.000000 (0.000000)     0.014196    0.040255    0.0161598   4.09421e-06 
thread0::softmax_grad                100         1.38025     1.380251 (1.000000)     0.000000 (0.000000)     0.011881    0.026854    0.0138025   3.49698e-06 
thread0::mean                        110         1.2885      1.288499 (1.000000)     0.000000 (0.000000)     0.009255    0.020004    0.0117136   3.26452e-06 
thread0::feed                        220         1.02354     1.023536 (1.000000)     0.000000 (0.000000)     0.000984    0.015596    0.00465244  2.59321e-06 
thread0::fetch                       20          0.150682    0.150682 (1.000000)     0.000000 (0.000000)     0.002294    0.042685    0.0075341   3.81765e-07 
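
To make the profiler columns easier to interpret, here is a small sketch that parses one row of the event summary and recomputes the "Ave." and "Ratio." columns from "Calls", "Total" and the run's "Total time". The `parse_row` helper and the regular expression are hypothetical and written only for this pasted text layout; the values appear to be in milliseconds, since the performance comparison quotes the same adam totals in ms.

```python
import re

# Matches the first three columns of an event row: name, call count, total time.
ROW = re.compile(r"^(?P<event>\S+)\s+(?P<calls>\d+)\s+(?P<total>[\d.eE+-]+)\s+")


def parse_row(line):
    """Return (event, calls, total) for an event row, or None for other lines."""
    match = ROW.match(line)
    if match is None:
        return None
    return match["event"], int(match["calls"]), float(match["total"])


# The adam row and total time from the "Adam JIT FP32 20 threads" run.
line = (
    "thread0::adam                        6000        933.149     "
    "933.148779 (1.000000)   0.000000 (0.000000)     "
    "0.015754    5.56512     0.155525    0.00365599"
)
total_time = 255238

event, calls, total = parse_row(line)
average = total / calls      # corresponds to the "Ave." column
share = total / total_time   # corresponds to the "Ratio." column
print(f"{event}: avg {average:.6f} per call, {share:.4%} of total time")
```

Only the first three columns are parsed, so the helper does not depend on the spacing of the CPU and GPU time columns.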

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See Paddle CI Manual for details.

@jakpiase jakpiase added the Intel label Jan 24, 2022
@jakpiase jakpiase force-pushed the adam_jit branch 2 times, most recently from 1d5bdf8 to 9e033d8 Compare January 26, 2022 22:54
@jakpiase jakpiase requested a review from jczaja January 27, 2022 02:20
@jakpiase
Contributor Author

@Silv3S please review this PR

@haohongxiang
Contributor

I see that you've implemented Adam-CPU, but it can't cover cases that use Adamw. I hope Adamw-CPU can also be implemented to provide higher performance. Thx.

@jakpiase
Contributor Author

Hi @haohongxiang, Adamw can also be implemented, but that will probably be done in a separate PR in the future.

@jczaja jczaja requested a review from sfraczek January 27, 2022 14:48
@jczaja
Contributor

jczaja commented Jan 27, 2022

@pawelpiotrowicz , @tsocha please help with review

@jczaja
Contributor

jczaja commented Jan 27, 2022

Excellent contribution! LGTM

@haohongxiang
Contributor

> Hi @haohongxiang, Adamw can also be implemented, but that would be probably done in future in another PR.

OK, thx.

Contributor

@pawelpiotrowicz pawelpiotrowicz left a comment

LGTM

@jakpiase
Contributor Author

@Aganlengzi could you please start your review?

Contributor

@sfraczek sfraczek left a comment

LGTM

@Silv3S
Member

Silv3S commented Feb 1, 2022

LGTM

@paddle-bot-old

paddle-bot-old bot commented Feb 3, 2022

Sorry to inform you that 9e033d8's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

Contributor

@Aganlengzi Aganlengzi left a comment

LGTM

8 participants