-
Notifications
You must be signed in to change notification settings - Fork 2
/
repose_single.py
111 lines (79 loc) · 4.38 KB
/
repose_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from absl import flags, app
import os
import sys
sys.path.insert(0,'')
import numpy as np
import torch
import trimesh
from nnutils.train_utils import v2s_trainer
from nnutils.geom_utils import correct_bones, get_interpolated_skinning_weights, joint_transform, \
get_refined_bones_transforms, gauss_mlp_skinning, lbs
from nnutils.vis_utils import get_skeleton_vis_v1
# Shared absl flag registry; ``main`` reads every parsed option through it.
# Flags used by ``main`` but not defined here (e.g. ``skeleton_file``,
# ``use_hs``) are presumably registered by the imported nnutils modules —
# TODO confirm.
opts = flags.FLAGS

# ---- Flags specific to this reposing script ----
flags.DEFINE_string(
    name='canonical_mesh_path',
    default='vis3/mesh-rest.obj',
    help='path to the canonical mesh',
)
flags.DEFINE_string(
    name='output_dir',
    default='example/output_single',
    help='output directory',
)
def _z_rotation_transform(degrees):
    """Return a flattened 12-entry link transform rotating ``degrees`` about z.

    Uses the same 12-entry layout as the identity transform built in ``main``:
    nine entries of a 3x3 rotation matrix followed by three zeros (presumably
    row-major rotation then translation — TODO confirm against
    ``forward_kinematics``).
    """
    theta = np.radians(degrees)
    c, s = np.cos(theta), np.sin(theta)
    return torch.Tensor([c, -s, 0, s, c, 0, 0, 0, 1, 0, 0, 0])


def main(_):
    """Repose the canonical mesh by rotating selected kinematic-chain links.

    Loads the trained model in eval mode and optionally applies the learned
    skeleton-bone residuals. It then poses the skeleton with the per-link
    z-rotations in ``link_rotations_deg`` via forward kinematics and skins the
    canonical mesh to that pose.

    Writes ``example/canonical_skel.obj`` (the rest skeleton), plus
    ``deformed_kinematic_chain.obj`` and ``deformed_mesh.obj`` inside
    ``--output_dir``.

    Raises:
        app.UsageError: if ``--skeleton_file`` is not set.
        ValueError: if a configured link index is outside the skeleton.
    """
    # Validate before the (expensive) dataset/model setup; an ``assert`` would
    # be silently stripped when Python runs with -O.
    if not opts.skeleton_file:
        raise app.UsageError('--skeleton_file must be set')

    trainer = v2s_trainer(opts, is_eval=True)
    data_info = trainer.init_dataset()
    trainer.define_model(data_info)
    model = trainer.model
    model.eval()
    device = model.device

    bones_rst = model.bones
    bones_rst, _ = correct_bones(model, bones_rst)

    if opts.skeleton_bone_residual > 0:
        print('residual updated')
        # tanh bounds the learned residuals to +/- skeleton_bone_residual.
        clipped_residuals = torch.tanh(model.skel_bone_residuals) * opts.skeleton_bone_residual
        model.skeleton.update_skeleton_with_residuals(clipped_residuals)

    skeleton = model.skeleton

    # Export the rest-pose skeleton. NOTE(review): this path ignores
    # --output_dir; kept for compatibility, but ensure its directory exists.
    canonical_skel_path = 'example/canonical_skel.obj'
    os.makedirs(os.path.dirname(canonical_skel_path), exist_ok=True)
    canonical_skel = get_skeleton_vis_v1(skeleton.joint_centers, skeleton.joint_connections)
    canonical_skel.export(canonical_skel_path)

    num_skeleton_bone = skeleton.joint_centers.shape[0]
    bone_to_skeleton_pairs = get_interpolated_skinning_weights(skeleton, bones_rst)

    os.makedirs(opts.output_dir, exist_ok=True)

    canonical_mesh = trimesh.load(opts.canonical_mesh_path)
    num_vertices = canonical_mesh.vertices.shape[0]
    pts_can = torch.from_numpy(canonical_mesh.vertices).float().to(device)

    ######################## Define rigid transformation of kinematic chain links here ########################
    # Skeleton link index -> rotation about the z axis, in degrees.
    link_rotations_deg = {
        17: -30,
        15: -30,
        3: 30,
        6: -20,
    }
    ###########################################################################################################

    # Start every link at the identity transform: shape (1, num_skeleton_bone, 12).
    # NOTE(review): this tensor stays on the CPU as in the original; confirm
    # forward_kinematics handles device placement when the model is on a GPU.
    identity = torch.Tensor([1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0])
    relative_transform = identity.view(1, 1, 12).repeat(1, num_skeleton_bone, 1)
    for link, degrees in link_rotations_deg.items():
        if not 0 <= link < num_skeleton_bone:
            raise ValueError(
                f'link index {link} is out of range for a skeleton with '
                f'{num_skeleton_bone} links'
            )
        relative_transform[0, link, :] = _z_rotation_transform(degrees)
    relative_transform = relative_transform.view(-1, num_skeleton_bone * 12)

    # Rest-pose code for embedding index 0.
    rest_pose_code = model.rest_pose_code(torch.zeros(1, dtype=torch.long, device=device))
    nerf_skin = model.nerf_skin

    # Pose the skeleton and export the resulting kinematic chain.
    skeleton_joint_transforms = skeleton.forward_kinematics(relative_transform)
    regulated_locations = joint_transform(
        skeleton.joint_centers, skeleton_joint_transforms[0], is_vec=True)[0]
    posed_skel = get_skeleton_vis_v1(regulated_locations, skeleton.joint_connections)
    posed_skel.export(os.path.join(opts.output_dir, 'deformed_kinematic_chain.obj'))

    # Skin the canonical vertices to the posed skeleton.
    refinement_transform = get_refined_bones_transforms(
        bones_rst, regulated_locations, bone_to_skeleton_pairs, num_vertices, device)
    skin_forward = gauss_mlp_skinning(
        pts_can[:, None], model.embedding_xyz, bones_rst, rest_pose_code, nerf_skin,
        use_hs=opts.use_hs, skin_aux=model.skin_aux)
    pts_dfm, _ = lbs(bones_rst, refinement_transform, skin_forward,
                     pts_can[:, None], backward=False)

    canonical_mesh.vertices = pts_dfm.squeeze(1).detach().cpu().numpy()
    canonical_mesh.export(os.path.join(opts.output_dir, 'deformed_mesh.obj'))


if __name__ == '__main__':
    app.run(main)