diff --git a/.gitignore b/.gitignore
index 65365db2..801935a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,7 +161,8 @@ cython_debug/
examples/results/*
gfpgan/*
-checkpoints/
+checkpoints/*
+assets/*
results/*
Dockerfile
start_docker.sh
diff --git a/app.py b/app.py
index edde0cf4..9b5afa3a 100644
--- a/app.py
+++ b/app.py
@@ -1,14 +1,70 @@
import os, sys
-import tempfile
import gradio as gr
from src.gradio_demo import SadTalker
-def get_source_image(image):
- return image
-def sadtalker_demo():
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
- sad_talker = SadTalker(lazy_load=True)
+# mimetypes.init()
+# mimetypes.add_type('application/javascript', '.js')
+
+# script_path = os.path.dirname(os.path.realpath(__file__))
+
+# def webpath(fn):
+# if fn.startswith(script_path):
+# web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+# else:
+# web_path = os.path.abspath(fn)
+
+# return f'file={web_path}?{os.path.getmtime(fn)}'
+
+# def javascript_html():
+# # Ensure localization is in `window` before scripts
+# # head = f'\n'
+# head = 'somehead'
+
+# script_js = os.path.join(script_path, "assets", "script.js")
+# head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
+
+# script_js = os.path.join(script_path, "assets", "aspectRatioOverlay.js")
+# head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
+
+# return head
+
+# def resize_from_to_html(width, height, scale_by):
+# target_width = int(width * scale_by)
+# target_height = int(height * scale_by)
+
+# if not target_width or not target_height:
+# return "no image selected"
+
+# return f"resize: from {width}x{height} to {target_width}x{target_height}"
+
+# def get_source_image(image):
+# return image
+
+# def reload_javascript():
+# js = javascript_html()
+
+# def template_response(*args, **kwargs):
+# res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
+# res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
+# res.init_headers()
+# return res
+
+# gradio.routes.templates.TemplateResponse = template_response
+
+# if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+# shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
+
+
+def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):
+
+ sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)
with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
gr.Markdown("
😠SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)
\
@@ -21,14 +77,15 @@ def sadtalker_demo():
with gr.Tabs(elem_id="sadtalker_source_image"):
with gr.TabItem('Upload image'):
with gr.Row():
- source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
-
+ source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
+
+
with gr.Tabs(elem_id="sadtalker_driven_audio"):
with gr.TabItem('Upload OR TTS'):
with gr.Column(variant='panel'):
driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
-
- if sys.platform != 'win32':
+
+ if sys.platform != 'win32' and not in_webui:
from src.utils.text2speech import TTSTalker
tts_talker = TTSTalker()
with gr.Column(variant='panel'):
@@ -36,105 +93,36 @@ def sadtalker_demo():
tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
-
with gr.Column(variant='panel'):
with gr.Tabs(elem_id="sadtalker_checkbox"):
with gr.TabItem('Settings'):
+                        gr.Markdown("need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more details")
with gr.Column(variant='panel'):
- preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?")
- is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion, works with preprocess `full`)")
- enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
+ pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) #
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
+ preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
+                            is_still_mode = gr.Checkbox(label="Still Mode (less head motion, works with preprocess `full`)")
+                            batch_size = gr.Slider(label="batch size in generation", step=1, minimum=1, maximum=10, value=2)
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+
with gr.Tabs(elem_id="sadtalker_genearted"):
gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
-
-
- with gr.Row():
- examples = [
- [
- 'examples/source_image/full_body_1.png',
- 'examples/driven_audio/bus_chinese.wav',
- 'crop',
- True,
- False
- ],
- [
- 'examples/source_image/full_body_2.png',
- 'examples/driven_audio/japanese.wav',
- 'crop',
- False,
- False
- ],
- [
- 'examples/source_image/full3.png',
- 'examples/driven_audio/deyu.wav',
- 'crop',
- False,
- True
- ],
- [
- 'examples/source_image/full4.jpeg',
- 'examples/driven_audio/eluosi.wav',
- 'full',
- False,
- True
- ],
- [
- 'examples/source_image/full4.jpeg',
- 'examples/driven_audio/imagine.wav',
- 'full',
- True,
- True
- ],
- [
- 'examples/source_image/full_body_1.png',
- 'examples/driven_audio/bus_chinese.wav',
- 'full',
- True,
- False
- ],
- [
- 'examples/source_image/art_13.png',
- 'examples/driven_audio/fayu.wav',
- 'resize',
- True,
- False
- ],
- [
- 'examples/source_image/art_5.png',
- 'examples/driven_audio/chinese_news.wav',
- 'resize',
- False,
- False
- ],
- [
- 'examples/source_image/art_5.png',
- 'examples/driven_audio/RD_Radio31_000.wav',
- 'resize',
- True,
- True
- ],
- ]
- gr.Examples(examples=examples,
- inputs=[
- source_image,
- driven_audio,
- preprocess_type,
- is_still_mode,
- enhancer],
- outputs=[gen_video],
- fn=sad_talker.test,
- cache_examples=os.getenv('SYSTEM') == 'spaces') #
-
submit.click(
fn=sad_talker.test,
inputs=[source_image,
driven_audio,
preprocess_type,
is_still_mode,
- enhancer],
+ enhancer,
+ batch_size,
+ size_of_image,
+ pose_style
+ ],
outputs=[gen_video]
)
@@ -144,6 +132,7 @@ def sadtalker_demo():
if __name__ == "__main__":
demo = sadtalker_demo()
- demo.launch(share=True)
+ demo.queue()
+ demo.launch()
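Note: the entry point now takes its paths as arguments, so it can be reused by the WebUI extension (see scripts/extension.py below). A minimal sketch of driving it standalone, assuming the repo layout above; `server_name`/`server_port` are ordinary Gradio launch options shown for illustration, not part of this patch:

```python
# Hypothetical standalone driver for the new sadtalker_demo() signature.
from app import sadtalker_demo

demo = sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config')
demo.queue()    # queue long-running GPU jobs instead of blocking the HTTP worker
demo.launch(server_name='0.0.0.0', server_port=7860)  # illustrative launch options
```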
diff --git a/inference.py b/inference.py
index 10409256..a0b00790 100644
--- a/inference.py
+++ b/inference.py
@@ -1,3 +1,5 @@
+from glob import glob
+import shutil
import torch
from time import strftime
import os, sys, time
@@ -8,6 +10,7 @@
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
+from src.utils.init_path import init_path
def main(args):
#torch.backends.cudnn.enabled = False
@@ -25,51 +28,23 @@ def main(args):
ref_eyeblink = args.ref_eyeblink
ref_pose = args.ref_pose
- current_code_path = sys.argv[0]
- current_root_path = os.path.split(current_code_path)[0]
+ current_root_path = os.path.split(sys.argv[0])[0]
- os.environ['TORCH_HOME']=os.path.join(current_root_path, args.checkpoint_dir)
-
- path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
- path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
- dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
- wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')
-
- audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
- audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
-
- audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
- audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
-
- free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')
-
- if args.preprocess == 'full':
- mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00109-model.pth.tar')
- facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
- else:
- mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
- facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')
+ sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
#init model
- print(path_of_net_recon_model)
- preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)
-
- print(audio2pose_checkpoint)
- print(audio2exp_checkpoint)
- audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
- audio2exp_checkpoint, audio2exp_yaml_path,
- wav2lip_checkpoint, device)
+ preprocess_model = CropAndExtract(sadtalker_paths, device)
+
+ audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
- print(free_view_checkpoint)
- print(mapping_checkpoint)
- animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
- facerender_yaml_path, device)
+ animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
#crop image and extract 3dmm from image
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
os.makedirs(first_frame_dir, exist_ok=True)
print('3DMM Extraction for source image')
- first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess, source_image_flag=True)
+ first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
+ source_image_flag=True, pic_size=args.size)
if first_coeff_path is None:
print("Can't get the coeffs of the input")
return
@@ -79,7 +54,7 @@ def main(args):
ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
print('3DMM Extraction for the reference video providing eye blinking')
- ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir)
+ ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
else:
ref_eyeblink_coeff_path=None
@@ -91,7 +66,7 @@ def main(args):
ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
os.makedirs(ref_pose_frame_dir, exist_ok=True)
print('3DMM Extraction for the reference video providing pose')
- ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir)
+ ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
else:
ref_pose_coeff_path=None
@@ -107,22 +82,30 @@ def main(args):
#coeff2video
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
batch_size, input_yaw_list, input_pitch_list, input_roll_list,
- expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)
+ expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
- animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
- enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)
+ result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
+ enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+ shutil.move(result, save_dir+'.mp4')
+ print('The generated video is named:', save_dir+'.mp4')
+
+ if not args.verbose:
+ shutil.rmtree(save_dir)
+
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
- parser.add_argument("--source_image", default='./examples/source_image/full_body_2.png', help="path to source image")
+ parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
+ parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
@@ -132,7 +115,10 @@ def main(args):
parser.add_argument("--cpu", dest="cpu", action="store_true")
parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
- parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize', 'full'], help="how to preprocess the images" )
+ parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
+    parser.add_argument("--verbose", action="store_true", help="save the intermediate output or not")
+    parser.add_argument("--old_version", action="store_true", help="use the older .pth checkpoints instead of the safetensors version")
+
# net structure and parameters
parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
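Note: the output contract of inference.py also changed: the final video is moved next to the timestamped work directory, and intermediates are deleted unless `--verbose` is passed. A condensed sketch of that flow (function name and paths are illustrative, not from the patch):

```python
# Sketch of the new result handling at the end of main().
import shutil

def finalize(result: str, save_dir: str, verbose: bool) -> str:
    """Move the generated video next to the work dir; drop intermediates unless verbose."""
    final_path = save_dir + '.mp4'
    shutil.move(result, final_path)      # e.g. ./results/<timestamp>.mp4
    if not verbose:
        shutil.rmtree(save_dir)          # first-frame crops, 3DMM coeffs, temp videos
    return final_path
```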
diff --git a/launcher.py b/launcher.py
index 565b209d..1dedb2fa 100644
--- a/launcher.py
+++ b/launcher.py
@@ -170,8 +170,8 @@ def run_extension_installer(extension_dir):
def prepare_environment():
global skip_install
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
- requirements_file = os.environ.get('REQS_FILE', "requirements.txt")
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")
+ requirements_file = os.environ.get('REQS_FILE', "req.txt")
commit = commit_hash()
@@ -181,16 +181,18 @@ def prepare_environment():
if not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
- run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
-
run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)")
+ if sys.platform != 'win32' and not is_installed('tts'):
+        run_pip("install TTS", "install TTS individually for SadTalker (this might not work on Windows)")
+
def start():
print(f"Launching SadTalker Web UI")
from app import sadtalker_demo
demo = sadtalker_demo()
- demo.launch(share=True)
+ demo.queue()
+ demo.launch()
if __name__ == "__main__":
prepare_environment()
diff --git a/req.txt b/req.txt
new file mode 100644
index 00000000..3c384dd8
--- /dev/null
+++ b/req.txt
@@ -0,0 +1,23 @@
+llvmlite==0.38.1
+numpy==1.21.6
+face_alignment==1.3.5
+imageio==2.19.3
+imageio-ffmpeg==0.4.7
+librosa==0.10.0.post2
+numba==0.55.1
+resampy==0.3.1
+pydub==0.25.1
+scipy==1.10.1
+kornia==0.6.8
+tqdm
+yacs==0.1.8
+pyyaml
+joblib==1.1.0
+scikit-image==0.19.3
+basicsr==1.4.2
+facexlib==0.3.0
+gradio
+gfpgan
+dlib-bin
+av
+safetensors
diff --git a/requirements.txt b/requirements.txt
index 238c237f..2321c742 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ gradio
gfpgan
dlib-bin
av
+safetensors
diff --git a/requirements3d.txt b/requirements3d.txt
index 25b6b262..c7eeeecd 100644
--- a/requirements3d.txt
+++ b/requirements3d.txt
@@ -18,4 +18,5 @@ facexlib==0.2.5
trimesh==3.9.20
dlib-bin
gradio
-gfpgan
\ No newline at end of file
+gfpgan
+safetensors
\ No newline at end of file
diff --git a/scripts/download_models.sh b/scripts/download_models.sh
index 213eb499..425cada0 100644
--- a/scripts/download_models.sh
+++ b/scripts/download_models.sh
@@ -1,19 +1,25 @@
mkdir ./checkpoints
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
-wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
+
+# legacy download links
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
+# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
+# unzip -n ./checkpoints/hub.zip -d ./checkpoints/
-unzip -n ./checkpoints/hub.zip -d ./checkpoints/
+#### download the new links.
+wget -nc
+
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
+### enhancer
mkdir -p ./gfpgan/weights
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth
diff --git a/scripts/extension.py b/scripts/extension.py
index 8cd9ca89..8fcdf3a8 100644
--- a/scripts/extension.py
+++ b/scripts/extension.py
@@ -10,6 +10,26 @@
from huggingface_hub import snapshot_download
+
+def check_all_files_safetensor(current_dir):
+ kv = {
+ "SadTalker_V0.0.2_256.safetensors": "sadtalker-256",
+ "SadTalker_V0.0.2_512.safetensors": "sadtalker-512",
+ "mapping_00109-model.pth.tar" : "mapping-109" ,
+ "mapping_00229-model.pth.tar" : "mapping-229" ,
+ }
+
+ if not os.path.isdir(current_dir):
+ return False
+
+ dirs = os.listdir(current_dir)
+
+ for f in dirs:
+ if f in kv.keys():
+ del kv[f]
+
+ return len(kv.keys()) == 0
+
def check_all_files(current_dir):
kv = {
"auido2exp_00300-model.pth": "audio2exp",
@@ -63,6 +83,14 @@ def get_default_checkpoint_path():
checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker"
extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints"
+ if check_all_files_safetensor(checkpoint_path):
+        # print('found sadtalker checkpoint in ' + str(checkpoint_path))
+ return checkpoint_path
+
+ if check_all_files_safetensor(extension_checkpoint_path):
+        # print('found sadtalker checkpoint in ' + str(extension_checkpoint_path))
+ return extension_checkpoint_path
+
if check_all_files(checkpoint_path):
        # print('found sadtalker checkpoint in ' + str(checkpoint_path))
return checkpoint_path
@@ -91,17 +119,17 @@ def install():
"gfpgan": "gfpgan",
}
- if 'darwin' in sys.platform:
- kv['dlib'] = "dlib"
- else:
- kv['dlib'] = 'dlib-bin'
+ # dlib is not necessary currently
+ # if 'darwin' in sys.platform:
+ # kv['dlib'] = "dlib"
+ # else:
+ # kv['dlib'] = 'dlib-bin'
for k,v in kv.items():
if not launch.is_installed(k):
print(k, launch.is_installed(k))
launch.run_pip("install "+ v, "requirements for SadTalker")
-
if os.getenv('SADTALKER_CHECKPOINTS'):
print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))
@@ -138,66 +166,15 @@ def on_ui_tabs():
result_dir = opts.sadtalker_result_dir
os.makedirs(result_dir, exist_ok=True)
- from src.gradio_demo import SadTalker
+ from app import sadtalker_demo
if os.getenv('SADTALKER_CHECKPOINTS'):
checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
else:
checkpoint_path = repo_dir+'checkpoints/'
- sad_talker = SadTalker(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', lazy_load=True)
-
- with gr.Blocks(analytics_enabled=False) as audio_to_video:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- with gr.Tabs(elem_id="sadtalker_source_image"):
- with gr.TabItem('Upload image'):
- with gr.Row():
- input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
-
- with gr.Row():
- submit_image2 = gr.Button('load From txt2img', variant='primary')
- submit_image2.click(fn=get_img_from_txt2img, inputs=input_image, outputs=[input_image, input_image])
-
- submit_image3 = gr.Button('load from img2img', variant='primary')
- submit_image3.click(fn=get_img_from_img2img, inputs=input_image, outputs=[input_image, input_image])
-
- with gr.Tabs(elem_id="sadtalker_driven_audio"):
- with gr.TabItem('Upload'):
- with gr.Column(variant='panel'):
-
- with gr.Row():
- driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
-
-
- with gr.Column(variant='panel'):
- with gr.Tabs(elem_id="sadtalker_checkbox"):
- with gr.TabItem('Settings'):
- with gr.Column(variant='panel'):
- gr.Markdown("Please visit [**[here]**](https://github.com/Winfredy/SadTalker/blob/main/docs/best_practice.md) if you don't know how to choose these configurations.")
- preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?")
- is_still_mode = gr.Checkbox(label="Remove head motion (works better with preprocess `full`)")
- enhancer = gr.Checkbox(label="Face enhancement")
- submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
- path_to_save = gr.Text(Path(paths.script_path) / "outputs/SadTalker/", visible=False)
-
- with gr.Tabs(elem_id="sadtalker_genearted"):
- gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
-
-
- ### gradio gpu call will always return the html,
- submit.click(
- fn=wrap_queued_call(sad_talker.test),
- inputs=[input_image,
- driven_audio,
- preprocess_type,
- is_still_mode,
- enhancer,
- path_to_save
- ],
- outputs=[gen_video, ]
- )
-
+    audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn=wrap_queued_call)
+
return [(audio_to_video, "SadTalker", "extension")]
def on_ui_settings():
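Note: both `check_all_files*` helpers implement the same "required set minus directory listing" test; a standalone sketch of that pattern (the set below mirrors the safetensors variant):

```python
import os

REQUIRED_SAFETENSOR_FILES = {
    "SadTalker_V0.0.2_256.safetensors",
    "SadTalker_V0.0.2_512.safetensors",
    "mapping_00109-model.pth.tar",
    "mapping_00229-model.pth.tar",
}

def all_files_present(current_dir: str, required: set) -> bool:
    # Equivalent to deleting found names from the dict and testing emptiness.
    if not os.path.isdir(current_dir):
        return False
    return required.issubset(os.listdir(current_dir))
```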
diff --git a/scripts/test.sh b/scripts/test.sh
index 69147611..281c9c03 100644
--- a/scripts/test.sh
+++ b/scripts/test.sh
@@ -1 +1,21 @@
-### some test command before commit.
\ No newline at end of file
+### some test commands run before commit.
+python inference.py --preprocess crop --size 256
+python inference.py --preprocess crop --size 512
+
+python inference.py --preprocess extcrop --size 256
+python inference.py --preprocess extcrop --size 512
+
+python inference.py --preprocess resize --size 256
+python inference.py --preprocess resize --size 512
+
+python inference.py --preprocess full --size 256
+python inference.py --preprocess full --size 512
+
+python inference.py --preprocess extfull --size 256
+python inference.py --preprocess extfull --size 512
+
+python inference.py --preprocess full --size 256 --enhancer gfpgan
+python inference.py --preprocess full --size 512 --enhancer gfpgan
+
+python inference.py --preprocess full --size 256 --enhancer gfpgan --still
+python inference.py --preprocess full --size 512 --enhancer gfpgan --still
diff --git a/src/audio2pose_models/audio_encoder.py b/src/audio2pose_models/audio_encoder.py
index ea9095ad..6279d201 100644
--- a/src/audio2pose_models/audio_encoder.py
+++ b/src/audio2pose_models/audio_encoder.py
@@ -41,14 +41,14 @@ def __init__(self, wav2lip_checkpoint, device):
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
- #### load the pre-trained audio_encoder
- wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
- state_dict = self.audio_encoder.state_dict()
+    #### load the pre-trained audio_encoder; we no longer need to load the wav2lip model here.
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+ # state_dict = self.audio_encoder.state_dict()
- for k,v in wav2lip_state_dict.items():
- if 'audio_encoder' in k:
- state_dict[k.replace('module.audio_encoder.', '')] = v
- self.audio_encoder.load_state_dict(state_dict)
+ # for k,v in wav2lip_state_dict.items():
+ # if 'audio_encoder' in k:
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
+ # self.audio_encoder.load_state_dict(state_dict)
def forward(self, audio_sequences):
diff --git a/src/config/similarity_Lm3D_all.mat b/src/config/similarity_Lm3D_all.mat
new file mode 100644
index 00000000..a0e23588
Binary files /dev/null and b/src/config/similarity_Lm3D_all.mat differ
diff --git a/src/face3d/extract_kp_videos_safe.py b/src/face3d/extract_kp_videos_safe.py
index 5c9cff87..262439b9 100644
--- a/src/face3d/extract_kp_videos_safe.py
+++ b/src/face3d/extract_kp_videos_safe.py
@@ -55,11 +55,6 @@ def extract_keypoint(self, images, name=None, info=True):
bboxes = self.det_net.detect_faces(images, 0.97)
bboxes = bboxes[0]
-
- # bboxes[0] -= 100
- # bboxes[1] -= 100
- # bboxes[2] += 100
- # bboxes[3] += 100
img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
diff --git a/src/facerender/animate.py b/src/facerender/animate.py
index 3adea961..563d87fe 100644
--- a/src/facerender/animate.py
+++ b/src/facerender/animate.py
@@ -4,7 +4,8 @@
import numpy as np
import warnings
from skimage import img_as_ubyte
-
+import safetensors
+import safetensors.torch
warnings.filterwarnings('ignore')
@@ -26,10 +27,9 @@
class AnimateFromCoeff():
- def __init__(self, free_view_checkpoint, mapping_checkpoint,
- config_path, device):
+ def __init__(self, sadtalker_path, device):
- with open(config_path) as f:
+ with open(sadtalker_path['facerender_yaml']) as f:
config = yaml.safe_load(f)
generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
@@ -40,7 +40,6 @@ def __init__(self, free_view_checkpoint, mapping_checkpoint,
**config['model_params']['common_params'])
mapping = MappingNet(**config['model_params']['mapping_params'])
-
generator.to(device)
kp_extractor.to(device)
he_estimator.to(device)
@@ -54,13 +53,16 @@ def __init__(self, free_view_checkpoint, mapping_checkpoint,
for param in mapping.parameters():
param.requires_grad = False
- if free_view_checkpoint is not None:
- self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+ if sadtalker_path is not None:
+ if 'safetensors' in sadtalker_path['checkpoint']:
+ self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
+ else:
+ self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
else:
raise AttributeError("Checkpoint should be specified for video head pose estimator.")
- if mapping_checkpoint is not None:
- self.load_cpk_mapping(mapping_checkpoint, mapping=mapping)
+ if sadtalker_path['mappingnet_checkpoint'] is not None:
+ self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
else:
raise AttributeError("Checkpoint should be specified for video head pose estimator.")
@@ -76,6 +78,33 @@ def __init__(self, free_view_checkpoint, mapping_checkpoint,
self.device = device
+ def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
kp_detector=None, he_estimator=None, optimizer_generator=None,
optimizer_discriminator=None, optimizer_kp_detector=None,
@@ -120,7 +149,7 @@ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
return checkpoint['epoch']
- def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
source_image=x['source_image'].type(torch.FloatTensor)
source_semantics=x['source_semantics'].type(torch.FloatTensor)
@@ -160,10 +189,10 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, backgr
video.append(image)
result = img_as_ubyte(video)
- ### the generated video is 256x256, so we keep the aspect ratio,
+        ### the generated video is img_size x img_size, so we keep the aspect ratio,
original_size = crop_info[0]
if original_size:
- result = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in result ]
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
video_name = x['video_name'] + '.mp4'
path = os.path.join(video_save_dir, 'temp_'+video_name)
@@ -185,14 +214,14 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, backgr
word.export(new_audio_path, format="wav")
save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
- print(f'The generated video is named {video_name} in {video_save_dir}')
+ print(f'The generated video is named {video_save_dir}/{video_name}')
- if preprocess.lower() == 'full':
+ if 'full' in preprocess.lower():
# only add watermark to the full image.
video_name_full = x['video_name'] + '_full.mp4'
full_video_path = os.path.join(video_save_dir, video_name_full)
return_path = full_video_path
- paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
+            paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop='ext' in preprocess.lower())
print(f'The generated video is named {video_save_dir}/{video_name_full}')
else:
full_video_path = av_path
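Note: the safetensors checkpoint is one flat tensor dict whose keys are namespaced per sub-model (`generator.*`, `kp_extractor.*`, `he_estimator.*`); `load_cpk_facevid2vid_safetensor` above repeats the same filter-and-strip loop three times. A condensed sketch of that idiom, assuming the checkpoint file name produced by init_path:

```python
import safetensors.torch

def sub_state_dict(flat: dict, name: str) -> dict:
    # Select the keys belonging to one sub-model and strip its namespace prefix,
    # mirroring the loops in load_cpk_facevid2vid_safetensor.
    return {k.replace(name + '.', ''): v for k, v in flat.items() if name in k}

checkpoint = safetensors.torch.load_file('checkpoints/SadTalker_V0.0.2_256.safetensors')
# generator.load_state_dict(sub_state_dict(checkpoint, 'generator'))
# kp_extractor.load_state_dict(sub_state_dict(checkpoint, 'kp_extractor'))
```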
diff --git a/src/facerender/modules/make_animation.py b/src/facerender/modules/make_animation.py
index e7887a3f..b2616648 100644
--- a/src/facerender/modules/make_animation.py
+++ b/src/facerender/modules/make_animation.py
@@ -29,7 +29,7 @@ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale
def headpose_pred_to_degree(pred):
device = pred.device
idx_tensor = [idx for idx in range(66)]
- idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+ idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
pred = F.softmax(pred)
degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
return degree
@@ -102,7 +102,7 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False):
def make_animation(source_image, source_semantics, target_semantics,
generator, kp_detector, he_estimator, mapping,
yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
- use_exp=True):
+ use_exp=True, use_half=False):
with torch.no_grad():
predictions = []
@@ -122,8 +122,6 @@ def make_animation(source_image, source_semantics, target_semantics,
kp_driving = keypoint_transformation(kp_canonical, he_driving)
- #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
- #kp_driving_initial=kp_driving_initial)
kp_norm = kp_driving
out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
'''
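Note: the `.type_as(pred)` fix keeps the bin-index tensor in the same dtype as the network output, which matters once the new `use_half` flag puts `pred` in fp16. For reference, head pose is decoded as the expectation over 66 classification bins of 3 degrees each, so bin 33 maps to roughly 0 degrees; a standalone check:

```python
import torch
import torch.nn.functional as F

pred = torch.zeros(1, 66)
pred[0, 33] = 10.0                              # put nearly all mass on bin 33
idx_tensor = torch.arange(66).type_as(pred)     # same dtype as pred
degree = torch.sum(F.softmax(pred, dim=1) * idx_tensor, 1) * 3 - 99
print(degree)                                   # ~0 degrees
```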
diff --git a/src/generate_facerender_batch.py b/src/generate_facerender_batch.py
index 9ec7a169..7538b475 100644
--- a/src/generate_facerender_batch.py
+++ b/src/generate_facerender_batch.py
@@ -7,7 +7,7 @@
def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
- expression_scale=1.0, still_mode = False, preprocess='crop'):
+ expression_scale=1.0, still_mode = False, preprocess='crop', size = 256):
semantic_radius = 13
video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
@@ -18,7 +18,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
img1 = Image.open(pic_path)
source_image = np.array(img1)
source_image = img_as_float32(source_image)
- source_image = transform.resize(source_image, (256, 256, 3))
+ source_image = transform.resize(source_image, (size, size, 3))
source_image = source_image.transpose((2, 0, 1))
source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)
source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)
@@ -26,7 +26,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
source_semantics_dict = scio.loadmat(first_coeff_path)
- if preprocess.lower() != 'full':
+ if 'full' not in preprocess.lower():
source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
else:
source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70
@@ -41,7 +41,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
generated_3dmm = generated_dict['coeff_3dmm']
generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
- if preprocess.lower() == 'full':
+ if 'full' in preprocess.lower():
generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
if still_mode:
diff --git a/src/gradio_demo.py b/src/gradio_demo.py
index 4bdb2177..11eaf9cd 100644
--- a/src/gradio_demo.py
+++ b/src/gradio_demo.py
@@ -6,8 +6,11 @@
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
+from src.utils.init_path import init_path
+
from pydub import AudioSegment
+
def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
mp3_file = AudioSegment.from_file(file=mp3_filename)
mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
@@ -28,56 +31,17 @@ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy
self.checkpoint_path = checkpoint_path
self.config_path = config_path
+
- self.path_of_lm_croper = os.path.join( checkpoint_path, 'shape_predictor_68_face_landmarks.dat')
- self.path_of_net_recon_model = os.path.join( checkpoint_path, 'epoch_20.pth')
- self.dir_of_BFM_fitting = os.path.join( checkpoint_path, 'BFM_Fitting')
- self.wav2lip_checkpoint = os.path.join( checkpoint_path, 'wav2lip.pth')
-
- self.audio2pose_checkpoint = os.path.join( checkpoint_path, 'auido2pose_00140-model.pth')
- self.audio2pose_yaml_path = os.path.join( config_path, 'auido2pose.yaml')
-
- self.audio2exp_checkpoint = os.path.join( checkpoint_path, 'auido2exp_00300-model.pth')
- self.audio2exp_yaml_path = os.path.join( config_path, 'auido2exp.yaml')
-
- self.free_view_checkpoint = os.path.join( checkpoint_path, 'facevid2vid_00189-model.pth.tar')
+    def test(self, source_image, driven_audio, preprocess='crop',
+             still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, result_dir='./results/'):
- self.lazy_load = lazy_load
-
- if not self.lazy_load:
- #init model
+ self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
+ print(self.sadtalker_paths)
- print(self.audio2pose_checkpoint)
- self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
- self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
-
- print(self.path_of_lm_croper)
- self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
-
- def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, result_dir='./results/'):
-
- ### crop: only model,
-
- if self.lazy_load:
- #init model
-
- print(self.audio2pose_checkpoint)
- self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
- self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
-
- print(self.path_of_lm_croper)
- self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
-
- if preprocess == 'full':
- self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00109-model.pth.tar')
- self.facerender_yaml_path = os.path.join(self.config_path, 'facerender_still.yaml')
- else:
- self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00229-model.pth.tar')
- self.facerender_yaml_path = os.path.join(self.config_path, 'facerender.yaml')
-
- print(self.free_view_checkpoint)
- self.animate_from_coeff = AnimateFromCoeff(self.free_view_checkpoint, self.mapping_checkpoint,
- self.facerender_yaml_path, self.device)
+ self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
+ self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
+ self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
time_tag = str(uuid.uuid4())
save_dir = os.path.join(result_dir, time_tag)
@@ -104,11 +68,11 @@ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False,
os.makedirs(save_dir, exist_ok=True)
- pose_style = 0
+
#crop image and extract 3dmm from image
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
os.makedirs(first_frame_dir, exist_ok=True)
- first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess)
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)
if first_coeff_path is None:
raise AttributeError("No face is detected")
@@ -117,16 +81,14 @@ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False,
batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=None, still=still_mode) # longer audio?
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
#coeff2video
- batch_size = 2
- data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess)
- return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess)
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size)
+ return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
video_name = data['video_name']
print(f'The generated video is named {video_name} in {save_dir}')
- if self.lazy_load:
- del self.preprocess_model
- del self.audio_to_coeff
- del self.animate_from_coeff
+ del self.preprocess_model
+ del self.audio_to_coeff
+ del self.animate_from_coeff
if torch.cuda.is_available():
torch.cuda.empty_cache()
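Note: SadTalker.test() now instantiates the three models on every call and deletes them before returning, trading per-request latency for VRAM headroom (important inside the WebUI). The pattern in outline, as a sketch with a hypothetical `build_models` helper rather than the actual method:

```python
import torch

def run_once(sadtalker_paths, device, job):
    models = build_models(sadtalker_paths, device)  # hypothetical: CropAndExtract etc.
    try:
        return job(models)
    finally:
        del models                                  # drop the only references
        if torch.cuda.is_available():
            torch.cuda.empty_cache()                # return cached blocks to the driver
```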
diff --git a/src/test_audio2coeff.py b/src/test_audio2coeff.py
index c3c6abcf..be66bffd 100644
--- a/src/test_audio2coeff.py
+++ b/src/test_audio2coeff.py
@@ -5,9 +5,13 @@
from yacs.config import CfgNode as CN
from scipy.signal import savgol_filter
+import safetensors
+import safetensors.torch
+
from src.audio2pose_models.audio2pose import Audio2Pose
from src.audio2exp_models.networks import SimpleWrapperV2
-from src.audio2exp_models.audio2exp import Audio2Exp
+from src.audio2exp_models.audio2exp import Audio2Exp
+from src.utils.safetensor_helper import load_x_from_safetensor
def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
@@ -20,25 +24,28 @@ def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
class Audio2Coeff():
- def __init__(self, audio2pose_checkpoint, audio2pose_yaml_path,
- audio2exp_checkpoint, audio2exp_yaml_path,
- wav2lip_checkpoint, device):
+ def __init__(self, sadtalker_path, device):
#load config
- fcfg_pose = open(audio2pose_yaml_path)
+ fcfg_pose = open(sadtalker_path['audio2pose_yaml_path'])
cfg_pose = CN.load_cfg(fcfg_pose)
cfg_pose.freeze()
- fcfg_exp = open(audio2exp_yaml_path)
+ fcfg_exp = open(sadtalker_path['audio2exp_yaml_path'])
cfg_exp = CN.load_cfg(fcfg_exp)
cfg_exp.freeze()
# load audio2pose_model
- self.audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint, device=device)
+ self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device)
self.audio2pose_model = self.audio2pose_model.to(device)
self.audio2pose_model.eval()
for param in self.audio2pose_model.parameters():
param.requires_grad = False
+
try:
- load_cpk(audio2pose_checkpoint, model=self.audio2pose_model, device=device)
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose'))
+ else:
+ load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device)
except:
raise Exception("Failed in loading audio2pose_checkpoint")
@@ -49,7 +56,11 @@ def __init__(self, audio2pose_checkpoint, audio2pose_yaml_path,
netG.requires_grad = False
netG.eval()
try:
- load_cpk(audio2exp_checkpoint, model=netG, device=device)
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp'))
+ else:
+ load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device)
except:
raise Exception("Failed in loading audio2exp_checkpoint")
self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)
@@ -106,7 +117,8 @@ def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
refpose_coeff_list.append(refpose_coeff[:re, :])
refpose_coeff = np.concatenate(refpose_coeff_list, axis=0)
- coeffs_pred_numpy[:, 64:70] = refpose_coeff[:num_frames, :]
+ #### relative head pose
+ coeffs_pred_numpy[:, 64:70] = coeffs_pred_numpy[:, 64:70] + ( refpose_coeff[:num_frames, :] - refpose_coeff[0:1, :] )
return coeffs_pred_numpy
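Note: the `using_refpose` change is behavioural: instead of overwriting the predicted pose with the reference clip's pose, it now transfers only the reference *motion* relative to its first frame. A tiny numeric illustration:

```python
import numpy as np

pred_pose = np.array([[10.0], [10.0], [10.0]])      # predicted pose, 3 frames
ref_pose  = np.array([[ 5.0], [ 6.0], [ 8.0]])      # reference-video pose

relative = pred_pose + (ref_pose - ref_pose[0:1])   # -> [[10.], [11.], [13.]]
# The old code returned ref_pose itself, discarding the predicted absolute pose.
```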
diff --git a/src/utils/croper.py b/src/utils/croper.py
index 0ccf5dfc..3d9a0ac5 100644
--- a/src/utils/croper.py
+++ b/src/utils/croper.py
@@ -6,57 +6,38 @@
import scipy
import numpy as np
from PIL import Image
+import torch
from tqdm import tqdm
from itertools import cycle
-from torch.multiprocessing import Pool, Process, set_start_method
-
-
-"""
-brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
-author: lzhbrian (https://lzhbrian.me)
-date: 2020.1.5
-note: code is heavily borrowed from
- https://github.com/NVlabs/ffhq-dataset
- http://dlib.net/face_landmark_detection.py.html
-requirements:
- apt install cmake
- conda install Pillow numpy scipy
- pip install dlib
- # download face landmark model from:
- # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
-"""
+from src.face3d.extract_kp_videos_safe import KeypointExtractor
+from facexlib.alignment import landmark_98_to_68
import numpy as np
from PIL import Image
-import dlib
-
-class Croper:
- def __init__(self, path_of_lm):
- # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
- self.predictor = dlib.shape_predictor(path_of_lm)
+class Preprocesser:
+ def __init__(self, device='cuda'):
+ self.predictor = KeypointExtractor(device)
def get_landmark(self, img_np):
"""get landmark with dlib
:return: np.array shape=(68, 2)
"""
- detector = dlib.get_frontal_face_detector()
- dets = detector(img_np, 1)
- # print("Number of faces detected: {}".format(len(dets)))
- # for k, d in enumerate(dets):
+ with torch.no_grad():
+ dets = self.predictor.det_net.detect_faces(img_np, 0.97)
+
if len(dets) == 0:
return None
- d = dets[0]
- # Get the landmarks/parts for the face in box d.
- shape = self.predictor(img_np, d)
- # print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1)))
- t = list(shape.parts())
- a = []
- for tt in t:
- a.append([tt.x, tt.y])
- lm = np.array(a)
- # lm is a shape=(68,2) np.array
+ det = dets[0]
+
+ img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
+ lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ lm[:,0] += int(det[0])
+ lm[:,1] += int(det[1])
+
return lm
def align_face(self, img, lm, output_size=1024):
@@ -138,34 +119,14 @@ def align_face(self, img, lm, output_size=1024):
ly = max(min(quad[1], quad[7]), 0)
rx = min(max(quad[4], quad[6]), img.size[0])
ry = min(max(quad[3], quad[5]), img.size[0])
- # img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(),
- # Image.BILINEAR)
- # if output_size < transform_size:
- # img = img.resize((output_size, output_size), Image.ANTIALIAS)
# Save aligned image.
return rsize, crop, [lx, ly, rx, ry]
-
- # def crop(self, img_np_list):
- # for _i in range(len(img_np_list)):
- # img_np = img_np_list[_i]
- # lm = self.get_landmark(img_np)
- # if lm is None:
- # return None
- # crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=512)
- # clx, cly, crx, cry = crop
- # lx, ly, rx, ry = quad
- # lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
-
- # _inp = img_np_list[_i]
- # _inp = _inp[cly:cry, clx:crx]
- # _inp = _inp[ly:ry, lx:rx]
- # img_np_list[_i] = _inp
- # return img_np_list
def crop(self, img_np_list, still=False, xsize=512): # first frame for all video
img_np = img_np_list[0]
lm = self.get_landmark(img_np)
+
if lm is None:
            raise RuntimeError('cannot detect the landmark from the source image')
rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
@@ -176,124 +137,8 @@ def crop(self, img_np_list, still=False, xsize=512): # first frame for all vi
_inp = img_np_list[_i]
_inp = cv2.resize(_inp, (rsize[0], rsize[1]))
_inp = _inp[cly:cry, clx:crx]
- # cv2.imwrite('test1.jpg', _inp)
if not still:
_inp = _inp[ly:ry, lx:rx]
- # cv2.imwrite('test2.jpg', _inp)
img_np_list[_i] = _inp
return img_np_list, crop, quad
-
-def read_video(filename, uplimit=100):
- frames = []
- cap = cv2.VideoCapture(filename)
- cnt = 0
- while cap.isOpened():
- ret, frame = cap.read()
- if ret:
- frame = cv2.resize(frame, (512, 512))
- frames.append(frame)
- else:
- break
- cnt += 1
- if cnt >= uplimit:
- break
- cap.release()
- assert len(frames) > 0, f'{filename}: video with no frames!'
- return frames
-
-
-def create_video(video_name, frames, fps=25, video_format='.mp4', resize_ratio=1):
- # video_name = os.path.dirname(image_folder) + video_format
- # img_list = glob.glob1(image_folder, 'frame*')
- # img_list.sort()
- # frame = cv2.imread(os.path.join(image_folder, img_list[0]))
- # frame = cv2.resize(frame, (0, 0), fx=resize_ratio, fy=resize_ratio)
- # height, width, layers = frames[0].shape
- height, width, layers = 512, 512, 3
- if video_format == '.mp4':
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- elif video_format == '.avi':
- fourcc = cv2.VideoWriter_fourcc(*'XVID')
- video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
- for _frame in frames:
- _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR)
- video.write(_frame)
-
-def create_images(video_name, frames):
- height, width, layers = 512, 512, 3
- images_dir = video_name.split('.')[0]
- os.makedirs(images_dir, exist_ok=True)
- for i, _frame in enumerate(frames):
- _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR)
- _frame_path = os.path.join(images_dir, str(i)+'.jpg')
- cv2.imwrite(_frame_path, _frame)
-
-def run(data):
- filename, opt, device = data
- os.environ['CUDA_VISIBLE_DEVICES'] = device
- croper = Croper()
-
- frames = read_video(filename, uplimit=opt.uplimit)
- name = filename.split('/')[-1] # .split('.')[0]
- name = os.path.join(opt.output_dir, name)
-
- frames = croper.crop(frames)
- if frames is None:
- print(f'{name}: detect no face. should removed')
- return
- # create_video(name, frames)
- create_images(name, frames)
-
-
-def get_data_path(video_dir):
- eg_video_files = ['/apdcephfs/share_1290939/quincheng/datasets/HDTF/backup_fps25/WDA_KatieHill_000.mp4']
- # filenames = list()
- # VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
- # VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
- # extensions = VIDEO_EXTENSIONS
- # for ext in extensions:
- # filenames = sorted(glob.glob(f'{opt.input_dir}/**/*.{ext}'))
- # print('Total number of videos:', len(filenames))
- return eg_video_files
-
-
-def get_wra_data_path(video_dir):
- if opt.option == 'video':
- videos_path = sorted(glob.glob(f'{video_dir}/*.mp4'))
- elif opt.option == 'image':
- videos_path = sorted(glob.glob(f'{video_dir}/*/'))
- else:
- raise NotImplementedError
- print('Example videos: ', videos_path[:2])
- return videos_path
-
-
-if __name__ == '__main__':
- set_start_method('spawn')
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
- parser.add_argument('--device_ids', type=str, default='0,1')
- parser.add_argument('--workers', type=int, default=8)
- parser.add_argument('--uplimit', type=int, default=500)
- parser.add_argument('--option', type=str, default='video')
-
- root = '/apdcephfs/share_1290939/quincheng/datasets/HDTF'
- cmd = f'--input_dir {root}/backup_fps25_first20s_sync/ ' \
- f'--output_dir {root}/crop512_stylegan_firstframe_sync/ ' \
- '--device_ids 0 ' \
- '--workers 8 ' \
- '--option video ' \
- '--uplimit 500 '
- opt = parser.parse_args(cmd.split())
- # filenames = get_data_path(opt.input_dir)
- filenames = get_wra_data_path(opt.input_dir)
- os.makedirs(opt.output_dir, exist_ok=True)
- print(f'Video numbers: {len(filenames)}')
- pool = Pool(opt.workers)
- args_list = cycle([opt])
- device_ids = opt.device_ids.split(",")
- device_ids = cycle(device_ids)
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
- None
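Note: in the rewritten `Preprocesser.get_landmark`, landmarks are predicted on the detected face crop, so the two `+=` lines shift them back into full-image coordinates. A minimal sketch of that offset, with made-up numbers:

```python
import numpy as np

det = (120, 80, 320, 280)                        # hypothetical face box (x1, y1, x2, y2)
lm_crop = np.array([[30.0, 40.0]])               # one landmark in crop coordinates
lm_full = lm_crop + np.array([det[0], det[1]])   # -> [[150., 120.]] in image coords
```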
diff --git a/src/utils/init_path.py b/src/utils/init_path.py
new file mode 100644
index 00000000..5f38d119
--- /dev/null
+++ b/src/utils/init_path.py
@@ -0,0 +1,47 @@
+import os
+import glob
+
+def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
+
+ if old_version:
+ #### load all the checkpoint of `pth`
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ use_safetensor = False
+ elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
+ print('using safetensor as default')
+ sadtalker_paths = {
+ "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),
+ }
+ use_safetensor = True
+ else:
+        print("WARNING: the new-version checkpoints are distributed as safetensors; you may need to download them manually. Falling back to the old-version checkpoints this time!")
+ use_safetensor = False
+
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'
+ sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')
+ sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')
+    sadtalker_paths['use_safetensor'] = use_safetensor
+
+ if 'full' in preprocess:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')
+ else:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')
+
+ return sadtalker_paths
\ No newline at end of file
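Note: expected usage of the new resolver, with the values the code above would produce when the safetensors checkpoints are present (results shown as comments):

```python
from src.utils.init_path import init_path

paths = init_path('./checkpoints', './src/config', size=256, old_version=False, preprocess='full')
# paths['checkpoint']            -> ./checkpoints/SadTalker_V0.0.2_256.safetensors
# paths['use_safetensor']        -> True
# paths['mappingnet_checkpoint'] -> ./checkpoints/mapping_00109-model.pth.tar
# paths['facerender_yaml']       -> ./src/config/facerender_still.yaml
```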
diff --git a/src/utils/model2safetensor.py b/src/utils/model2safetensor.py
new file mode 100644
index 00000000..50c48500
--- /dev/null
+++ b/src/utils/model2safetensor.py
@@ -0,0 +1,141 @@
+import torch
+import yaml
+import os
+
+import safetensors
+from safetensors.torch import save_file
+from yacs.config import CfgNode as CN
+import sys
+
+sys.path.append('/apdcephfs/private_shadowcun/SadTalker')
+
+from src.face3d.models import networks
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.test_audio2coeff import load_cpk
+
+size = 256
+############ face vid2vid
+config_path = os.path.join('src', 'config', 'facerender.yaml')
+current_root_path = '.'
+
+path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
+net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
+checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
+net_recon.load_state_dict(checkpoint['net_recon'])
+
+with open(config_path) as f:
+ config = yaml.safe_load(f)
+
+generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+mapping = MappingNet(**config['model_params']['mapping_params'])
+
+def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
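+ """Load the original face-vid2vid .pth.tar checkpoint into any of the given
+ modules (each may be None) and return the stored epoch."""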
+
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ except (KeyError, RuntimeError):
+ print('No discriminator in the state dict. The discriminator will be randomly initialized.')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+ except RuntimeError as e:
+ print('No discriminator optimizer in the state dict. The optimizer will not be initialized.')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+
+def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
+load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+
+wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
+
+audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
+audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
+
+audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
+audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
+
+with open(audio2pose_yaml_path) as fcfg_pose:
+ cfg_pose = CN.load_cfg(fcfg_pose)
+cfg_pose.freeze()
+audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
+audio2pose_model.eval()
+load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')
+
+# load audio2exp_model
+netG = SimpleWrapperV2()
+netG.eval()
+load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
+
+class SadTalker(torch.nn.Module):
+ def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
+ super(SadTalker, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.audio2exp = netG
+ self.audio2pose = audio2pose
+ self.face_3drecon = face_3drecon
+
+
+model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)
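+# model.state_dict() prefixes every tensor with its attribute name
+# ('kp_extractor.', 'generator.', 'audio2exp.', 'audio2pose.', 'face_3drecon.');
+# load_x_from_safetensor() and the safetensors loader above rely on these prefixes.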
+
+# here, we want to convert it to safetensor
+save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors")
+
+### test
+load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None)
\ No newline at end of file
diff --git a/src/utils/paste_pic.py b/src/utils/paste_pic.py
index cf25aebc..f9989e21 100644
--- a/src/utils/paste_pic.py
+++ b/src/utils/paste_pic.py
@@ -5,7 +5,7 @@
from src.utils.videoio import save_video_with_watermark
-def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path):
+def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False):
if not os.path.isfile(pic_path):
raise ValueError('pic_path must be a valid path to video/image file')
@@ -47,13 +47,16 @@ def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path):
lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
# oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
# oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
- oy1, oy2, ox1, ox2 = cly, cry, clx, crx
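+ # paste back the whole crop box when extended_crop is set; otherwise only
+ # the tighter face quad inside it is blended back into the frame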
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
tmp_path = str(uuid.uuid4())+'.mp4'
out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
for crop_frame in tqdm(crop_frames, 'seamlessClone:'):
- p = cv2.resize(crop_frame.astype(np.uint8), (crx-clx, cry - cly))
+ p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1))
mask = 255*np.ones(p.shape, p.dtype)
location = ((ox1+ox2) // 2, (oy1+oy2) // 2)
diff --git a/src/utils/preprocess.py b/src/utils/preprocess.py
index db94004c..0f784e6c 100644
--- a/src/utils/preprocess.py
+++ b/src/utils/preprocess.py
@@ -4,21 +4,19 @@
from PIL import Image
# 3dmm extraction
+import safetensors
+import safetensors.torch
from src.face3d.util.preprocess import align_img
from src.face3d.util.load_mats import load_lm3d
from src.face3d.models import networks
-try:
- import webui
- from src.face3d.extract_kp_videos_safe import KeypointExtractor
- assert torch.cuda.is_available() == True
-except:
- from src.face3d.extract_kp_videos import KeypointExtractor
-
from scipy.io import loadmat, savemat
-from src.utils.croper import Croper
+from src.utils.croper import Preprocesser
+
+
+import warnings
-import warnings
+from src.utils.safetensor_helper import load_x_from_safetensor
warnings.filterwarnings("ignore")
def split_coeff(coeffs):
@@ -46,20 +44,24 @@ def split_coeff(coeffs):
class CropAndExtract():
- def __init__(self, path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device):
+ def __init__(self, sadtalker_path, device):
- self.croper = Croper(path_of_lm_croper)
- self.kp_extractor = KeypointExtractor(device)
+ self.propress = Preprocesser(device)
self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
- checkpoint = torch.load(path_of_net_recon_model, map_location=torch.device(device))
- self.net_recon.load_state_dict(checkpoint['net_recon'])
+
+ if sadtalker_path['use_safetensor']:
+ checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ else:
+ checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))
+ self.net_recon.load_state_dict(checkpoint['net_recon'])
+
self.net_recon.eval()
- self.lm3d_std = load_lm3d(dir_of_BFM_fitting)
+ self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])
self.device = device
- def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False):
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256):
- pic_size = 256
pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
@@ -90,15 +92,15 @@ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_fla
x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
#### crop images as the
- if crop_or_resize.lower() == 'crop': # default crop
- x_full_frames, crop, quad = self.croper.crop(x_full_frames, still=True, xsize=512)
+ if 'crop' in crop_or_resize.lower(): # default crop
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still='ext' in crop_or_resize.lower(), xsize=512)
clx, cly, crx, cry = crop
lx, ly, rx, ry = quad
lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
- elif crop_or_resize.lower() == 'full':
- x_full_frames, crop, quad = self.croper.crop(x_full_frames, still=True, xsize=512)
+ elif 'full' in crop_or_resize.lower():
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still='ext' in crop_or_resize.lower(), xsize=512)
clx, cly, crx, cry = crop
lx, ly, rx, ry = quad
lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
@@ -119,7 +121,7 @@ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_fla
# 2. get the landmark according to the detected face.
if not os.path.isfile(landmarks_path):
- lm = self.kp_extractor.extract_keypoint(frames_pil, landmarks_path)
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
else:
print(' Using saved landmarks.')
lm = np.loadtxt(landmarks_path).astype(np.float32)
diff --git a/src/utils/safetensor_helper.py b/src/utils/safetensor_helper.py
new file mode 100644
index 00000000..3cdbdd21
--- /dev/null
+++ b/src/utils/safetensor_helper.py
@@ -0,0 +1,8 @@
+def load_x_from_safetensor(checkpoint, key):
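+ """Return every tensor whose key contains `key`, with the '<key>.' prefix
+ stripped, ready for load_state_dict().
+
+ Illustrative use (mirroring preprocess.py):
+ checkpoint = safetensors.torch.load_file(path)
+ net.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ """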
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if key in k:
+ x_generator[k.replace(key+'.', '')] = v
+ return x_generator
\ No newline at end of file
diff --git a/src/utils/text2speech.py b/src/utils/text2speech.py
index 6948edf1..00d165b6 100644
--- a/src/utils/text2speech.py
+++ b/src/utils/text2speech.py
@@ -3,7 +3,6 @@
from TTS.api import TTS
-
class TTSTalker():
def __init__(self) -> None:
model_name = TTS.list_models()[0]
diff --git a/src/utils/videoio.py b/src/utils/videoio.py
index bbb5f099..08bfbdd7 100644
--- a/src/utils/videoio.py
+++ b/src/utils/videoio.py
@@ -19,7 +19,7 @@ def load_video_to_cv2(input_path):
def save_video_with_watermark(video, audio, save_path, watermark=False):
temp_file = str(uuid.uuid4())+'.mp4'
- cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
os.system(cmd)
if watermark is False:
@@ -36,6 +36,6 @@ def save_video_with_watermark(video, audio, save_path, watermark=False):
dir_path = os.path.dirname(os.path.realpath(__file__))
watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
- cmd = r'ffmpeg -y -hide_banner -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
os.system(cmd)
os.remove(temp_file)
\ No newline at end of file
diff --git a/webui.sh b/webui.sh
new file mode 100755
index 00000000..24575023
--- /dev/null
+++ b/webui.sh
@@ -0,0 +1,140 @@
+#!/usr/bin/env bash
+
+
+# On macOS, pin torch/torchvision versions known to work there
+if [[ "$OSTYPE" == "darwin"* ]]; then
+ export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
+fi
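+# TORCH_COMMAND is presumably consumed by the launcher (this script follows
+# stable-diffusion-webui's webui.sh, whose launch.py reads the same variable)
+# to decide how torch gets installed.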
+
+# python3 executable
+if [[ -z "${python_cmd}" ]]
+then
+ python_cmd="python3"
+fi
+
+# git executable
+if [[ -z "${GIT}" ]]
+then
+ export GIT="git"
+fi
+
+# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
+if [[ -z "${venv_dir}" ]]
+then
+ venv_dir="venv"
+fi
+
+if [[ -z "${LAUNCH_SCRIPT}" ]]
+then
+ LAUNCH_SCRIPT="launcher.py"
+fi
+
+# this script cannot be run as root by default (pass -f to allow it)
+can_run_as_root=0
+
+# read any command-line flags passed to webui.sh
+while getopts "f" flag > /dev/null 2>&1
+do
+ case ${flag} in
+ f) can_run_as_root=1;;
+ *) break;;
+ esac
+done
+
+# Disable sentry logging
+export ERROR_REPORTING=FALSE
+
+# Do not reinstall existing pip packages on Debian/Ubuntu
+export PIP_IGNORE_INSTALLED=0
+
+# Pretty print
+delimiter="################################################################"
+
+printf "\n%s\n" "${delimiter}"
+printf "\e[1m\e[32mInstall script for SadTalker + Web UI\n"
+printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
+printf "\n%s\n" "${delimiter}"
+
+# Do not run as root
+if [[ $(id -u) -eq 0 && ${can_run_as_root} -eq 0 ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)"
+ printf "\n%s\n" "${delimiter}"
+fi
+
+if [[ -d .git ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "Repo already cloned, using it as install directory"
+ printf "\n%s\n" "${delimiter}"
+fi
+
+# derive install/clone directories from the current working directory, so the
+# later cd also succeeds when this is not a git clone
+install_dir="${PWD}/../"
+clone_dir="${PWD##*/}"
+
+# Check prerequisites
+gpu_info=$(lspci 2>/dev/null | grep VGA)
+case "$gpu_info" in
+ *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ ;;
+ *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
+ printf "\n%s\n" "${delimiter}"
+ printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "\n%s\n" "${delimiter}"
+ ;;
+ *)
+ ;;
+esac
+if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+fi
+
+for preq in "${GIT}" "${python_cmd}"
+do
+ if ! hash "${preq}" &>/dev/null
+ then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+ fi
+done
+
+if ! "${python_cmd}" -c "import venv" &>/dev/null
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+printf "\n%s\n" "${delimiter}"
+printf "Create and activate python venv"
+printf "\n%s\n" "${delimiter}"
+cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+if [[ ! -d "${venv_dir}" ]]
+then
+ "${python_cmd}" -m venv "${venv_dir}"
+ first_launch=1
+fi
+# shellcheck source=/dev/null
+if [[ -f "${venv_dir}"/bin/activate ]]
+then
+ source "${venv_dir}"/bin/activate
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+printf "\n%s\n" "${delimiter}"
+printf "Launching launcher.py..."
+printf "\n%s\n" "${delimiter}"
+exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
\ No newline at end of file