
Commit

remove duplicate frames and keep timestamp, fix psgan docs (PaddlePaddle#74)

* remove duplicate frames and keep timestamp, fix psgan docs
lijianshe02 committed Nov 6, 2020
1 parent 83c1b7e commit f5be3a9
Showing 3 changed files with 76 additions and 33 deletions.
6 changes: 3 additions & 3 deletions docs/en_US/tutorials/psgan.md
@@ -35,14 +35,14 @@ python tools/psgan_infer.py \
```
mv landmarks/makeup MT-Dataset/landmarks/makeup
mv landmarks/non-makeup MT-Dataset/landmarks/non-makeup
-mv landmarks/train_makeup.txt MT-Dataset/makeup.txt
-mv tlandmarks/train_non-makeup.txt MT-Dataset/non-makeup.txt
+cp landmarks/train_makeup.txt MT-Dataset/train_makeup.txt
+cp landmarks/train_non-makeup.txt MT-Dataset/train_non-makeup.txt
```
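
If you want to double-check that the files ended up where the training step expects them, a quick check such as the one below can be run from the directory that contains `MT-Dataset` (this snippet is only an illustration and is not part of the tutorial scripts):

```python
# Optional sanity check for the layout produced by the commands above.
import os

expected = [
    "MT-Dataset/landmarks/makeup",
    "MT-Dataset/landmarks/non-makeup",
    "MT-Dataset/train_makeup.txt",
    "MT-Dataset/train_non-makeup.txt",
]
for path in expected:
    print(path, "->", "ok" if os.path.exists(path) else "missing")
```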

The final data directory should look like this:

```
-data
+data/MT-Dataset
├── images
│ ├── makeup
│ └── non-makeup
8 changes: 4 additions & 4 deletions docs/zh_CN/tutorials/psgan.md
@@ -2,7 +2,7 @@

## 1. Principle of PSGAN

-The task of the [PSGAN](https://arxiv.org/abs/1909.06956) model is makeup transfer, i.e., transferring the makeup from an arbitrary reference image onto a makeup-free source image. Many portrait beautification applications need this technique. Most recent makeup transfer methods are based on generative adversarial networks (GANs). They usually adopt the CycleGAN framework and are trained on two datasets, i.e., images without makeup and images with makeup. However, existing methods share a limitation: they only perform well on frontal face images, with no module designed to handle the pose and expression differences between the source and reference images. PSGAN is a novel pose-robust, spatial-aware generative generative adversarial network. PSGAN consists of three parts: the Makeup Distill Network (MDNet), the Attentive Makeup Morphing (AMM) module, and the De-makeup Re-makeup Network (DRNet). These three newly proposed modules give PSGAN the capabilities that the ideal makeup transfer model described above should have.
+The task of the [PSGAN](https://arxiv.org/abs/1909.06956) model is makeup transfer, i.e., transferring the makeup from an arbitrary reference image onto a makeup-free source image. Many portrait beautification applications need this technique. Most recent makeup transfer methods are based on generative adversarial networks (GANs). They usually adopt the CycleGAN framework and are trained on two datasets, i.e., images without makeup and images with makeup. However, existing methods share a limitation: they only perform well on frontal face images, with no module designed to handle the pose and expression differences between the source and reference images. PSGAN is a novel pose-robust, spatial-aware generative adversarial network. PSGAN consists of three parts: the Makeup Distill Network (MDNet), the Attentive Makeup Morphing (AMM) module, and the De-makeup Re-makeup Network (DRNet). These three newly proposed modules give PSGAN the capabilities that the ideal makeup transfer model described above should have.
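
As a rough illustration of how the three parts fit together, the sketch below chains them in Python; the class name and callables are placeholders invented for this example, not the actual PaddleGAN modules:

```python
# Illustrative composition of the three PSGAN components described above.
# mdnet, amm and drnet are assumed to be callables; none of this is the
# real implementation.
class PSGANPipeline:
    def __init__(self, mdnet, amm, drnet):
        self.mdnet = mdnet  # Makeup Distill Network: distills makeup features from the reference face
        self.amm = amm      # Attentive Makeup Morphing: warps those features to the source pose/expression
        self.drnet = drnet  # De-makeup Re-makeup Network: removes the source makeup and applies the warped one

    def transfer(self, source_img, reference_img):
        makeup_feats = self.mdnet(reference_img)
        morphed_feats = self.amm(makeup_feats, source_img, reference_img)
        return self.drnet(source_img, morphed_feats)
```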

<div align="center">
<img src="../../imgs/psgan_arc.png" width="800"/>
@@ -35,13 +35,13 @@ python tools/psgan_infer.py \
```
mv landmarks/makeup MT-Dataset/landmarks/makeup
mv landmarks/non-makeup MT-Dataset/landmarks/non-makeup
-mv landmarks/train_makeup.txt MT-Dataset/makeup.txt
-mv tlandmarks/train_non-makeup.txt MT-Dataset/non-makeup.txt
+cp landmarks/train_makeup.txt MT-Dataset/train_makeup.txt
+cp landmarks/train_non-makeup.txt MT-Dataset/train_non-makeup.txt
```

The final dataset directory should look like this:
```
-data
+data/MT-Dataset
├── images
│   ├── makeup
│   └── non-makeup
95 changes: 69 additions & 26 deletions ppgan/apps/dain_predictor.py
@@ -82,14 +82,9 @@ def run(self, video_path):
vidname = video_path.split('/')[-1].split('.')[0]

frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
-orig_frames = len(frames)
-need_frames = orig_frames * times_interp

if self.remove_duplicates:
frames = self.remove_duplicate_frames(out_path)
-left_frames = len(frames)
-timestep = left_frames / need_frames
-num_frames = int(1.0 / timestep) - 1

img = imread(frames[0])
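
For reference, the bookkeeping done by the removed lines above can be reproduced in isolation; the numbers are made up purely for illustration:

```python
# Worked example of the removed timestep bookkeeping (illustrative numbers).
orig_frames = 100                          # frames extracted from the input video
times_interp = 2                           # requested interpolation factor
need_frames = orig_frames * times_interp   # 200 frames wanted in the output
left_frames = 50                           # frames surviving duplicate removal
timestep = left_frames / need_frames       # 0.25
num_frames = int(1.0 / timestep) - 1       # 3 frames to synthesize per kept pair
print(timestep, num_frames)
```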

@@ -125,9 +120,11 @@ def run(self, video_path):
if not os.path.exists(os.path.join(frame_path_combined, vidname)):
os.makedirs(os.path.join(frame_path_combined, vidname))

-for i in tqdm(range(frame_num - 1)):
+for i in range(frame_num - 1):
first = frames[i]
second = frames[i + 1]
+first_index = int(first.split('/')[-1].split('.')[-2])
+second_index = int(second.split('/')[-1].split('.')[-2])

img_first = imread(first)
img_second = imread(second)
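
The two added `*_index` lookups assume the extracted frames are named as zero-padded integers; with a made-up path the parsing behaves like this:

```python
# How the index parsing behaves on a typical zero-padded frame name
# (the directory is invented for this example).
first = "output/frames/00000012.png"
first_index = int(first.split('/')[-1].split('.')[-2])
print(first_index)  # 12
```
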
@@ -173,22 +170,43 @@ def run(self, video_path):
padding_left:padding_left + int_width],
(1, 2, 0)) for item in y_
]
-time_offsets = [kk * timestep for kk in range(1, 1 + num_frames, 1)]

-count = 1
-for item, time_offset in zip(y_, time_offsets):
-    out_dir = os.path.join(frame_path_interpolated, vidname,
-                           "{:0>6d}_{:0>4d}.png".format(i, count))
-    count = count + 1
-    imsave(out_dir, np.round(item).astype(np.uint8))

+num_frames = int(1.0 / timestep) - 1
+if self.remove_duplicates:
+    num_frames = times_interp * (second_index - first_index) - 1
+    time_offsets = [
+        kk * timestep for kk in range(1, 1 + num_frames, 1)
+    ]
+    start = times_interp * first_index + 1
+    for item, time_offset in zip(y_, time_offsets):
+        out_dir = os.path.join(frame_path_interpolated, vidname,
+                               "{:08d}.png".format(start))
+        imsave(out_dir, np.round(item).astype(np.uint8))
+        start = start + 1

+else:
+    time_offsets = [
+        kk * timestep for kk in range(1, 1 + num_frames, 1)
+    ]
+
+    count = 1
+    for item, time_offset in zip(y_, time_offsets):
+        out_dir = os.path.join(
+            frame_path_interpolated, vidname,
+            "{:0>6d}_{:0>4d}.png".format(i, count))
+        count = count + 1
+        imsave(out_dir, np.round(item).astype(np.uint8))
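
Under `remove_duplicates`, the arithmetic above numbers each synthesized frame so it stays tied to the original timeline; with illustrative indices the mapping works out like this:

```python
# Illustrative mapping for the remove_duplicates branch above.
# Duplicates were dropped between kept frames 3 and 5, and times_interp is 2,
# so three frames are synthesized and named to sit between the originals'
# combined positions (6 and 10).
times_interp = 2
first_index, second_index = 3, 5
num_frames = times_interp * (second_index - first_index) - 1   # 3
start = times_interp * first_index + 1                         # 7
names = ["{:08d}.png".format(start + k) for k in range(num_frames)]
print(names)  # ['00000007.png', '00000008.png', '00000009.png']
```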

input_dir = os.path.join(frame_path_input, vidname)
interpolated_dir = os.path.join(frame_path_interpolated, vidname)
combined_dir = os.path.join(frame_path_combined, vidname)
-self.combine_frames(input_dir, interpolated_dir, combined_dir,
-                    num_frames)

+if self.remove_duplicates:
+    self.combine_frames_with_rm(input_dir, interpolated_dir,
+                                combined_dir, times_interp)
+
+else:
+    num_frames = int(1.0 / timestep) - 1
+    self.combine_frames(input_dir, interpolated_dir, combined_dir,
+                        num_frames)

frame_pattern_combined = os.path.join(frame_path_combined, vidname,
'%08d.png')
@@ -223,6 +241,26 @@ def combine_frames(self, input, interpolated, combined, num_frames):
except Exception as e:
print(e)

+def combine_frames_with_rm(self, input, interpolated, combined,
+                           times_interp):
+    frames1 = sorted(glob.glob(os.path.join(input, '*.png')))
+    frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png')))
+    num1 = len(frames1)
+    num2 = len(frames2)
+
+    for i in range(num1):
+        src = frames1[i]
+        index = int(src.split('/')[-1].split('.')[-2])
+        dst = os.path.join(combined,
+                           '{:08d}.png'.format(times_interp * index))
+        shutil.copy2(src, dst)
+
+    for i in range(num2):
+        src = frames2[i]
+        imgname = src.split('/')[-1]
+        dst = os.path.join(combined, imgname)
+        shutil.copy2(src, dst)
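
Completing the picture from the interpolation example above, a dry run of the renumbering performed by `combine_frames_with_rm` (no files are touched, and the names are illustrative) shows kept originals and interpolated frames merging into one ordered sequence:

```python
# Dry run of the combined numbering used above (illustrative file names).
# Kept originals land at times_interp * index; interpolated frames already
# carry the in-between names written earlier.
times_interp = 2
kept = ["00000003.png", "00000005.png"]
interpolated = ["00000007.png", "00000008.png", "00000009.png"]
combined = ["{:08d}.png".format(times_interp * int(name.split('.')[-2]))
            for name in kept] + interpolated
print(sorted(combined))
# ['00000006.png', '00000007.png', '00000008.png', '00000009.png', '00000010.png']
```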

def remove_duplicate_frames(self, paths):
def dhash(image, hash_size=8):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
@@ -241,14 +279,19 @@ def dhash(image, hash_size=8):

for (h, hashed_paths) in hashes.items():
if len(hashed_paths) > 1:
-for p in hashed_paths[1:]:
-    os.remove(p)

-frames = sorted(glob.glob(os.path.join(paths, '*.png')))
-for fid, frame in enumerate(frames):
-    new_name = '{:08d}'.format(fid) + '.png'
-    new_name = os.path.join(paths, new_name)
-    os.rename(frame, new_name)
+first_index = int(hashed_paths[0].split('/')[-1].split('.')[-2])
+last_index = int(
+    hashed_paths[-1].split('/')[-1].split('.')[-2]) + 1
+gap = 2 * (last_index - first_index) - 1
+if gap > 9:
+    mid = len(hashed_paths) // 2
+    for p in hashed_paths[1:mid - 1]:
+        os.remove(p)
+    for p in hashed_paths[mid + 1:]:
+        os.remove(p)
+else:
+    for p in hashed_paths[1:]:
+        os.remove(p)

frames = sorted(glob.glob(os.path.join(paths, '*.png')))
return frames
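
The duplicate detector groups frames by a difference hash and, with the lines added here, thins long duplicate runs down to the first frame plus the two frames around the middle of the run instead of always keeping only the first. A generic, self-contained sketch of that idea is below; the `dhash` body is only partially visible in this diff, so this is a conventional implementation rather than the repository's exact code, and `frames_to_drop` is a helper name invented for the example:

```python
# Generic difference-hash sketch plus the gap rule added in this commit.
import cv2

def dhash(image, hash_size=8):
    # Conventional dhash: resize to (hash_size+1) x hash_size and compare
    # horizontal neighbours, packing the booleans into one integer.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    resized = cv2.resize(gray, (hash_size + 1, hash_size))
    diff = resized[:, 1:] > resized[:, :-1]
    return sum(2 ** i for i, v in enumerate(diff.flatten()) if v)

def frames_to_drop(hashed_paths):
    # Mirror of the added logic: runs spanning a large gap keep the first
    # frame and the two middle frames; short runs keep only the first.
    first_index = int(hashed_paths[0].split('/')[-1].split('.')[-2])
    last_index = int(hashed_paths[-1].split('/')[-1].split('.')[-2]) + 1
    gap = 2 * (last_index - first_index) - 1
    if gap > 9:
        mid = len(hashed_paths) // 2
        return hashed_paths[1:mid - 1] + hashed_paths[mid + 1:]
    return hashed_paths[1:]
```

For example, a run of eight duplicates named `00000010.png` through `00000017.png` gives `gap = 2 * (18 - 10) - 1 = 15 > 9`, so only positions 0, 3 and 4 of the run are kept and the rest are deleted.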
