Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different dtypes of inputs for elementwise ops #38859

Merged
merged 12 commits into from
Feb 11, 2022

Conversation

zhangting2020
Copy link
Contributor

@zhangting2020 zhangting2020 commented Jan 11, 2022

PR types

Function optimization

PR changes

OPs

Describe

Support different dtypes of inputs for elementwise ops

背景

  • dropout反向算子为elementwise类型运算,但输入类型不一致。
  • 原始的elementwise模版仅支持相同dtype输入,因此dropout反向算子无法使用LaunchSameDimsElementwiseCudaKernel应用向量化优化。

修改内容

使用std::tuple支持不同类型输入,对原始实现的主要修改点如下:

  • 因使用tuple支持,实现中有 ArgsTuple a;这种写法,计算函数参数不能使用引用传递,因此改为了值传递。
  • 对需要遍历输入并且需要使用输入类型的接口,例如GetVectorizedSizeForTensors、kps::Init、kps::ReadData等,实现static unroller
  • LaunchSameDimsElementwiseCudaKernel的模版参数简化,仅保留了如下2个;在实际调用时,若不是多输出,则写法为LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, &outs, functor);
template <typename Functor, int NumOuts = 1>
void LaunchSameDimsElementwiseCudaKernel
  • 原始实现中,最后一层ElementwisePrimitiveCaller需要分别实现一元、二元、三元、任意输入个数的caller接口,在目前的实现下,可统一为一个。但因为broadcast情况下用到,因此暂时无法清理代码。
  • 因kps::Init、kps::ReadData需要适配tuple输入,但因其他类型运算的调用暂时不能统一;本PR对GPU、XPU2分别新增了kps::Init、kps::ReadData接口

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@JamesLim-sy JamesLim-sy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job!FunctorAny 作为一种过渡时期的产物,退出历史舞台了

Func<Begin, VecSize>::Apply(std::forward<Args>(args)...);
Unroller<Func, VecSize, End, Begin + 1>::step(args...);
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果reduce也需要支持多类型输入或者输出的话,感觉迭代器这部分的代码感觉可以放到 function_traits.h内,扩大适用范围,不过是否有必要需要 @niuliling123 @ZzSean @Xreki 大佬们判断

Copy link
Contributor

@Xreki Xreki Jan 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • reduce现在就只有1个输入,暂时还没涉及到多类型问题,当前reduce也完全没有用到function_traits
  • 这部分功能是否要放到function_traits.h内,主要还是看该功能是否通用,是否是function traits。我觉得也确实有一部分属于相关的基础功能,可以考虑放进去

@paddle-bot-old
Copy link

Sorry to inform you that 0ed5439's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

const ArgsT &args,
int *vec_size) {
using Type = std::tuple_element_t<Index, ArgsT>;
*vec_size = std::min<int>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vec_size不能作为返回值?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是因为这里在访问每一个输入Tensor时就需要知道dtype,只能套用Unroller去处理。在这个功能设计里,unroller模版是一个统一的接口,InputSetter、Loader这2个过程都需要用unroller,而这2个接口没有返回值。所以为了统一,只能将vecsize作为函数参数

}
ArgsT arg;
// The Arg VecSize=1 is to match the Unroller template.
Unroller<VecSizeGetter, 1, Arity>::step(ins, arg, &vec_size);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vec_size不是只与输入、输出的长度相关,需要加这些逻辑吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没有修改原始逻辑,只是将原来遍历所有输入计算vecsize的for循环过程,换成了unroller

namespace detail {
template <class F, class Tuple, std::size_t... INDEX>
// GCC/Clang need the decltype() return type
HOSTDEVICE constexpr decltype(auto) apply_impl(F &&f,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apply_impl -> ApplyImpl

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

std::forward<Tuple>(t),
std::make_index_sequence<
std::tuple_size<std::remove_reference_t<Tuple>>::value>{});
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnrollApply这些基础功能,是否适合挪到FunctionTraits里面?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

考虑到本身unroller和apply不属于FunctionTraits,只是为了实现提取多个输入类型的功能而配合使用的,所以本质上来说不太适合。

另外,在这个功能中,Unroller在多处使用,为了配合多种场景,不得不添加Vecsize模版参数。因此通用性上会有一些问题,类似于本PR中的VecsizeGetter功能,虽然并没有用到VecSize这个模版参数,但在接口调用时不得不设置它,否则就得再增加一个Unroller。

@@ -505,6 +579,39 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
}
};

