
Commit a459764

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_fetchop

chengduoZH committed May 8, 2018
2 parents 6f306f0 + 22ab14c commit a459764
Showing 56 changed files with 1,310 additions and 276 deletions.
28 changes: 22 additions & 6 deletions benchmark/cluster/vgg16/vgg16_fluid.py
@@ -80,6 +80,8 @@ def str2bool(v):
type=str,
default="",
help="Comma-separated list of hostname:port pairs")
parser.add_argument(
"--profile", action='store_true', help="If set, profile a few steps.")

# Flags for defining the tf.train.Server
parser.add_argument(
@@ -183,8 +185,8 @@ def train_loop(exe, trainer_prog):
start_time = time.time()
num_samples = 0
train_pass_acc.reset()
for batch_id, data in enumerate(train_reader()):
ts = time.time()

def run_step(batch_id, data):
img_data = np.array(
map(lambda x: x[0].reshape(data_shape), data)).astype(
"float32")
@@ -196,14 +198,28 @@ def train_loop(exe, trainer_prog):
feed={"pixel": img_data,
"label": y_data},
fetch_list=[avg_cost, batch_acc, batch_size])
return loss, acc, b_size

if args.profile and args.task_index == 0:
# warmup.
for batch_id, data in enumerate(train_reader()):
if batch_id > 5: break
run_step(batch_id, data)
with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
for batch_id, data in enumerate(train_reader()):
if batch_id > 5: break
run_step(batch_id, data)

for batch_id, data in enumerate(train_reader()):
ts = time.time()
loss, acc, b_size = run_step(batch_id, data)
iters += 1
num_samples += len(data)
train_pass_acc.add(value=acc, weight=b_size)
print(
"Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s " % (args.task_index, pass_id, iters,
loss, acc,
len(data) / (time.time() - ts))
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s" % (pass_id, iters, loss, acc,
len(data) / (time.time() - ts))
) # The accuracy is the accumulation of batches, but not the current batch.

pass_elapsed = time.time() - start_time


1 change: 1 addition & 0 deletions doc/fluid/build_and_install/paddleci.png
6 changes: 3 additions & 3 deletions doc/fluid/design/motivation/refactorization.md
@@ -125,12 +125,12 @@ Compile Time -> IR -> Runtime

## Operator/OpWithKernel/OpKernel

![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/49caf1fb70820fb4a6c217634317c9306f361f36/op_op_with_kern_class_diagram.dot)
![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op_op_with_kern_class_diagram.dot)

---

## Operator
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op.dot)

* `Operator` is the fundamental building block of the user interface.
* Operator stores input/output variable names and attributes.
@@ -141,7 +141,7 @@ Compile Time -> IR -> Runtime

## OpWithKernel/Kernel

![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/9d7f4eba185cf41c8e2fbfb40ae21890dbddcd39/op_with_kernel.dot)
![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op_with_kernel.dot)

* `OpWithKernel` inherits `Operator`.
* `OpWithKernel` contains a Kernel map.
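For reference, the Operator/OpWithKernel/OpKernel relationship described above, and encoded in the .dot diagrams added below, boils down to roughly the following C++ skeleton. This is a minimal illustrative sketch with placeholder types (Place, AttributeMap) and a deliberately trivial kernel-selection step; the real classes live in paddle/fluid/framework and take richer arguments.

// Illustrative sketch of the class diagram above; not the actual Paddle headers.
#include <map>
#include <memory>
#include <string>
#include <vector>

using AttributeMap = std::map<std::string, std::string>;  // placeholder attribute storage

struct Place { int device_id = 0; };  // placeholder for platform::Place

struct OpKernelKey {
  Place place;  // which device the kernel targets
  bool operator<(const OpKernelKey& o) const {  // needed to serve as a map key
    return place.device_id < o.place.device_id;
  }
};

class Operator {
 public:
  virtual void InferShape() = 0;
  virtual void Run() = 0;
  virtual ~Operator() = default;

 protected:
  std::map<std::string, std::vector<std::string>> inputs_;
  std::map<std::string, std::vector<std::string>> outputs_;
  AttributeMap attrs_;
};

class OpKernel {
 public:
  virtual void Compute() = 0;  // the real kernel takes an ExecutionContext
  virtual ~OpKernel() = default;
};

class OpWithKernel : public Operator {
 public:
  // Run() picks a kernel out of kernels_ (keyed by OpKernelKey) and calls
  // Compute(); the selection logic here is deliberately oversimplified.
  void Run() override {
    if (!kernels_.empty()) kernels_.begin()->second->Compute();
  }

 protected:
  std::map<OpKernelKey, std::unique_ptr<OpKernel>> kernels_;
};

// A concrete op such as MulOp derives from OpWithKernel and registers one
// MulOpKernel<Place> per device type.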
4 changes: 4 additions & 0 deletions doc/fluid/images/op.dot
@@ -0,0 +1,4 @@
digraph sample {
graph [rankdir=TD]; node [shape=record];
op [label="{Operator| InferShape()=0\lRun()=0\l | map<string, string[]> inputs_\lmap<string, string[]> outputs_ \l AttributeMap attrs_\l}"];
}
38 changes: 38 additions & 0 deletions doc/fluid/images/op_op_with_kern_class_diagram.dot
@@ -0,0 +1,38 @@
digraph sample {
graph [rankdir=TD]; node [shape=record];
op [label="{Operator| InferShape()=0\lRun()=0\l | map<string, string[]> inputs_\lmap<string, string[]> outputs_ \l AttributeMap attrs_\l}"];
op_with_kern [label="{OpWithKernel | InferShape()=0\lRun()\l | map<OpKernelKey,OpKernel>kernels_ }"]
op_kernel [label="{OpKernel | Compute()=0}"]
op_kernel_key [label="{OpKernelKey| Place place\n...}"]

op -> op_with_kern [dir=back, arrowtail=onormal]
op_with_kern -> op_kernel [arrowhead=vee, label="contains many"]

{
rank=same;
op_with_kern
op_kernel
}

op_kernel -> op_kernel_key [style=invis]

{
rank=same;
op_kernel
op_kernel_key
}

op_with_kern -> op_kernel_key [arrowhead=vee, label ="\nas map key"]

mul_op [label="MulOp"]
op_with_kern -> mul_op [dir=back, arrowtail=onormal]
mul_kernel [label="template <typename Place>\lclass MulOpKernel\l"]
op_kernel -> mul_kernel [dir=back, arrowtail=onormal]
mul_op -> mul_kernel [arrowhead=vee, label="register many"]

{
rank=same;
mul_op;
mul_kernel;
}
}
26 changes: 26 additions & 0 deletions doc/fluid/images/op_with_kernel.dot
@@ -0,0 +1,26 @@
digraph sample {
graph [rankdir=TD]; node [shape=record];
op [label="{Operator}"];
op_with_kern [label="{OpWithKernel | InferShape()=0\lRun()\l | map<OpKernelKey,OpKernel>kernels_ }"]
op_kernel [label="{OpKernel | Compute()=0}"]
op_kernel_key [label="{OpKernelKey| Place place\n...}"]

op -> op_with_kern [dir=back, arrowtail=onormal]
op_with_kern -> op_kernel [arrowhead=vee, label="contains many"]

{
rank=same;
op_with_kern
op_kernel
}

op_kernel -> op_kernel_key [style=invis]

{
rank=same;
op_kernel
op_kernel_key
}

op_with_kern -> op_kernel_key [arrowhead=vee, label ="\nas map key"]
}
11 changes: 8 additions & 3 deletions doc/v2/api/config/layer.rst
@@ -142,7 +142,7 @@ gated_unit
-----------
.. autoclass:: paddle.v2.layer.gated_unit
:noindex:

Recurrent Layer Group
=====================

@@ -354,7 +354,7 @@ dropout
--------
.. autoclass:: paddle.v2.layer.dropout
:noindex:

dot_prod
---------
.. autoclass:: paddle.v2.layer.dot_prod
@@ -460,6 +460,11 @@ multi_binary_label_cross_entropy_cost
.. autoclass:: paddle.v2.layer.multi_binary_label_cross_entropy_cost
:noindex:

classification_cost
-------------------
.. autoclass:: paddle.v2.layer.classification_cost
:noindex:

huber_regression_cost
-------------------------
.. autoclass:: paddle.v2.layer.huber_regression_cost
@@ -534,7 +539,7 @@ detection_output
----------------
.. autoclass:: paddle.v2.layer.detection_output
:noindex:

Check Layer
============

4 changes: 2 additions & 2 deletions doc/v2/howto/cluster/multi_cluster/k8s_distributed_en.md
@@ -41,7 +41,7 @@ Training docker image needs to package the paddle pserver and paddle trainer run
- Generating the initialization arguments for `Paddle PServer` and `Paddle Training` processes.

Since the official PaddlePaddle Docker image already has the runtimes we need, we'll take it as the base image and pack some additional scripts for the processes mentioned above to build our training image. For more detail, please refer to the following link:
- https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/usage/cluster/src/k8s_train/Dockerfile
- https://github.com/PaddlePaddle/Paddle/tree/develop/doc/v2/howto/cluster/multi_cluster/src/k8s_train/Dockerfile


```bash
@@ -62,7 +62,7 @@ represent the Docker Image which built in this step.
### Prepare Training Data

We can download and split the training data by creating a Kubernetes Job, or customize your image
by editing [k8s_train](./src/k8s_train/).
by editing [k8s_train](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/v2/howto/cluster/multi_cluster/src/k8s_train).

Before creating a Job, we need to bind a [persistentVolumeClaim](https://kubernetes.io/docs/user-guide/persistent-volumes) appropriate to the underlying
file system; the generated dataset will be saved on this volume.
4 changes: 4 additions & 0 deletions paddle/fluid/framework/executor.cc
@@ -348,8 +348,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
}
}
platform::DeviceContextPool::Instance().Get(place_)->Wait();
if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
} else {
// Delete the local scopes created in operators.
scope->DropKids();
}
if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------";
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,3 +1,4 @@
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
57 changes: 57 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/io_converter.cc
@@ -0,0 +1,57 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <cuda.h>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using platform::is_gpu_place;
using platform::is_cpu_place;

class DefaultInputConverter : public EngineInputConverter {
public:
DefaultInputConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
PADDLE_ENFORCE_LE(in.memory_size(), max_size);
const auto& place = in.place();
if (is_cpu_place(place)) {
PADDLE_ENFORCE(stream_ != nullptr);
PADDLE_ENFORCE_EQ(0,
cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
cudaMemcpyHostToDevice, *stream_));

} else if (is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(0,
cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
cudaMemcpyDeviceToDevice, *stream_));

} else {
PADDLE_THROW("Unknown device for converter");
}
cudaStreamSynchronize(*stream_);
}
};

REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);

} // namespace tensorrt
} // namespace inference
} // namespace paddle
67 changes: 67 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/io_converter.h
@@ -0,0 +1,67 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <unordered_map>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"

namespace paddle {
namespace inference {
namespace tensorrt {

using framework::LoDTensor;

/*
* Convert Input from Fluid to an Engine.
* TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
* most cases just need to copy the data.
*/
class EngineInputConverter {
public:
EngineInputConverter() {}

virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}

void SetStream(cudaStream_t* stream) { stream_ = stream; }

static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineInputConverter>::Lookup(
in_op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}

virtual ~EngineInputConverter() {}

protected:
cudaStream_t* stream_{nullptr};
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
struct trt_input_##in_op_type__##_converter { \
trt_input_##in_op_type__##_converter() { \
::paddle::inference::Registry<EngineInputConverter>::Register< \
Converter__>(#in_op_type__); \
} \
}; \
trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
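To illustrate how this header is meant to be used: a converter for a hypothetical op type would be defined, registered, and dispatched roughly as follows. The op name "my_op", the class name, and the empty copy body are placeholders for illustration, not part of this commit.

// Sketch only: subclass EngineInputConverter, register it under an op type,
// then dispatch through EngineInputConverter::Run. "my_op" is hypothetical.
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class MyOpInputConverter : public EngineInputConverter {
 public:
  void operator()(const LoDTensor& in, void* out, size_t max_size) override {
    // A real converter would copy/reorder `in` into the engine buffer `out`,
    // staying within max_size and using stream_, which Run() sets beforehand.
  }
};

// Makes the converter discoverable in the Registry under the name "my_op";
// EngineInputConverter::Run looks converters up by op type, with "default"
// as the fallback type.
REGISTER_TENSORRT_INPUT_CONVERTER(my_op, MyOpInputConverter);

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// Call site, e.g. while feeding a TensorRT engine:
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   paddle::inference::tensorrt::EngineInputConverter::Run(
//       "my_op", fluid_tensor, engine_buffer, buffer_size_in_bytes, &stream);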
