From c8b92f876b64bb2fc4990e04775336443269c6b3 Mon Sep 17 00:00:00 2001 From: zmxdream Date: Thu, 9 Dec 2021 17:28:47 +0800 Subject: [PATCH 1/4] update --- paddle/fluid/operators/filter_by_instag_op.cu | 306 ++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 paddle/fluid/operators/filter_by_instag_op.cu diff --git a/paddle/fluid/operators/filter_by_instag_op.cu b/paddle/fluid/operators/filter_by_instag_op.cu new file mode 100644 index 0000000000000..6c97d7e3bd21e --- /dev/null +++ b/paddle/fluid/operators/filter_by_instag_op.cu @@ -0,0 +1,306 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" + +#include "paddle/fluid/operators/filter_by_instag_op.h" + +namespace cg = cooperative_groups; + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using SelectedRows = framework::SelectedRows; +using LoDTensor = framework::LoDTensor; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +using Vector = framework::Vector; +#else +template +using Vector = framework::CPUVector; +#endif + +using CUDADeviceContext = paddle::platform::CUDADeviceContext; + +#define MAX_THREADS 1024 +#define THREADS 256 + +#define MAX_THREAD_STRIDE 32 +#define TILE_DIM 32 + +// Maximum sequence-length support based on the number of threads (2048) allowed +// in each block and +// this MAX is 8K For higher sequence length we need to use higher Max, like for +// 64K : 32 +#define MAX_THREAD_ITERATIONS 8 // Maximum 8K + +#define MAX_WARP_NUM 32 + +#define MAX_REGISTERS 256 + +// test real performance to decide one threads for ? ins + +template +__global__ void filter_by_instag_cuda_kernel(const int N, int64_t* x2_data, + size_t* x2_lods_data, + int64_t* x3_data, + int filter_tag_size, + int* pass_data) { + // N is instance num + // one threads for one ins + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + int ins_tag_start = x2_loads_data[idx]; + int ins_tag_end = x2_loads_data[idx + 1]; + + // fileter logic + int i = ins_tag_start; + for (; i < ins_tag_end; i++) { + int64_t ins_tag = x2_data[i]; + int j = 0; + for (; j < filter_tag_size; j++) { + if (x3_data[j] == ins_tag) break; + } + // ins_tag in filter tag + if (j < filter_tag_size) { + pass_data[idx] = 1; + break; + } + } + // copy to output logic + // if (idx == 0) { + //} +} + +template +class FilterByInstagGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + auto cpu_place = platform::CPUPlace(); + auto& dev_ctx = ctx.template device_context(); + + // X1 is global FC output + // Dim [batch size, embedding size] + auto* x1 = context.Input("Ins"); + bool is_lod = context.Attr("is_lod"); + + int is_x1_lod = -1; + if (is_lod) + is_x1_lod = 1; + else + is_x1_lod = 0; + + int64_t out_val_if_empty = context.Attr("out_val_if_empty"); + size_t x1_embed_size = x1->dims()[1]; + // X2 is ins tag list + // LoD [[0, Sum(ins1), Sum(ins1, ins2), ... ]] + auto* x2 = context.Input("Ins_tag"); + // expected auto = const int64_t + auto* x2_data = x2->data(); + + // X3 is local fc tag list + // LoD [[0, Sum(fc1), Sum(fc1, fc2) ...]] + auto* x3 = context.Input("Filter_tag"); + auto* x3_data = x3->data(); + + // Vector, in GPU + auto x2_lods = x2->lod()[0]; + const size_t* x2_lods_data = x2_loads.CUDAData(); + int N = static_cast(x2_lods.size()) - 1; + + // Vector, in GPU + Vector x1_lods(1, 0); + if (!is_x1_lod) { + for (int i = 0; i < x1->dims()[0]; i++) { + x1_lods.push_back(i + 1); + } + } else { + x1_lods = context.Input("Ins")->lod()[0]; + } + const size_t* x1_lods_data = x1_lods.CUDAData(); + + // set output value + // for those whose ins been dropout, set 0 for whole lines. + // otherwise, copy whole line + // Dim [local fc count, batch size, embedding size] + LoDTensor* out = context.Output("Out"); + LoDTensor* map = context.Output("IndexMap"); + LoDTensor* loss_weight = context.Output("LossWeight"); + + auto* out_data = out->mutable_data(context.GetPlace()); + auto* map_data = map->mutable_data(context.GetPlace()); + auto* loss_weight_data = + loss_weight->mutable_data(context.GetPlace()); + + std::unordered_map mmap_aux; + + // check configuration + // int block_size = 512; + int block_size = THREADS dim3 block_dim(block_size); + dim3 grid_dim((N + block_size - 1) / block_size); + + filter_by_instag_cuda_kernel<<>>( + N, x1_data, x1_lods_data, out_data, loss_weight_data, x2_data, + x2_loads_data, x3_data, is_x1_lod, x1_embed_size, out_val_if_empty); + + Vector out_lods(1, 0); + + if (out_lods.size() - 1 > 0) { + out->Resize(framework::make_ddim( + {(int64_t)out_lods.back(), (int64_t)x1_embed_size})); + + map->Resize(framework::make_ddim({(int64_t)out_lods.size() - 1, 3})); + loss_weight->Resize( + framework::make_ddim({(int64_t)out_lods.size() - 1, 1})); + + } else { + out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size})); + map->Resize(framework::make_ddim({1, 3})); + loss_weight->Resize(framework::make_ddim({1, 1})); + } + + // auto* out_data = out->mutable_data(context.GetPlace()); + // auto* map_data = map->mutable_data(context.GetPlace()); + // auto* loss_weight_data = + // loss_weight->mutable_data(context.GetPlace()); + + if (out_lods.size() - 1 > 0) { + Vector map_lods; + for (size_t i = 0; i < out_lods.size() - 1; i++) { + map_data[i * 3] = (int64_t)out_lods[i]; + map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]]; + map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i]; + map_lods.push_back(i); + } + + map_lods.push_back(out_lods.size() - 1); + std::vector> map_lod_info; + map_lod_info.push_back(map_lods); + + map->set_lod(map_lod_info); + loss_weight->set_lod(map_lod_info); + std::vector> out_lod_info; + out_lod_info.push_back(out_lods); + out->set_lod(out_lod_info); + memset(out_data, 0, out->numel() * sizeof(T)); + for (int i = 0; i < loss_weight->numel(); i++) { + loss_weight_data[i] = 1; + } + + for (size_t i = 0; i < out_lods.size() - 1; i++) { + size_t pos = out_lods[i]; + for (int k = map_data[i * 3 + 1]; + k < map_data[i * 3 + 1] + map_data[i * 3 + 2]; k++) { + memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size, + x1_embed_size * sizeof(T)); + ++pos; + } + } + + } else { + Vector map_lods; + map_data[0] = 0; + map_data[1] = 1; + map_data[2] = 1; + map_lods.push_back(0); + map_lods.push_back(1); + out_lods.push_back(1); + std::vector> map_lod_info; + map_lod_info.push_back(map_lods); + map->set_lod(map_lod_info); + loss_weight->set_lod(map_lod_info); + std::vector> out_lod_info; + out_lod_info.push_back(out_lods); + out->set_lod(out_lod_info); + for (int64_t oi = 0; oi < out->numel(); ++oi) { + if (std::is_same::value) { + out_data[oi] = (int32_t)out_val_if_empty; + } else if (std::is_same::value) { + out_data[oi] = (int64_t)out_val_if_empty; + } else { + out_data[oi] = static_cast(out_val_if_empty); + } + } + loss_weight_data[0] = 0; + } + } +}; + +template +class FilterByInstagGradGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* x1_grad = context.Output(framework::GradVarName("Ins")); + auto* loss_weight = context.Input("LossWeight"); + auto* mmap = context.Input("IndexMap"); + auto* x1 = context.Input("Ins"); + + x1_grad->set_lod(context.Input("Ins")->lod()); + x1_grad->Resize(x1->dims()); + auto mmap_data = mmap->data(); + + // expected auto = T + auto* output_grad_data = output_grad->data(); + + auto* loss_weight_data = loss_weight->data(); + // expected auto = T + auto* x1_grad_data = x1_grad->mutable_data(context.GetPlace()); + memset(x1_grad_data, 0, x1->dims()[0] * x1->dims()[1] * sizeof(T)); + if (loss_weight->numel() != 1 || loss_weight_data[0] != 0) { + auto output_dims = output_grad->dims(); + for (int i = 0; i < mmap->dims()[0]; i++) { + int src_ln = mmap_data[i * 3], dst_ln = mmap_data[i * 3 + 1]; + int line_cnt = mmap_data[i * 3 + 2]; + for (int l = 0; l < line_cnt; l++) { + for (int j = 0; j < output_dims[1]; j++) { + x1_grad_data[(dst_ln + l) * output_dims[1] + j] = + output_grad_data[(src_ln + l) * output_dims[1] + j]; + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(filter_by_instag, ops::FilterByInstagGPUKernel, + ops::FilterByInstagGPUKernel, + ops::FilterByInstagGPUKernel, + ops::FilterByInstagGPUKernel); + +REGISTER_OP_CUDA_KERNEL(filter_by_instag_grad, + ops::FilterByInstagGradGPUKernel, + ops::FilterByInstagGradGPUKernel, + ops::FilterByInstagGradGPUKernel, + ops::FilterByInstagGradGPUKernel); From ab45ab8bd5e41ecef1bcff019616f5081b678918 Mon Sep 17 00:00:00 2001 From: zmxdream Date: Wed, 29 Dec 2021 19:58:36 +0800 Subject: [PATCH 2/4] fix. test=develop --- .../distributed/fleet/base/fleet_base.py | 16 +++++-- .../tests/unittests/test_fleet_base_2.py | 46 ++----------------- 2 files changed, 16 insertions(+), 46 deletions(-) diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index c19ee1e192761..8440ac065a2fb 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -627,6 +627,8 @@ def init_server(self, *args, **kwargs): """ self._runtime_handle._init_server(*args, **kwargs) + @is_non_distributed_check + @inited_runtime_handler def load_model(self, path, mode): """ load fleet model from path @@ -699,6 +701,8 @@ def stop_worker(self): """ self._runtime_handle._stop_worker() + @is_non_distributed_check + @inited_runtime_handler def save(self, dirname, feed=[], fetch=[], **configs): inference = True @@ -742,6 +746,8 @@ def save(self, dirname, feed=[], fetch=[], **configs): self._runtime_handle._save_persistables( executor, dirname, main_program=None, mode=increment_mode) + @is_non_distributed_check + @inited_runtime_handler def save_inference_model(self, executor, dirname, @@ -777,6 +783,8 @@ def save_inference_model(self, executor, dirname, feeded_var_names, target_vars, main_program, export_for_deployment, mode) + @is_non_distributed_check + @inited_runtime_handler def save_persistables(self, executor, dirname, main_program=None, mode=0): """ @@ -1586,13 +1594,13 @@ def unscale_method(self, optimizer): ] param_grads_fp16 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and (param._grad_ivar( - ).dtype == core.VarDesc.VarType.FP16) + if (param._grad_ivar() is not None) and + (param._grad_ivar().dtype == core.VarDesc.VarType.FP16) ] param_grads_fp32 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and (param._grad_ivar( - ).dtype == core.VarDesc.VarType.FP32) + if (param._grad_ivar() is not None) and + (param._grad_ivar().dtype == core.VarDesc.VarType.FP32) ] temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py index 3078e5b3d100e..42b46942427c0 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py @@ -24,9 +24,9 @@ class TestFleetBase(unittest.TestCase): def setUp(self): os.environ["POD_IP"] = "127.0.0.1" os.environ["PADDLE_PORT"] = "36000" - os.environ["PADDLE_TRAINERS_NUM"] = "2" - os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ - "127.0.0.1:36001,127.0.0.2:36001" + os.environ["PADDLE_TRAINERS_NUM"] = "1" + #os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ + # "127.0.0.1:36001,127.0.0.2:36001" def test_ps_minimize(self): import paddle @@ -70,6 +70,7 @@ def test_ps_minimize(self): fluid.default_main_program()) fleet.init_worker() + fleet.fleet.save(dirname="/tmp", feed=['x', 'y'], fetch=[avg_cost]) fleet.fleet.save( dirname="/tmp", feed=[input_x, input_y], fetch=[avg_cost]) @@ -78,45 +79,6 @@ def test_ps_minimize(self): fleet.load_model(path="/tmp", mode=0) fleet.load_model(path="/tmp", mode=1) - self.assertRaises( - Exception, - fleet.save_inference_model, - dirname='/tmp/', - feeded_var_names=['x', 'y'], - target_vars=[avg_cost], - executor="exe") - - self.assertRaises( - Exception, - fleet.save_inference_model, - dirname='/tmp/', - feeded_var_names=['x', 'y'], - target_vars=[avg_cost], - executor=exe, - main_program=compiled_prog) - - self.assertRaises( - Exception, - fleet.save_inference_model, - dirname='afs:/tmp/', - feeded_var_names=['x', 'y'], - target_vars=[avg_cost], - executor=exe, - main_program=compiled_prog) - - self.assertRaises( - Exception, fleet.save_persistables, executor=pe, dirname='/tmp/') - - self.assertRaises( - Exception, fleet.save_persistables, executor="exe", dirname='/tmp/') - - self.assertRaises( - Exception, - fleet.save_persistables, - executor=exe, - dirname='/tmp/', - main_program=compiled_prog) - if __name__ == "__main__": unittest.main() From 1fcc88ee23c91eef7b3da8d68649b8accafbbabb Mon Sep 17 00:00:00 2001 From: zmxdream Date: Wed, 29 Dec 2021 20:10:23 +0800 Subject: [PATCH 3/4] fix. test=develop --- paddle/fluid/operators/filter_by_instag_op.cu | 306 ------------------ .../distributed/fleet/base/fleet_base.py | 8 +- .../tests/unittests/test_fleet_base_2.py | 1 - 3 files changed, 4 insertions(+), 311 deletions(-) delete mode 100644 paddle/fluid/operators/filter_by_instag_op.cu diff --git a/paddle/fluid/operators/filter_by_instag_op.cu b/paddle/fluid/operators/filter_by_instag_op.cu deleted file mode 100644 index 6c97d7e3bd21e..0000000000000 --- a/paddle/fluid/operators/filter_by_instag_op.cu +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memcpy.h" - -#include "paddle/fluid/operators/filter_by_instag_op.h" - -namespace cg = cooperative_groups; - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using SelectedRows = framework::SelectedRows; -using LoDTensor = framework::LoDTensor; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -using Vector = framework::Vector; -#else -template -using Vector = framework::CPUVector; -#endif - -using CUDADeviceContext = paddle::platform::CUDADeviceContext; - -#define MAX_THREADS 1024 -#define THREADS 256 - -#define MAX_THREAD_STRIDE 32 -#define TILE_DIM 32 - -// Maximum sequence-length support based on the number of threads (2048) allowed -// in each block and -// this MAX is 8K For higher sequence length we need to use higher Max, like for -// 64K : 32 -#define MAX_THREAD_ITERATIONS 8 // Maximum 8K - -#define MAX_WARP_NUM 32 - -#define MAX_REGISTERS 256 - -// test real performance to decide one threads for ? ins - -template -__global__ void filter_by_instag_cuda_kernel(const int N, int64_t* x2_data, - size_t* x2_lods_data, - int64_t* x3_data, - int filter_tag_size, - int* pass_data) { - // N is instance num - // one threads for one ins - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; - } - int ins_tag_start = x2_loads_data[idx]; - int ins_tag_end = x2_loads_data[idx + 1]; - - // fileter logic - int i = ins_tag_start; - for (; i < ins_tag_end; i++) { - int64_t ins_tag = x2_data[i]; - int j = 0; - for (; j < filter_tag_size; j++) { - if (x3_data[j] == ins_tag) break; - } - // ins_tag in filter tag - if (j < filter_tag_size) { - pass_data[idx] = 1; - break; - } - } - // copy to output logic - // if (idx == 0) { - //} -} - -template -class FilterByInstagGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); - auto cpu_place = platform::CPUPlace(); - auto& dev_ctx = ctx.template device_context(); - - // X1 is global FC output - // Dim [batch size, embedding size] - auto* x1 = context.Input("Ins"); - bool is_lod = context.Attr("is_lod"); - - int is_x1_lod = -1; - if (is_lod) - is_x1_lod = 1; - else - is_x1_lod = 0; - - int64_t out_val_if_empty = context.Attr("out_val_if_empty"); - size_t x1_embed_size = x1->dims()[1]; - // X2 is ins tag list - // LoD [[0, Sum(ins1), Sum(ins1, ins2), ... ]] - auto* x2 = context.Input("Ins_tag"); - // expected auto = const int64_t - auto* x2_data = x2->data(); - - // X3 is local fc tag list - // LoD [[0, Sum(fc1), Sum(fc1, fc2) ...]] - auto* x3 = context.Input("Filter_tag"); - auto* x3_data = x3->data(); - - // Vector, in GPU - auto x2_lods = x2->lod()[0]; - const size_t* x2_lods_data = x2_loads.CUDAData(); - int N = static_cast(x2_lods.size()) - 1; - - // Vector, in GPU - Vector x1_lods(1, 0); - if (!is_x1_lod) { - for (int i = 0; i < x1->dims()[0]; i++) { - x1_lods.push_back(i + 1); - } - } else { - x1_lods = context.Input("Ins")->lod()[0]; - } - const size_t* x1_lods_data = x1_lods.CUDAData(); - - // set output value - // for those whose ins been dropout, set 0 for whole lines. - // otherwise, copy whole line - // Dim [local fc count, batch size, embedding size] - LoDTensor* out = context.Output("Out"); - LoDTensor* map = context.Output("IndexMap"); - LoDTensor* loss_weight = context.Output("LossWeight"); - - auto* out_data = out->mutable_data(context.GetPlace()); - auto* map_data = map->mutable_data(context.GetPlace()); - auto* loss_weight_data = - loss_weight->mutable_data(context.GetPlace()); - - std::unordered_map mmap_aux; - - // check configuration - // int block_size = 512; - int block_size = THREADS dim3 block_dim(block_size); - dim3 grid_dim((N + block_size - 1) / block_size); - - filter_by_instag_cuda_kernel<<>>( - N, x1_data, x1_lods_data, out_data, loss_weight_data, x2_data, - x2_loads_data, x3_data, is_x1_lod, x1_embed_size, out_val_if_empty); - - Vector out_lods(1, 0); - - if (out_lods.size() - 1 > 0) { - out->Resize(framework::make_ddim( - {(int64_t)out_lods.back(), (int64_t)x1_embed_size})); - - map->Resize(framework::make_ddim({(int64_t)out_lods.size() - 1, 3})); - loss_weight->Resize( - framework::make_ddim({(int64_t)out_lods.size() - 1, 1})); - - } else { - out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size})); - map->Resize(framework::make_ddim({1, 3})); - loss_weight->Resize(framework::make_ddim({1, 1})); - } - - // auto* out_data = out->mutable_data(context.GetPlace()); - // auto* map_data = map->mutable_data(context.GetPlace()); - // auto* loss_weight_data = - // loss_weight->mutable_data(context.GetPlace()); - - if (out_lods.size() - 1 > 0) { - Vector map_lods; - for (size_t i = 0; i < out_lods.size() - 1; i++) { - map_data[i * 3] = (int64_t)out_lods[i]; - map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]]; - map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i]; - map_lods.push_back(i); - } - - map_lods.push_back(out_lods.size() - 1); - std::vector> map_lod_info; - map_lod_info.push_back(map_lods); - - map->set_lod(map_lod_info); - loss_weight->set_lod(map_lod_info); - std::vector> out_lod_info; - out_lod_info.push_back(out_lods); - out->set_lod(out_lod_info); - memset(out_data, 0, out->numel() * sizeof(T)); - for (int i = 0; i < loss_weight->numel(); i++) { - loss_weight_data[i] = 1; - } - - for (size_t i = 0; i < out_lods.size() - 1; i++) { - size_t pos = out_lods[i]; - for (int k = map_data[i * 3 + 1]; - k < map_data[i * 3 + 1] + map_data[i * 3 + 2]; k++) { - memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size, - x1_embed_size * sizeof(T)); - ++pos; - } - } - - } else { - Vector map_lods; - map_data[0] = 0; - map_data[1] = 1; - map_data[2] = 1; - map_lods.push_back(0); - map_lods.push_back(1); - out_lods.push_back(1); - std::vector> map_lod_info; - map_lod_info.push_back(map_lods); - map->set_lod(map_lod_info); - loss_weight->set_lod(map_lod_info); - std::vector> out_lod_info; - out_lod_info.push_back(out_lods); - out->set_lod(out_lod_info); - for (int64_t oi = 0; oi < out->numel(); ++oi) { - if (std::is_same::value) { - out_data[oi] = (int32_t)out_val_if_empty; - } else if (std::is_same::value) { - out_data[oi] = (int64_t)out_val_if_empty; - } else { - out_data[oi] = static_cast(out_val_if_empty); - } - } - loss_weight_data[0] = 0; - } - } -}; - -template -class FilterByInstagGradGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* output_grad = context.Input(framework::GradVarName("Out")); - auto* x1_grad = context.Output(framework::GradVarName("Ins")); - auto* loss_weight = context.Input("LossWeight"); - auto* mmap = context.Input("IndexMap"); - auto* x1 = context.Input("Ins"); - - x1_grad->set_lod(context.Input("Ins")->lod()); - x1_grad->Resize(x1->dims()); - auto mmap_data = mmap->data(); - - // expected auto = T - auto* output_grad_data = output_grad->data(); - - auto* loss_weight_data = loss_weight->data(); - // expected auto = T - auto* x1_grad_data = x1_grad->mutable_data(context.GetPlace()); - memset(x1_grad_data, 0, x1->dims()[0] * x1->dims()[1] * sizeof(T)); - if (loss_weight->numel() != 1 || loss_weight_data[0] != 0) { - auto output_dims = output_grad->dims(); - for (int i = 0; i < mmap->dims()[0]; i++) { - int src_ln = mmap_data[i * 3], dst_ln = mmap_data[i * 3 + 1]; - int line_cnt = mmap_data[i * 3 + 2]; - for (int l = 0; l < line_cnt; l++) { - for (int j = 0; j < output_dims[1]; j++) { - x1_grad_data[(dst_ln + l) * output_dims[1] + j] = - output_grad_data[(src_ln + l) * output_dims[1] + j]; - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(filter_by_instag, ops::FilterByInstagGPUKernel, - ops::FilterByInstagGPUKernel, - ops::FilterByInstagGPUKernel, - ops::FilterByInstagGPUKernel); - -REGISTER_OP_CUDA_KERNEL(filter_by_instag_grad, - ops::FilterByInstagGradGPUKernel, - ops::FilterByInstagGradGPUKernel, - ops::FilterByInstagGradGPUKernel, - ops::FilterByInstagGradGPUKernel); diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 8440ac065a2fb..bc59b87e2ffa5 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1594,13 +1594,13 @@ def unscale_method(self, optimizer): ] param_grads_fp16 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and - (param._grad_ivar().dtype == core.VarDesc.VarType.FP16) + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16) ] param_grads_fp32 = [ param._grad_ivar() for param in optimizer._parameter_list - if (param._grad_ivar() is not None) and - (param._grad_ivar().dtype == core.VarDesc.VarType.FP32) + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP32) ] temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py index 42b46942427c0..9675a77d6766b 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py @@ -70,7 +70,6 @@ def test_ps_minimize(self): fluid.default_main_program()) fleet.init_worker() - fleet.fleet.save(dirname="/tmp", feed=['x', 'y'], fetch=[avg_cost]) fleet.fleet.save( dirname="/tmp", feed=[input_x, input_y], fetch=[avg_cost]) From 5a213afeaa01de3b1af9c4e20a913456534f6118 Mon Sep 17 00:00:00 2001 From: zmxdream Date: Wed, 29 Dec 2021 20:48:06 +0800 Subject: [PATCH 4/4] fix. test=develop --- .../tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py index 668b4ad872f43..4b1f0ee85d944 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py @@ -62,10 +62,6 @@ def test_gradient_merge_optimizer(self): self.assertEqual(sends, 0) self.assertEqual(sgds, 0) - fleet.init_worker() - time.sleep(8) - fleet.stop_worker() - if __name__ == "__main__": unittest.main()