Add FillOp #3505

Status: Closed · wants to merge 1 commit
3 changes: 2 additions & 1 deletion paddle/framework/CMakeLists.txt
@@ -55,5 +55,6 @@ cc_library(paddle_pybind SHARED
recurrent_op
uniform_random_op
gaussian_random_op
-fill_zeros_like_op)
+fill_zeros_like_op
+fill_op)
endif(WITH_PYTHON)
1 change: 1 addition & 0 deletions paddle/framework/pybind.cc
@@ -41,6 +41,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
+USE_OP(fill);

namespace paddle {
namespace framework {
2 changes: 2 additions & 0 deletions paddle/framework/tensor.h
@@ -107,6 +107,8 @@ class Tensor {

platform::Place place() const { return holder_->place(); }

+bool IsHoldingMemory() const { return holder_ != nullptr; }

private:
template <typename T>
inline void check_memory_size() const;
3 changes: 1 addition & 2 deletions paddle/framework/tensor_impl.h
@@ -67,12 +67,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
-}
#else
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size));
-}
#endif
+}
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
19 changes: 19 additions & 0 deletions paddle/memory/memcpy.cc
@@ -29,6 +29,16 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
}

#ifndef PADDLE_ONLY_CPU

+template <>
+void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::GPUPlace src_place,
+                                                  const void* src, size_t num) {
+  platform::SetDeviceId(src_place.device);
+  platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
Member commented:

Maybe we do not need a synchronous Copy here; Copy already works on a specific CUDA stream. If we really want to synchronize the copy:

Copy(dst_place, dst, src_place, src, num, stream_);
cudaStreamSynchronize(stream_);

Right now we only have the default stream (I am fixing that in #3497), so you can pass 0 as the CUDA stream for now.
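For concreteness, a minimal sketch of that suggestion (not code from this PR): it reuses the asynchronous specialization declared alongside this one and then blocks on the stream.

template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::GPUPlace src_place,
                                                  const void* src, size_t num) {
  cudaStream_t stream = 0;  // default stream; #3497 will add real streams
  // Reuse the asynchronous specialization, then block until it completes.
  Copy(dst_place, dst, src_place, src, num, stream);
  cudaStreamSynchronize(stream);
}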

Collaborator (author) replied:

It is very strange that invoking some Copy methods declared in memory.h triggers a link error while compiling: the templates are declared in the header but only explicitly specialized in the .cc file, so any place combination without a matching specialization compiles at the call site and then fails at link time.

That is hard to debug if the developer is not familiar with C++, templates, and memory.{h/cc}. So we should implement Copy completely in memory.{h/cc}; whether to pass a stream is then the developer's choice.
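A minimal, self-contained sketch of that pitfall, with illustrative names rather than Paddle's actual ones:

// header (declarations only)
struct CPUPlace {};
struct GPUPlace {};
template <typename Dst, typename Src>
void Copy(Dst, void* dst, Src, const void* src, size_t num);

// .cc file: only the CPU->CPU pair is explicitly specialized
#include <cstring>
template <>
void Copy<CPUPlace, CPUPlace>(CPUPlace, void* dst, CPUPlace, const void* src,
                              size_t num) {
  std::memcpy(dst, src, num);
}

// A caller doing Copy(CPUPlace{}, dst, GPUPlace{}, src, n) compiles,
// then fails at link time:
//   undefined reference to `void Copy<CPUPlace, GPUPlace>(...)'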

+}

template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
@@ -39,6 +49,15 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

+template <>
+void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::CPUPlace src_place,
+                                                  const void* src, size_t num) {
+  platform::SetDeviceId(dst_place.device);
+  platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
+}

template <>
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst,
2 changes: 2 additions & 0 deletions paddle/operators/CMakeLists.txt
@@ -67,3 +67,5 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)

+op_library(fill_op SRCS fill_op.cc fill_op.cu)
67 changes: 67 additions & 0 deletions paddle/operators/fill_op.cc
@@ -0,0 +1,67 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/fill_op.h"

namespace paddle {
namespace operators {

template <typename T>
class FillOp : public framework::OperatorWithKernel {
public:
FillOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto &shape = GetAttr<std::vector<int>>("shape");
auto dim = framework::make_ddim(shape);
auto numel = framework::product(dim);
PADDLE_ENFORCE_EQ(numel, GetAttr<std::vector<T>>("data").size(),
"Shape's numel should be as same as data element count");
ctx.Output<framework::Tensor>("Out")->Resize(dim);
}
};

template <typename T>
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "Output of Fill Op");
AddComment("Fill a variable with shape and buffer each time.");
AddAttr<int>("run_once", "Set it once or each time when run")
.SetDefault(false)
.InEnum({true, false});
AddAttr<std::vector<int>>("shape", "The shape of fill parameter");
AddAttr<std::vector<T>>("data", "The data will be filled");
Member commented:

Please have a look at #2917. There are mainly two ways to load data: the first is to load from a vector or a numpy array; the second is for Paddle to generate the data itself. Will we have another method like FeedVariable (caffe2 has FeedBlob)?

Collaborator (author) replied:

The fill_op is part of the topology, and it does not conflict with FeedVariable.

Consider a situation: the minus operator's gradient is a combination of operators, namely

  • an Identity or Copy operator;
  • a Fill operator that fills a scalar with -1, followed by a Scale operator.

A sketch of that composition follows below.
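A hypothetical sketch of that composition; only the attribute names "shape", "data", and "run_once" come from fill_op.cc in this PR, and the CreateOp signature here is illustrative.

auto fill_neg_one = framework::OpRegistry::CreateOp(
    "fill",
    /*inputs=*/{},
    /*outputs=*/{{"Out", {"neg_one"}}},
    /*attrs=*/{{"shape", std::vector<int>{1}},
               {"data", std::vector<float>{-1.0f}},
               {"run_once", 1}});  // constant: fill once, reuse afterwards
// A Scale op would then multiply the incoming gradient by neg_one, and an
// Identity/Copy op would forward the gradient for the other operand.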

Collaborator (author) replied:

This is not about loading data; it is about designing the topology.

}
};

template <typename T>
class FillOpCPUKernel : public FillOpKernelBase<T> {
public:
void Copy(const platform::Place &place, const std::vector<T> &src,
T *dst) const override {
std::copy(src.begin(), src.end(), dst);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(fill, ops::FillOp<float>, ops::FillOpMaker<float>);
REGISTER_OP_CPU_KERNEL(fill, ops::FillOpCPUKernel<float>);
32 changes: 32 additions & 0 deletions paddle/operators/fill_op.cu
@@ -0,0 +1,32 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/memory/memcpy.h"
#include "paddle/operators/fill_op.h"
namespace paddle {
namespace operators {
template <typename T>
class FillOpGPUKernel : public FillOpKernelBase<T> {
public:
void Copy(const platform::Place &place, const std::vector<T> &src,
T *dst) const override {
auto &gpu_place = boost::get<platform::GPUPlace>(place);
auto &cpu_place = platform::default_cpu();
memory::Copy(gpu_place, dst, cpu_place, src.data(), src.size() * sizeof(T));
}
};
} // namespace operators
} // namespace paddle

REGISTER_OP_GPU_KERNEL(fill, paddle::operators::FillOpGPUKernel<float>);
42 changes: 42 additions & 0 deletions paddle/operators/fill_op.h
@@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {
template <typename T>
class FillOpKernelBase : public framework::OpKernel {
QiJune (Member) commented on Aug 16, 2017:

Maybe the base class FillOpKernelBase is a little complex; just implementing the data fill directly in FillOpGPUKernel and FillOpCPUKernel would be fine.

Collaborator (author) replied:

There are common lines of code shared between the CPU and GPU kernels; making a base class lets that code be shared.

public:
void Compute(const framework::ExecutionContext& context) const override {
using namespace paddle::framework;
auto* tensor = context.Output<Tensor>("Out");
auto run_once = static_cast<bool>(context.op_.GetAttr<int>("run_once"));
if (run_once && tensor->IsHoldingMemory()) {
return;
}
T* dst = tensor->mutable_data<T>(context.GetPlace());
auto& src = context.op_.GetAttr<std::vector<T>>("data");
this->Copy(context.GetPlace(), src, dst);
}

virtual void Copy(const platform::Place& place, const std::vector<T>& src,
T* dst) const = 0;
};

} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions python/paddle/v2/framework/tests/CMakeLists.txt
@@ -25,3 +25,4 @@ py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
+py_test(test_fill_op SRCS test_fill_op.py)
21 changes: 21 additions & 0 deletions python/paddle/v2/framework/tests/test_fill_op.py
@@ -0,0 +1,21 @@
import unittest
from op_test_util import OpTestMeta
import numpy


class TestFillOp(unittest.TestCase):
__metaclass__ = OpTestMeta

def setUp(self):
self.type = "fill"
data = [0.1, 0.2, 0.3, 0.4]

self.attrs = {'data': data, 'shape': [2, 2], 'run_once': True}
self.outputs = {
'Out': numpy.array(
[[0.1, 0.2], [0.3, 0.4]], dtype=numpy.float32)
}


if __name__ == '__main__':
unittest.main()