New optimizer (PaddlePaddle#225)
* add pair_shuffle

* add adagrad v2

* add comment
Yelrose authored and danleifeng committed Sep 12, 2023
1 parent 049e42d commit 67cc0ca
Showing 13 changed files with 1,140 additions and 13 deletions.
63 changes: 63 additions & 0 deletions paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
@@ -334,5 +334,68 @@ void SparseSharedAdamSGDRule::InitValueWork(float *value,
*(sgd + Beta1PowIndex()) = _beta1_decay_rate;
*(sgd + Beta2PowIndex()) = _beta2_decay_rate;
}

void SparseAdaGradV2SGDRule::LoadConfig(const SparseCommonSGDRuleParameter &param,
size_t emb_dim) {
_embedding_dim = emb_dim;
auto adagrad_param = param.adagrad();
learning_rate_ = adagrad_param.learning_rate();
_initial_g2sum = adagrad_param.initial_g2sum();
_initial_range = adagrad_param.initial_range();

if (adagrad_param.weight_bounds_size() == 0) {
_min_bound = -std::numeric_limits<float>::max();
_max_bound = std::numeric_limits<float>::max();
} else {
CHECK(adagrad_param.weight_bounds_size() >= 2)
<< "invalid repeated size for weight_bounds:"
<< adagrad_param.weight_bounds_size();
_min_bound = adagrad_param.weight_bounds(0);
_max_bound = adagrad_param.weight_bounds(1);
}
}

void SparseAdaGradV2SGDRule::UpdateValueWork(float *w,
float *sgd,
const float *grad,
float scale) {
float &g2sum = sgd[G2SumIndex()];
double add_g2sum = 0;
float epsilon = 1e-8;

for (size_t i = 0; i < _embedding_dim; i++) {
double scaled_grad = grad[i] / scale;
add_g2sum += scaled_grad * scaled_grad;
}
g2sum += add_g2sum / _embedding_dim;

for (size_t i = 0; i < _embedding_dim; i++) {
double scaled_grad = grad[i] / scale;
w[i] -= learning_rate_ * scaled_grad / (sqrt(g2sum) + epsilon);
BoundValue(w[i]);
}
}

void SparseAdaGradV2SGDRule::InitValueWork(float *value,
float *sgd,
bool zero_init) {
for (size_t i = 0; i < _embedding_dim; ++i) {
if (zero_init) {
value[i] = 0.0;
BoundValue(value[i]);
} else {
value[i] =
(local_uniform_real_distribution<double>()(local_random_engine()) *
2 -
1) *
_initial_range;
BoundValue(value[i]);
}
}
sgd[G2SumIndex()] = 0;
}

} // namespace distributed
} // namespace paddle
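
For readers who want to check the arithmetic by hand, the following standalone C++ sketch mirrors what UpdateValueWork above does (names and values are illustrative; this is not part of the commit): the squared, scale-corrected gradients are averaged into the single g2sum accumulator, and each weight then moves by lr * grad / (sqrt(g2sum) + epsilon) before being clamped to the configured bounds. Keeping one shared g2sum across the embedding dimensions is what holds the per-feature optimizer state to a single float.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Standalone sketch of the AdaGrad-V2 update above; names are illustrative.
    void AdaGradV2Update(std::vector<float>* w, float* g2sum,
                         const std::vector<float>& grad, float scale, float lr,
                         float min_bound, float max_bound) {
      const float epsilon = 1e-8f;
      double add_g2sum = 0.0;
      for (std::size_t i = 0; i < w->size(); ++i) {
        double scaled_grad = grad[i] / scale;
        add_g2sum += scaled_grad * scaled_grad;
      }
      *g2sum += static_cast<float>(add_g2sum / w->size());  // mean over dims
      for (std::size_t i = 0; i < w->size(); ++i) {
        double scaled_grad = grad[i] / scale;
        (*w)[i] -= lr * scaled_grad / (std::sqrt(*g2sum) + epsilon);
        (*w)[i] = std::min(std::max((*w)[i], min_bound), max_bound);  // BoundValue
      }
    }

    int main() {
      std::vector<float> w(8, 0.0f), grad(8, 0.1f);
      float g2sum = 0.0f;
      AdaGradV2Update(&w, &g2sum, grad, /*scale=*/1.0f, /*lr=*/0.05f, -10.0f, 10.0f);
      std::printf("w[0]=%f g2sum=%f\n", w[0], g2sum);
      return 0;
    }
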
23 changes: 23 additions & 0 deletions paddle/fluid/distributed/ps/table/sparse_sgd_rule.h
@@ -106,6 +106,29 @@ class SparseAdaGradSGDRule : public SparseValueSGDRule {
float _initial_g2sum;
};

class SparseAdaGradV2SGDRule : public SparseValueSGDRule {
  // SparseAdaGradV2 uses the standard AdaGrad update rule:
  //   g2sum = g2sum + mean_over_dims(grad_x * grad_x)
  //   x = x - lr * grad_x / (sqrt(g2sum) + epsilon)
public:
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
size_t emb_dim);
virtual void UpdateValueWork(float* w,
float* sgd,
const float* push_value,
float scale);
virtual void InitValueWork(float* value, float* sgd, bool zero_init);
virtual size_t Dim() { return 1; }
size_t G2SumIndex() { return 0; }

private:
float learning_rate_;
float _initial_g2sum;
};

class StdAdaGradSGDRule : public SparseValueSGDRule {
public:
virtual void LoadConfig(const SparseCommonSGDRuleParameter& param,
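
A note on the declaration above: Dim() == 1 means the rule keeps exactly one optimizer-state float per feature, the g2sum accumulator at sgd[G2SumIndex()] (index 0), alongside the embedding itself. InitValueWork (defined in sparse_sgd_rule.cc earlier in this diff) zeroes that slot and draws each embedding value uniformly from (-initial_range, initial_range). A small sketch of that draw, using std::mt19937 in place of Paddle's local_random_engine (illustrative only):

    #include <cstddef>
    #include <random>
    #include <vector>

    // Illustrative re-creation of the non-zero branch of InitValueWork.
    std::vector<float> InitEmbedding(std::size_t emb_dim, float initial_range,
                                     std::mt19937* rng) {
      std::uniform_real_distribution<double> uniform(0.0, 1.0);
      std::vector<float> value(emb_dim);
      for (std::size_t i = 0; i < emb_dim; ++i) {
        // (u * 2 - 1) maps U(0, 1) onto U(-1, 1); scaling gives U(-range, range).
        value[i] = static_cast<float>((uniform(*rng) * 2.0 - 1.0) * initial_range);
      }
      return value;  // the caller also sets sgd[G2SumIndex()] = 0
    }
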
1 change: 1 addition & 0 deletions paddle/fluid/distributed/ps/table/table.cc
@@ -48,6 +48,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradV2SGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseSharedAdamSGDRule);

int32_t TableManager::Initialize() {
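
The REGISTER_PSCORE_CLASS line is what lets a table configuration refer to the new rule by name and have it constructed at runtime. As a rough illustration only, the sketch below shows the generic name-to-factory pattern such registration macros boil down to; it is not Paddle's actual registerer and the names are invented:

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    // Toy registry: maps a class name to a factory producing the base type.
    template <typename Base>
    class ToyRegistry {
     public:
      using Creator = std::function<std::unique_ptr<Base>()>;

      static ToyRegistry& Instance() {
        static ToyRegistry instance;
        return instance;
      }
      void Register(const std::string& name, Creator creator) {
        creators_[name] = std::move(creator);
      }
      std::unique_ptr<Base> Create(const std::string& name) const {
        auto it = creators_.find(name);
        return it == creators_.end() ? nullptr : it->second();
      }

     private:
      std::map<std::string, Creator> creators_;
    };

With a registry like this in place, the string "SparseAdaGradV2SGDRule" in a table config is enough to instantiate the rule the table was configured with.
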
51 changes: 49 additions & 2 deletions paddle/fluid/framework/data_feed.cu
@@ -313,6 +313,25 @@ __global__ void FillSlotValueOffsetKernel(const int ins_num,
}
}

struct RandInt {
  int low, high;

  __host__ __device__ RandInt(int low, int high) : low(low), high(high) {}

  // Stateless per-index draw: a fresh engine discards n values so that
  // index n always gets the same integer in [low, high].
  __host__ __device__ int operator()(const unsigned int n) const {
    thrust::default_random_engine rng;
    thrust::uniform_int_distribution<int> dist(low, high);
    rng.discard(n);
    return dist(rng);
  }
};

void SlotRecordInMemoryDataFeed::FillSlotValueOffset(
const int ins_num,
const int used_slot_num,
@@ -451,6 +470,7 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
uint64_t *walk,
uint8_t *walk_ntype,
int *row,
int *row_col_shift,
int central_word,
int step,
int len,
@@ -471,8 +491,12 @@ __global__ void GraphFillIdKernel(uint64_t *id_tensor,
// id_tensor[dst] = walk[src];
// id_tensor[dst + 1] = walk[src + step];
if (idx < len) {
int src = row[idx] * col_num + central_word;
if (walk[src] != 0 && walk[src + step] != 0) {
int col_idx = (central_word + row_col_shift[idx]) % col_num;
int src = row[idx] * col_num + col_idx;
int last_row = row[idx] * col_num;
int next_row = last_row + col_num;

    if ((src + step) >= last_row && (src + step) < next_row &&
        walk[src] != 0 && walk[src + step] != 0) {
for (int i = 0; i < excluded_train_pair_len; i += 2) {
if (walk_ntype[src] == excluded_train_pair[i] &&
walk_ntype[src + step] == excluded_train_pair[i + 1]) {
@@ -735,16 +759,19 @@ int GraphDataGenerator::MakeInsPair(cudaStream_t stream) {
}
uint64_t *ins_buf = reinterpret_cast<uint64_t *>(d_ins_buf_->ptr());
int *random_row = reinterpret_cast<int *>(d_random_row_->ptr());
int *random_row_col_shift = reinterpret_cast<int *>(d_random_row_col_shift_->ptr());
int *d_pair_num = reinterpret_cast<int *>(d_pair_num_->ptr());
cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream);
int len = buf_state_.len;

// make pair
GraphFillIdKernel<<<GET_BLOCKS(len), CUDA_NUM_THREADS, 0, stream>>>(
ins_buf + ins_buf_pair_len_ * 2,
d_pair_num,
walk,
walk_ntype,
random_row + buf_state_.cursor,
random_row_col_shift + buf_state_.cursor,
buf_state_.central_word,
window_step_[buf_state_.step],
len,
@@ -2354,6 +2381,7 @@ int GraphDataGenerator::FillWalkBuf() {
platform::CUDADeviceGuard guard2(gpuid_);
buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());
int *d_random_row_col_shift = reinterpret_cast<int *>(d_random_row_col_shift_->ptr());

thrust::random::default_random_engine engine(shuffle_seed_);
const auto &exec_policy = thrust::cuda::par.on(sample_stream_);
@@ -2364,6 +2392,12 @@ int GraphDataGenerator::FillWalkBuf() {
thrust::device_pointer_cast(d_random_row),
engine);

thrust::transform(exec_policy,
cnt_iter,
cnt_iter + total_row_,
thrust::device_pointer_cast(d_random_row_col_shift),
RandInt(0, walk_len_));

cudaStreamSynchronize(sample_stream_);
shuffle_seed_ = engine();

@@ -2590,6 +2624,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
platform::CUDADeviceGuard guard2(gpuid_);
buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());
int *d_random_row_col_shift = reinterpret_cast<int *>(d_random_row_col_shift_->ptr());

thrust::random::default_random_engine engine(shuffle_seed_);
const auto &exec_policy = thrust::cuda::par.on(sample_stream_);
@@ -2600,6 +2635,12 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
thrust::device_pointer_cast(d_random_row),
engine);

thrust::transform(exec_policy,
cnt_iter,
cnt_iter + total_row_,
thrust::device_pointer_cast(d_random_row_col_shift),
RandInt(0, walk_len_));

cudaStreamSynchronize(sample_stream_);
shuffle_seed_ = engine();

@@ -2778,6 +2819,12 @@ void GraphDataGenerator::AllocResource(
place_,
(once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));

d_random_row_col_shift_ = memory::AllocShared(
place_,
(once_sample_startid_len_ * walk_degree_ * repeat_time_) * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));

shuffle_seed_ = 0;

ins_buf_pair_len_ = 0;
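
Taken together, the data_feed.cu changes are the pair_shuffle from the commit message: d_random_row_col_shift_ gets one random shift per row (filled with the RandInt functor through thrust::transform over a counting iterator), and GraphFillIdKernel rotates the central word by that shift, emitting a (src, src + step) pair only when both positions land inside the same walk row (the kernel additionally requires both walk entries to be non-zero). A host-side C++ sketch of just that index arithmetic, with illustrative names and values:

    #include <cstdio>

    // Host-side sketch of the shifted pair selection in GraphFillIdKernel.
    // Fills *src and reports whether (src, src + step) stays inside row `row`.
    bool MakeShiftedPair(int row, int col_num, int central_word, int col_shift,
                         int step, int* src) {
      int col_idx = (central_word + col_shift) % col_num;  // rotated central word
      *src = row * col_num + col_idx;
      int row_begin = row * col_num;       // first element of this walk row
      int row_end = row_begin + col_num;   // one past its last element
      int dst = *src + step;
      return dst >= row_begin && dst < row_end;
    }

    int main() {
      int src = 0;
      // Example: walk length (col_num) 24, central word 3, shift 22, step +1.
      bool ok = MakeShiftedPair(/*row=*/2, /*col_num=*/24, /*central_word=*/3,
                                /*col_shift=*/22, /*step=*/1, &src);
      std::printf("src=%d pair-in-row=%d\n", src, ok ? 1 : 0);
      return 0;
    }
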
14 changes: 6 additions & 8 deletions paddle/fluid/framework/data_feed.h
@@ -881,7 +881,8 @@ struct BufState {

int GetNextStep() {
step++;
if (step <= right && central_word + (*window)[step] < walk_len) {
// Out-of-bound checking is handled in MakeInsPair.
if (step <= right) {
return 1;
}
return 0;
@@ -905,13 +906,9 @@ struct BufState {
VLOG(2) << "random window: " << random_window << " window[" << left
<< "] = " << (*window)[left] << " window[" << right
<< "] = " << (*window)[right];

for (step = left; step <= right; step++) {
if (central_word + (*window)[step] >= 0) {
return 1;
}
}
return 0;
// Out-of-bound checking is handled in MakeInsPair.
step = left;
return 1;
}

int GetNextBatch() {
@@ -1054,6 +1051,7 @@ class GraphDataGenerator {
std::shared_ptr<phi::Allocation> d_feature_;
std::shared_ptr<phi::Allocation> d_len_per_row_;
std::shared_ptr<phi::Allocation> d_random_row_;
std::shared_ptr<phi::Allocation> d_random_row_col_shift_;
std::shared_ptr<phi::Allocation> d_uniq_node_num_;
std::shared_ptr<phi::Allocation> d_slot_feature_num_map_;
std::shared_ptr<phi::Allocation> d_actual_slot_id_map_;
2 changes: 1 addition & 1 deletion paddle/fluid/framework/distributed_strategy.proto
@@ -282,7 +282,7 @@ message SparseNaiveSGDRuleParameter { // SparseNaiveSGDRule
}

message
SparseAdagradSGDRuleParameter { // SparseAdaGradSGDRule|StdAdaGradSGDRule
SparseAdagradSGDRuleParameter { // SparseAdaGradSGDRule|StdAdaGradSGDRule|SparseAdaGradV2SGDRule
optional double learning_rate = 1 [ default = 0.05 ];
optional double initial_g2sum = 2 [ default = 3.0 ];
optional double initial_range = 3 [ default = 0.0001 ];
17 changes: 15 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
@@ -253,7 +253,9 @@ void HashTable<KeyType, ValType>::set_sparse_sgd(
cudaMemcpyAsync(device_optimizer_config_,
&host_optimizer_config_,
sizeof(OptimizerConfig),
cudaMemcpyHostToDevice);
cudaMemcpyHostToDevice,
stream_);
cudaStreamSynchronize(stream_);
}

template <typename KeyType, typename ValType>
@@ -263,7 +265,9 @@ void HashTable<KeyType, ValType>::set_embedx_sgd(
cudaMemcpyAsync(device_optimizer_config_,
&host_optimizer_config_,
sizeof(OptimizerConfig),
cudaMemcpyHostToDevice);
cudaMemcpyHostToDevice,
stream_);
cudaStreamSynchronize(stream_);
}

template <typename KeyType, typename ValType>
@@ -524,6 +528,15 @@ template void HashTable<uint64_t, float*>::update<
size_t len,
SparseAdagradOptimizer<CommonFeatureValueAccessor> sgd,
cudaStream_t stream);

template void HashTable<uint64_t, float*>::update<
SparseAdagradV2Optimizer<CommonFeatureValueAccessor>,
cudaStream_t>(const uint64_t* d_keys,
const char* d_grads,
size_t len,
SparseAdagradV2Optimizer<CommonFeatureValueAccessor> sgd,
cudaStream_t stream);

template void HashTable<uint64_t, float*>::update<
StdAdagradOptimizer<CommonFeatureValueAccessor>,
cudaStream_t>(const uint64_t* d_keys,
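
The set_sparse_sgd / set_embedx_sgd hunks now push the updated OptimizerConfig to the GPU on the table's own stream and block until the copy finishes, so kernels launched on that stream afterwards read a consistent config. A minimal standalone sketch of that copy-then-synchronize pattern with the CUDA runtime API (the struct and names are illustrative, not Paddle code):

    #include <cuda_runtime.h>
    #include <cstdio>

    struct OptimizerConfigLike {  // stand-in for OptimizerConfig
      float learning_rate;
      float initial_g2sum;
    };

    int main() {
      OptimizerConfigLike host_cfg{0.05f, 3.0f};
      OptimizerConfigLike* device_cfg = nullptr;
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      cudaMalloc(reinterpret_cast<void**>(&device_cfg), sizeof(OptimizerConfigLike));

      // Enqueue the copy on the stream, then wait so the new config is visible
      // before any later kernel on this stream reads it.
      cudaMemcpyAsync(device_cfg, &host_cfg, sizeof(OptimizerConfigLike),
                      cudaMemcpyHostToDevice, stream);
      cudaStreamSynchronize(stream);

      cudaFree(device_cfg);
      cudaStreamDestroy(stream);
      std::printf("config copied\n");
      return 0;
    }
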
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps.cc
@@ -41,6 +41,9 @@ HeterPsBase* HeterPsBase::get_instance(
} else if (optimizer_type == 4) {
return new HeterPs<CommonFeatureValueAccessor, SparseAdamSharedOptimizer>(
capacity, resource, *gpu_accessor);
} else if (optimizer_type == 5) {
return new HeterPs<CommonFeatureValueAccessor, SparseAdagradV2Optimizer>(
capacity, resource, *gpu_accessor);
}
} else {
VLOG(0) << " HeterPsBase get_instance Warning: now only support "
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
@@ -42,6 +42,9 @@ HeterPsBase* HeterPsBase::get_instance(
} else if (optimizer_type == 4) {
return new HeterPs<CommonFeatureValueAccessor, SparseAdamSharedOptimizer>(
capacity, resource, *gpu_accessor);
} else if (optimizer_type == 5) {
return new HeterPs<CommonFeatureValueAccessor, SparseAdagradV2Optimizer>(
capacity, resource, *gpu_accessor);
}
} else if (accessor_type == "DownpourCtrDymfAccessor" ||
accessor_type == "DownpourCtrDoubleDymfAccessor") {