Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HeterPs] add delta score, scale show #33492

Merged
merged 1 commit into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,6 @@ struct FeaturePushValue {
float lr_g;
float mf_g[MF_DIM];
};
// class DownpourFixedFeatureValue {
// public:
// DownpourFixedFeatureValue() {}
// ~DownpourFixedFeatureValue() {}
// float* data() {
// return _data.data();
// }
// size_t size() {
// return _data.size();
// }
// void resize(size_t size) {
// _data.resize(size);
// }
// void shrink_to_fit() {
// _data.shrink_to_fit();
// }
// private:
// std::vector<float> _data;
// };

} // end namespace framework
} // end namespace paddle
Expand Down
31 changes: 4 additions & 27 deletions paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,6 @@ limitations under the License. */
namespace paddle {
namespace framework {

__device__ double cuda_double_random(unsigned long long seed) {
// copy from MurmurHash3
seed ^= seed >> 33;
seed *= 0xff51afd7ed558ccd;
seed ^= seed >> 33;
seed *= 0xc4ceb9fe1a85ec53;
seed ^= seed >> 33;
return ((double)seed / 18446744073709551615.0);
}

__device__ float cuda_normal_random(unsigned long long idx) {
static double pi = 3.1415926897932384;
unsigned long long x = clock64() + idx;
double x1, x2, res;
while (1) {
x1 = cuda_double_random(x);
x2 = cuda_double_random(x + 33);
res = sqrt(-2.0 * log(x1)) * cos(2.0 * pi * x2);
if (-10 < res && res < 10) break;
x += 207;
}
return res;
}

template <typename ValType, typename GradType>
class Optimizer {
public:
Expand Down Expand Up @@ -95,11 +71,12 @@ class Optimizer {
}
__device__ void update_value(ValType& val, const GradType& grad) {
val.slot = grad.slot;
;
val.show += grad.show;
val.clk += grad.clk;
val.delta_score += optimizer_config::nonclk_coeff * (grad.show - grad.clk) +
optimizer_config::clk_coeff * grad.clk;

update_lr(val.lr, val.lr_g2sum, grad.lr_g, 1.0);
update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);

if (val.mf_size == 0) {
if (optimizer_config::mf_create_thresholds <=
Expand All @@ -116,7 +93,7 @@ class Optimizer {
}
}
} else {
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, 1.0);
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
}
}
};
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ limitations under the License. */

namespace optimizer_config {

__constant__ float mf_create_thresholds = 0;
__constant__ float nonclk_coeff = 0.1;
__constant__ float clk_coeff = 1;

__constant__ float min_bound = -10;
__constant__ float max_bound = 10;
__constant__ float learning_rate = 0.05;
__constant__ float initial_g2sum = 3.0;
__constant__ float initial_range = 1e-4;
__constant__ float initial_range = 0;

__constant__ float mf_create_thresholds = 10;
__constant__ float mf_learning_rate = 0.05;
__constant__ float mf_initial_g2sum = 3.0;
__constant__ float mf_initial_range = 1e-4;
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/framework/io/fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,16 @@ void set_download_command(const std::string& x) {

std::shared_ptr<FILE> hdfs_open_read(std::string path, int* err_no,
const std::string& converter) {
if (fs_end_with_internal(path, ".gz")) {
path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(),
if (download_cmd() != "") { // use customized download command
path = string::format_string("%s \"%s\"", download_cmd().c_str(),
path.c_str());
} else {
const std::string file_path = path;
path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(),
file_path.c_str());
if (download_cmd() != "") { // use customized download command
path = string::format_string("%s \"%s\"", download_cmd().c_str(),
file_path.c_str());
if (fs_end_with_internal(path, ".gz")) {
path = string::format_string("%s -text \"%s\"", hdfs_command().c_str(),
path.c_str());
} else {
path = string::format_string("%s -cat \"%s\"", hdfs_command().c_str(),
path.c_str());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import copy
from .node import DownpourWorker, DownpourServer
from . import ps_pb2 as pslib
import os

OpRole = core.op_proto_and_checker_maker.OpRole
# this dict is for store info about pull/push sparse ops.
Expand Down Expand Up @@ -765,7 +766,8 @@ def _minimize(self,
"user_define_dump_filename", "")
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", [])
opt_info["worker_places"] = strategy.get("worker_places", [])
gpus_env = os.getenv("FLAGS_selected_gpus")
opt_info["worker_places"] = [int(s) for s in gpus_env.split(",")]
opt_info["use_ps_gpu"] = strategy.get("use_ps_gpu", False)
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class in [
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/incubate/fleet/utils/fleet_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
__all__ = ["FleetUtil"]

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
__name__, logging.INFO, fmt='%(asctime)s %(levelname)s: %(message)s')

fleet = None

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/log_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_logger(name, level, fmt=None):
handler = logging.StreamHandler()

if fmt:
formatter = logging.Formatter(fmt=fmt)
formatter = logging.Formatter(fmt=fmt, datefmt='%a %b %d %H:%M:%S')
handler.setFormatter(formatter)

logger.addHandler(handler)
Expand Down