
[CustomOp] Support output as input argument of kernel func #39353

Merged

Conversation

@chenwhql (Contributor) commented Feb 6, 2022

PR types

New features

PR changes

Others

Describe

[CustomOp] Support output as input argument of kernel func

Extends custom operators to support passing the output Tensor to the kernel function as a Tensor* input argument.

Previously, custom operators required the return value to be a vector<Tensor>, which cannot satisfy some use cases, including but not limited to:

  1. Inplace operations, where the input and output are the same Tensor
  2. Cases where the data address of the output Tensor must not change across the computation

Therefore, to make the feature complete, custom operator development now also supports the following style, in which the output is passed as a function input argument:

// Dispatch to the device-specific kernel; `out` is the output Tensor
// pre-allocated and passed in by the framework.
void ReluForwardOut(const paddle::Tensor& x, paddle::Tensor* out) {
  if (x.place() == paddle::PlaceType::kCPU) {
    return relu_cpu_forward_out(x, out);
  } else if (x.place() == paddle::PlaceType::kGPU) {
    return relu_cuda_forward_out(x, out);
  } else {
    PD_THROW("Not implemented.");
  }
}
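
As a rough sketch of how such a kernel might be registered, the snippet below assumes the usual PD_BUILD_OP / PD_KERNEL macros from paddle/extension.h apply unchanged to the new Tensor*-output signature; the operator name and the input/output names are illustrative only:

// Hypothetical registration sketch; not taken from this PR.
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForwardOut));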

When the output is specified as a paddle::Tensor* and the custom operator kernel does not reallocate its memory, the Holder behind that Tensor* is the Holder of the Tensor passed in by the framework during execution, which ensures the operator works in the inplace and fuse scenarios dictated by the framework.
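
For illustration, the CPU branch referenced above (relu_cpu_forward_out) could be sketched as follows. This is not the PR's code; it assumes the paddle::Tensor member functions data<T>(), mutable_data<T>() and size() from the custom-op extension API, and it writes through the pre-allocated output rather than creating a new Tensor, so the Holder stays shared as described above:

#include <algorithm>
#include "paddle/extension.h"

// Hypothetical CPU kernel body for the dispatch function above
// (a sketch, not the PR's implementation).
void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
  auto* x_data = x.data<float>();
  // Writing through the framework-provided output keeps its original Holder,
  // which is what the inplace / fuse scenarios described above rely on.
  auto* out_data = out->mutable_data<float>(x.place());
  for (int64_t i = 0; i < x.size(); ++i) {
    out_data[i] = std::max(x_data[i], 0.0f);
  }
}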

However, since paddle::Tensor differs somewhat from framework::Variable, the DenseTensor inside the output Tensor prepared for the custom operator is in fact created as a copy (the meta is copied, the holder is shared). Therefore, after the custom operator kernel finishes, the meta information has to be synchronized back; if the holder is unchanged, it does not need to be Reset again.

Meanwhile, the original writing style remains fully compatible; a sketch of it follows below for contrast.
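
For contrast, the previously supported style returns the outputs by value, along the lines of the relu example in the custom-op documentation; the relu_cpu_forward / relu_cuda_forward helpers are assumed here:

// Original style: outputs are constructed inside the function and returned.
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
  if (x.place() == paddle::PlaceType::kCPU) {
    return relu_cpu_forward(x);
  } else if (x.place() == paddle::PlaceType::kGPU) {
    return relu_cuda_forward(x);
  } else {
    PD_THROW("Not implemented.");
  }
}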

@paddle-bot-old (bot) commented Feb 6, 2022

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@sneaxiy (Collaborator) left a comment


LGTM.

@sneaxiy merged commit f1f74e9 into PaddlePaddle:develop on Feb 7, 2022