Skip to content

Commit

Permalink
[CustomDevice] add custom ccl 1/2 (PaddlePaddle#44294)
Browse files Browse the repository at this point in the history
* [CustomDevice] add custom ccl api

* add ut
  • Loading branch information
ronny1996 committed Jul 14, 2022
1 parent c446ab7 commit d88e77a
Show file tree
Hide file tree
Showing 15 changed files with 994 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ def FindParsingFunctionFromAttributeType(atype):
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
}}
if (paddle::platform::is_custom_place(place)) {{
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
phi::DeviceManager::SetDevice(place);
VLOG(1) <<"CurrentDeviceId: " << phi::DeviceManager::GetDevice(place.GetDeviceType()) << " from " << (int)place.device;
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace."));
#endif
}}
"""
Expand All @@ -200,6 +209,7 @@ def FindParsingFunctionFromAttributeType(atype):
#include <Python.h>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ class BuddyAllocatorList {
private:
explicit BuddyAllocatorList(const std::string &device_type)
: device_type_(device_type) {
auto devices = phi::DeviceManager::GetDeviceList(device_type);
auto devices = phi::DeviceManager::GetSelectedDeviceList(device_type);
for (auto dev_id : devices) {
init_flags_[dev_id].reset(new std::once_flag());
}
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/platform/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,11 @@ void InitDevices(const std::vector<int> devices) {

auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (auto &dev_type : device_types) {
auto device_count = phi::DeviceManager::GetDeviceCount(dev_type);
auto device_list = phi::DeviceManager::GetSelectedDeviceList(dev_type);
LOG(INFO) << "CustomDevice: " << dev_type
<< ", visible devices count: " << device_count;
for (size_t i = 0; i < device_count; i++) {
places.push_back(platform::CustomPlace(dev_type, i));
<< ", visible devices count: " << device_list.size();
for (auto &dev_id : device_list) {
places.push_back(platform::CustomPlace(dev_type, dev_id));
}
}
} else {
Expand Down
20 changes: 20 additions & 0 deletions paddle/phi/backends/c_comm_lib.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/c_comm_lib.h"

namespace phi {
// Intentionally empty: although this source file does not contain any code,
// it is kept so that the build system (cmake) has a translation unit to
// attach the c_comm_lib dependency to.
} // namespace phi
60 changes: 60 additions & 0 deletions paddle/phi/backends/c_comm_lib.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/macros.h"

namespace phi {
namespace ccl {
// Type-erased communicator handle; the concrete type behind the pointer is
// not visible here (presumably owned by the device plugin -- TODO confirm).
using CCLComm = void*;
// Byte buffer holding a root/unique id; the encoding of the bytes is not
// defined in this header (presumably backend-specific -- TODO confirm).
using CCLRootId = std::vector<uint8_t>;

// Reduction operators understood by the CCL layer. The numeric values are
// spelled out explicitly so they stay stable if enumerators are reordered.
enum CCLReduceOp {
  SUM = 0,
  AVG = 1,
  MAX = 2,
  MIN = 3,
  PRODUCT = 4
};
// Element types understood by the CCL layer. Values are written out
// explicitly: floating-point types first (widest to narrowest), then signed
// integer types (widest to narrowest).
enum CCLDataType {
  CCL_DATA_TYPE_FP64 = 0,
  CCL_DATA_TYPE_FP32 = 1,
  CCL_DATA_TYPE_FP16 = 2,
  CCL_DATA_TYPE_INT64 = 3,
  CCL_DATA_TYPE_INT32 = 4,
  CCL_DATA_TYPE_INT16 = 5,
  CCL_DATA_TYPE_INT8 = 6
};

// Maps a Paddle tensor element type onto the matching CCLDataType.
//
// Every enumerator declared in CCLDataType has a corresponding case here,
// including INT16 (previously declared in the enum but never produced, so
// INT16 tensors were rejected as unsupported).
//
// @param type  Paddle data type to translate.
// @return      The CCLDataType enumerator for `type`.
// @throws      An Unimplemented error (via PADDLE_THROW) for any data type
//              that has no CCLDataType counterpart.
inline CCLDataType ToCCLDataType(paddle::experimental::DataType type) {
  switch (type) {
    case paddle::experimental::DataType::FLOAT64:
      return CCL_DATA_TYPE_FP64;
    case paddle::experimental::DataType::FLOAT32:
      return CCL_DATA_TYPE_FP32;
    case paddle::experimental::DataType::FLOAT16:
      return CCL_DATA_TYPE_FP16;
    case paddle::experimental::DataType::INT64:
      return CCL_DATA_TYPE_INT64;
    case paddle::experimental::DataType::INT32:
      return CCL_DATA_TYPE_INT32;
    case paddle::experimental::DataType::INT16:
      return CCL_DATA_TYPE_INT16;
    case paddle::experimental::DataType::INT8:
      return CCL_DATA_TYPE_INT8;
    default:
      // No CCL counterpart exists for the remaining data types.
      PADDLE_THROW(
          phi::errors::Unimplemented("This datatype in CCL is not supported."));
  }
}

} // namespace ccl
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/backends/callback_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_guard.h"

namespace phi {

Expand All @@ -33,12 +34,13 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
(*callback_func)();
});
});

phi::DeviceGuard guard(stream_->GetPlace());
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->AddCallback(stream_, func);
}

void CallbackManager::Wait() const {
phi::DeviceGuard guard(stream_->GetPlace());
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->SynchronizeStream(stream_);

Expand Down
Loading

0 comments on commit d88e77a

Please sign in to comment.