Skip to content

Commit

Permalink
Merge pull request #4598 from jacquesqiao/fix-sgd-learning-rate
Browse files Browse the repository at this point in the history
use EigenVector to get learning_rate for GPU device in SGD operator
  • Loading branch information
jacquesqiao committed Oct 5, 2017
2 parents ffd092d + 8ebc31d commit c0511c8
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions paddle/operators/sgd_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,25 @@ limitations under the License. */
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("Param");
auto grad = ctx.Input<Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("ParamOut");
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto param = ctx.Input<framework::Tensor>("Param");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");

param_out->mutable_data<T>(ctx.GetPlace());

auto p = EigenVector<T>::Flatten(*param);
auto g = EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out);
auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>();

o.device(place) = p - lr * g;
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
}
};

Expand Down

0 comments on commit c0511c8

Please sign in to comment.