Skip to content

Commit

Permalink
Merge branch 'develop' into update_comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Jan 11, 2018
2 parents 49c20a0 + a9dbdab commit 473fd63
Show file tree
Hide file tree
Showing 28 changed files with 1,119 additions and 117 deletions.
9 changes: 9 additions & 0 deletions adversarial/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Advbox

Advbox is a Python toolbox to create adversarial examples that fool neural networks. It requires Python and paddle.

## How to use

1. train a model and save it's parameters. (like fluid_mnist.py)
2. load the parameters which is trained in step1, then reconstruct the model.(like mnist_tutorial_fgsm.py)
3. use advbox to generate the adversarial sample.
16 changes: 16 additions & 0 deletions adversarial/advbox/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A set of tools for generating adversarial example on paddle platform
"""
39 changes: 39 additions & 0 deletions adversarial/advbox/attacks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
The base model of the model.
"""
from abc import ABCMeta, abstractmethod


class Attack(object):
"""
Abstract base class for adversarial attacks. `Attack` represent an adversarial attack
which search an adversarial example. subclass should implement the _apply() method.
Args:
model(Model): an instance of the class advbox.base.Model.
"""
__metaclass__ = ABCMeta

def __init__(self, model):
self.model = model

def __call__(self, image_label):
"""
Generate the adversarial sample.
Args:
image_label(list): The image and label tuple list with one element.
"""
adv_img = self._apply(image_label)
return adv_img

@abstractmethod
def _apply(self, image_label):
"""
Search an adversarial example.
Args:
image_batch(list): The image and label tuple list with one element.
"""
raise NotImplementedError
38 changes: 38 additions & 0 deletions adversarial/advbox/attacks/gradientsign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
This module provide the attack method for FGSM's implement.
"""
from __future__ import division
import numpy as np
from collections import Iterable
from .base import Attack


class GradientSignAttack(Attack):
"""
This attack was originally implemented by Goodfellow et al. (2015) with the
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
the Fast Gradient Method.
Paper link: https://arxiv.org/abs/1412.6572
"""

def _apply(self, image_label, epsilons=1000):
assert len(image_label) == 1
pre_label = np.argmax(self.model.predict(image_label))

min_, max_ = self.model.bounds()
gradient = self.model.gradient(image_label)
gradient_sign = np.sign(gradient) * (max_ - min_)

if not isinstance(epsilons, Iterable):
epsilons = np.linspace(0, 1, num=epsilons + 1)

for epsilon in epsilons:
adv_img = image_label[0][0].reshape(
gradient_sign.shape) + epsilon * gradient_sign
adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
if pre_label != adv_label:
return adv_img


FGSM = GradientSignAttack
16 changes: 16 additions & 0 deletions adversarial/advbox/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paddle model for target of attack
"""
90 changes: 90 additions & 0 deletions adversarial/advbox/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
The base model of the model.
"""
from abc import ABCMeta
import abc

abstractmethod = abc.abstractmethod


class Model(object):
"""
Base class of model to provide attack.
Args:
bounds(tuple): The lower and upper bound for the image pixel.
channel_axis(int): The index of the axis that represents the color channel.
preprocess(tuple): Two element tuple used to preprocess the input. First
substract the first element, then divide the second element.
"""
__metaclass__ = ABCMeta

def __init__(self, bounds, channel_axis, preprocess=None):
assert len(bounds) == 2
assert channel_axis in [0, 1, 2, 3]

if preprocess is None:
preprocess = (0, 1)
self._bounds = bounds
self._channel_axis = channel_axis
self._preprocess = preprocess

def bounds(self):
"""
Return the upper and lower bounds of the model.
"""
return self._bounds

def channel_axis(self):
"""
Return the channel axis of the model.
"""
return self._channel_axis

def _process_input(self, input_):
res = input_
sub, div = self._preprocess
if sub != 0:
res = input_ - sub
assert div != 0
if div != 1:
res /= div
return res

@abstractmethod
def predict(self, image_batch):
"""
Calculate the prediction of the image batch.
Args:
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
raise NotImplementedError

@abstractmethod
def num_classes(self):
"""
Determine the number of the classes
Return:
int: the number of the classes
"""
raise NotImplementedError

@abstractmethod
def gradient(self, image_batch):
"""
Calculate the gradient of the cross-entropy loss w.r.t the image.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with
the shape (height, width, channel).
"""
raise NotImplementedError
101 changes: 101 additions & 0 deletions adversarial/advbox/models/paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import absolute_import

import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import program_guard

from .base import Model


class PaddleModel(Model):
"""
Create a PaddleModel instance.
When you need to generate a adversarial sample, you should construct an instance of PaddleModel.
Args:
program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample.
input_name(string): The name of the input.
logits_name(string): The name of the logits.
predict_name(string): The name of the predict.
cost_name(string): The name of the loss in the program.
"""

def __init__(self,
program,
input_name,
logits_name,
predict_name,
cost_name,
bounds,
channel_axis=3,
preprocess=None):
super(PaddleModel, self).__init__(
bounds=bounds, channel_axis=channel_axis, preprocess=preprocess)

if preprocess is None:
preprocess = (0, 1)

self._program = program
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)

self._input_name = input_name
self._logits_name = logits_name
self._predict_name = predict_name
self._cost_name = cost_name

# gradient
loss = self._program.block(0).var(self._cost_name)
param_grads = fluid.backward.append_backward(
loss, parameter_list=[self._input_name])
self._gradient = dict(param_grads)[self._input_name]

def predict(self, image_batch):
"""
Predict the label of the image_batch.
Args:
image_batch(list): The image and label tuple list.
Return:
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program)
predict_var = self._program.block(0).var(self._predict_name)
predict = self._exe.run(self._program,
feed=feeder.feed(image_batch),
fetch_list=[predict_var])
return predict

def num_classes(self):
"""
Calculate the number of classes of the output label.
Return:
int: the number of classes
"""
predict_var = self._program.block(0).var(self._predict_name)
assert len(predict_var.shape) == 2
return predict_var.shape[1]

def gradient(self, image_batch):
"""
Calculate the gradient of the loss w.r.t the input.
Args:
image_batch(list): The image and label tuple list.
Return:
list: The list of the gradient of the image.
"""
feeder = fluid.DataFeeder(
feed_list=[self._input_name, self._logits_name],
place=self._place,
program=self._program)

grad, = self._exe.run(self._program,
feed=feeder.feed(image_batch),
fetch_list=[self._gradient])
return grad
Loading

0 comments on commit 473fd63

Please sign in to comment.