Skip to content

Commit

Permalink
【auto parallel】剔除切分推导相关的头文件对proto 的依赖 (#60543)
Browse files Browse the repository at this point in the history
* decouple proto

* format

* format

* strcuct pre def
  • Loading branch information
liuzhenhai93 committed Jan 4, 2024
1 parent 2ad9e24 commit 353cb27
Show file tree
Hide file tree
Showing 16 changed files with 202 additions and 81 deletions.
10 changes: 7 additions & 3 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"

namespace paddle {
namespace distributed {
Expand Down Expand Up @@ -406,14 +407,17 @@ OperatorDistAttrProto OperatorDistAttr::to_proto() const {
for (const auto& item : input_dist_attrs_) {
auto proto_item = proto.mutable_input_dist_attrs()->Add();
proto_item->set_name(item.first);
proto_item->mutable_tensor_dist_attr()->CopyFrom(item.second.to_proto());
proto_item->mutable_tensor_dist_attr()->CopyFrom(
phi::distributed::to_proto(item.second));
}
for (const auto& item : output_dist_attrs_) {
auto proto_item = proto.mutable_output_dist_attrs()->Add();
proto_item->set_name(item.first);
proto_item->mutable_tensor_dist_attr()->CopyFrom(item.second.to_proto());
proto_item->mutable_tensor_dist_attr()->CopyFrom(
phi::distributed::to_proto(item.second));
}
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
proto.mutable_process_mesh()->CopyFrom(
phi::distributed::to_proto(process_mesh_));
proto.set_impl_type(impl_type_);
proto.set_impl_idx(impl_idx_);
proto.set_chunk_id(chunk_id_);
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ collect_srcs(
dist_mapper.cc
dist_tensor.cc
dist_meta_tensor.cc
proto_helper.cc
placement_types.cc
inferspmd_utils.cc)

Expand Down
70 changes: 31 additions & 39 deletions paddle/phi/core/distributed/auto_parallel/device_mesh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License. */
#include <iterator>

#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
namespace distributed {
namespace auto_parallel {
Expand All @@ -41,13 +41,11 @@ DeviceCapability DeviceCapability::from_proto(
return capability;
}

DeviceCapabilityProto DeviceCapability::to_proto() const {
DeviceCapabilityProto proto;
proto.set_single_precision_flops(single_precision_flops);
proto.set_double_precision_flops(double_precision_flops);
proto.set_memory_size_in_bytes(memory_size_in_bytes);
proto.set_clock_rate_in_ghz(clock_rate_in_ghz);
return proto;
void DeviceCapability::to_proto(DeviceCapabilityProto *proto) const {
proto->set_single_precision_flops(single_precision_flops);
proto->set_double_precision_flops(double_precision_flops);
proto->set_memory_size_in_bytes(memory_size_in_bytes);
proto->set_clock_rate_in_ghz(clock_rate_in_ghz);
}

std::string Device::to_string() const {
Expand All @@ -69,14 +67,13 @@ Device Device::from_proto(const DeviceProto &proto) {
return device;
}

DeviceProto Device::to_proto() const {
DeviceProto proto;
proto.set_global_id(global_id_);
proto.set_local_id(local_id_);
proto.set_machine_id(machine_id_);
proto.set_type(type_);
proto.mutable_capability()->CopyFrom(capability_.to_proto());
return proto;
void Device::to_proto(DeviceProto *proto) const {
proto->set_global_id(global_id_);
proto->set_local_id(local_id_);
proto->set_machine_id(machine_id_);
proto->set_type(type_);
proto->mutable_capability()->CopyFrom(
phi::distributed::to_proto(capability_));
}

bool operator==(const Device &lhs, const Device &rhs) {
Expand Down Expand Up @@ -109,11 +106,9 @@ LinkCapability LinkCapability::from_proto(const LinkCapabilityProto &proto) {
return capability;
}

LinkCapabilityProto LinkCapability::to_proto() const {
LinkCapabilityProto proto;
proto.set_bandwidth(bandwidth);
proto.set_latency(latency);
return proto;
void LinkCapability::to_proto(LinkCapabilityProto *proto) const {
proto->set_bandwidth(bandwidth);
proto->set_latency(latency);
}

std::string Link::to_string() const {
Expand All @@ -133,13 +128,12 @@ Link Link::from_proto(const LinkProto &proto) {
return link;
}

LinkProto Link::to_proto() const {
LinkProto proto;
proto.set_source_id(source_id_);
proto.set_target_id(target_id_);
proto.set_type(type_);
proto.mutable_capability()->CopyFrom(capability_.to_proto());
return proto;
void Link::to_proto(LinkProto *proto) const {
proto->set_source_id(source_id_);
proto->set_target_id(target_id_);
proto->set_type(type_);
proto->mutable_capability()->CopyFrom(
phi::distributed::to_proto(capability_));
}

bool operator==(const Link &lhs, const Link &rhs) {
Expand Down Expand Up @@ -355,34 +349,32 @@ DeviceMesh DeviceMesh::from_proto(const DeviceMeshProto &proto) {
return mesh;
}

DeviceMeshProto DeviceMesh::to_proto() const {
DeviceMeshProto proto;

proto.set_name(name_);
void DeviceMesh::to_proto(DeviceMeshProto *proto) const {
proto->set_name(name_);

for (const auto &i : shape_) {
proto.add_shape(i);
proto->add_shape(i);
}

for (const auto &i : device_ids_) {
proto.add_device_ids(i);
proto->add_device_ids(i);
}

for (const auto &i : dim_names_) {
proto.add_dim_names(i);
proto->add_dim_names(i);
}

for (const auto &device : devices_) {
proto.mutable_devices()->Add()->CopyFrom(device.second.to_proto());
proto->mutable_devices()->Add()->CopyFrom(
phi::distributed::to_proto(device.second));
}

for (const auto &neighbors : links_) {
for (const auto &link : neighbors.second) {
proto.mutable_links()->Add()->CopyFrom(link.second.to_proto());
proto->mutable_links()->Add()->CopyFrom(
phi::distributed::to_proto(link.second));
}
}

return proto;
}

bool operator==(const DeviceMesh &lhs, const DeviceMesh &rhs) {
Expand Down
17 changes: 12 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/device_mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ limitations under the License. */
namespace phi {
namespace distributed {
namespace auto_parallel {

class DeviceCapabilityProto;
class DeviceProto;
class LinkCapabilityProto;
class LinkProto;
class DeviceMeshProto;

struct DeviceCapability {
double single_precision_flops = 0.0;
double double_precision_flops = 0.0;
Expand All @@ -40,7 +47,7 @@ struct DeviceCapability {
std::string to_string() const;

static DeviceCapability from_proto(const DeviceCapabilityProto& proto);
DeviceCapabilityProto to_proto() const;
void to_proto(DeviceCapabilityProto* proto) const;
};

inline std::ostream& operator<<(std::ostream& os, const DeviceCapability& obj) {
Expand Down Expand Up @@ -74,7 +81,7 @@ class Device {
std::string to_string() const;

static Device from_proto(const DeviceProto& proto);
DeviceProto to_proto() const;
void to_proto(DeviceProto* proto) const;

private:
int64_t global_id_;
Expand Down Expand Up @@ -103,7 +110,7 @@ struct LinkCapability {
std::string to_string() const;

static LinkCapability from_proto(const LinkCapabilityProto& proto);
LinkCapabilityProto to_proto() const;
void to_proto(LinkCapabilityProto* proto) const;
};

inline std::ostream& operator<<(std::ostream& os, const LinkCapability& obj) {
Expand Down Expand Up @@ -131,7 +138,7 @@ class Link {
std::string to_string() const;

static Link from_proto(const LinkProto& proto);
LinkProto to_proto() const;
void to_proto(LinkProto* proto) const;

private:
int64_t source_id_;
Expand Down Expand Up @@ -273,7 +280,7 @@ class DeviceMesh {
std::string to_string() const;

static DeviceMesh from_proto(const DeviceMeshProto& proto);
DeviceMeshProto to_proto() const;
void to_proto(DeviceMeshProto* proto) const;

private:
std::string name_;
Expand Down
20 changes: 10 additions & 10 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <iterator>

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"

namespace phi {
namespace distributed {
Expand Down Expand Up @@ -308,25 +309,24 @@ void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
}
}

TensorDistAttrProto TensorDistAttr::to_proto() const {
TensorDistAttrProto proto;
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
void TensorDistAttr::to_proto(TensorDistAttrProto* proto) const {
proto->mutable_process_mesh()->CopyFrom(
phi::distributed::to_proto(process_mesh_));
for (const auto& i : dims_mapping_) {
proto.add_dims_mapping(i);
proto->add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
proto.set_chunk_id(chunk_id_);
proto->set_batch_dim(batch_dim_);
proto->set_chunk_id(chunk_id_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
proto->add_dynamic_dims(i);
}
return proto;
}

std::string TensorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
auto proto = phi::distributed::to_proto(*this);
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
PADDLE_ENFORCE_EQ(phi::distributed::to_proto(*this).SerializeToString(&data),
true,
errors::InvalidArgument(
"Failed to serialize tensor dist attr to string."));
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License. */
#include <vector>

#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
Expand All @@ -32,6 +31,10 @@ limitations under the License. */
namespace phi {
namespace distributed {

namespace auto_parallel {
class TensorDistAttrProto;
}

constexpr int kReplicateDim = -1;

class PlacementStatus {
Expand Down Expand Up @@ -169,7 +172,7 @@ class TEST_API TensorDistAttr {
// future partial-support-stage-II.
void from_proto(const auto_parallel::TensorDistAttrProto& proto);

auto_parallel::TensorDistAttrProto to_proto() const;
void to_proto(auto_parallel::TensorDistAttrProto* proto) const;

std::string serialize_to_string();

Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/dist_mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>

#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
Expand Down Expand Up @@ -91,20 +92,19 @@ DistributedMapper DistributedMapper::from_proto(
return dist_mapper;
}

DistributedMapperProto DistributedMapper::to_proto() const {
DistributedMapperProto proto;
void DistributedMapper::to_proto(DistributedMapperProto* proto) const {
for (const auto& item : device_meshes_) {
proto.mutable_device_meshes()->Add()->CopyFrom(item.second.to_proto());
proto->mutable_device_meshes()->Add()->CopyFrom(
phi::distributed::to_proto(item.second));
}
for (const auto& outer : process_id_to_device_ids_) {
auto proto_item = proto.mutable_process_id_to_device_ids()->Add();
auto proto_item = proto->mutable_process_id_to_device_ids()->Add();
proto_item->set_process_id(outer.first);
proto_item->set_device_mesh_name(outer.second.first);
for (const auto& inner : outer.second.second) {
proto_item->add_device_ids(inner);
}
}
return proto;
}

std::string DistributedMapper::to_string() const {
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/dist_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ limitations under the License. */

#include <utility>

#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"

namespace phi {
namespace distributed {
namespace auto_parallel {

class DistributedMapperProto;

class DistributedMapper {
public:
DistributedMapper() = default;
Expand Down Expand Up @@ -52,7 +53,7 @@ class DistributedMapper {
std::string to_string() const;

static DistributedMapper from_proto(const DistributedMapperProto& proto);
DistributedMapperProto to_proto() const;
void to_proto(DistributedMapperProto* proto) const;

private:
std::map<std::string, DeviceMesh> device_meshes_;
Expand Down
13 changes: 4 additions & 9 deletions paddle/phi/core/distributed/auto_parallel/process_mesh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */

#include <algorithm>
#include <iterator>

#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
Expand Down Expand Up @@ -105,22 +104,18 @@ ProcessMesh ProcessMesh::from_proto(const ProcessMeshProto &proto) {
return mesh;
}

ProcessMeshProto ProcessMesh::to_proto() const {
ProcessMeshProto proto;

void ProcessMesh::to_proto(ProcessMeshProto *proto) const {
for (const auto &i : shape_) {
proto.add_shape(i);
proto->add_shape(i);
}

for (const auto &i : process_ids_) {
proto.add_process_ids(i);
proto->add_process_ids(i);
}

for (const auto &i : dim_names_) {
proto.add_dim_names(i);
proto->add_dim_names(i);
}

return proto;
}

bool operator==(const ProcessMesh &lhs, const ProcessMesh &rhs) {
Expand Down
Loading

0 comments on commit 353cb27

Please sign in to comment.