diff --git a/adversarial/README.md b/adversarial/README.md new file mode 100644 index 0000000000000..51da21918a9d6 --- /dev/null +++ b/adversarial/README.md @@ -0,0 +1,9 @@ +# Advbox + +Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle. + +## How to use + +1. train a model and save it's parameters. (like fluid_mnist.py) +2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py) +3. use advbox to generate the adversarial sample. diff --git a/adversarial/advbox/__init__.py b/adversarial/advbox/__init__.py new file mode 100644 index 0000000000000..f56f14f18dafd --- /dev/null +++ b/adversarial/advbox/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + A set of tools for generating adversarial example on paddle platform +""" diff --git a/adversarial/advbox/attacks/base.py b/adversarial/advbox/attacks/base.py new file mode 100644 index 0000000000000..98a65f2fddff9 --- /dev/null +++ b/adversarial/advbox/attacks/base.py @@ -0,0 +1,39 @@ +""" +The base model of the model. +""" +from abc import ABCMeta, abstractmethod + + +class Attack(object): + """ + Abstract base class for adversarial attacks. `Attack` represent an adversarial attack + which search an adversarial example. subclass should implement the _apply() method. + + Args: + model(Model): an instance of the class advbox.base.Model. + + """ + __metaclass__ = ABCMeta + + def __init__(self, model): + self.model = model + + def __call__(self, image_label): + """ + Generate the adversarial sample. + + Args: + image_label(list): The image and label tuple list with one element. + """ + adv_img = self._apply(image_label) + return adv_img + + @abstractmethod + def _apply(self, image_label): + """ + Search an adversarial example. + + Args: + image_batch(list): The image and label tuple list with one element. + """ + raise NotImplementedError diff --git a/adversarial/advbox/attacks/gradientsign.py b/adversarial/advbox/attacks/gradientsign.py new file mode 100644 index 0000000000000..15b1d176cb113 --- /dev/null +++ b/adversarial/advbox/attacks/gradientsign.py @@ -0,0 +1,38 @@ +""" +This module provide the attack method for FGSM's implement. +""" +from __future__ import division +import numpy as np +from collections import Iterable +from .base import Attack + + +class GradientSignAttack(Attack): + """ + This attack was originally implemented by Goodfellow et al. (2015) with the + infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called + the Fast Gradient Method. + Paper link: https://arxiv.org/abs/1412.6572 + """ + + def _apply(self, image_label, epsilons=1000): + assert len(image_label) == 1 + pre_label = np.argmax(self.model.predict(image_label)) + + min_, max_ = self.model.bounds() + gradient = self.model.gradient(image_label) + gradient_sign = np.sign(gradient) * (max_ - min_) + + if not isinstance(epsilons, Iterable): + epsilons = np.linspace(0, 1, num=epsilons + 1) + + for epsilon in epsilons: + adv_img = image_label[0][0].reshape( + gradient_sign.shape) + epsilon * gradient_sign + adv_img = np.clip(adv_img, min_, max_) + adv_label = np.argmax(self.model.predict([(adv_img, 0)])) + if pre_label != adv_label: + return adv_img + + +FGSM = GradientSignAttack diff --git a/adversarial/advbox/models/__init__.py b/adversarial/advbox/models/__init__.py new file mode 100644 index 0000000000000..eee0f6efd4774 --- /dev/null +++ b/adversarial/advbox/models/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Paddle model for target of attack +""" diff --git a/adversarial/advbox/models/base.py b/adversarial/advbox/models/base.py new file mode 100644 index 0000000000000..74e1045def764 --- /dev/null +++ b/adversarial/advbox/models/base.py @@ -0,0 +1,90 @@ +""" +The base model of the model. +""" +from abc import ABCMeta +import abc + +abstractmethod = abc.abstractmethod + + +class Model(object): + """ + Base class of model to provide attack. + + + Args: + bounds(tuple): The lower and upper bound for the image pixel. + channel_axis(int): The index of the axis that represents the color channel. + preprocess(tuple): Two element tuple used to preprocess the input. First + substract the first element, then divide the second element. + """ + __metaclass__ = ABCMeta + + def __init__(self, bounds, channel_axis, preprocess=None): + assert len(bounds) == 2 + assert channel_axis in [0, 1, 2, 3] + + if preprocess is None: + preprocess = (0, 1) + self._bounds = bounds + self._channel_axis = channel_axis + self._preprocess = preprocess + + def bounds(self): + """ + Return the upper and lower bounds of the model. + """ + return self._bounds + + def channel_axis(self): + """ + Return the channel axis of the model. + """ + return self._channel_axis + + def _process_input(self, input_): + res = input_ + sub, div = self._preprocess + if sub != 0: + res = input_ - sub + assert div != 0 + if div != 1: + res /= div + return res + + @abstractmethod + def predict(self, image_batch): + """ + Calculate the prediction of the image batch. + + Args: + image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels). + + Return: + numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). + """ + raise NotImplementedError + + @abstractmethod + def num_classes(self): + """ + Determine the number of the classes + + Return: + int: the number of the classes + """ + raise NotImplementedError + + @abstractmethod + def gradient(self, image_batch): + """ + Calculate the gradient of the cross-entropy loss w.r.t the image. + + Args: + image_batch(list): The image and label tuple list. + + Return: + numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with + the shape (height, width, channel). + """ + raise NotImplementedError diff --git a/adversarial/advbox/models/paddle.py b/adversarial/advbox/models/paddle.py new file mode 100644 index 0000000000000..33b2a3d5c6973 --- /dev/null +++ b/adversarial/advbox/models/paddle.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +from paddle.v2.fluid.framework import program_guard + +from .base import Model + + +class PaddleModel(Model): + """ + Create a PaddleModel instance. + When you need to generate a adversarial sample, you should construct an instance of PaddleModel. + + Args: + program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample. + input_name(string): The name of the input. + logits_name(string): The name of the logits. + predict_name(string): The name of the predict. + cost_name(string): The name of the loss in the program. + """ + + def __init__(self, + program, + input_name, + logits_name, + predict_name, + cost_name, + bounds, + channel_axis=3, + preprocess=None): + super(PaddleModel, self).__init__( + bounds=bounds, channel_axis=channel_axis, preprocess=preprocess) + + if preprocess is None: + preprocess = (0, 1) + + self._program = program + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) + + self._input_name = input_name + self._logits_name = logits_name + self._predict_name = predict_name + self._cost_name = cost_name + + # gradient + loss = self._program.block(0).var(self._cost_name) + param_grads = fluid.backward.append_backward( + loss, parameter_list=[self._input_name]) + self._gradient = dict(param_grads)[self._input_name] + + def predict(self, image_batch): + """ + Predict the label of the image_batch. + + Args: + image_batch(list): The image and label tuple list. + Return: + numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes). + """ + feeder = fluid.DataFeeder( + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program) + predict_var = self._program.block(0).var(self._predict_name) + predict = self._exe.run(self._program, + feed=feeder.feed(image_batch), + fetch_list=[predict_var]) + return predict + + def num_classes(self): + """ + Calculate the number of classes of the output label. + + Return: + int: the number of classes + """ + predict_var = self._program.block(0).var(self._predict_name) + assert len(predict_var.shape) == 2 + return predict_var.shape[1] + + def gradient(self, image_batch): + """ + Calculate the gradient of the loss w.r.t the input. + + Args: + image_batch(list): The image and label tuple list. + Return: + list: The list of the gradient of the image. + """ + feeder = fluid.DataFeeder( + feed_list=[self._input_name, self._logits_name], + place=self._place, + program=self._program) + + grad, = self._exe.run(self._program, + feed=feeder.feed(image_batch), + fetch_list=[self._gradient]) + return grad diff --git a/adversarial/fluid_mnist.py b/adversarial/fluid_mnist.py new file mode 100644 index 0000000000000..db4d4b51868ff --- /dev/null +++ b/adversarial/fluid_mnist.py @@ -0,0 +1,86 @@ +""" +CNN on mnist data using fluid api of paddlepaddle +""" +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + + +def mnist_cnn_model(img): + """ + Mnist cnn model + + Args: + img(Varaible): the input image to be recognized + + Returns: + Variable: the label prediction + """ + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + return logits + + +def main(): + """ + Train the cnn model on mnist datasets + """ + img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + logits = mnist_cnn_model(img) + cost = fluid.layers.cross_entropy(input=logits, label=label) + avg_cost = fluid.layers.mean(x=cost) + optimizer = fluid.optimizer.Adam(learning_rate=0.01) + optimizer.minimize(avg_cost) + + accuracy = fluid.evaluator.Accuracy(input=logits, label=label) + + BATCH_SIZE = 50 + PASS_NUM = 3 + ACC_THRESHOLD = 0.98 + LOSS_THRESHOLD = 10.0 + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + exe.run(fluid.default_startup_program()) + + for pass_id in range(PASS_NUM): + accuracy.reset(exe) + for data in train_reader(): + loss, acc = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) + pass_acc = accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + + str(pass_acc)) + if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD: + break + + pass_acc = accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc)) + fluid.io.save_params( + exe, dirname='./mnist', main_program=fluid.default_main_program()) + print('train mnist done') + + +if __name__ == '__main__': + main() diff --git a/adversarial/mnist_tutorial_fgsm.py b/adversarial/mnist_tutorial_fgsm.py new file mode 100644 index 0000000000000..8b29346b8cd7f --- /dev/null +++ b/adversarial/mnist_tutorial_fgsm.py @@ -0,0 +1,87 @@ +""" +FGSM demos on mnist using advbox tool. +""" +import paddle.v2 as paddle +import paddle.v2.fluid as fluid +import matplotlib.pyplot as plt +import numpy as np + +from advbox.models.paddle import PaddleModel +from advbox.attacks.gradientsign import GradientSignAttack + + +def cnn_model(img): + """ + Mnist cnn model + Args: + img(Varaible): the input image to be recognized + Returns: + Variable: the label prediction + """ + #conv1 = fluid.nets.conv2d() + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + num_filters=20, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + num_filters=50, + filter_size=5, + pool_size=2, + pool_stride=2, + act='relu') + + logits = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + return logits + + +def main(): + """ + Advbox demo which demonstrate how to use advbox. + """ + IMG_NAME = 'img' + LABEL_NAME = 'label' + + img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32') + # gradient should flow + img.stop_gradient = False + label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64') + logits = cnn_model(img) + cost = fluid.layers.cross_entropy(input=logits, label=label) + avg_cost = fluid.layers.mean(x=cost) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + BATCH_SIZE = 1 + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=BATCH_SIZE) + feeder = fluid.DataFeeder( + feed_list=[IMG_NAME, LABEL_NAME], + place=place, + program=fluid.default_main_program()) + + fluid.io.load_params( + exe, "./mnist/", main_program=fluid.default_main_program()) + + # advbox demo + m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME, + logits.name, avg_cost.name, (-1, 1)) + att = GradientSignAttack(m) + for data in train_reader(): + # fgsm attack + adv_img = att(data) + plt.imshow(n[0][0], cmap='Greys_r') + plt.show() + #np.save('adv_img', adv_img) + break + + +if __name__ == '__main__': + main() diff --git a/doc/design/memory_optimization.md b/doc/design/memory_optimization.md index 00f514711a46b..1f68cef4cc28c 100644 --- a/doc/design/memory_optimization.md +++ b/doc/design/memory_optimization.md @@ -5,28 +5,28 @@ In a lecture from Andrew Ng, he attributes the recent sucess of AI due to a combination of these: -- availability of Big Data -- supercomputing power to process this Big Data over very large neural networks -- modern algorithms +- Availability of Big Data +- Supercomputing power to process this Big Data over very large neural networks +- Modern algorithms Following graph shows the details: ![](images/deep_learning.png) -Larger model usually brings better performance. However, GPU memory is certain limited. For example, the memory size of a GTX TITAN X is only 12GB. To train complex and large model, we have to take care of memory using. Besides, memory optimization is also necessary in both online/mobile inference. +Larger model usually bring better performance. However, GPU memory is limited. For example, the memory size of a GTX TITAN X is only 12GB. To train complex and large models, we have to take care of memory usage. Besides, memory optimization is also necessary in both online/mobile inference. ## Solution ### Basic Strategy -There are some basic strategies to make memory optimization, including in-place operation and memory sharing. +There are some basic strategies to improve memory usage, including in-place operations and memory sharing. #### In-place Operation In a relu activation operator: $y = \max(x, 0)$ -If the variable x is not used in any other operator, we can make an in-place operation. In other words, the memory block of variable y and variable x are the same. In-place operation will save 50% memory occupancy immediately. +If the variable x is not used in any other operator, we can make an in-place operation. In other words, the memory block of variable y and variable x will be the same. In-place operations will save 50% memory occupancy immediately. #### Memory Sharing @@ -40,18 +40,18 @@ d = op2(a) e = op3(d, f) ``` -In this case, variable a is no longer used, and op2 does not support in-place operation. After op2 finished, we can put the memory of variable a to a memory pool. Then, variable e can share the memory of variable a from the pool. +In this case, variable a is no longer used, and op2 does not support in-place operation. After op2 finishes, we can put the memory of variable a to a memory pool. Then, variable e can share the memory of variable a from the pool. ### Live Variable Analysis -It's not enough to only have some basic strategies. The prerequisite of memory optimization is to know if a variable is still "live" after an operation. +It's not enough to only have some basic strategies. The pre-requisite of memory optimization is to know if a variable is still "live" after an operation. In our design, the neural network topology is defined as a program. Luckily, [live variable analysis](https://en.wikipedia.org/wiki/Live_variable_analysis) is a classic problem in compilers which can be used in many stages, such as register allocation. -In compilers, the front end of the compilers translates programs into an intermediate language with an unbounded number of temporaries. This program must run on a machine with a bounded number of registers. Two temporaries a and b can fit into the same register, if a and b are never "in use" at the same time. Thus, many temporaries can fit in few registers; if they don't all fit, the excess temporaries can be kept in memory. +In compilers, the front end of the compiler translates programs into an intermediate language with an unbounded number of temporary variables. This program must run on a machine with a bounded number of registers. Two temporary variables a and b can fit into the same register, if a and b are never "in use" at the same time. Thus, many temporary variables can fit in few registers; if they don't all fit, the excess tempory variables can be kept in memory. -Therefore, the compiler needs to analyze the intermediate-representation program to determine which temporaries are in use at the same time. We say a variable is "live" if it holds a value that may be needed in the future, so this analysis is called liveness analysis. +Therefore, the compiler needs to analyze the intermediate-representation program to determine which temporary variables are in use at the same time. We say a variable is "live" if it holds a value that may be needed in the future, so this analysis is called liveness analysis. We can leran these techniques from compilers. There are mainly two stages to make live variable analysis: @@ -60,7 +60,7 @@ We can leran these techniques from compilers. There are mainly two stages to mak #### Control Flow Graph -To preform analyses on a program, it is often useful to make a control flow graph. A [control flow graph](https://en.wikipedia.org/wiki/Control_flow_graph) (CFG) in computer science is a representation, using graph notation, of all paths that might be traversed through a program during its execution. Each statement in the program is a node in the flow graph; if statemment x can be followed by statement y, there is an egde from x to y. +To perform analysis on a program, it is often useful to make a control flow graph. A [control flow graph](https://en.wikipedia.org/wiki/Control_flow_graph) (CFG) in computer science is a representation, using graph notation, of all paths that might be traversed through a program during its execution. Each statement in the program is a node in the flow graph; if statemment x can be followed by statement y, there is an egde from x to y. Following is the flow graph for a simple loop. @@ -68,18 +68,18 @@ Following is the flow graph for a simple loop. #### Dataflow Analysis -liveness of variable "flows" around the edges of the control flow graph; determining the live range of each variable is an example of a dataflow problem. [Dataflow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) is a technique for gathering information about the possible set of values calculated at various points in a computer program. +Liveness of variable "flows" around the edges of the control flow graph; determining the live range of each variable is an example of a dataflow problem. [Dataflow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) is a technique for gathering information about the possible set of values calculated at various points in a computer program. A simple way to perform data-flow analysis of programs is to set up dataflow equations for each node of the control flow graph and solve them by repeatedly calculating the output from the input locally at each node until the whole system stabilizes. - Flow Graph Terminology -A flow graph node has out-edges that lead to sucessor nodes, and in-edges that come from presucessor nodes. The set *pred[n]* is all the predecessors of node n, and *succ[n]* is the set of sucessors. +A flow graph node has out-edges that lead to sucessor nodes, and in-edges that come from predecessor nodes. The set *pred[n]* is all the predecessors of node n, and *succ[n]* is the set of sucessors. In former control flow graph, the out-edges of node 5 are 5 --> 6 and 5 --> 2, and *succ[5]* = {2, 6}. The in-edges of 2 are 5 --> 2 and 1 --> 2, and *pred[2]* = {1, 5}. - Uses and Defs -An assignmemt to a variable or temporary defines that variable. An occurence of a variable on the right-hand side of an assginment(or in other expressions) uses the variable. We can speak the *def* of a variable as the set of graph nodes that define it; or the *def* of a graph node as the set of variables that it defines; and the similarly for the *use* of a variable or graph node. In former control flow graph, *def(3)* = {c}, *use(3)* = {b, c}. +An assignmemt to a variable or temporary defines that variable. An occurence of a variable on the right-hand side of an assginment(or in other expressions) uses the variable. We can define the *def* of a variable as the set of graph nodes that define it; or the *def* of a graph node as the set of variables that it defines; and the similarly for the *use* of a variable or graph node. In former control flow graph, *def(3)* = {c}, *use(3)* = {b, c}. - Liveness @@ -168,9 +168,9 @@ class ControlFlowGraph(object): return self._program ``` -#### make dataflow analysis +#### Make dataflow analysis -We follow guide from compilers and try to solve the dataflow equation to get liveness of every variable. If the live-in of an operator node is different from the live-out, then we can make memory sharing. +We follow the guide from compilers and try to solve the dataflow equation to get liveness of every variable. If the live-in of an operator node is different from the live-out, then we can make memory sharing. For example: diff --git a/doc/design/operator_kernel_type.md b/doc/design/operator_kernel_type.md index aa82e96bf7931..f86e6b7a564ed 100644 --- a/doc/design/operator_kernel_type.md +++ b/doc/design/operator_kernel_type.md @@ -1,6 +1,6 @@ # Design Doc: The Keys of Operator Kernel Type ## Problem -An operator can have different kernel implementations, and each operator will have a map to store the related kernels. Fluid uses `OpKernelType` as a key to identify a unique Kernel. Before an operator runs, an certain kernel must be chosen by a key of `OpKernelType`. Currently, `OpKernelType` is defined as follows: +An operator can have different kernel implementations, and each operator will have a map to store the related kernels. Fluid uses `OpKernelType` as a key to identify a unique kernel. Before an operator runs, a certain type of kernel must be chosen via a key of `OpKernelType`. Currently, `OpKernelType` is defined as follows: ```cpp struct OpKernelType { @@ -10,13 +10,13 @@ struct OpKernelType { ``` For more details, please refer to [codes](https://github.com/PaddlePaddle/Paddle/blob/2d5ec16bc8a09fb8e0f62c89b116b0cd1d333907/paddle/framework/operator.h#L348-L374) in github. -It contains two keys, `Place` and `DataType`. And these two keys will be hashed to a unique key to represent a certain type of kernel. However, these two keys are not enough. We need a more complete representation of `OpKernelType`. +It contains two keys, `Place` and `DataType`. And these two keys will be hashed to a unique key to represent a certain type of kernel. However, these two keys do not provide enough information. We need a more complete representation of `OpKernelType`. -We often implement a kernel of an operator with some computing library in certain device(place). Please remind that computing library and device are not one-to-one corresponding. A device can have a lot of computing libraries and a computing library can also support several devices. +We often implement a kernel of an operator with some computing library on certain device(place). Please note that computing library and device do not have a one-to-one correspondence. A device can have a lot of computing libraries and a computing library can also support different devices. -For example, Eigen library can support Nvidia GPU/AMD GPU/CPU. And MKLDNN library can support Intel CPU/Intel FPGA. Both `Place` and `Library` should be a key of `OpKernelType`. +For example, Eigen library supports Nvidia GPU/AMD GPU/CPU and MKLDNN library supports Intel CPU/Intel FPGA. Both `Place` and `Library` should be a key of `OpKernelType`. -It's obvious that different DataTypes, like fp64/fp32/int8 will have different kernels. But the data layout of a Tensor will also lead to different implementation. Please refer to the batch norm operator [kernels](https://github.com/PaddlePaddle/Paddle/blob/a948fac4d0ad7e0412d373b8aabeb711c2899563/paddle/operators/batch_norm_op.cc#L180-L209). Data Layout should also be taken into consideration. +Different DataTypes, such as fp64/fp32/int8, will obviously have different kernels. But different data layout of a Tensor will also lead to different implementations. Please refer to the batch norm operator [kernels](https://github.com/PaddlePaddle/Paddle/blob/a948fac4d0ad7e0412d373b8aabeb711c2899563/paddle/operators/batch_norm_op.cc#L180-L209) as an example. Data layout should also be taken into consideration. ## Solution @@ -31,17 +31,17 @@ struct OpKernelType { }; ``` -Following is the details: +The details are as follows: ### Place -`Place` is defined as follows: +`Place` is defined as: ```cpp typedef boost::variant Place; ``` -`Place` is to represent the device memory where data is locating. +`Place` represents the device memory where data is located. ### Library @@ -52,10 +52,10 @@ One operator kernel is usually implemented based on one library. `Library` is de enum Library { Plain, MKLDNN, CUDNN }; ``` -We use `Plain` enumerator to represent default library. Since most operators in Fluid are implemented based on `Eigen` library, we take `Eigen` library as the `Plain` enumerator. -A library usually has a corresponding `DeviceContext` which contains some handles needed by computation. Fluid now have two default DeviceContexts in CPU and CUDA, `CPUDeviceContext` and `CUDADeviceContext`. `CPUDeviceContext` contains a Eigen library handle and `CDUADeviceContext` contains a Eigen library handle and cuBLAS handle. +We use `Plain` enumerator to represent default library. Since most operators in Fluid are implemented based on the `Eigen` library, we take `Eigen` library as the `Plain` enumerator. +A library usually has a corresponding `DeviceContext` which contains some handles needed for computation. Fluid now has two default DeviceContexts for CPU and CUDA, namely, `CPUDeviceContext` and `CUDADeviceContext`. `CPUDeviceContext` contains an Eigen library handle and `CDUADeviceContext` contains an Eigen library handle and a cuBLAS handle. -If we want to support new Library, a new enumerator need to be added to `Library` and a new corresponding `LibraryDeviceContext` will be created. +If we want to support new library, a new enumerator need to be added to `Library` and a corresponding new `LibraryDeviceContext` need to be created. ### DataType @@ -67,15 +67,15 @@ If we want to support new Library, a new enumerator need to be added to `Library Actually, a Tensor is a view of a block of memory. Besides a pointer to the memory, we also have to get some other descriptions of this block of memory, such as shape(ddim), stride, and layout. -Different layout leads to different implementation of operator kernel. There are mainly 4 principles we have to follow to support layout in our fluid framework. +Different layout leads to different implementation of the operator kernel. There are mainly 4 principles we have to follow to support layout in our Fluid framework. -- We take layout as a data member of Tensor. Layout is actually a enum variable. If fluid is built with MKLDNN, then, the memory format in MKLDNN will be added into this enum variable too. +- We take layout as a data member of Tensor. Layout is actually a enum variable. If Fluid is built with MKLDNN, then the memory format in MKLDNN will also be added into this enum variable. -- Users have to set layout for input data. And some operators like fill_constant/random, also have to set layout of generating data. Of course, we can have some default layout, like NCHW. +- Users have to set layout for input data. And some operators like fill_constant/random, also have to set layout for generating data. Of course, we can have some default layout, like NCHW. -- The inference of Layout is at run-time, not compile-time. +- The inference of Layout is at run-time, not at compile-time. -- Every operator have to implement different kernels for different layouts. Let's take MKLDNN as an example, if we want to implement a MKLDNN convolution operator, we have to realize all the kernels for different layout, list at [here](http://01org.github.io/mkl-dnn/structmkldnn_1_1memory.html). And we will have a special macro to do registering kernels for MKLDNN operators. +- Every operator has to implement different kernels for different layouts. Let's take MKLDNN as an example. If we want to implement an MKLDNN convolution operator, we have to implement all the kernels for different layouts, which are listed [here](http://01org.github.io/mkl-dnn/structmkldnn_1_1memory.html). And we will have a special macro to register kernels for MKLDNN operators. `Layout` is also defined as a enum variable: diff --git a/doc/design/python_api.md b/doc/design/python_api.md index cb5fdc765b712..73f6d7b90c7dc 100644 --- a/doc/design/python_api.md +++ b/doc/design/python_api.md @@ -279,6 +279,26 @@ class LayerHelper(object): return tmp ``` +### Return value of layer functions + +The layer will return a Variable, which is also the output of an operator. However, outputs of a layer function have more attributes than an operator. There are parameter variables, and their gradient variables need to return. To return them is useful. For example, + +1. Users can debug the network by printing parameter gradients. +2. Users can append attributes to a parameter, such as, `param.stop_gradient=True` will make a parameter stop generate the gradient. We can fix the parameter value during training by using this attribute. + +However, it is good to return a Variable for layers, since all layers and operators use Variables as their parameters. We can just append a `param` field and a `grad` field for layer function since the Python is dynamic typing. + +The sample usage is + +```python +data = fluid.layers.data(...) +hidden = fluid.layers.fc(data, ...) +... + +executor.run(fetch_list=[hidden.param, hidden.param.grad], ...) +``` + + ## Optimizer [Optimizer Design Doc](./optimizer.md) diff --git a/doc/howto/read_source.md b/doc/howto/read_source.md index e4211abb3be9c..31987920f32f2 100644 --- a/doc/howto/read_source.md +++ b/doc/howto/read_source.md @@ -26,16 +26,16 @@ sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) sgd_optimizer.minimize(avg_cost) ``` -- Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/framework.py#L93) -- Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/layers.py) +- Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/framework.py#) +- Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/v2/fluid/layers) - Every Layer has one or more operators and variables/parameters - All the operators are defined at [`paddle/operators/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators). Other worth-looking files: - Base class: [`paddle/framework/operator.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h) - Operator Registration: [`paddle/framework/op_registry.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h) - Operator Lookup: [`paddle/framework/op_info.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_info.h) - Optimizer: `fluid.optimizer.SGD`. It does the following - - Add backward operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/backward.py), [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/backward.cc)] - - Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/optimizer.py), [C++](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/optimizer)] + - Add backward operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/backward.py)] + - Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/optimizer.py)] # Run Time diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc index e12bac1d78e3f..4ef82a541efaa 100644 --- a/paddle/framework/init.cc +++ b/paddle/framework/init.cc @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include // for strdup #include #include @@ -60,7 +61,9 @@ void InitDevices() { } void InitGLOG(const std::string &prog_name) { - google::InitGoogleLogging(prog_name.c_str()); + // glog will not hold the ARGV[0] inside. + // Use strdup to alloc a new string. + google::InitGoogleLogging(strdup(prog_name.c_str())); google::InstallFailureSignalHandler(); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index ef2c55cc3799b..7756a52ca9b79 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include @@ -21,6 +22,10 @@ limitations under the License. */ #include "paddle/framework/shape_inference.h" #include "paddle/framework/var_type.h" +DEFINE_bool(op_sync, false, + "Default cuda is asynchronous device, set to True will" + "force op run in synchronous mode."); + namespace paddle { namespace framework { @@ -542,8 +547,14 @@ void OperatorWithKernel::Run(const Scope& scope, auto kernel_iter = kernels.find(expected_kernel_key); - kernel_iter->second->Compute(ExecutionContext( - *this, new_scope, *pool.Get(expected_kernel_key.place_))); + auto* new_dev_ctx = pool.Get(expected_kernel_key.place_); + kernel_iter->second->Compute( + ExecutionContext(*this, new_scope, *new_dev_ctx)); + + /*For profiling/benchmark only*/ + if (FLAGS_op_sync) { + new_dev_ctx->Wait(); + } } proto::DataType OperatorWithKernel::IndicateDataType( diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc index 1065441e47c51..4208c62a4f07d 100644 --- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc +++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc @@ -50,8 +50,8 @@ X = [Seq0, Seq1, Seq2, Seq3]. The indices in RankTable are [3, 0, 2, 1]. Out = [Seq3, Seq0, Seq2, Seq1] with a new LoD information. If the LoD information of Input(X) is empty, this means Input(X) is not a -sequcence. This is also identical to a batch of sequences, each sequence in -which has a fixed length 1. In this case, the reorder_lod_tensor_by_rank operator +sequcence. This is also identical to a batch of sequences each sequence in which +has a fixed length 1. In this case, the reorder_lod_tensor_by_rank operator reorders each slice of Input(X) along the first axis according to Input(RankTable). @@ -62,7 +62,7 @@ Out = [Slice3, Slice0, Slice2, Slice1] with no LoD information is appended. NOTE: This operator sorts Input(X) according to a given LoDRankTable which dose not need to be calculated according to Input(X). It can be calculated according -to other different sequence, and then this operator sorts Input(X) according +to another different sequence, and then this operator sorts Input(X) according to the given LoDRankTable. )DOC"); diff --git a/paddle/operators/sequence_erase_op.cc b/paddle/operators/sequence_erase_op.cc new file mode 100644 index 0000000000000..d17b2686238b2 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/sequence_erase_op.h" + +namespace paddle { +namespace operators { + +class SequenceEraseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceEraseOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceEraseOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1, + "Input(X) of SequenceEraseOp should be a 2-D LoDTensor " + "with the 2nd dimension equal to 1."); + ctx->SetOutputDim("Out", x_dims); + } +}; + +class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Input LoDTensor of SequenceEraseOp."); + AddOutput("Out", + "(2-D LoDTensor with the 2nd dim. equal to 1) " + "Output LoDTensor of SequenceEraseOp."); + AddAttr>("tokens", + "(vector) Tokens need to be erased from " + "input sequences."); + AddComment(R"DOC( +Sequence Erase Operator. + +Sequence erase operator erases tokens specified by Attr(tokens) from the input +sequences Input(X), and outputs the remaining data and modifies the LoD +information at the same time. For example, given a 2-D LoDTensor + + X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T + +with lod = [[0, 3, 6, 10]], there are three sequences in the input: + + X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T. + +If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing +operation, the three sequences become + + X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T. + +Hence the LoDTensor Output(Out) should be + + Out = [[6, 1, 9, 6, 1, 0, 1]]^T, + +with lod = [[0, 1, 3, 7]]. + +An example usage for this operator is to remove the special tokens when +computing the edit distance between two strings, such as blank, start token, +and end token. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp, + ops::SequenceEraseOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_erase, + ops::SequenceEraseKernel); diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu new file mode 100644 index 0000000000000..5da8eba3e1ac1 --- /dev/null +++ b/paddle/operators/sequence_erase_op.cu @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/operators/sequence_erase_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template +__global__ void LabelErasedIdx(const T* in_dat, const int in_len, + const T* tokens, const int tokens_len, + int* num_erased) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + int erased = 0; + for (int i = 0; i < tokens_len; ++i) { + if (in_dat[index] == tokens[i]) { + erased = 1; + } + } + num_erased[index + 1] = erased; + if (index == 0) { + num_erased[0] = 0; + } + } +} + +template +__global__ void GetOutLod(const T* num_erased, const int* in_lod, + const int lod_len, int* out_lod0) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < lod_len) { + out_lod0[index] = in_lod[index] - num_erased[in_lod[index]]; + } +} + +template +__global__ void SetOutput(const T* in_dat, const int in_len, + const int* num_erased, T* out_dat) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + if (in_dat[index] != in_dat[index + 1]) { + out_dat[index - num_erased[index]] = in_dat[index]; + } + } +} + +template +class SequenceEraseOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); + auto tokens = ctx.Attr>("tokens"); + auto tokens_len = tokens.size(); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + thrust::host_vector host_tokens(tokens_len); + for (size_t i = 0; i < tokens.size(); ++i) { + host_tokens[i] = tokens[i]; + } + thrust::device_vector dev_tokens = host_tokens; + thrust::device_vector num_erased(in_len + 1); + + T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data()); + int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data()); + + auto stream = ctx.cuda_device_context().stream(); + LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr); + thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(), + num_erased.begin() + 1); + + // Calc LoD + auto lod_len = lod0.size(); + thrust::host_vector host_lod(lod_len); + for (size_t i = 0; i < lod_len; ++i) { + host_lod[i] = lod0[i]; + } + thrust::device_vector dev_in_lod = host_lod; + thrust::device_vector dev_out_lod(lod_len); + int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data()); + int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data()); + GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); + thrust::host_vector host_out_lod = dev_out_lod; + std::vector out_lod0(lod_len, 0); + for (size_t i = 0; i < lod_len; i++) { + out_lod0[i] = host_out_lod[i]; + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + + // Set output + out->Resize({out_lod0.back(), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len, + num_erased_ptr, out_dat); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(sequence_erase, + paddle::operators::SequenceEraseOpCUDAKernel); diff --git a/paddle/operators/sequence_erase_op.h b/paddle/operators/sequence_erase_op.h new file mode 100644 index 0000000000000..cb2d7be009dcb --- /dev/null +++ b/paddle/operators/sequence_erase_op.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SequenceEraseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = in->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(), + "The actual size mismatches with the LoD information."); + auto tokens = ctx.Attr>("tokens"); + auto in_len = in->numel(); + auto in_dat = in->data(); + auto lod0 = lod[0]; + + std::vector num_erased(in_len + 1, 0); + std::vector out_lod0(1, 0); + for (size_t i = 0; i < lod0.size() - 1; ++i) { + size_t num_out = 0; + for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) { + num_erased[j] = num_erased[j - 1]; + if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) != + tokens.end()) { + num_erased[j] += 1; + } else { + num_out += 1; + } + } + out_lod0.push_back(out_lod0.back() + num_out); + } + + auto out_len = in_len - num_erased[in_len]; + out->Resize({static_cast(out_len), 1}); + auto out_dat = out->mutable_data(ctx.GetPlace()); + + for (int64_t i = 0; i < in_len; ++i) { + if (num_erased[i] == num_erased[i + 1]) { + out_dat[i - num_erased[i]] = in_dat[i]; + } + } + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index 266d87e298692..ade94b40bed91 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -80,7 +80,7 @@ class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker { This operator is used to shrink output batch of memory defined in dynamic RNN. Dynamic RNN is able to handle variable-length sequences, in which, sequences in -a mini-batch are sorted by its length first. After that, the longest sequence +a mini-batch are sorted by their lengths first. After that, the longest sequence becomes the first one in the sorted batch, followed by the second longest, the third longest, and so on. Dynamic RNN then slices a batch input timestep by timestep from the sorted input. Once any sequence in the input batch reaches its diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index ccd5998e3592a..422aa0a5ba2e4 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -58,7 +58,7 @@ def __bootstrap__(): read_env_flags = ['use_pinned_memory', 'check_nan_inf'] if core.is_compile_gpu(): - read_env_flags.append('fraction_of_gpu_memory_to_use') + read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync'] core.init_gflags([sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) core.init_glog(sys.argv[0]) diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 2fb388acfc0a9..3ef6b33192d95 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -236,6 +236,9 @@ def to_string(self, throw_on_error): __repr__ = __str__ + def set_desc(self, input): + self.desc = input + @property def persistable(self): return self.desc.persistable() diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index 584027e518e60..afc5b78ba8521 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -690,7 +690,7 @@ def lod_tensor_to_array(x, table): """ Convert a LOD_TENSOR_ARRAY to an TensorArray. Args: - x (Variable|list): The LoD tensor to be converted to a LoD tensor array. + x (Variable|list): The lod tensor to be converted to a lod tensor array. table (ParamAttr|list): The variable that stores the level of lod which is ordered by sequence length in descending order. @@ -723,7 +723,7 @@ def array_to_lod_tensor(x, table): """Convert a LoD_Tensor_Aarry to an LoDTensor. Args: - x (Variable|list): The LoD Tensor Array to be converted to a tensor. + x (Variable|list): The lod tensor array to be converted to a tensor. table (ParamAttr|list): The variable that stores the level of lod which is ordered by sequence length in descending order. diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index b1534c5a886db..48a6bee558894 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -14,7 +14,7 @@ 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', - 'sequence_first_step', 'sequence_last_step' + 'sequence_first_step', 'sequence_last_step', 'dropout' ] @@ -386,6 +386,21 @@ def cos_sim(X, Y, **kwargs): return out +def dropout(x, dropout_prob, is_test=False, seed=0, **kwargs): + helper = LayerHelper('dropout', **kwargs) + out = helper.create_tmp_variable(dtype=x.dtype) + mask = helper.create_tmp_variable(dtype=x.dtype, stop_gradient=True) + helper.append_op( + type='dropout', + inputs={'X': [x]}, + outputs={'Out': [out], + 'Mask': [mask]}, + attrs={'dropout_prob': dropout_prob, + 'is_test': is_test, + 'seed': seed}) + return out + + def cross_entropy(input, label, **kwargs): """ **Cross Entropy Layer** diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 544623c4bce0c..d3a5b70785947 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -1,23 +1,12 @@ from ..registry import register_layer __activations__ = [ - 'abs', - 'ceil', - 'exp', - 'floor', - 'log', - 'relu', - 'round', - 'sigmoid', - 'sqrt', - 'square', - 'tanh', + 'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round' ] __all__ = [ 'mean', 'mul', - 'dropout', 'reshape', 'scale', 'transpose', diff --git a/python/paddle/v2/fluid/memory_optimization_transpiler.py b/python/paddle/v2/fluid/memory_optimization_transpiler.py index 571fce7fac616..6800d7ddbb141 100644 --- a/python/paddle/v2/fluid/memory_optimization_transpiler.py +++ b/python/paddle/v2/fluid/memory_optimization_transpiler.py @@ -3,6 +3,17 @@ from framework import Program, default_main_program, Parameter, Variable import backward from backward import _rename_arg_ +from . import core + +dtype_to_size = { + core.DataType.FP16: 2, + core.DataType.FP32: 4, + core.DataType.FP64: 8, + core.DataType.INT16: 2, + core.DataType.INT32: 4, + core.DataType.INT64: 8, + core.DataType.BOOL: 1 +} class ControlFlowGraph(object): @@ -28,18 +39,33 @@ def _build_graph(self): block_size = program_desc.num_blocks() # TODO(qijun) handle Program with if/while operators - self.global_block = program_desc.block(0) - self.op_size = self.global_block.op_size() + self.global_block_desc = program_desc.block(0) + self.op_size = self.global_block_desc.op_size() op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)] self._add_connections(op_node_connections) - self.ops = [self.global_block.op(i) for i in range(self.op_size)] + self.ops = [self.global_block_desc.op(i) for i in range(self.op_size)] for i in range(self.op_size): self._uses[i].update(self.ops[i].input_arg_names()) self._defs[i].update(self.ops[i].output_arg_names()) + def _update_graph(self, old_name, new_name, begin_idx=0): + for i in range(begin_idx, self.op_size): + if old_name in self._uses[i]: + self._uses[i].remove(old_name) + self._uses[i].add(new_name) + if old_name in self._defs[i]: + self._defs[i].remove(old_name) + self._defs[i].add(new_name) + if old_name in self._live_in[i]: + self._live_in[i].remove(old_name) + self._live_out[i].add(new_name) + if old_name in self._live_out[i]: + self._live_out[i].remove(old_name) + self._live_out[i].add(new_name) + def _reach_fixed_point(self, live_in, live_out): if len(live_in) != len(self._live_in): return False @@ -79,30 +105,45 @@ def memory_optimize(self): self.pool = [] for i in range(self.op_size): if self.pool: - out_pair = [(x, self.global_block.var(str(x)).shape()) + out_pair = [(x, self.global_block_desc.var(str(x)).shape()) for x in self._defs[i]] for x, x_shape in out_pair: - for index, cache_pair in enumerate(self.pool): - cache_var = cache_pair[0] - cache_shape = cache_pair[1] - if x_shape == cache_shape: - print( - "Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s " - % (index, x, cache_var, str(cache_shape))) - self.pool.pop(index) - _rename_arg_(self.ops, x, cache_var, begin_idx=i) - self._dataflow_analyze() - break + if not self.global_block_desc.var(str(x)).persistable(): + for index, cache_pair in enumerate(self.pool): + cache_var = cache_pair[0] + cache_shape = cache_pair[1] + if x_shape == cache_shape: + x_dtype = self.global_block_desc.var(str( + x)).dtype() + cache_dtype = self.global_block_desc.var( + str(cache_var)).dtype() + # TODO(qijun): actually, we should compare dtype_to_size[x_dtype] + # and dtype_to_size[cache_dtype] + if x_dtype == cache_dtype: + print( + "Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s " + % + (index, x, cache_var, str(cache_shape))) + self.pool.pop(index) + _rename_arg_( + self.ops, x, cache_var, begin_idx=i) + self._program.current_block().var(str( + x)).desc = self.global_block_desc.var( + str(cache_var)) + self._update_graph( + x, cache_var, begin_idx=i) + break in_diff, out_diff = self._get_diff(self._live_in[i], self._live_out[i]) can_optimize = filter( - lambda x: not self.global_block.var(str(x)).persistable(), + lambda x: not self.global_block_desc.var(str(x)).persistable(), in_diff) if can_optimize: for var_name in can_optimize: - self.pool.append(( - var_name, self.global_block.var(str(var_name)).shape())) + self.pool.append( + (var_name, + self.global_block_desc.var(str(var_name)).shape())) def get_program(self): return self._program diff --git a/python/paddle/v2/fluid/tests/test_parallel_op.py b/python/paddle/v2/fluid/tests/test_parallel_op.py index 59ed041e7fa1d..2b51a1f50473d 100644 --- a/python/paddle/v2/fluid/tests/test_parallel_op.py +++ b/python/paddle/v2/fluid/tests/test_parallel_op.py @@ -1,45 +1,156 @@ import unittest - -import paddle.v2.fluid.layers as layers import paddle.v2.fluid as fluid -from paddle.v2.fluid.framework import Program -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.backward import append_backward -import numpy as np -import paddle.v2.fluid.core as core - - -class ParallelOpTest(unittest.TestCase): - def setUp(self): - x = layers.data( - shape=[-1, 30, 40], - dtype='float32', - name='x', - append_batch_size=False, - stop_gradient=False) - - places = layers.get_places(device_count=4) - pd = layers.ParallelDo(places=places) - - with pd.do(): - data = pd.read_input(x) - hidden = layers.fc(input=data, size=7) - pd.write_output(hidden) - data = pd() - loss = layers.mean(x=data) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - sgd_optimizer.minimize(loss) - - exe = fluid.Executor(fluid.CPUPlace()) - exe.run(fluid.default_startup_program()) - exe.run(fluid.default_main_program(), - feed={ - x.name: np.random.uniform(0.1, 0.6, - (20, 30, 40)).astype("float32") - }) - - def test_forward(self): - pass +import numpy + + +class BaseParallelForTest(unittest.TestCase): + def run_test(self, callback, feed, fetch): + """ + Run the unittest for parallel.for + Args: + callback(callable): A callable function returns a generator. There + are two yields in the generator function. The first yield + returns the data layers, and the second yield returns the loss. + The modified data variables will be sent back during the first + yield. + + feed(dict): The executor feeding dictionary. + fetch(list|basestr): The fetch name lists. + + Returns: + None + + Raises: + AssertionError when the computation of cpu, parallel.for in cpu, + gpu, parallel.for in gpu are different. + + """ + cpu = fluid.CPUPlace() + result_cpu = self._run_test_impl_( + callback=callback, + feed=feed, + fetch=fetch, + place=cpu, + use_parallel=False) + result_cpu_parallel = self._run_test_impl_( + callback=callback, + feed=feed, + fetch=fetch, + place=cpu, + use_parallel=True) + if fluid.core.is_compile_gpu(): + gpu = fluid.CUDAPlace(0) + result_gpu = self._run_test_impl_( + callback=callback, + feed=feed, + fetch=fetch, + place=gpu, + use_parallel=False) + result_gpu_parallel = self._run_test_impl_( + callback=callback, + feed=feed, + fetch=fetch, + place=gpu, + use_parallel=True) + self._assert_same_(fetch, result_cpu, result_cpu_parallel, + result_gpu, result_gpu_parallel) + else: + self._assert_same_(fetch, result_cpu, result_cpu_parallel) + + def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False): + """ + Run a single test, returns the fetch values + Args: + place(Place): the computation place. + use_parallel(bool): Whether use parallel.for or not. + + Returns: + Fetched numpy arrays. + + """ + if isinstance(fetch, basestring): + fetch = [fetch] + main = fluid.Program() + startup = fluid.Program() + # Fix seed + main.random_seed = 10 + startup.random_seed = 10 + + with fluid.program_guard(main, startup): + generator = callback() + # Automatically insert parallel do if use_parallel = True + if use_parallel: + places = fluid.layers.get_places() + pd = fluid.layers.ParallelDo(places) + data = next(generator) + + if isinstance(data, fluid.Variable): + data = [data] + + with pd.do(): + ins = map(pd.read_input, data) + if len(ins) == 1: + ins = ins[0] + loss = generator.send(ins) # patch input + pd.write_output(loss) + + loss = pd() + else: + data = next(generator) + loss = generator.send(data) + self.assertIsNotNone(loss) + avg_loss = fluid.layers.mean(x=loss) + fluid.backward.append_backward(loss=avg_loss) + + exe = fluid.Executor(place) + exe.run(startup) + return exe.run(main, feed=feed, fetch_list=fetch) + + def _assert_same_(self, fetch, *args): + """ + Assert the return values of `run_test` are same. + Args: + fetch: Fetch list. Used for print error message + *args: The fetch result lists of each situations. + + Returns: + None + + Raises: + AssertionError + + """ + + def _impl_(a, b, fetch_id, item_id): + item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU'] + flag = numpy.allclose(a, b, rtol=0.1) + self.assertTrue(flag, "The {0} are different in {1}".format( + fetch[fetch_id], item_str[item_id])) + + for i, items in enumerate(zip(*args)): + self.assertGreater(len(items), 0) + for j in range(1, len(items)): + _impl_(items[0], items[j], fetch_id=i, item_id=j) + + +class ParallelOpTest(BaseParallelForTest): + def test_simple_fc(self): + def __network__(): + x = fluid.layers.data(shape=[784], dtype='float32', name='img') + # FIXME: This is a bug of parallel.do + x.stop_gradient = False + x = yield x + hidden = fluid.layers.fc(input=x, size=200, param_attr='fc1.w') + loss = fluid.layers.mean(x=hidden) + yield loss + + self.run_test( + callback=__network__, + feed={ + 'img': + numpy.random.random(size=(128 * 3, 784)).astype('float32') + }, + fetch='fc1.w@GRAD') if __name__ == '__main__': diff --git a/python/paddle/v2/fluid/tests/test_sequence_erase_op.py b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py new file mode 100644 index 0000000000000..bf257fefea0d9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_sequence_erase_op.py @@ -0,0 +1,35 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_erase(in_seq, lod0, tokens): + new_lod0 = [0] + out_seq = [] + for i in range(0, len(lod0) - 1): + num_out = 0 + for dat in in_seq[lod0[i]:lod0[i + 1]]: + if dat not in tokens: + out_seq.append(dat) + num_out += 1 + new_lod0.append(new_lod0[-1] + num_out) + return np.array(out_seq).astype("int32"), new_lod0 + + +class TestSequenceEraseOp(OpTest): + def setUp(self): + self.op_type = "sequence_erase" + in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + lod = [[0, 9, 13, 24, 30]] + tokens = [2, 3, 5] + out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens) + self.attrs = {'tokens': tokens} + self.inputs = {'X': (in_seq, lod)} + self.outputs = {'Out': (out_seq, [new_lod0])} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main()