Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
stnoah1 committed Nov 22, 2021
1 parent 68d13e9 commit 8ff68e1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def train(self, epoch, save_model=False):
k: '{:02d}%'.format(int(round(v * 100 / sum(timer.values()))))
for k, v in timer.items()
}
self.print_log(f'\tTraining loss: {np.mean(loss_value):.4f}. Training acc: {np.mean(acc_value)*100):.2f}%.')
self.print_log(f'\tTraining loss: {np.mean(loss_value):.4f}. Training acc: {np.mean(acc_value)*100:.2f}%.')
self.print_log(f'\tTime consumption: [Data]{proportion["dataloader"]}, [Network]{proportion["model"]}')

if save_model:
Expand Down
23 changes: 10 additions & 13 deletions model/infogcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
from torch import linalg as LA

from model.ms_tcn import MultiScale_TemporalConv as MS_TCN
from model.ms_gcn import MultiScale_GraphConv as MS_GCN
from model.ms_gcn import MultiHead_GraphConv as MH_GCN
from model.port import MORT
from einops import rearrange, repeat

from utils import set_parameter_requires_grad, get_vector_property
Expand All @@ -37,15 +34,15 @@ def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, in_chan
self.to_joint_embedding = nn.Linear(in_channels, base_channel)
self.pos_embedding = nn.Parameter(torch.randn(1, self.num_point, base_channel))

self.l2 = EncodingBlock(base_channel, base_channel,A, adaptive=adaptive)
self.l3 = EncodingBlock(base_channel, base_channel,A, adaptive=adaptive)
self.l4 = EncodingBlock(base_channel, base_channel,A, adaptive=adaptive)
self.l5 = EncodingBlock(base_channel, base_channel*2, A, stride=2, adaptive=adaptive)
self.l6 = EncodingBlock(base_channel*2, base_channel*2, A, adaptive=adaptive)
self.l7 = EncodingBlock(base_channel*2, base_channel*2, A, adaptive=adaptive)
self.l8 = EncodingBlock(base_channel*2, base_channel*4, A, stride=2, adaptive=adaptive)
self.l9 = EncodingBlock(base_channel*4, base_channel*4, A, adaptive=adaptive)
self.l10 = EncodingBlock(base_channel*4, base_channel*4, A, adaptive=adaptive)
self.l1 = EncodingBlock(base_channel, base_channel,A)
self.l2 = EncodingBlock(base_channel, base_channel,A)
self.l3 = EncodingBlock(base_channel, base_channel,A)
self.l4 = EncodingBlock(base_channel, base_channel*2, A, stride=2)
self.l5 = EncodingBlock(base_channel*2, base_channel*2, A)
self.l6 = EncodingBlock(base_channel*2, base_channel*2, A)
self.l7 = EncodingBlock(base_channel*2, base_channel*4, A, stride=2)
self.l8 = EncodingBlock(base_channel*4, base_channel*4, A)
self.l9 = EncodingBlock(base_channel*4, base_channel*4, A)
self.fc = nn.Linear(base_channel*4, base_channel*4)
self.fc_mu = nn.Linear(base_channel*4, base_channel*4)
self.fc_logvar = nn.Linear(base_channel*4, base_channel*4)
Expand Down Expand Up @@ -89,6 +86,7 @@ def forward(self, x):

x = self.data_bn(x)
x = rearrange(x, 'n (m v c) t -> (n m) c t v', m=M, v=V).contiguous()
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)
x = self.l4(x)
Expand All @@ -97,7 +95,6 @@ def forward(self, x):
x = self.l7(x)
x = self.l8(x)
x = self.l9(x)
x = self.l10(x)

# N*M,C,T,V
c_new = x.size(1)
Expand Down
8 changes: 4 additions & 4 deletions model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def forward(self, x, attn=None):
A = attn * self.shared_topology.unsqueeze(0)
for h in range(self.num_head):
A_h = A[:, h, :, :] # (nt)vv
x = rearrange(x, 'n c t v -> (n t) v c')
z = A_h@x
feature = rearrange(x, 'n c t v -> (n t) v c')
z = A_h@feature
z = rearrange(z, '(n t) v c-> n c t v', t=T).contiguous()
z = self.conv_d[h](z)
out = z + out if out is not None else z
Expand All @@ -132,9 +132,9 @@ def forward(self, x, attn=None):

return out

class EncodingBlcok(nn.Module):
class EncodingBlock(nn.Module):
def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
super(EncodingBlcok, self).__init__()
super(EncodingBlock, self).__init__()
self.agcn = SA_GC(in_channels, out_channels, A)
self.tcn = MS_TCN(out_channels, out_channels, kernel_size=5, stride=stride,
dilations=[1, 2], residual=False)
Expand Down

0 comments on commit 8ff68e1

Please sign in to comment.