Problems of Parameter registration #7

Open

Kylin9511 opened this issue May 8, 2019 · 10 comments
@Kylin9511

Kylin9511 commented May 8, 2019

key_rel_w = nn.Parameter(torch.randn((2 * W - 1, dk), requires_grad=True)).to(device)
rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w")
key_rel_h = nn.Parameter(torch.randn((2 * H - 1, dk), requires_grad=True)).to(device)
rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), key_rel_h, W, H, Nh, "h")

  • I think if you create your Parameters here, inside forward, they cannot be optimized correctly.
  • The optimizer is usually built from model.named_parameters() (or model.parameters()), so optimizer.step() and optimizer.zero_grad() will ignore key_rel_w and key_rel_h because they are not registered on the module. [The gradients will still be computed normally when loss.backward() is called, though.]
  • Use self.key_rel_w and self.key_rel_h instead, so they are registered on the module (see the sketch below).
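
A minimal sketch of what that would look like (the class and argument names here are made up, not the repo's actual code), assuming H, W and dk are known when the layer is constructed:

import torch
import torch.nn as nn

# Hypothetical sketch: assigning an nn.Parameter to a module attribute registers
# it automatically, so it appears in model.named_parameters() and is updated by
# optimizer.step().
class RelativeAttentionSketch(nn.Module):
    def __init__(self, H, W, dk):
        super().__init__()
        self.key_rel_w = nn.Parameter(torch.randn(2 * W - 1, dk))
        self.key_rel_h = nn.Parameter(torch.randn(2 * H - 1, dk))

    def forward(self, q):
        # ... use self.key_rel_w / self.key_rel_h in relative_logits_1d ...
        return q

(Note too that calling .to(device) on a freshly created nn.Parameter generally returns a plain tensor rather than the Parameter itself, which is another reason the original snippet would not train.)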
@leaderj1001
Owner

leaderj1001 commented May 8, 2019

I'll check it as soon as possible. Thank you.

@Kylin9511
Author

Kylin9511 commented May 10, 2019

Emmm, I realized that registering your parameter is not enough (register_parameter is the recommended way to do it).

You also need to call optim.add_param_group({"params": my_new_param}) to enable training of the new parameter. A rough sketch is below.

Maybe there are easier ways. Tell me if you find one, Thx~
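
Something like this, assuming the layer can see the optimizer (the helper name and its arguments are hypothetical):

import torch
import torch.nn as nn

# Hypothetical helper sketching the register_parameter + add_param_group route:
# create the relative embedding lazily on the first forward, register it on the
# module, then tell the already-built optimizer about it.
def ensure_rel_embedding(module, optimizer, name, size, dk):
    if getattr(module, name, None) is None:
        param = nn.Parameter(torch.randn(2 * size - 1, dk))
        module.register_parameter(name, param)
        # Without this, optimizer.step() would never update the new parameter,
        # even though loss.backward() still computes its gradient.
        optimizer.add_param_group({"params": [param]})
    return getattr(module, name)

The awkward part is that forward() then needs a handle to the optimizer, which is part of why this feels troublesome.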

@leaderj1001
Owner

Thank you for pointing this out. I'll solve this problem.
I have solved it by declaring nn.Parameter() in __init__.
I will fix the code soon. Thank you so much.

@Kylin9511
Author

Kylin9511 commented May 11, 2019

Well, that is the simple but ugly way (laughter).

Since the sizes of key_rel_w and key_rel_h depend on the feature-map size, setting them in __init__ before the first forward will work, but it still feels like a hack. If the input image size changes even slightly, your model will break, which is a limitation for domain transfer and fine-tuning.

Anyway, fixing h and w in __init__ will solve this problem.

By the way, could you share your reimplementation results on ImageNet?

@leaderj1001
Owner

I know.. If we use "relative=True", we have to fix the output shape. If you know how to solve this problem, could you share your idea?
I don't have enough GPUs, so the results for CIFAR-100 are being tested first.
I'll try as soon as possible. Thank you.

@Kylin9511
Author

Sorry, I don't have an elegant way either.

You can use register_parameter and optim.add_param_group as I described above; that enables a dynamic network definition.

But it is still troublesome.

@leaderj1001
Owner

Thank you for the advice. I'll try it later.

Thanks for the issue. I'll make changes to keep it going :)

@nhatsmrt

nhatsmrt commented Jun 19, 2019

This is not mentioned in the paper, but I think we can get around the problem of variable input size by putting an adaptive pooling layer right before this module. That way, even though we still have to fix H and W for this particular layer, we can feed the network whatever input size we want (rough sketch below).
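
For example, something along these lines (the module here is a placeholder, not the repo's actual block):

import torch
import torch.nn as nn

# Placeholder sketch of the adaptive-pooling idea: pool every input down to a
# fixed H x W before the relative-attention layer, so key_rel_h / key_rel_w can
# keep a fixed size no matter what resolution the network receives.
class FixedSizeAttentionBlock(nn.Module):
    def __init__(self, fixed_hw=(14, 14)):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(fixed_hw)  # any spatial size -> 14x14
        self.attn = nn.Identity()                   # stand-in for the attention layer

    def forward(self, x):
        x = self.pool(x)  # (N, C, 14, 14) regardless of the input resolution
        return self.attn(x)

block = FixedSizeAttentionBlock()
print(block(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 14, 14])
print(block(torch.randn(1, 64, 57, 41)).shape)  # torch.Size([1, 64, 14, 14])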

@xarryon

xarryon commented Oct 15, 2020

so... why should dim0 of key_rel_w/h be 2*self.shape-1 rather than some other value?

@GingerCohle

so... why the dim0 of key_rel_w/h should be 2*self.shape-1 instead of others?

Have you figured out this problem?
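
For context, the usual reasoning behind the 2*W-1 size in relative position embeddings: for positions 0..W-1, the relative offset j - i can take any value from -(W-1) to W-1, which is 2W-1 distinct values, so one embedding row is needed per possible offset. A quick check:

# All possible relative offsets along an axis of length W:
W = 5
offsets = sorted({j - i for i in range(W) for j in range(W)})
print(offsets)       # [-4, -3, -2, -1, 0, 1, 2, 3, 4]
print(len(offsets))  # 9 == 2 * W - 1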
