Skip to content

Commit

Permalink
[Hybrid Parallel]add op_device in seed op for recompute
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes committed Jul 14, 2021
1 parent cb6510f commit 52c1a95
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python/paddle/fluid/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,20 @@ def modify_forward_desc_for_recompute(self):
persistable=False,
stop_gradient=False)
seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed"))

op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName(
)
op_device = ""
if op.desc.has_attr(op_device_attr_name):
op_device = op.desc.attr(op_device_attr_name)

added_op = self.block._insert_op(
index=op.idx,
type='seed',
inputs={},
outputs={'Out': [added_var]},
attrs={'seed': seed})
attrs={'seed': seed,
'op_device': op_device})
self.ops.insert(op_idx, added_op)
# modify dropout op desc so that it accept a seed var as input
op.desc.set_input("Seed", [var_unique_name])
Expand Down

0 comments on commit 52c1a95

Please sign in to comment.