[Eager Hook] Support eager hook_for_layer #39531

Merged
merged 5 commits on Feb 16, 2022
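
For context on the diff below: the change reruns the existing layer-hook tests inside _test_eager_guard() so the same hooks are exercised on the eager dygraph code path. A minimal sketch of that wrapping pattern follows; it is not taken from the diff, and check_hooks is a hypothetical stand-in for the real test body, which builds a SimpleNet, registers hooks, and compares outputs.

import paddle
from paddle.fluid.framework import _test_eager_guard

def check_hooks():
    # hypothetical placeholder for the existing dygraph assertions
    # (build a layer, register forward hooks, assert on the outputs)
    pass

# legacy dygraph pass, exactly as the tests ran before this PR
check_hooks()

# same assertions again, but executed in eager mode
with _test_eager_guard():
    check_hooks()
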
python/paddle/fluid/dygraph/layers.py (2 changes: 1 addition & 1 deletion)
@@ -342,7 +342,7 @@ def register_forward_pre_hook(self, hook):
import paddle
import numpy as np

# the forward_post_hook change the input of the layer: input = input * 2
# the forward_pre_hook change the input of the layer: input = input * 2
def forward_pre_hook(layer, input):
# user can use layer and input for information statistics tasks

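The docstring comment above describes a forward_pre_hook that rewrites the layer input (input = input * 2). A hedged sketch of that behaviour, not part of the diff; paddle.nn.Linear and the tensor shapes are illustrative stand-ins:

import numpy as np
import paddle

def forward_pre_hook(layer, inputs):
    # user code can inspect layer and inputs here (e.g. for statistics);
    # returning a modified tuple makes the layer run on inputs[0] * 2 instead
    return (inputs[0] * 2, )

linear = paddle.nn.Linear(4, 4)
handle = linear.register_forward_pre_hook(forward_pre_hook)

x = paddle.to_tensor(np.ones([2, 4], dtype='float32'))
y = linear(x)        # forward actually sees x * 2 because of the pre-hook

handle.remove()      # detach the hook when it is no longer needed
y_plain = linear(x)  # runs on the original x again
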
python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py (194 changes: 168 additions & 26 deletions)
@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,22 +25,36 @@
import paddle.fluid.dygraph.base as base

from test_imperative_lod_tensor_to_selected_rows import SimpleNet
from paddle.fluid.framework import _test_eager_guard

call_forward_hook = False
call_forward_post_hook = False
call_forward_pre_hook = False

call_forward_post_hook_eager = False
call_forward_pre_hook_eager = False

def forward_hook(layer, input, output):
global call_forward_hook
call_forward_hook = True

def forward_post_hook(layer, input, output):
global call_forward_post_hook
call_forward_post_hook = True


def forward_pre_hook(layer, input):
global call_forward_pre_hook
call_forward_pre_hook = True


def forward_hook1(layer, input, output):
def eager_forward_post_hook(layer, input, output):
global call_forward_post_hook_eager
call_forward_post_hook_eager = True


def eager_forward_pre_hook(layer, input):
global call_forward_pre_hook_eager
call_forward_pre_hook_eager = True


def forward_post_hook1(layer, input, output):
return output * 2


@@ -50,7 +64,7 @@ def forward_pre_hook1(layer, input):


class Test_Forward_Hook(unittest.TestCase):
# test forward_pre_hook and forward_hook that have return value
# test forward_pre_hook and forward_post_hook that have return value
def test_forward_hook_return_value(self):
seed = 90

@@ -104,22 +118,86 @@ def test_forward_hook_return_value(self):
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))

# register forward_hook
forward_hook_handle1 = simplenet.register_forward_post_hook(
forward_hook1)
# register forward_post_hook
forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_post_hook1)
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2))

# remove forward_hook
forward_hook_handle1.remove()
# remove forward_post_hook
forward_post_hook_handle1.remove()
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))

# test forward_pre_hook and forward_hook that don't have return value
for place in places:
with fluid.dygraph.guard(place):
with _test_eager_guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})

input_word = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
8]).reshape(6, 3).astype('int64')
input_word1 = input_word * 2
input_word = input_word.reshape((-1, 3, 1))
input_word1 = input_word1.reshape((-1, 3, 1))
y_data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
9]).reshape(6, 3).astype('int64')
y_data = y_data.reshape((-1, 1))

input = base.to_variable(input_word)
input1 = base.to_variable(input_word1)
y = base.to_variable(y_data)

simplenet = SimpleNet(
hidden_size=20,
vocab_size=32,
num_steps=3,
init_scale=0.1,
is_sparse=False,
dtype="float32")

# origin, don't register any hook
outs_origin = simplenet(input, y)
outs_origin1 = simplenet(input1, y)

# register forward_pre_hook
forward_pre_hook_handle1 = simplenet.register_forward_pre_hook(
forward_pre_hook1)
outs_pre_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(),
outs_origin1.numpy()))

# remove forward_pre_hook
forward_pre_hook_handle1.remove()
outs_pre_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(),
outs_origin.numpy()))

# register forward_post_hook
forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_post_hook1)
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2))

# remove forward_post_hook
forward_post_hook_handle1.remove()
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))

# test forward_pre_hook and forward_post_hook that don't have return value
def test_forward_hook(self):
seed = 90

@@ -133,7 +211,7 @@ def test_forward_hook(self):
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})

global call_forward_hook
global call_forward_post_hook
global call_forward_pre_hook

input_word = np.array(
@@ -158,38 +236,102 @@

# origin, don't register any hook
outs_origin = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)

# register forward_hook and forward_pre_hook
forward_hook_handle = simplenet.register_forward_post_hook(
forward_hook)
# register forward_post_hook and forward_pre_hook
forward_post_hook_handle = simplenet.register_forward_post_hook(
forward_post_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)

outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)

# remove forward_hook
forward_hook_handle.remove()
call_forward_hook = False
# remove forward_post_hook
forward_post_hook_handle.remove()
call_forward_post_hook = False
call_forward_pre_hook = False
outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)

# remove forward_pre_hook
forward_pre_hook_handle.remove()
call_forward_hook = False
call_forward_post_hook = False
call_forward_pre_hook = False
outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)

for place in places:
with fluid.dygraph.guard(place):
with _test_eager_guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})

global call_forward_post_hook_eager
global call_forward_pre_hook_eager

input_word = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
8]).reshape(6, 3).astype('int64')
input_word = input_word.reshape((-1, 3, 1))
y_data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
9]).reshape(6, 3).astype('int64')
y_data = y_data.reshape((-1, 1))

input = base.to_variable(input_word)
y = base.to_variable(y_data)

simplenet = SimpleNet(
hidden_size=20,
vocab_size=32,
num_steps=3,
init_scale=0.1,
is_sparse=False,
dtype="float32")

# origin, don't register any hook
outs_origin = simplenet(input, y)
self.assertFalse(call_forward_post_hook_eager)
self.assertFalse(call_forward_pre_hook_eager)

# register forward_post_hook and forward_pre_hook
forward_post_hook_handle = simplenet.register_forward_post_hook(
eager_forward_post_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
eager_forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_post_hook_eager)
self.assertTrue(call_forward_pre_hook_eager)

outs_hook = simplenet(input, y)
self.assertTrue(call_forward_post_hook_eager)
self.assertTrue(call_forward_pre_hook_eager)

# remove forward_post_hook
forward_post_hook_handle.remove()
call_forward_post_hook_eager = False
call_forward_pre_hook_eager = False
outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_post_hook_eager)
self.assertTrue(call_forward_pre_hook_eager)

# remove forward_pre_hook
forward_pre_hook_handle.remove()
call_forward_post_hook_eager = False
call_forward_pre_hook_eager = False
outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_post_hook_eager)
self.assertFalse(call_forward_pre_hook_eager)


if __name__ == '__main__':
unittest.main()