diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index 3ac30456a2d98..fdcd19b03098c 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -134,10 +134,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( const OpFuncNode& op_func_node) { auto& op_type = op_func_node.operator_base_->Type(); auto* dev_ctx = op_func_node.dev_ctx_; - if (op_type == interpreter::kMemcpyH2D) { + if (op_type == interpreter::kMemcpyD2H) { VLOG(3) << "Get dev_ctx from d2h_context_pool_"; dev_ctx = d2h_ctx_pool_.Get(place_); - } else if (op_type == interpreter::kMemcpyD2H) { + } else if (op_type == interpreter::kMemcpyH2D) { VLOG(3) << "Get dev_ctx from h2d_context_pool_"; dev_ctx = h2d_ctx_pool_.Get(place_); } diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index 93035dddefee7..29132f2930acb 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -157,6 +157,7 @@ class FetchV2Kernel { DeepCopy(src_item, fetch_var_name, dst_item); } else { dst_item->ShareDataWith(src_item); + dst_item->set_lod(src_item.lod()); } } else { auto &src_item = fetch_var->Get(); @@ -172,6 +173,7 @@ class FetchV2Kernel { DeepCopy(src_item[i], fetch_var_name, &dst_item[i]); } else { dst_item[i].ShareDataWith(src_item[i]); + dst_item[i].set_lod(src_item[i].lod()); } } }