Add training support (PaddlePaddle#26)
XBWGC committed Aug 5, 2021
1 parent 9b51865 commit 2d40322
Showing 6 changed files with 170 additions and 57 deletions.
49 changes: 45 additions & 4 deletions paddle/fluid/framework/ipu/ipu_backend.cc
@@ -91,6 +91,28 @@ void IpuBackend::Compile(ir::Graph* graph,
}
}

std::unique_ptr<popart::Optimizer> IpuBackend::GetPopartOptimizer() {
// TODO(xiaobingw): change type_ to enum
PADDLE_ENFORCE_NE(
optimizer_.type_, "",
platform::errors::InvalidArgument("Optimizer type has not been set."));
if (optimizer_.type_ == "adam") {
auto optimizer = std::make_unique<popart::Adam>(
popart::OptimizerValue(0.01, false),
popart::OptimizerValue(0.0f, false),
popart::OptimizerValue(GetOptimizerAttr("beta1"), false),
popart::OptimizerValue(GetOptimizerAttr("beta2"), false),
popart::OptimizerValue(GetOptimizerAttr("epsilon"), false),
popart::OptimizerValue(1.0f, false), popart::AdamMode::Adam,
popart::WeightDecayMode::Decay, popart::DataType::FLOAT,
popart::DataType::FLOAT, popart::DataType::FLOAT);
return optimizer;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Optimizer %s is not implemented now.", optimizer_.type_));
}
}

void IpuBackend::Prepare() {
VLOG(1) << "Save Model to file paddle_model.onnx ...\n";
builder_->saveModelProto("paddle_model.onnx");
@@ -111,19 +133,38 @@ void IpuBackend::Prepare() {
deviceOpts);
// or acquireAvailableDevice();

VLOG(1) << "Creating session from Onnx Model...";
session_ = popart::InferenceSession::createFromOnnxModel(proto, dataFlow,
ipuModelDevice);
if (ipu_build_strategy_ != nullptr && ipu_build_strategy_->is_training_) {
VLOG(1) << "Creating TrainingSession from Onnx Model...";
auto popart_optimizer = GetPopartOptimizer();
auto it = tensors_.find(optimizer_.loss_);
PADDLE_ENFORCE_NE(
it, tensors_.end(),
paddle::platform::errors::InvalidArgument(
"loss_id = %s doesn't exist in popart graph.", optimizer_.loss_));
session_ = popart::TrainingSession::createFromOnnxModel(
proto, dataFlow, it->second, *popart_optimizer, ipuModelDevice);
} else {
VLOG(1) << "Creating InferenceSession from Onnx Model...";
session_ = popart::InferenceSession::createFromOnnxModel(proto, dataFlow,
ipuModelDevice);
}
VLOG(1) << "Creating session from Onnx Model...done";

VLOG(1) << "Preparing session device...";
session_->prepareDevice();
VLOG(1) << "Preparing session device...done";

VLOG(1) << "Copy weights from host to device...";
session_->weightsFromHost();
VLOG(1) << "Copy weights from host to device...done";
}

void IpuBackend::Run(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) {
Prepare();
if (!is_prepared_) {
Prepare();
is_prepared_ = true;
}

std::map<popart::TensorId, popart::IArray&> popart_inputs;
std::map<popart::TensorId, popart::NDArrayWrapper<float>> input_wrappers;
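For orientation, a minimal Python-side sketch (not part of the diff) of how this training path gets selected: the is_training flag set here is what IpuBackend::Prepare() checks when choosing between a popart::TrainingSession and a popart::InferenceSession. It assumes a PaddlePaddle build compiled with IPU support.

# Sketch only: choosing the session type created by IpuBackend::Prepare().
# Requires a PaddlePaddle build with IPU support (PADDLE_WITH_IPU).
import paddle.fluid.compiler as compiler

ipu_build_strategy = compiler.get_ipu_build_strategy()
# True  -> Prepare() builds a popart::TrainingSession (needs the loss tensor
#          and the optimizer extracted from the program);
# False -> Prepare() builds a popart::InferenceSession.
ipu_build_strategy.is_training = True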
22 changes: 17 additions & 5 deletions paddle/fluid/framework/ipu/ipu_backend.h
@@ -18,11 +18,13 @@ limitations under the License. */
#include <string>
#include <vector>

#include <popart/adam.hpp>
#include <popart/builder.hpp>
#include <popart/dataflow.hpp>
#include <popart/devicemanager.hpp>
#include <popart/names.hpp>
#include <popart/ndarraywrapper.hpp>
#include <popart/optimizer.hpp>
#include <popart/session.hpp>
#include <popart/sessionoptions.hpp>
#include <popart/stepio.hpp>
@@ -62,8 +64,11 @@ class IpuBackend {

void SetOptimizerType(const std::string &type) { optimizer_.type_ = type; }

const std::map<std::string, float> &GetOptimizerAttr() {
return optimizer_.attrs_;
float GetOptimizerAttr(const std::string &name, float default_value = 0.0f) {
if (optimizer_.attrs_.count(name) == 0) {
return default_value;
}
return optimizer_.attrs_.at(name);
}

void SetOptimizerAttr(const std::string &attr, float value) {
@@ -72,12 +77,18 @@

void SetLoss(const std::string &loss) { optimizer_.loss_ = loss; }

std::unique_ptr<popart::Optimizer> GetPopartOptimizer();

std::vector<int64_t> GetTensorShape(const std::string &var_name) {
return builder_->getTensorShape(tensors_[var_name]);
}

// SetScope, so we can get model parameters from scope
void SetScope(Scope *scope) { scope_ = scope; }
void SetScope(const Scope &scope) { scope_.reset(&scope); }

void SetIpuBuildStrategy(const IpuBuildStrategy &strategy) {
ipu_build_strategy_.reset(&strategy);
}

static std::shared_ptr<IpuBackend> GetInstance() {
if (NULL == instance_) {
@@ -94,8 +105,9 @@

private:
Optimizer optimizer_;
IpuBuildStrategy ipu_build_strategy_;
Scope *scope_ = nullptr;
bool is_prepared_ = false;
std::shared_ptr<const Scope> scope_ = nullptr;
std::shared_ptr<const IpuBuildStrategy> ipu_build_strategy_ = nullptr;

std::vector<popart::TensorId> inputs_;
std::vector<popart::TensorId> outputs_;
1 change: 1 addition & 0 deletions paddle/fluid/framework/ipu/ipu_build_strategy.h
@@ -35,6 +35,7 @@ namespace framework {
namespace ipu {

struct IpuBuildStrategy {
bool is_training_ = true;
popart::SessionOptions popart_options_;
};

24 changes: 19 additions & 5 deletions paddle/fluid/pybind/pybind.cc
@@ -117,8 +117,8 @@ limitations under the License. */
#endif

#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/ipu_info.h"
#include "paddle/fluid/framework/ipu/ipu_backend.h"
#include "paddle/fluid/platform/ipu_info.h"
#endif

#ifdef PADDLE_WITH_CRYPTO
@@ -1816,7 +1816,7 @@ All parameter, weight, gradient are variables in Paddle.
&IsSamePlace<platform::NPUPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::NPUPlace &>);

// IPUPlace
// IPUPlace
py::class_<platform::IPUPlace>(m, "IPUPlace", R"DOC(
IPUPlace is a descriptor of a device.
It represents an IPU device on which a tensor will be allocated and a model will run.
@@ -3204,10 +3204,24 @@ All parameter, weight, gradient are variables in Paddle.
.def("device_count", &ParallelExecutor::DeviceCount);

#ifdef PADDLE_WITH_IPU
py::class_<framework::IpuBackend, std::shared_ptr<framework::IpuBackend>>(m,
"IpuBackend")
py::class_<framework::IpuBackend, std::shared_ptr<framework::IpuBackend>>(
m, "IpuBackend")
.def(py::init(&IpuBackend::GetInstance))
.def("set_scope", &IpuBackend::SetScope);
.def("set_scope",
[](IpuBackend &self, const Scope &scope) { self.SetScope(scope); })
.def("set_ipu_build_strategy",
[](IpuBackend &self, const IpuBuildStrategy &strategy) {
self.SetIpuBuildStrategy(strategy);
});
// TODO(xiaobingw): maybe refactor in the future
py::class_<framework::ipu::IpuBuildStrategy>(m, "IpuBuildStrategy")
.def(py::init())
.def_property(
"is_training",
[](const IpuBuildStrategy &self) { return self.is_training_; },
[](IpuBuildStrategy &self, bool is_training) {
self.is_training_ = is_training;
});
#endif

BindFleetWrapper(&m);
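As a quick orientation (not part of the diff), a minimal sketch of how these new pybind bindings are driven from Python; it mirrors the IpuCompiler.__init__ change in compiler.py below and assumes an IPU-enabled build.

# Sketch only; mirrors IpuCompiler.__init__ in compiler.py below.
import paddle
from paddle.fluid import core

paddle.enable_static()

strategy = core.IpuBuildStrategy()      # binding added in this commit
strategy.is_training = True             # exposed through the is_training property

backend = core.IpuBackend()             # py::init is bound to IpuBackend::GetInstance
backend.set_scope(paddle.static.global_scope())
backend.set_ipu_build_strategy(strategy)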
31 changes: 27 additions & 4 deletions python/paddle/fluid/compiler.py
@@ -504,16 +504,23 @@ def __init__(self, program, scope=None, ipu_build_strategy=None):
# import here to avoid confusion
import paddle

self._program = program
self._graph = core.Graph(program.desc)
self._compiled = False

if scope is not None:
self._scope = scope
else:
self._scope = paddle.static.global_scope()
self._program = program
self._graph = core.Graph(program.desc)
self._ipu_build_strategy = ipu_build_strategy
self._compiled = False

if ipu_build_strategy is not None:
self._ipu_build_strategy = ipu_build_strategy
else:
self._ipu_build_strategy = get_ipu_build_strategy()

self._backend = core.IpuBackend()
self._backend.set_scope(self._scope)
self._backend.set_ipu_build_strategy(self._ipu_build_strategy)
self._graph_passes = [
"optimizer_extract_pass", "forward_graph_extract_pass",
"popart_canonicalization_pass"
@@ -541,3 +548,19 @@ def compile(self, feed_list, fetch_list, scope=None):
program = framework.Program._construct_from_desc(desc)

return program


def get_ipu_build_strategy():
"""
Create and return an IpuBuildStrategy instance. The strategy is created on
the Python side and then passed to the backend via IpuBackend.set_ipu_build_strategy.
"""
if not core.is_compiled_with_ipu():
raise ValueError(
"Can't get ipu_build_strategy, since PaddlePaddle is not compiled" \
" with IPU"
)

ipu_build_strategy = core.IpuBuildStrategy()

return ipu_build_strategy
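
Putting the compiler-side pieces together, a condensed usage sketch (it summarizes the new training test below; the compile() feed and fetch lists are variable names, and an IPU-enabled build is assumed):

# Condensed from ipu_training_test.py below; assumes a PADDLE_WITH_IPU build.
import numpy as np
import paddle
import paddle.fluid.compiler as compiler

paddle.enable_static()

image = paddle.static.data(name='image', shape=[1, 3, 10, 10], dtype='float32')
conv = paddle.static.nn.conv2d(image, num_filters=3, filter_size=3, bias_attr=False)
loss = paddle.mean(conv)
paddle.optimizer.Adam(learning_rate=1e-3).minimize(loss)

executor = paddle.static.Executor(paddle.IPUPlace(0))
executor.run(paddle.static.default_startup_program())

ipu_build_strategy = compiler.get_ipu_build_strategy()
ipu_build_strategy.is_training = True   # take the TrainingSession path
program = compiler.IpuCompiler(
    paddle.static.default_main_program(),
    ipu_build_strategy=ipu_build_strategy).compile([image.name], [loss.name])

np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
res = executor.run(program, feed={image.name: np_image}, fetch_list=[loss])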
100 changes: 61 additions & 39 deletions python/paddle/fluid/tests/unittests/ipu/ipu_training_test.py
@@ -12,44 +12,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import paddle
import paddle.static
import paddle.fluid.compiler as compiler

paddle.seed(2021)
np.random.seed(2021)

if __name__ == "__main__":
run_on_ipu = True
fetch_loss = True

paddle.enable_static()

# model input
image = paddle.static.data(
name='image', shape=[1, 3, 10, 10], dtype='float32')
conv1 = paddle.static.nn.conv2d(
image, num_filters=3, filter_size=3, bias_attr=False)
conv2 = conv1 + conv1
loss = paddle.mean(conv2)
adam = paddle.optimizer.Adam(learning_rate=1e-3)

# apply optimizer
adam.minimize(loss)

# switch cpu/ipu place
if run_on_ipu:
place = paddle.IPUPlace(0)
else:
place = paddle.CPUPlace()
executor = paddle.static.Executor(place)

startup_prog = paddle.static.default_startup_program()
executor.run(startup_prog)

# graph
feed_list = [image.name]

# switch between fetching loss and conv1
if fetch_loss:
fetch_node = loss
else:
fetch_node = conv1
fetch_list = [fetch_node.name]

main_prog = paddle.static.default_main_program()

if run_on_ipu:
ipu_build_strategy = compiler.get_ipu_build_strategy()
ipu_build_strategy.is_training = False # default True
program = compiler.IpuCompiler(
main_prog, ipu_build_strategy=ipu_build_strategy).compile(
feed_list, fetch_list)
else:
program = main_prog

np_image = np.random.rand(1, 3, 10, 10).astype(np.float32)
res = executor.run(program,
feed={image.name: np_image},
fetch_list=[fetch_node])

# Paddle 2.x defaults to dynamic graph mode, so static graph mode has to be enabled
paddle.enable_static()

# Compile time: write the Python program with Paddle APIs. The code below defines a network containing conv2d and optimizes its parameters with the Adam optimizer.
image = paddle.static.data(
name='image', shape=[None, 3, 224, 224], dtype='float32')
conv_result = paddle.static.nn.conv2d(image, num_filters=64, filter_size=3)
loss = paddle.mean(conv_result)
adam = paddle.optimizer.Adam(learning_rate=1e-3)
adam.minimize(loss)

# Run time: first run the startup program once to initialize the network parameters, then run the network with Paddle's Executor and CompiledProgram APIs.
place = paddle.IPUPlace(0)  # which device runs the network; IPUPlace means run on an IPU
executor = paddle.static.Executor(place)  # create the executor
print("---------- startup_program --------------")
prog = paddle.static.default_startup_program()
print(prog._to_readable_code())
executor.run(prog)  # run the startup program to initialize the parameters

print("---------- main_program --------------")
prog = paddle.static.default_main_program()
print(prog._to_readable_code())

# Then compile the network with CompiledProgram in preparation for execution.
compiled_program = paddle.static.CompiledProgram(prog)

BATCH_NUM = 2
BATCH_SIZE = 32

for batch_id in range(BATCH_NUM):
input_image = np.random.random([BATCH_SIZE, 3, 224, 224]).astype('float32')
loss_numpy, = executor.run(compiled_program,
feed={'image': input_image},
fetch_list=[loss])
print("Batch {}, loss = {}".format(batch_id, loss_numpy))

# disable static graph mode
paddle.disable_static()
print(res)
