Support multi-output feature for elementwise #38410

JamesLim-sy
Contributor

@JamesLim-sy JamesLim-sy commented Dec 23, 2021

PR types

New features

PR changes

OPs

Describe

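  1. Features

    • Modified the core implementation in Elementwise_no_broadcast so that elementwise supports multiple outputs, consistent with broadcast in PR38329.
    • Unified the template parameter lists of LaunchBroadcastElementwiseCudaKernel, LaunchSameDimsElementwiseCudaKernel, and their higher-level wrapper LaunchElementwiseCudaKernel.
    • Together with PR38044, verified the correctness of multi-output results and the performance improvement:

    | case | pytorch (kernel) | before optimization | before vs. pytorch | after optimization | after vs. pytorch |
    |---|---|---|---|---|---|
    | [50, 128, 1000], [128, 1000] | 0.46865 | 0.24259 | better (48.24%) | 0.2379 | better (49.24%) |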
  2. Usage

    • Use paddle::framework::Array as the data type of the functor's outputs, and write the functor as follows:
    template <typename InT, typename OutT>
    struct Functor {
      HOSTDEVICE paddle::framework::Array<OutT, 2>  operator()(InT x, InT y, InT z) {
        paddle::framework::Array<OutT, 2> outs;
        outs[0] = x + y * z;
        outs[1] = x / z;
        return outs;
      }
    };
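    • When calling LaunchBroadcastElementwiseCudaKernel, the template parameters change from <InT, OutT, functor> to <InT, OutT, functor, NumOuts>, where NumOuts is the number of outputs the functor produces (2 in this example). NumOuts defaults to 1, so existing code that uses a single-output functor stays compatible.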
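  3. Remaining work
    • The template parameter NumOuts is currently used to handle the multi-output case. In that case, however, the ReturnType in FunctionTraits is itself of type paddle::framework::Array<OutT, NumOuts>, so I have been trying to obtain NumOuts through using Traits = paddle::platform::FunctionTraits, but have not found a working approach yet.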
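The usage and the remaining-work item above can be illustrated with a small host-only sketch. It is not Paddle code: `Array` is a `std::array` stand-in for paddle::framework::Array, and `NumOutsOf`, `TernaryResultT`, `Store`, and `LaunchTernaryOnHost` are hypothetical names introduced only for this example. It shows one possible way to read the output count off the functor's return type and to scatter a packed result into NumOuts output buffers.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Stand-in for paddle::framework::Array (assumption: only operator[] is used).
template <typename T, std::size_t N>
using Array = std::array<T, N>;

// The functor from the description, written against the stand-in Array.
template <typename InT, typename OutT>
struct Functor {
  Array<OutT, 2> operator()(InT x, InT y, InT z) const {
    Array<OutT, 2> outs;
    outs[0] = x + y * z;
    outs[1] = x / z;
    return outs;
  }
};

// Hypothetical trait: output count read off a functor's return type.
template <typename R>
struct NumOutsOf { static constexpr int value = 1; };
template <typename T, std::size_t N>
struct NumOutsOf<std::array<T, N>> {
  static constexpr int value = static_cast<int>(N);
};

// Return type of a ternary functor call, obtained with decltype.
template <typename Func, typename InT>
using TernaryResultT = decltype(std::declval<Func>()(
    std::declval<InT>(), std::declval<InT>(), std::declval<InT>()));

// A scalar result fills output 0; a packed result fills every output.
template <typename OutT>
void Store(std::vector<std::vector<OutT>>* outs, std::size_t i, OutT v) {
  (*outs)[0][i] = v;
}
template <typename OutT, std::size_t N>
void Store(std::vector<std::vector<OutT>>* outs, std::size_t i,
           const Array<OutT, N>& v) {
  for (std::size_t k = 0; k < N; ++k) (*outs)[k][i] = v[k];
}

// Hypothetical host-side analogue of a same-dims ternary launch.
template <typename InT, typename OutT, typename Func, int NumOuts = 1>
void LaunchTernaryOnHost(const std::vector<InT>& x, const std::vector<InT>& y,
                         const std::vector<InT>& z,
                         std::vector<std::vector<OutT>>* outs, Func func) {
  static_assert(NumOutsOf<TernaryResultT<Func, InT>>::value == NumOuts,
                "NumOuts must match the functor's return type");
  assert(static_cast<int>(outs->size()) == NumOuts);
  for (std::size_t i = 0; i < x.size(); ++i) {
    Store(outs, i, func(x[i], y[i], z[i]));
  }
}
```

A caller would pass `Functor<float, float>` with `NumOuts = 2` and two pre-sized output vectors. Whether a trait like `NumOutsOf` fits into the real FunctionTraits is left open here, as in the description.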

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

 int num,
 int data_offset,
 Functor func) {
 InT args[Arity][VecSize];
-OutT result[VecSize];
+OutType<OutT, NumOuts> result[VecSize];
Contributor

Please see whether there is a more suitable name. OutType is not very apt; perhaps call it PackedOutT, or make the conditional nature visible, e.g. ConditionalOutT.

Contributor Author

@JamesLim-sy JamesLim-sy Dec 28, 2021

Renamed it to ConditionalT<OutT, NumOuts>.
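The thread does not show how ConditionalT is defined. A plausible sketch, assuming it simply selects between a plain OutT and a packed array based on NumOuts, is given below; `Array` is again a `std::array` stand-in for paddle::framework::Array, not the real type.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>

// Stand-in for paddle::framework::Array (assumption for this sketch).
template <typename T, int N>
using Array = std::array<T, static_cast<std::size_t>(N)>;

// Assumed shape of ConditionalT: a single output stays a plain OutT,
// several outputs are packed into an Array<OutT, NumOuts>.
template <typename OutT, int NumOuts>
using ConditionalT =
    typename std::conditional<NumOuts == 1, OutT, Array<OutT, NumOuts>>::type;

static_assert(std::is_same<ConditionalT<float, 1>, float>::value,
              "one output: no packing");
static_assert(std::is_same<ConditionalT<float, 2>, std::array<float, 2>>::value,
              "two outputs: packed array");
```

With such an alias, a declaration like `ConditionalT<OutT, NumOuts> result[VecSize];` keeps the single-output path identical to the old `OutT result[VecSize];`.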

@@ -174,19 +207,39 @@ void LaunchSameDimsElementwiseCudaKernel(
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(outs->size(),
Contributor

Doesn't the value of ET need to be checked here, the way broadcast does?

Contributor Author

Because of this snippet in function_traits:

using Traits = paddle::platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;

So the value of ET is superseded by kArity, and kArity is what the later checks use.
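To make the reply concrete, here is a minimal sketch of a traits class with the two members the snippet relies on. It is an assumption about the general shape, not a copy of paddle::platform::FunctionTraits; the enumerator values, `ArityOf`, `AddFunctor`, and `PointerSumFunctor` are hypothetical and exist only for illustration.

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

// Assumed enumerator values; the thread does not show the real definition.
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

// Generic case: inspect the functor's operator().
template <typename T>
struct FunctionTraits : FunctionTraits<decltype(&T::operator())> {};

// const call operator. This sketch assumes at least one argument.
template <typename C, typename R, typename... Args>
struct FunctionTraits<R (C::*)(Args...) const> {
  static constexpr std::size_t arity = sizeof...(Args);
  static constexpr bool has_pointer_args =
      arity == 1 &&
      std::is_pointer<
          typename std::tuple_element<0, std::tuple<Args...>>::type>::value;
};

// Non-const call operator: reuse the const case.
template <typename C, typename R, typename... Args>
struct FunctionTraits<R (C::*)(Args...)>
    : FunctionTraits<R (C::*)(Args...) const> {};

// Mirrors the kArity expression quoted above.
template <ElementwiseType ET, typename Functor>
constexpr int ArityOf() {
  return FunctionTraits<Functor>::has_pointer_args
             ? static_cast<int>(ET)
             : static_cast<int>(FunctionTraits<Functor>::arity);
}

// Arity is visible in the signature, so ET is ignored.
struct AddFunctor {
  float operator()(float a, float b) const { return a + b; }
};

// Arguments arrive through one pointer, so ET supplies the arity.
struct PointerSumFunctor {
  float operator()(const float* args) const {
    return args[0] + args[1] + args[2];
  }
};
```

ET only matters when the functor takes its arguments through a single pointer. In every other case the arity comes from the signature, which is why a separate check on ET adds nothing.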

Contributor

@Xreki Xreki left a comment

LGTM

@@ -162,7 +162,8 @@ struct DimensionsTransform {
}
};

-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
Contributor

Can the DimensionsTransform in this file be deleted now?

Contributor Author

Yes, it can be deleted. I will open a PR to remove it; I had assumed the pten side would delete it.

-PADDLE_ENFORCE_EQ(kArity,
-2,
+PADDLE_ENFORCE_LE(kArity,
+ElementwiseType::kTernary,
Contributor

The literal 3 should really be used here: ElementwiseType::kTernary is an enumerator, and it could be set to some other value.

Contributor Author

OK, I will change this in the same PR that removes DimensionsTransform.

Contributor Author

> The literal 3 should really be used here: ElementwiseType::kTernary is an enumerator, and it could be set to some other value.

This has already been fixed in PR38550.
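The reviewer's point about the enumerator can be shown in a few lines. The enumerator values below are assumed for illustration (the thread does not show the real ones), and `kMaxArity` and `IsSupportedArity` are hypothetical names.

```cpp
// Assumed values for illustration; an enumerator names a kind of op,
// and nothing forces its numeric value to equal the op's arity.
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };

// Stating the limit as a plain number keeps the bound independent of
// how the enum happens to be numbered.
constexpr int kMaxArity = 3;

constexpr bool IsSupportedArity(int arity) {
  return arity >= 1 && arity <= kMaxArity;
}

// If the bound and the enumerator are meant to agree, say so explicitly,
// so that renumbering the enum fails at compile time instead of silently
// changing the check.
static_assert(static_cast<int>(kTernary) == kMaxArity,
              "kTernary is expected to equal the maximum arity");
```

Under these assumed values, a bound written as `arity <= kAny` would reject every functor, which is the kind of accident the literal avoids.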

@JamesLim-sy JamesLim-sy merged commit 48f061f into PaddlePaddle:develop Dec 28, 2021
3 participants