[BUG] Extension hook executes on custom module #6

Closed · mseitzer opened this issue Jul 15, 2021 · 5 comments
Labels: 👷 Status: In Progress (Issue is being worked on) · 🐛 Type: Bug (Something isn't working)

@mseitzer

Cockpit crashes with a non-descriptive error.

Description

I am trying to run Cockpit on a simple network trained with an MSE loss (implemented through custom modules rather than nn.Sequential). As far as I can see, no unsupported operations are involved (Linear, Sequential, Tanh, Identity).

Cockpit crashes while computing the BatchGrad extension. Using BackPACK directly to compute the batch gradients works without crashing.
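
For reference, here is a hypothetical reconstruction of the custom modules, pieced together from the module structure printed in the stacktraces below; the actual code may differ, but class names, layer sizes, and attribute names are taken from the debug output:

import torch.nn as nn

class LinearHead(nn.Module):
    # Custom head module wrapping a single linear layer.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.l = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.l(x)

class MLP(nn.Module):
    # Custom top-level module; its layers are wrapped in a Sequential.
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            LinearHead(128, 1),
            nn.Identity(),
        )

    def forward(self, x):
        return self.layers(x)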

Steps to Reproduce

  1. Setup

import torch
from backpack import extend
from cockpit import Cockpit
from cockpit.utils.configuration import configuration

# Per-sample MSE loss: one loss value per example in the batch.
self.loss_fn = lambda pred, target: ((target - pred) ** 2).mean(dim=-1)

# Build and extend the model before handing its parameters to Cockpit.
model = MyCustomModel()
model = extend(model)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

self._cockpit = Cockpit(model.parameters(), quantities=configuration("economy"))
  2. Training Step
# `step` is the current global training iteration.
pred = model(inp)
loss = self.loss_fn(pred, target)
mean_loss = loss.mean()

optimizer.zero_grad()
info = {
    "batch_size": len(loss),
    "individual_losses": loss,
    "loss": mean_loss,
    "optimizer": optimizer,
}
with self._cockpit(step, info=info, debug=True):
    create_graph = self._cockpit.create_graph(step)
    mean_loss.backward(create_graph=create_graph)  # CRASH HERE
optimizer.step()

Source or Possible Fix


Stacktrace

[DEBUG, step 0]
 ↪Quantities  : [<cockpit.quantities.alpha.Alpha object at 0x7fa751af5f10>, <cockpit.quantities.distance.Distance object at 0x7fa74e7c0cd0>, <cockpit.quantities.grad_hist.GradHist1d object at 0x7fa74e7c0d10>, <cockpit.quantities.grad_norm.GradNorm object at 0x7fa74e7c0d90>, <cockpit.quantities.inner_test.InnerTest object at 0x7fa74e7c0dd0>, <cockpit.quantities.loss.Loss object at 0x7fa74e7cd050>, <cockpit.quantities.norm_test.NormTest object at 0x7fa74e7cd090>, <cockpit.quantities.ortho_test.OrthoTest object at 0x7fa74e7cd0d0>, <cockpit.quantities.update_size.UpdateSize object at 0x7fa74e7cd110>]
 ↪Extensions  : [<backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750>]
 ↪Hooks       : <cockpit.quantities.utils_transforms.BatchGradTransformsHook object at 0x7fa74e7931d0>
 ↪Create graph: False
 ↪Save memory : True
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on Linear(in_features=128, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on LinearHead(
  (l): Linear(in_features=128, out_features=1, bias=True)
)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on Identity()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on Sequential(
  (0): Linear(in_features=1, out_features=128, bias=True)
  (1): Tanh()
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): Tanh()
  (4): LinearHead(
    (l): Linear(in_features=128, out_features=1, bias=True)
  )
  (5): Identity()
)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on MLP(
  (layers): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): Tanh()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): Tanh()
    (4): LinearHead(
      (l): Linear(in_features=128, out_features=1, bias=True)
    )
    (5): Identity()
  )
)
Traceback (most recent call last):
  File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 169, in run_extension_hook
    CTX.get_extension_hook()(module)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 53, in __call__
    self.run_hook(param, module)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 80, in run_hook
    value = self.module_hook(param, module)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 139, in module_hook
    return self.param_hook(param)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/utils_transforms.py", line 78, in param_hook
    param.grad_batch._param_weakref = weakref.ref(param)
AttributeError: 'Parameter' object has no attribute 'grad_batch'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./code/trainers/base_trainer.py", line 112, in backward
    mean_loss.backward(create_graph=create_graph)
  File "./envs/pytorch/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "./envs/pytorch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 151, in hook_run_extensions
    run_extension_hook(module)
  File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 172, in run_extension_hook
    raise RuntimeError(f"Post extensions hook failed: {message}")
RuntimeError: Post extensions hook failed: AttributeError("'Parameter' object has no attribute 'grad_batch'")

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "code/train.py", line 215, in <module>
    metrics = main(opts)
  File "code/train.py", line 185, in main
    training_info = trainer.train(model, dataset, eval_data, logger)
  File "./code/trainers/mse_trainer.py", line 128, in train
    self.backward(global_step, loss, mean_loss, optimizer)
  File "./code/trainers/base_trainer.py", line 112, in backward
    mean_loss.backward(create_graph=create_graph)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/context.py", line 137, in __exit__
    self.cp.track(self.global_step, protected_savefields=self.protected_savefields)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/cockpit.py", line 178, in track
    q.track(global_step, self.params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 87, in track
    iteration, result = self.compute(global_step, params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 516, in compute
    save_result = self._compute(global_step, params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 538, in _compute
    self._compute_start(global_step, params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/alpha.py", line 280, in _compute_start
    self._save_1st_order_info(global_step, params, batch_loss, point, until)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/alpha.py", line 326, in _save_1st_order_info
    grad_dict = {id(p): p.grad.data.clone().detach() for p in params}
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/alpha.py", line 326, in <dictcomp>
    grad_dict = {id(p): p.grad.data.clone().detach() for p in params}
AttributeError: 'NoneType' object has no attribute 'data'

@mseitzer added the 🆕 Status: New and 🐛 Type: Bug labels on Jul 15, 2021
@f-dangel (Owner) commented Jul 15, 2021

Hey Max,

thanks for taking the time to report these issues here; it helps us a lot in improving the library. I think you've caught the same bug as #5. TL;DR: the problem is caused by using Alpha with a non-SGD optimizer. Two hotfixes:

  1. Exclude Alpha, keep Adam
  2. Keep Alpha, swap Adam for SGD
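
A minimal sketch of hotfix 1 (assuming Cockpit's configuration helper and that cockpit.quantities.Alpha is the class to exclude):

from cockpit import Cockpit, quantities
from cockpit.utils.configuration import configuration

# Keep Adam, but drop the Alpha quantity from the "economy" configuration.
quants = [q for q in configuration("economy") if not isinstance(q, quantities.Alpha)]
self._cockpit = Cockpit(model.parameters(), quantities=quants)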

Moving #5 up on my priority list :)

Would be great if you confirmed my guess.

Cheers

@mseitzer (Author)

I see. I would suggest removing Alpha from the default configurations, or documenting this clearly, as this was surprising behavior to me; most people would probably not use plain SGD to train their networks.

Anyways, I removed Alpha:

config = [
    quantities.Distance(schedules.linear(interval=1)),
    quantities.GradHist1d(schedules.linear(interval=1)),
    quantities.GradNorm(schedules.linear(interval=1)),
    quantities.InnerTest(schedules.linear(interval=1)),
    quantities.Loss(schedules.linear(interval=1)),
    quantities.NormTest(schedules.linear(interval=1)),
    quantities.OrthoTest(schedules.linear(interval=1)),
    quantities.UpdateSize(schedules.linear(interval=1)),
]
self._cockpit = Cockpit(model.parameters(), quantities=config)

This did not do the trick:

[DEBUG, step 0]
 ↪Quantities  : [<cockpit.quantities.distance.Distance object at 0x7fe44176ae90>, <cockpit.quantities.grad_hist.GradHist1d object at 0x7fe43e2bffd0>, <cockpit.quantities.grad_norm.GradNorm object at 0x7fe43e2cd190>, <cockpit.quantities.inner_test.InnerTest object at 0x7fe43e2cd250>, <cockpit.quantities.loss.Loss object at 0x7fe43e2cd310>, <cockpit.quantities.norm_test.NormTest object at 0x7fe43e2cd3d0>, <cockpit.quantities.ortho_test.OrthoTest object at 0x7fe43e2cd490>, <cockpit.quantities.update_size.UpdateSize object at 0x7fe43e2cd550>]
 ↪Extensions  : [<backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fe43e2cda10>]
 ↪Hooks       : <cockpit.quantities.utils_transforms.BatchGradTransformsHook object at 0x7fe43e28e890>
 ↪Create graph: False
 ↪Save memory : True
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fe43e2cda10> on Linear(in_features=128, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fe43e2cda10> on LinearHead(
  (l): Linear(in_features=128, out_features=1, bias=True)
)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fe43e2cda10> on Identity()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fe43e2cda10> on Sequential(
  (0): Linear(in_features=1, out_features=128, bias=True)
  (1): Tanh()
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): Tanh()
  (4): LinearHead(
    (l): Linear(in_features=128, out_features=1, bias=True)
  )
  (5): Identity()
)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fe43e2cda10> on MLP(
  (layers): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): Tanh()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): Tanh()
    (4): LinearHead(
      (l): Linear(in_features=128, out_features=1, bias=True)
    )
    (5): Identity()
  )
)
Traceback (most recent call last):
  File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 169, in run_extension_hook
    CTX.get_extension_hook()(module)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 53, in __call__
    self.run_hook(param, module)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 80, in run_hook
    value = self.module_hook(param, module)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 139, in module_hook
    return self.param_hook(param)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/utils_transforms.py", line 78, in param_hook
    param.grad_batch._param_weakref = weakref.ref(param)
AttributeError: 'Parameter' object has no attribute 'grad_batch'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./code/trainers/base_trainer.py", line 122, in backward
    mean_loss.backward(create_graph=create_graph)
  File "./envs/pytorch/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "./envs/pytorch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 151, in hook_run_extensions
    run_extension_hook(module)
  File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 172, in run_extension_hook
    raise RuntimeError(f"Post extensions hook failed: {message}")
RuntimeError: Post extensions hook failed: AttributeError("'Parameter' object has no attribute 'grad_batch'")

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "code/train.py", line 215, in <module>
    metrics = main(opts)
  File "code/train.py", line 185, in main
    training_info = trainer.train(model, dataset, eval_data, logger)
  File "./code/trainers/mse_trainer.py", line 128, in train
    self.backward(global_step, loss, mean_loss, optimizer)
  File "./code/trainers/base_trainer.py", line 122, in backward
    mean_loss.backward(create_graph=create_graph)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/context.py", line 137, in __exit__
    self.cp.track(self.global_step, protected_savefields=self.protected_savefields)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/cockpit.py", line 178, in track
    q.track(global_step, self.params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/grad_hist.py", line 93, in track
    super().track(global_step, params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 87, in track
    iteration, result = self.compute(global_step, params, batch_loss)
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 444, in compute
    return (global_step, self._compute(global_step, params, batch_loss))
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/grad_hist.py", line 114, in _compute
    hist = sum(p.grad_batch_transforms["hist_1d"][0] for p in params).detach()
  File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/grad_hist.py", line 114, in <genexpr>
    hist = sum(p.grad_batch_transforms["hist_1d"][0] for p in params).detach()
AttributeError: 'Parameter' object has no attribute 'grad_batch_transforms'

f-dangel added a commit that referenced this issue Jul 19, 2021
@f-dangel changed the title from "Cockpit crash on computing BatchGrad extension" to "[BUG] Extension hook executes on custom module" on Jul 19, 2021
@f-dangel (Owner)

Thanks for the report,

Based on the traceback, I think the problem is introduced by your custom MLP module, which is not recognized as a container module for which extension hook execution should be skipped. I pushed a fix to the bug-extension-hook-executes-on-custom-module branch that skips execution for any module that contains other modules.
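
In essence, the fix boils down to skipping the hook on any module that has submodules. A sketch (with a hypothetical function name; the actual implementation on the branch may differ):

def should_run_extension_hook(module):
    # Containers are skipped: their parameters are already visited when
    # the hook runs on the leaf modules they contain.
    return len(list(module.children())) == 0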

Can you install from that branch and check if it did the trick?
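
For reference, installing from a branch should work along these lines (assuming the repository lives at github.com/f-dangel/cockpit):

pip install git+https://github.com/f-dangel/cockpit.git@bug-extension-hook-executes-on-custom-module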

@mseitzer (Author)

Yes, now it works, very nice!

@f-dangel added the 👷 Status: In Progress label and removed the 🆕 Status: New label on Jul 20, 2021
f-dangel added a commit that referenced this issue Jul 20, 2021
* [BUG] Reproduce bug described in #6

* [FIX] Skip extension hook if module has children

* [FIX] flake8
@f-dangel (Owner)

Merged to development.