namespace detail {
template <class F, class Tuple, std::size_t... INDEX>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

INDEX -> Index

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Functor,
Arity,
kCallElementwiseAny>()(func, args, result);
SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有个疑问,Any这种计算模式是不是不支持了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any这种计算模式依然是保留的。只是在SameDims的情况下不再需要了,但没有删除,因为原始实现最后是在ElementwisePrimitiveCaller接口中决定是否调用Any这种模式,目前Broadcast在用ElementwisePrimitiveCaller,所以依然保留

AnnaTrainingG
AnnaTrainingG previously approved these changes Feb 10, 2022
Xreki
Xreki previously approved these changes Feb 10, 2022
Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM and good work~

@zhangting2020 zhangting2020 merged commit bf30503 into PaddlePaddle:develop Feb 11, 2022
Shixiaowei02 added a commit that referenced this pull request Feb 16, 2022
* 【Pten】Adjust the Empyt dev_api (#39143)

* adjust the Empyt dev_api

* fix merge conflict

* fix sparse_utils_kernel

* Fix code conflict of empty dev_api (#39430)

* fix code conflict

* clear cache

* just try

* [PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)

* [PTen] Add standard kernel suffix set (#39404)

* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name

* [pten] update isnan registration (#39419)

* update isnan registration

* fix compile

* [bf16] add bf16 kernel: dropout & reshape & slice (#39395)

* add dropout

* add reshape

* add slice

* refien slice unittest

* refine slice unittest

* add cpu bf16 kernel

* [bf16] add bf16 kernel: squeeze & unsqueeze & stack (#39402)

* add squeeze unsqueeze stack

* add unittest

* add cpu kernel

* Modify the unsqueeze dimension of input data in conv1d NCL And NLC format (#38425)

* optimize conv1d forward

* add conv opt

* Optimize memory copy

* delete share data with

* set num_filters=512

* add nlc optimize

* Optimize num_filter=512 data on A100 and V100

* Fix the workspace_size size setting of filter

* 【Pten】Refactor C++ API code-gen (#39408)

* refactor C++ API code-gen

* fix windows problem of C++ API

* Refactored Python-C Attributes Parsing Functions (#39328)

* Add _get_parameter method to Lamb optimizer (#39416)

* add _get_parameter func to lamb

* remove duplicate code

* mkldnn layout issue fix (#39422)

* mkldnn conv fix

* definetion

* fix compile error on jetson (#39441)

* move Masked select to pten (#39193)

* move masked select cpu kernel

* add masked selected gpu kernel; test=develop

* fix bugs; test=develop

* bug fix; test=develop

* bug fix; test=develop

* add namespace to set mask array; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* fix ddim bug; test=develop

* fix npu op bug; test=develop

* fix xpu dependecy bug; test=develop

* move kernel args to sig.cc; test=develop

* 【PaddlePaddle Hackathon】31. Add Java frontend for Paddle Inference  (#37162)

* fix check error of ResetHolder (#39439)

* Added python-c code generation for final state Eager Dygraph (#39233)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Fixed issues from merge

* Fixed merge issues

* change dtype of pooling mask to 'int32' for Paddle2ONNX (#39314)

* change dtype of pooling mask to 'int32' for Paddle2ONNX

* empty commit to rerun ci

* fix format

* share MemOptVarInfos of external variables into cinn_launch subgraph (#39209)

* add a graph pass to share MemOptVarInfos of external variables into subgraph

* update pass name

* fix compile failed

* add share_mem_opt_info_to_subgraph_pass test

* share_mem_opt_info_to_subgraph_pass_test pass

* modify some codes for better style and more robust

* update cmake

* [NPU] add reduce_min (#39019)

[NPU] add reduce_min

* [MLU] add mlu kernel for accuracy op (#39337)

* [MLU] add mlu kernel for accuracy op

* fix license format

* fix error message

* [Dy2St]Handle `a, b = paddle.shape(x)` in Static Analysis (#39245)

* refine Assign

* add UT

* 【Pten】Auto-Generate InterMeta register (#39436)

* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code

* Support different dtypes of inputs for elementwise ops (#38859)

* improve backward performance

* support different dtypes for elementwise ops

* Add profiler node tree implementation (#39316)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* add print pten kernel tool (#39371)

* test=document_fix;add print pten kernel tool

* test=document_fix

* test=document_fix

* test=document_fix

* test=document_fix

* add print_pten_kernels tool

* add print_pten_kernels tool

* fix windows complie

* notest,test=rocm_ci

* add merge tool

* add comments

* [new-exec] set type of op-kernel op by place (#39458)

* Add log for executor (#39459)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173.

* add log for Executor

Co-authored-by: liutiexing <liutiexing@google.com>

* [Paddle Inference] support ernie quant model with interleaved (#39424)

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* 统一 ps 开发 - python (#39431)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* .

* .

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* refactor ps optimize

* refactor ps optimize

* refactor ps optimize

* .

* .

* .

* .

* .

* .

* refactor theoneps

* the_one_ps

* add ps pass unittest

* add ps pass unittest

* ps unitest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* ps unittest ready

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* solve conflict

* solve conflict

* .

* .

* cpu-async-ps minimize test ok & gpu minimize test ok

Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>

* [PTen] Move grad GetExpectedPtenKernelArgs into pten (#39418)

* move grad get expected pten kernel args

* fix reduce sum error

* fix element_sub_grad failed

* revert kernel judge change

* fix compilation warning on mac (#39438)

* get build time (#39368)

* fix prelu trt convert (#39389)

* Optimize bilinear interpolation foward (#39243)

* bilinear_fw init

* optimize code

* pre-compute linear_interp input index

* Optimize performance of softmax_bwd when axis!=-1 (#38609)

* Optimize performance of softmax_bwd when axis!=-1

* fix

* fix

* fix

* fix

* [PTen] Remove pten core's dependency on fluid xxx_info.h (#39401)

* ermove xxx_info include

* fix namespace error

* resolve conflict

* skip xpu context in registry

* fix macro error

* resolve conflict

* resolve conflict

* revert xpu convert

* remove trans to fluid place

* remove useless headers

* [Pten] move operators/math/math_function_* to pten/kernels/func (#39300)

* move operators/math/math_function_* to pten/kernels/func
* namespace from `paddle::operators::math` to `pten::funcs`

* [MLU] add pool2d and pool2d_grad mlu kernel (#39453)

* [MLU]support c_gen_cncl_id_op run on MLU device (#39336)

Co-authored-by: zhangna <zhangna@cambricon.com>

* [bf16] add bf16 kernel: transpose & unbind (#39457)

* add transpose unbind

* add unittest

* refine transpose unittest

* uniform_random op for mlu (#39450)

* [MLU] add pool2d pytest (#39454)

* Added shape (U)INT8/BF16/FP32 oneDNN kernel (#36033)

* added shape oneDNN kernel

* removed unnecessary import from test

* added skipping tests for GPU

* refactoring

* refactored shape kernel

* added tests in new framework

* removed one line

* minor change

* added newline at EOF

* added formatting

* added attributes as extra

* move memcpy.h into cc file (#39469)

* Add TensorRT inspector into Paddle-TRT (#38362)

* Fix add profiler node tree implementation cmake error (#39474)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* fix dependency error

* unify naming style (#39481)

* [Pten] Generate Wrapped InferMeta by Yaml (#39482)

* generate wrapped_infer_meta

* add test for wrapped_infer_meta

* Update test_meta_fn_utils.cc

* change the dir of generated file

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>

* Adjusted python-level trace_op to accomodate final state Eager Dygraph (#39319)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Adjusted python-level trace_op to accomodate final state Eager Dygraph

* Added Logs for final state Eager Dygraph

* Fixed merge issues

* Fixed minor issue

* Fixed get_tensor method for EagerTensor (#39414)

* Enabled Eager OpTest #1

* Enabled Eager OpTest #1

* Fixed get_tensor method for EagerTensor

* [Approver Update] update check approver of qili93, test=document_fix (#39483)

* [MLU] add mlu kernel for c_broadcast op (#39470)

* update xpu test build script and fix get_test_cover_info, *test=kunlun (#39235)

* fix gather_nd, *test=kunlun (#39283)

* [pten] add split kernel (#39060)

* add split kernel

* add split kernel signature

* fix split bug

* modify MakePtenScalarArrayFromVarList

* modify MakePtenScalarArrayFromVarList

* fix split windows register error

* add test case for split kernel

* replace raw split kernel with pten kernel

* fix makeScalar/ScalarArray bug

* remove debug log

* remove int64_t type in buildPtcontext

* update by code review

* fix split dev test failed

* change DenseTensorMeta to MetaTensor

* change split api code from auto gen to manual

* split cuda kernel support bfloat16 type

* fix conflict

* rm raw split kernel

* merge develop branch

* change to pten::errors

* new may of test cases, *test=kunlun (#39444)

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* [PTen] Add HasAttr for ArgumentMappingContext (#39464)

* add has_attr for arg map context

* skip useless attr now

* skip attr if not exists

* fix typo

* [ROCm] fix missing dcu kernel in operator.cmake, test=develop (#39480)

Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Co-authored-by: Aganlengzi <aganlengzi@gmail.com>
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Leo Chen <chenqiuliang@baidu.com>
Co-authored-by: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Co-authored-by: crystal <62974595+Zjq9409@users.noreply.github.com>
Co-authored-by: Zhanlue Yang <jim19930609@gmail.com>
Co-authored-by: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Co-authored-by: wenbin <wang3323032@qq.com>
Co-authored-by: Wilber <jiweibo@baidu.com>
Co-authored-by: hong <43953930+phlrain@users.noreply.github.com>
Co-authored-by: chenyanlann <62465397+chenyanlann@users.noreply.github.com>
Co-authored-by: Wei Shengyu <weisy11@163.com>
Co-authored-by: TeFeng Chen <ctfeng66@163.com>
Co-authored-by: furnace <34057289+windstamp@users.noreply.github.com>
Co-authored-by: fwenguang <95677191+fwenguang@users.noreply.github.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: Zhang Ting <zhangting_2017@163.com>
Co-authored-by: chenjian <chenjian26@baidu.com>
Co-authored-by: Shang Zhizhou <shangzhizhou@baidu.com>
Co-authored-by: liutiexing <74819124+liutiexing@users.noreply.github.com>
Co-authored-by: liutiexing <liutiexing@google.com>
Co-authored-by: Wangzheee <634486483@qq.com>
Co-authored-by: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com>
Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Co-authored-by: zhangchunle <clzhang_cauc@163.com>
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: Lijunhui <1578034415@qq.com>
Co-authored-by: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
Co-authored-by: zn <96479180+kangna-qi@users.noreply.github.com>
Co-authored-by: zhangna <zhangna@cambricon.com>
Co-authored-by: joeqiao12 <45232181+joeqiao12@users.noreply.github.com>
Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Leo Chen <39020268+leo0519@users.noreply.github.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>
Co-authored-by: Qi Li <qili93@qq.com>
Co-authored-by: maxhuiy <1508399706@qq.com>
Co-authored-by: TTerror <tangzhiyi11@users.noreply.github.com>
Co-authored-by: chentianyu03 <chentianyu03@baidu.com>
Co-authored-by: helen88 <z8hanghuan@126.com>
winter-wang pushed a commit to winter-wang/Paddle that referenced this pull request Feb 16, 2022
* 【Pten】Adjust the Empyt dev_api (PaddlePaddle#39143)

* adjust the Empyt dev_api

* fix merge conflict

* fix sparse_utils_kernel

* Fix code conflict of empty dev_api (PaddlePaddle#39430)

* fix code conflict

* clear cache

* just try

* [PluggableDevice] custom kernel supports multi cpp_dtype registering (PaddlePaddle#39385)

* [PTen] Add standard kernel suffix set (PaddlePaddle#39404)

* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name

* [pten] update isnan registration (PaddlePaddle#39419)

* update isnan registration

* fix compile

* [bf16] add bf16 kernel: dropout & reshape & slice (PaddlePaddle#39395)

* add dropout

* add reshape

* add slice

* refien slice unittest

* refine slice unittest

* add cpu bf16 kernel

* [bf16] add bf16 kernel: squeeze & unsqueeze & stack (PaddlePaddle#39402)

* add squeeze unsqueeze stack

* add unittest

* add cpu kernel

* Modify the unsqueeze dimension of input data in conv1d NCL And NLC format (PaddlePaddle#38425)

* optimize conv1d forward

* add conv opt

* Optimize memory copy

* delete share data with

* set num_filters=512

* add nlc optimize

* Optimize num_filter=512 data on A100 and V100

* Fix the workspace_size size setting of filter

* 【Pten】Refactor C++ API code-gen (PaddlePaddle#39408)

* refactor C++ API code-gen

* fix windows problem of C++ API

* Refactored Python-C Attributes Parsing Functions (PaddlePaddle#39328)

* Add _get_parameter method to Lamb optimizer (PaddlePaddle#39416)

* add _get_parameter func to lamb

* remove duplicate code

* mkldnn layout issue fix (PaddlePaddle#39422)

* mkldnn conv fix

* definetion

* fix compile error on jetson (PaddlePaddle#39441)

* move Masked select to pten (PaddlePaddle#39193)

* move masked select cpu kernel

* add masked selected gpu kernel; test=develop

* fix bugs; test=develop

* bug fix; test=develop

* bug fix; test=develop

* add namespace to set mask array; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* fix ddim bug; test=develop

* fix npu op bug; test=develop

* fix xpu dependecy bug; test=develop

* move kernel args to sig.cc; test=develop

* 【PaddlePaddle Hackathon】31. Add Java frontend for Paddle Inference  (PaddlePaddle#37162)

* fix check error of ResetHolder (PaddlePaddle#39439)

* Added python-c code generation for final state Eager Dygraph (PaddlePaddle#39233)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Fixed issues from merge

* Fixed merge issues

* change dtype of pooling mask to 'int32' for Paddle2ONNX (PaddlePaddle#39314)

* change dtype of pooling mask to 'int32' for Paddle2ONNX

* empty commit to rerun ci

* fix format

* share MemOptVarInfos of external variables into cinn_launch subgraph (PaddlePaddle#39209)

* add a graph pass to share MemOptVarInfos of external variables into subgraph

* update pass name

* fix compile failed

* add share_mem_opt_info_to_subgraph_pass test

* share_mem_opt_info_to_subgraph_pass_test pass

* modify some codes for better style and more robust

* update cmake

* [NPU] add reduce_min (PaddlePaddle#39019)

[NPU] add reduce_min

* [MLU] add mlu kernel for accuracy op (PaddlePaddle#39337)

* [MLU] add mlu kernel for accuracy op

* fix license format

* fix error message

* [Dy2St]Handle `a, b = paddle.shape(x)` in Static Analysis (PaddlePaddle#39245)

* refine Assign

* add UT

* 【Pten】Auto-Generate InterMeta register (PaddlePaddle#39436)

* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code

* Support different dtypes of inputs for elementwise ops (PaddlePaddle#38859)

* improve backward performance

* support different dtypes for elementwise ops

* Add profiler node tree implementation (PaddlePaddle#39316)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* add print pten kernel tool (PaddlePaddle#39371)

* test=document_fix;add print pten kernel tool

* test=document_fix

* test=document_fix

* test=document_fix

* test=document_fix

* add print_pten_kernels tool

* add print_pten_kernels tool

* fix windows complie

* notest,test=rocm_ci

* add merge tool

* add comments

* [new-exec] set type of op-kernel op by place (PaddlePaddle#39458)

* Add log for executor (PaddlePaddle#39459)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173.

* add log for Executor

Co-authored-by: liutiexing <liutiexing@google.com>

* [Paddle Inference] support ernie quant model with interleaved (PaddlePaddle#39424)

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* 统一 ps 开发 - python (PaddlePaddle#39431)

* delete gloo connect retry

* the_one_ps dirs reconstruct

* .

* .

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* create the_one_ps dirs

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* the one ps dirs modify

* refactor ps optimize

* refactor ps optimize

* refactor ps optimize

* .

* .

* .

* .

* .

* .

* refactor theoneps

* the_one_ps

* add ps pass unittest

* add ps pass unittest

* ps unitest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* ps unittest frame

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* add cpu_async_ps_mode test

* ps unittest ready

* ps unittest ready

* solve dist_pass init conflict

* solve import CommContext error

* unittest ok

* implement AllocateFrom

* solve setup.py.in conflict

* solve conflict

* solve conflict

* solve conflict

* .

* .

* cpu-async-ps minimize test ok & gpu minimize test ok

Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>

* [PTen] Move grad GetExpectedPtenKernelArgs into pten (PaddlePaddle#39418)

* move grad get expected pten kernel args

* fix reduce sum error

* fix element_sub_grad failed

* revert kernel judge change

* fix compilation warning on mac (PaddlePaddle#39438)

* get build time (PaddlePaddle#39368)

* fix prelu trt convert (PaddlePaddle#39389)

* Optimize bilinear interpolation foward (PaddlePaddle#39243)

* bilinear_fw init

* optimize code

* pre-compute linear_interp input index

* Optimize performance of softmax_bwd when axis!=-1 (PaddlePaddle#38609)

* Optimize performance of softmax_bwd when axis!=-1

* fix

* fix

* fix

* fix

* [PTen] Remove pten core's dependency on fluid xxx_info.h (PaddlePaddle#39401)

* ermove xxx_info include

* fix namespace error

* resolve conflict

* skip xpu context in registry

* fix macro error

* resolve conflict

* resolve conflict

* revert xpu convert

* remove trans to fluid place

* remove useless headers

* [Pten] move operators/math/math_function_* to pten/kernels/func (PaddlePaddle#39300)

* move operators/math/math_function_* to pten/kernels/func
* namespace from `paddle::operators::math` to `pten::funcs`

* [MLU] add pool2d and pool2d_grad mlu kernel (PaddlePaddle#39453)

* [MLU]support c_gen_cncl_id_op run on MLU device (PaddlePaddle#39336)

Co-authored-by: zhangna <zhangna@cambricon.com>

* [bf16] add bf16 kernel: transpose & unbind (PaddlePaddle#39457)

* add transpose unbind

* add unittest

* refine transpose unittest

* uniform_random op for mlu (PaddlePaddle#39450)

* [MLU] add pool2d pytest (PaddlePaddle#39454)

* Added shape (U)INT8/BF16/FP32 oneDNN kernel (PaddlePaddle#36033)

* added shape oneDNN kernel

* removed unnecessary import from test

* added skipping tests for GPU

* refactoring

* refactored shape kernel

* added tests in new framework

* removed one line

* minor change

* added newline at EOF

* added formatting

* added attributes as extra

* move memcpy.h into cc file (PaddlePaddle#39469)

* Add TensorRT inspector into Paddle-TRT (PaddlePaddle#38362)

* Fix add profiler node tree implementation cmake error (PaddlePaddle#39474)

* add event node implementation

* modify profiler.stop interface

* fix according to review

* fix file mode

* modify class method name in event_node.cc

* modify LLONG_MAX to ULLONG_MAX

* fix ci error

* fix ci error

* fix dependency error

* unify naming style (PaddlePaddle#39481)

* [Pten] Generate Wrapped InferMeta by Yaml (PaddlePaddle#39482)

* generate wrapped_infer_meta

* add test for wrapped_infer_meta

* Update test_meta_fn_utils.cc

* change the dir of generated file

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>

* Adjusted python-level trace_op to accomodate final state Eager Dygraph (PaddlePaddle#39319)

* Removed debug info

* Added automatic code generation for final state Eager Dygraph

* Modified backward yaml

* Added EagerUtils helper functions for final state CodeGen

* Adjusted CMakeFiles to support compilation for final state auto generated codes

* Added python-c code generation for final state Eager Dygraph

* Fixed minor issue

* Fixed yaml.load() method failure

* Fixed minor issues

* Refactored Python-C Attributes Parsing Functions

* Fixed minor issue with Python-C AddFunctions

* Adjusted python-level trace_op to accomodate final state Eager Dygraph

* Added Logs for final state Eager Dygraph

* Fixed merge issues

* Fixed minor issue

* Fixed get_tensor method for EagerTensor (PaddlePaddle#39414)

* Enabled Eager OpTest PaddlePaddle#1

* Enabled Eager OpTest PaddlePaddle#1

* Fixed get_tensor method for EagerTensor

* [Approver Update] update check approver of qili93, test=document_fix (PaddlePaddle#39483)

* [MLU] add mlu kernel for c_broadcast op (PaddlePaddle#39470)

* update xpu test build script and fix get_test_cover_info, *test=kunlun (PaddlePaddle#39235)

* fix gather_nd, *test=kunlun (PaddlePaddle#39283)

* [pten] add split kernel (PaddlePaddle#39060)

* add split kernel

* add split kernel signature

* fix split bug

* modify MakePtenScalarArrayFromVarList

* modify MakePtenScalarArrayFromVarList

* fix split windows register error

* add test case for split kernel

* replace raw split kernel with pten kernel

* fix makeScalar/ScalarArray bug

* remove debug log

* remove int64_t type in buildPtcontext

* update by code review

* fix split dev test failed

* change DenseTensorMeta to MetaTensor

* change split api code from auto gen to manual

* split cuda kernel support bfloat16 type

* fix conflict

* rm raw split kernel

* merge develop branch

* change to pten::errors

* new may of test cases, *test=kunlun (PaddlePaddle#39444)

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* new may of test cases, *test=kunlun

* [PTen] Add HasAttr for ArgumentMappingContext (PaddlePaddle#39464)

* add has_attr for arg map context

* skip useless attr now

* skip attr if not exists

* fix typo

* [ROCm] fix missing dcu kernel in operator.cmake, test=develop (PaddlePaddle#39480)

Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Co-authored-by: Aganlengzi <aganlengzi@gmail.com>
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Leo Chen <chenqiuliang@baidu.com>
Co-authored-by: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Co-authored-by: crystal <62974595+Zjq9409@users.noreply.github.com>
Co-authored-by: Zhanlue Yang <jim19930609@gmail.com>
Co-authored-by: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Co-authored-by: wenbin <wang3323032@qq.com>
Co-authored-by: Wilber <jiweibo@baidu.com>
Co-authored-by: hong <43953930+phlrain@users.noreply.github.com>
Co-authored-by: chenyanlann <62465397+chenyanlann@users.noreply.github.com>
Co-authored-by: Wei Shengyu <weisy11@163.com>
Co-authored-by: TeFeng Chen <ctfeng66@163.com>
Co-authored-by: furnace <34057289+windstamp@users.noreply.github.com>
Co-authored-by: fwenguang <95677191+fwenguang@users.noreply.github.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: Zhang Ting <zhangting_2017@163.com>
Co-authored-by: chenjian <chenjian26@baidu.com>
Co-authored-by: Shang Zhizhou <shangzhizhou@baidu.com>
Co-authored-by: liutiexing <74819124+liutiexing@users.noreply.github.com>
Co-authored-by: liutiexing <liutiexing@google.com>
Co-authored-by: Wangzheee <634486483@qq.com>
Co-authored-by: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com>
Co-authored-by: zkh2016 <zhangkaihuo@baidu.com>
Co-authored-by: zhangchunle <clzhang_cauc@163.com>
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: Lijunhui <1578034415@qq.com>
Co-authored-by: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Co-authored-by: Feiyu Chan <chenfeiyu@baidu.com>
Co-authored-by: zn <96479180+kangna-qi@users.noreply.github.com>
Co-authored-by: zhangna <zhangna@cambricon.com>
Co-authored-by: joeqiao12 <45232181+joeqiao12@users.noreply.github.com>
Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Leo Chen <39020268+leo0519@users.noreply.github.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>
Co-authored-by: Qi Li <qili93@qq.com>
Co-authored-by: maxhuiy <1508399706@qq.com>
Co-authored-by: TTerror <tangzhiyi11@users.noreply.github.com>
Co-authored-by: chentianyu03 <chentianyu03@baidu.com>
Co-authored-by: helen88 <z8hanghuan@126.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants