
Commit

Added adam kernel
jakpiase committed Jan 26, 2022
1 parent 3ab9aef commit 9e033d8
Showing 14 changed files with 422 additions and 20 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kAdam)
USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast)
153 changes: 153 additions & 0 deletions paddle/fluid/operators/jit/gen/adam.cc
@@ -0,0 +1,153 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#include "paddle/fluid/operators/jit/gen/adam.h"

#include <stddef.h> // offsetof

#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void AdamJitCode::loadArgs() {
static constexpr int32_t one_as_float = 0x3f800000;
static constexpr int32_t mask_all_ones = 0xFFFFFFFF;
static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8;
static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8;

mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]);
mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]);
mov(eax, one_as_float);
movd(xmm_one, eax);

vbroadcastss(ymm_one, xmm_one); // 1
vbroadcastss(ymm_beta1, xmm_beta1); // beta1
vbroadcastss(ymm_beta2, xmm_beta2); // beta2
vbroadcastss(ymm_lr, xmm_lr); // -lr
vbroadcastss(ymm_eps, xmm_eps); // eps
vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1
vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2

mov(reg_numel_without_tail, reg_numel);
and_(reg_numel_without_tail, mask_8_divisible); // make it 8-divisible

shl(reg_numel_without_tail, 2); // elements -> byte offset (* sizeof(float))
shl(reg_numel, 2);

mov(eax, mask_all_ones);
kmovw(k1, eax);

xor_(reg_offset, reg_offset);
}
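
A note on the two stack loads at the top of loadArgs(): under the System V x86-64 calling convention (assumed here), only the first six integer arguments travel in registers, so the 7th and 8th pointers arrive on the stack. After preCode() pushes num_g_abi_regs callee-saved registers they sit just above the return address; a minimal sketch of the offset arithmetic, with the register count as an assumption:

#include <cstdint>
// Illustration only: offsets used by loadArgs(), System V AMD64 ABI assumed.
constexpr int64_t kNumAbiRegs = 8;               // num_g_abi_regs: assumed value
constexpr int64_t kRetAddrOff = kNumAbiRegs * 8; // saved regs, then return address
constexpr int64_t kArg7Off = kRetAddrOff + 8;    // mom2_out_ptr
constexpr int64_t kArg8Off = kRetAddrOff + 16;   // param_out_ptr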

void AdamJitCode::setTailOpmask() {
mov(r13, rcx);

mov(rcx, reg_numel);
sub(rcx, reg_offset); // tail size in bytes
shr(rcx, 2); // bytes -> number of tail elements
mov(r14, 1);
shl(r14, cl); // 2 ^ elements
dec(r14); // 2 ^ elements - 1: the low "elements" bits are set to 1
kmovw(k1, r14d);

mov(rcx, r13);
}
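
The bit trick above builds a partial-lane predicate: with n tail elements left (n < 8), (1 << n) - 1 sets exactly the low n bits of opmask k1, so the masked loads and stores in mainCode() touch only the valid elements. A scalar sketch of the same computation (function name hypothetical):

#include <cstdint>
// Scalar equivalent of setTailOpmask(), illustration only.
uint16_t TailMask(int64_t numel_bytes, int64_t offset_bytes) {
  const int64_t tail_elems = (numel_bytes - offset_bytes) >> 2;  // bytes -> floats
  return static_cast<uint16_t>((1u << tail_elems) - 1);  // low tail_elems bits set
}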

void AdamJitCode::mainCode() {
// load grad
vmovups(ymm7 | k1, ptr[reg_grad_ptr + reg_offset]);

// beta1 * mom1 + (1 - beta1) * g
vmulps(ymm8 | k1, ymm_one_sub_beta1, ymm7);
vfmadd231ps(ymm8 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]);

// beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm7 | k1, ymm7, ymm7);
vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7);
vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]); // ymm1 == ymm_beta2

// store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8);
vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7);

// sqrt(mom2) + eps
vsqrtps(ymm7 | k1, ymm7);
vaddps(ymm7 | k1, ymm7, ymm3); // ymm3 == ymm_eps

// p + (-lr) * (mom1 / (sqrt(mom2) + eps))
vdivps(ymm7 | k1, ymm8, ymm7);
vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]); // ymm2 == ymm_lr

// store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7);
}

void AdamJitCode::genCode() {
static constexpr int64_t main_loop_elems_size =
8 * sizeof(float); // 8 floats in YMM
static constexpr int64_t offset_increment = main_loop_elems_size;
preCode();
loadArgs();

cmp(reg_numel, main_loop_elems_size);
jl("process_tail");

L("main_loop");
{
mainCode();
add(reg_offset, offset_increment);
cmp(reg_numel_without_tail, reg_offset);
jg("main_loop");
}

cmp(reg_numel, reg_offset);
je("end");

L("process_tail");
{
setTailOpmask();
mainCode();
}

L("end");
postCode();
}
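
Taken together, the generated kernel is equivalent to this scalar outline — a sketch for orientation, not part of the emitted code: full 8-float chunks run through mainCode() with an all-ones mask, and any remainder runs through it once more under the tail opmask.

#include <cstdint>
// Illustrative control flow of genCode(); offsets are in bytes, as after
// loadArgs() multiplies the element count by sizeof(float).
void AdamGeneratedOutline(int64_t numel_bytes) {
  const int64_t without_tail = numel_bytes & ~int64_t{31};  // 8 floats = 32 bytes
  int64_t offset = 0;
  for (; offset < without_tail; offset += 32) {
    // mainCode(): one full-width Adam step on 8 floats
  }
  if (offset != numel_bytes) {
    // setTailOpmask(); mainCode(): masked step on the remaining < 8 floats
  }
}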

class AdamCreator : public JitCodeCreator<adam_attr_t> {
public:
bool CanBeUsed(const adam_attr_t& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const adam_attr_t& attr) const override {
return 96 + 32 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(
const adam_attr_t& attr) const override {
return make_unique<AdamJitCode>(attr, CodeSize(attr));
}
};

} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

namespace gen = paddle::operators::jit::gen;

REGISTER_JITKERNEL_GEN(kAdam, gen::AdamCreator);
75 changes: 75 additions & 0 deletions paddle/fluid/operators/jit/gen/adam.h
@@ -0,0 +1,75 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#pragma once

#include <string>

#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

class AdamJitCode : public JitCode {
public:
explicit AdamJitCode(const adam_attr_t& attr, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}

DECLARE_JIT_CODE(AdamJitCode);
void genCode() override;
void loadArgs();
void setTailOpmask();
void mainCode();

private:
reg64_t reg_numel{abi_param1};
reg64_t reg_grad_ptr{abi_param2};
reg64_t reg_mom1_ptr{abi_param3};
reg64_t reg_mom2_ptr{abi_param4};
reg64_t reg_param_ptr{abi_param5};
reg64_t reg_mom1_out_ptr{abi_param6};

xmm_t xmm_beta1 = xmm_t(0);
xmm_t xmm_beta2 = xmm_t(1);
xmm_t xmm_lr = xmm_t(2);
xmm_t xmm_eps = xmm_t(3);
xmm_t xmm_one_sub_beta1 = xmm_t(4);
xmm_t xmm_one_sub_beta2 = xmm_t(5);
xmm_t xmm_one = xmm_t(6);

ymm_t ymm_beta1 = ymm_t(0);
ymm_t ymm_beta2 = ymm_t(1);
ymm_t ymm_lr = ymm_t(2);
ymm_t ymm_eps = ymm_t(3);
ymm_t ymm_one_sub_beta1 = ymm_t(4);
ymm_t ymm_one_sub_beta2 = ymm_t(5);
ymm_t ymm_one = ymm_t(6);

reg64_t reg_mom2_out_ptr{r10};
reg64_t reg_param_out_ptr{r11};
reg64_t reg_numel_without_tail{r12};
reg64_t reg_offset{rax};
};

} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/gen/jitcode.h
@@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using opmask_t = const Xbyak::Opmask;
using Label = Xbyak::Label;

typedef enum {
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/helper.cc
@@ -58,6 +58,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kAdam);
ONE_CASE(kHSum);
ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax);
5 changes: 5 additions & 0 deletions paddle/fluid/operators/jit/helper.h
@@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os,
return os;
}

inline std::ostream& operator<<(std::ostream& os, const adam_attr_t& attr) {
os << "beta1[" << attr.beta1 << "],beta2[" << attr.beta2 << "]";
return os;
}

inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) {
os << "param_height[" << attr.param_height << "],param_width["
<< attr.param_width << "],grad_height[" << attr.grad_height
20 changes: 18 additions & 2 deletions paddle/fluid/operators/jit/kernel_base.h
@@ -24,8 +24,9 @@ namespace jit {
typedef enum {
kNone = 0,
// sort by alphabet
kCRFDecoding = 1,
kEmbSeqPool = 2,
kAdam = 1,
kCRFDecoding,
kEmbSeqPool,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
@@ -269,6 +270,21 @@ struct SgdTuple {
const sgd_attr_t*);
};

typedef struct adam_attr_s {
float beta1, beta2;
adam_attr_s() = default;
explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {}
} adam_attr_t;

template <typename T>
struct AdamTuple {
static constexpr KernelType kernel_type = kAdam;
typedef T data_type;
typedef adam_attr_t attr_type;
typedef void (*func_type)(T, T, T, T, int64_t, const T*, const T*, const T*,
const T*, T*, T*, T*);
};
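
For context, this func_type matches the refer and JIT kernels added in this commit; a caller would typically fetch the kernel through the KernelFuncs cache, following the pattern of the neighbouring tuples. A hedged sketch of a call site (variable names hypothetical; note that the JIT code expects the learning rate pre-negated, per the "-lr" comment in adam.cc):

// Hypothetical call site, mirroring how other tuples in this header are used.
auto adam = jit::KernelFuncs<jit::AdamTuple<float>, platform::CPUPlace>::Cache()
                .At(jit::adam_attr_t(beta1, beta2));
adam(beta1, beta2, -lr, eps, numel, grad_ptr, mom1_ptr, mom2_ptr, param_ptr,
     mom1_out_ptr, mom2_out_ptr, param_out_ptr);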

typedef struct matmul_attr_s {
int m, n, k;
void* packed_weight{nullptr};
5 changes: 5 additions & 0 deletions paddle/fluid/operators/jit/kernel_key.cc
@@ -63,6 +63,11 @@ int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width;
}

template <>
int64_t JitCodeKey<adam_attr_t>(const adam_attr_t& attr) {
return static_cast<int64_t>(attr.beta1 + attr.beta2);
}

} // namespace jit
} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kAdam)
USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast)
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/refer/refer.cc
@@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Adam);
REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast);

14 changes: 14 additions & 0 deletions paddle/fluid/operators/jit/refer/refer.h
@@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}

template <typename T>
void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr,
const T* mom1_ptr, const T* mom2_ptr, const T* param_ptr,
T* mom1_out_ptr, T* mom2_out_ptr, T* param_out_ptr) {
for (int64_t i = 0; i < numel; ++i) {
mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i];
mom2_out_ptr[i] =
beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i];
param_out_ptr[i] =
param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps));
}
}
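
In equation form, the reference kernel above is the classic Adam update, with any bias correction and the sign flip folded into lr by the caller (the JIT code broadcasts it as -lr):

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\theta_t = \theta_{t-1} + lr \cdot m_t / (\sqrt{v_t} + \epsilon)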

#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
@@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Adam);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast);
