diff --git a/.gitignore b/.gitignore index 65365db2..801935a6 100644 --- a/.gitignore +++ b/.gitignore @@ -161,7 +161,8 @@ cython_debug/ examples/results/* gfpgan/* -checkpoints/ +checkpoints/* +assets/* results/* Dockerfile start_docker.sh diff --git a/app.py b/app.py index edde0cf4..9b5afa3a 100644 --- a/app.py +++ b/app.py @@ -1,14 +1,70 @@ import os, sys -import tempfile import gradio as gr from src.gradio_demo import SadTalker -def get_source_image(image): - return image -def sadtalker_demo(): +try: + import webui # in webui + in_webui = True +except: + in_webui = False - sad_talker = SadTalker(lazy_load=True) +# mimetypes.init() +# mimetypes.add_type('application/javascript', '.js') + +# script_path = os.path.dirname(os.path.realpath(__file__)) + +# def webpath(fn): +# if fn.startswith(script_path): +# web_path = os.path.relpath(fn, script_path).replace('\\', '/') +# else: +# web_path = os.path.abspath(fn) + +# return f'file={web_path}?{os.path.getmtime(fn)}' + +# def javascript_html(): +# # Ensure localization is in `window` before scripts +# # head = f'\n' +# head = 'somehead' + +# script_js = os.path.join(script_path, "assets", "script.js") +# head += f'\n' + +# script_js = os.path.join(script_path, "assets", "aspectRatioOverlay.js") +# head += f'\n' + +# return head + +# def resize_from_to_html(width, height, scale_by): +# target_width = int(width * scale_by) +# target_height = int(height * scale_by) + +# if not target_width or not target_height: +# return "no image selected" + +# return f"resize: from {width}x{height} to {target_width}x{target_height}" + +# def get_source_image(image): +# return image + +# def reload_javascript(): +# js = javascript_html() + +# def template_response(*args, **kwargs): +# res = shared.GradioTemplateResponseOriginal(*args, **kwargs) +# res.body = res.body.replace(b'', f'{js}'.encode("utf8")) +# res.init_headers() +# return res + +# gradio.routes.templates.TemplateResponse = template_response + +# if not hasattr(shared, 'GradioTemplateResponseOriginal'): +# shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def sadtalker_demo(checkpoint_path='checkpoint', config_path='src/config', warpfn=None): + + sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True) with gr.Blocks(analytics_enabled=False) as sadtalker_interface: gr.Markdown("

😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)

\ @@ -21,14 +77,15 @@ def sadtalker_demo(): with gr.Tabs(elem_id="sadtalker_source_image"): with gr.TabItem('Upload image'): with gr.Row(): - source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256) - + source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512) + + with gr.Tabs(elem_id="sadtalker_driven_audio"): with gr.TabItem('Upload OR TTS'): with gr.Column(variant='panel'): driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath") - - if sys.platform != 'win32': + + if sys.platform != 'win32' and not in_webui: from src.utils.text2speech import TTSTalker tts_talker = TTSTalker() with gr.Column(variant='panel'): @@ -36,105 +93,36 @@ def sadtalker_demo(): tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary') tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio]) - with gr.Column(variant='panel'): with gr.Tabs(elem_id="sadtalker_checkbox"): with gr.TabItem('Settings'): + gr.Markdown("need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials") with gr.Column(variant='panel'): - preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?") - is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion, works with preprocess `full`)") - enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer") + # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width + # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width + pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) # + size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") # + preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?") + is_still_mode = gr.Checkbox(label="Still Mode (fewer hand motion, works with preprocess `full`)") + batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2) + enhancer = gr.Checkbox(label="GFPGAN as Face enhancer") submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary') + with gr.Tabs(elem_id="sadtalker_genearted"): gen_video = gr.Video(label="Generated video", format="mp4").style(width=256) - - - with gr.Row(): - examples = [ - [ - 'examples/source_image/full_body_1.png', - 'examples/driven_audio/bus_chinese.wav', - 'crop', - True, - False - ], - [ - 'examples/source_image/full_body_2.png', - 'examples/driven_audio/japanese.wav', - 'crop', - False, - False - ], - [ - 'examples/source_image/full3.png', - 'examples/driven_audio/deyu.wav', - 'crop', - False, - True - ], - [ - 'examples/source_image/full4.jpeg', - 'examples/driven_audio/eluosi.wav', - 'full', - False, - True - ], - [ - 'examples/source_image/full4.jpeg', - 'examples/driven_audio/imagine.wav', - 'full', - True, - True - ], - [ - 'examples/source_image/full_body_1.png', - 'examples/driven_audio/bus_chinese.wav', - 'full', - True, - False - ], - [ - 'examples/source_image/art_13.png', - 'examples/driven_audio/fayu.wav', - 'resize', - True, - False - ], - [ - 'examples/source_image/art_5.png', - 'examples/driven_audio/chinese_news.wav', - 
'resize', - False, - False - ], - [ - 'examples/source_image/art_5.png', - 'examples/driven_audio/RD_Radio31_000.wav', - 'resize', - True, - True - ], - ] - gr.Examples(examples=examples, - inputs=[ - source_image, - driven_audio, - preprocess_type, - is_still_mode, - enhancer], - outputs=[gen_video], - fn=sad_talker.test, - cache_examples=os.getenv('SYSTEM') == 'spaces') # - submit.click( fn=sad_talker.test, inputs=[source_image, driven_audio, preprocess_type, is_still_mode, - enhancer], + enhancer, + batch_size, + size_of_image, + pose_style + ], outputs=[gen_video] ) @@ -144,6 +132,7 @@ def sadtalker_demo(): if __name__ == "__main__": demo = sadtalker_demo() - demo.launch(share=True) + demo.queue() + demo.launch() diff --git a/inference.py b/inference.py index 10409256..a0b00790 100644 --- a/inference.py +++ b/inference.py @@ -1,3 +1,5 @@ +from glob import glob +import shutil import torch from time import strftime import os, sys, time @@ -8,6 +10,7 @@ from src.facerender.animate import AnimateFromCoeff from src.generate_batch import get_data from src.generate_facerender_batch import get_facerender_data +from src.utils.init_path import init_path def main(args): #torch.backends.cudnn.enabled = False @@ -25,51 +28,23 @@ def main(args): ref_eyeblink = args.ref_eyeblink ref_pose = args.ref_pose - current_code_path = sys.argv[0] - current_root_path = os.path.split(current_code_path)[0] + current_root_path = os.path.split(sys.argv[0])[0] - os.environ['TORCH_HOME']=os.path.join(current_root_path, args.checkpoint_dir) - - path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat') - path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth') - dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting') - wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth') - - audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth') - audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') - - audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth') - audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') - - free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar') - - if args.preprocess == 'full': - mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00109-model.pth.tar') - facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml') - else: - mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar') - facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml') + sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess) #init model - print(path_of_net_recon_model) - preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device) - - print(audio2pose_checkpoint) - print(audio2exp_checkpoint) - audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path, - audio2exp_checkpoint, audio2exp_yaml_path, - wav2lip_checkpoint, device) + preprocess_model = CropAndExtract(sadtalker_paths, device) + + audio_to_coeff = Audio2Coeff(sadtalker_paths, device) - 
print(free_view_checkpoint) - print(mapping_checkpoint) - animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint, - facerender_yaml_path, device) + animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) #crop image and extract 3dmm from image first_frame_dir = os.path.join(save_dir, 'first_frame_dir') os.makedirs(first_frame_dir, exist_ok=True) print('3DMM Extraction for source image') - first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess, source_image_flag=True) + first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\ + source_image_flag=True, pic_size=args.size) if first_coeff_path is None: print("Can't get the coeffs of the input") return @@ -79,7 +54,7 @@ def main(args): ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname) os.makedirs(ref_eyeblink_frame_dir, exist_ok=True) print('3DMM Extraction for the reference video providing eye blinking') - ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir) + ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False) else: ref_eyeblink_coeff_path=None @@ -91,7 +66,7 @@ def main(args): ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname) os.makedirs(ref_pose_frame_dir, exist_ok=True) print('3DMM Extraction for the reference video providing pose') - ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir) + ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False) else: ref_pose_coeff_path=None @@ -107,22 +82,30 @@ def main(args): #coeff2video data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, input_yaw_list, input_pitch_list, input_roll_list, - expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess) + expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size) - animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \ - enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess) + result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \ + enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size) + + shutil.move(result, save_dir+'.mp4') + print('The generated video is named:', save_dir+'.mp4') + + if not args.verbose: + shutil.rmtree(save_dir) + if __name__ == '__main__': parser = ArgumentParser() parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio") - parser.add_argument("--source_image", default='./examples/source_image/full_body_2.png', help="path to source image") + parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image") parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking") parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose") parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output") parser.add_argument("--result_dir", default='./results', help="path to output") parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)") 
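The hunks above replace the long list of per-checkpoint path variables with a single `sadtalker_paths` dict built by the new `src/utils/init_path.py` helper added later in this patch. A minimal sketch of how the refactored entry points are wired together, using only the constructors shown in this diff; the import paths for `CropAndExtract` and `Audio2Coeff` are not visible in the hunks and are assumptions:

```python
# Sketch only: wiring of the refactored SadTalker pipeline after this patch.
# The CropAndExtract / Audio2Coeff import paths are assumed, not taken from the hunks above.
import os, sys
import torch

from src.utils.init_path import init_path                 # added in this patch
from src.utils.preprocess import CropAndExtract           # assumed module path
from src.test_audio2coeff import Audio2Coeff              # assumed module path
from src.facerender.animate import AnimateFromCoeff

device = 'cuda' if torch.cuda.is_available() else 'cpu'
current_root_path = os.path.split(sys.argv[0])[0]

# One dict now carries every checkpoint and config path.
sadtalker_paths = init_path('./checkpoints',
                            os.path.join(current_root_path, 'src/config'),
                            size=256, old_version=False, preprocess='crop')

preprocess_model   = CropAndExtract(sadtalker_paths, device)
audio_to_coeff     = Audio2Coeff(sadtalker_paths, device)
animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
```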
parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender") + parser.add_argument("--size", type=int, default=256, help="the image size of the facerender") parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender") parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ") parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user") @@ -132,7 +115,10 @@ def main(args): parser.add_argument("--cpu", dest="cpu", action="store_true") parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks") parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion") - parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize', 'full'], help="how to preprocess the images" ) + parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" ) + parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" ) + parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" ) + # net structure and parameters parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless') diff --git a/launcher.py b/launcher.py index 565b209d..1dedb2fa 100644 --- a/launcher.py +++ b/launcher.py @@ -170,8 +170,8 @@ def run_extension_installer(extension_dir): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") - requirements_file = os.environ.get('REQS_FILE', "requirements.txt") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113") + requirements_file = os.environ.get('REQS_FILE', "req.txt") commit = commit_hash() @@ -181,16 +181,18 @@ def prepare_environment(): if not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) - run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") - run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)") + if sys.platform != 'win32' and not is_installed('tts'): + run_pip(f"install TTS", "install TTS individually in SadTalker, which might not work on windows.") + def start(): print(f"Launching SadTalker Web UI") from app import sadtalker_demo demo = sadtalker_demo() - demo.launch(share=True) + demo.queue() + demo.launch() if __name__ == "__main__": prepare_environment() diff --git a/req.txt b/req.txt new file mode 100644 index 00000000..3c384dd8 --- /dev/null +++ b/req.txt @@ -0,0 +1,23 @@ +llvmlite==0.38.1 +numpy==1.21.6 +face_alignment==1.3.5 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +librosa==0.10.0.post2 +numba==0.55.1 +resampy==0.3.1 +pydub==0.25.1 +scipy==1.10.1 +kornia==0.6.8 +tqdm +yacs==0.1.8 +pyyaml +joblib==1.1.0 +scikit-image==0.19.3 +basicsr==1.4.2 +facexlib==0.3.0 +gradio +gfpgan +dlib-bin +av 
+safetensors diff --git a/requirements.txt b/requirements.txt index 238c237f..2321c742 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ gradio gfpgan dlib-bin av +safetensors diff --git a/requirements3d.txt b/requirements3d.txt index 25b6b262..c7eeeecd 100644 --- a/requirements3d.txt +++ b/requirements3d.txt @@ -18,4 +18,5 @@ facexlib==0.2.5 trimesh==3.9.20 dlib-bin gradio -gfpgan \ No newline at end of file +gfpgan +safetensors \ No newline at end of file diff --git a/scripts/download_models.sh b/scripts/download_models.sh index 213eb499..425cada0 100644 --- a/scripts/download_models.sh +++ b/scripts/download_models.sh @@ -1,19 +1,25 @@ mkdir ./checkpoints -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip -wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip + +# lagency download link +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar +# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip +# unzip -n ./checkpoints/hub.zip -d ./checkpoints/ -unzip -n ./checkpoints/hub.zip -d ./checkpoints/ +#### download the new links. 
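The refreshed download script leaves the new checkpoint URLs blank (the bare `wget -nc` below). The `scripts/extension.py` hunk further down already imports `snapshot_download` from `huggingface_hub`, so a hedged sketch of fetching the new `SadTalker_V0.0.2_*.safetensors` files that way is shown here; the repo id is a placeholder assumption, not something specified in this diff:

```python
# Sketch only: fetch the new safetensors checkpoints via huggingface_hub instead of wget.
# The repo_id is a hypothetical placeholder; this patch does not name the hosting repo.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="vinthony/SadTalker",   # assumed repo id
                  local_dir="./checkpoints")
```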
+wget -nc + +wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/ +### enhancer mkdir -p ./gfpgan/weights wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth diff --git a/scripts/extension.py b/scripts/extension.py index 8cd9ca89..8fcdf3a8 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -10,6 +10,26 @@ from huggingface_hub import snapshot_download + +def check_all_files_safetensor(current_dir): + kv = { + "SadTalker_V0.0.2_256.safetensors": "sadtalker-256", + "SadTalker_V0.0.2_512.safetensors": "sadtalker-512", + "mapping_00109-model.pth.tar" : "mapping-109" , + "mapping_00229-model.pth.tar" : "mapping-229" , + } + + if not os.path.isdir(current_dir): + return False + + dirs = os.listdir(current_dir) + + for f in dirs: + if f in kv.keys(): + del kv[f] + + return len(kv.keys()) == 0 + def check_all_files(current_dir): kv = { "auido2exp_00300-model.pth": "audio2exp", @@ -63,6 +83,14 @@ def get_default_checkpoint_path(): checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker" extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints" + if check_all_files_safetensor(checkpoint_path): + # print('founding sadtalker checkpoint in ' + str(checkpoint_path)) + return checkpoint_path + + if check_all_files_safetensor(extension_checkpoint_path): + # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path)) + return extension_checkpoint_path + if check_all_files(checkpoint_path): # print('founding sadtalker checkpoint in ' + str(checkpoint_path)) return checkpoint_path @@ -91,17 +119,17 @@ def install(): "gfpgan": "gfpgan", } - if 'darwin' in sys.platform: - kv['dlib'] = "dlib" - else: - kv['dlib'] = 'dlib-bin' + # dlib is not necessary currently + # if 'darwin' in sys.platform: + # kv['dlib'] = "dlib" + # else: + # kv['dlib'] = 'dlib-bin' for k,v in kv.items(): if not launch.is_installed(k): print(k, launch.is_installed(k)) launch.run_pip("install "+ v, "requirements for SadTalker") - if os.getenv('SADTALKER_CHECKPOINTS'): print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS')) @@ -138,66 +166,15 @@ def on_ui_tabs(): result_dir = opts.sadtalker_result_dir os.makedirs(result_dir, exist_ok=True) - from src.gradio_demo import SadTalker + from app import sadtalker_demo if os.getenv('SADTALKER_CHECKPOINTS'): checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS') else: checkpoint_path = repo_dir+'checkpoints/' - sad_talker = SadTalker(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', lazy_load=True) - - with gr.Blocks(analytics_enabled=False) as audio_to_video: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="sadtalker_source_image"): - with gr.TabItem('Upload image'): - with gr.Row(): - input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256) - - with gr.Row(): - submit_image2 = gr.Button('load From txt2img', variant='primary') - submit_image2.click(fn=get_img_from_txt2img, inputs=input_image, outputs=[input_image, input_image]) - - submit_image3 = gr.Button('load from img2img', variant='primary') - 
submit_image3.click(fn=get_img_from_img2img, inputs=input_image, outputs=[input_image, input_image]) - - with gr.Tabs(elem_id="sadtalker_driven_audio"): - with gr.TabItem('Upload'): - with gr.Column(variant='panel'): - - with gr.Row(): - driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath") - - - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="sadtalker_checkbox"): - with gr.TabItem('Settings'): - with gr.Column(variant='panel'): - gr.Markdown("Please visit [**[here]**](https://github.com/Winfredy/SadTalker/blob/main/docs/best_practice.md) if you don't know how to choose these configurations.") - preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?") - is_still_mode = gr.Checkbox(label="Remove head motion (works better with preprocess `full`)") - enhancer = gr.Checkbox(label="Face enhancement") - submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary') - path_to_save = gr.Text(Path(paths.script_path) / "outputs/SadTalker/", visible=False) - - with gr.Tabs(elem_id="sadtalker_genearted"): - gen_video = gr.Video(label="Generated video", format="mp4").style(width=256) - - - ### gradio gpu call will always return the html, - submit.click( - fn=wrap_queued_call(sad_talker.test), - inputs=[input_image, - driven_audio, - preprocess_type, - is_still_mode, - enhancer, - path_to_save - ], - outputs=[gen_video, ] - ) - + audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call) + return [(audio_to_video, "SadTalker", "extension")] def on_ui_settings(): diff --git a/scripts/test.sh b/scripts/test.sh index 69147611..281c9c03 100644 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1 +1,21 @@ -### some test command before commit. \ No newline at end of file +### some test command before commit. +python inference.py --preprocess crop --size 256 +python inference.py --preprocess crop --size 512 + +python inference.py --preprocess extcrop --size 256 +python inference.py --preprocess extcrop --size 512 + +python inference.py --preprocess resize --size 256 +python inference.py --preprocess resize --size 512 + +python inference.py --preprocess full --size 256 +python inference.py --preprocess full --size 512 + +python inference.py --preprocess extfull --size 256 +python inference.py --preprocess extfull --size 512 + +python inference.py --preprocess full --size 256 --enhancer gfpgan +python inference.py --preprocess full --size 512 --enhancer gfpgan + +python inference.py --preprocess full --size 256 --enhancer gfpgan --still +python inference.py --preprocess full --size 512 --enhancer gfpgan --still diff --git a/src/audio2pose_models/audio_encoder.py b/src/audio2pose_models/audio_encoder.py index ea9095ad..6279d201 100644 --- a/src/audio2pose_models/audio_encoder.py +++ b/src/audio2pose_models/audio_encoder.py @@ -41,14 +41,14 @@ def __init__(self, wav2lip_checkpoint, device): Conv2d(256, 512, kernel_size=3, stride=1, padding=0), Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) - #### load the pre-trained audio_encoder - wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] - state_dict = self.audio_encoder.state_dict() + #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] + # state_dict = self.audio_encoder.state_dict() - for k,v in wav2lip_state_dict.items(): - if 'audio_encoder' in k: - state_dict[k.replace('module.audio_encoder.', '')] = v - self.audio_encoder.load_state_dict(state_dict) + # for k,v in wav2lip_state_dict.items(): + # if 'audio_encoder' in k: + # state_dict[k.replace('module.audio_encoder.', '')] = v + # self.audio_encoder.load_state_dict(state_dict) def forward(self, audio_sequences): diff --git a/src/config/similarity_Lm3D_all.mat b/src/config/similarity_Lm3D_all.mat new file mode 100644 index 00000000..a0e23588 Binary files /dev/null and b/src/config/similarity_Lm3D_all.mat differ diff --git a/src/face3d/extract_kp_videos_safe.py b/src/face3d/extract_kp_videos_safe.py index 5c9cff87..262439b9 100644 --- a/src/face3d/extract_kp_videos_safe.py +++ b/src/face3d/extract_kp_videos_safe.py @@ -55,11 +55,6 @@ def extract_keypoint(self, images, name=None, info=True): bboxes = self.det_net.detect_faces(images, 0.97) bboxes = bboxes[0] - - # bboxes[0] -= 100 - # bboxes[1] -= 100 - # bboxes[2] += 100 - # bboxes[3] += 100 img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0] diff --git a/src/facerender/animate.py b/src/facerender/animate.py index 3adea961..563d87fe 100644 --- a/src/facerender/animate.py +++ b/src/facerender/animate.py @@ -4,7 +4,8 @@ import numpy as np import warnings from skimage import img_as_ubyte - +import safetensors +import safetensors.torch warnings.filterwarnings('ignore') @@ -26,10 +27,9 @@ class AnimateFromCoeff(): - def __init__(self, free_view_checkpoint, mapping_checkpoint, - config_path, device): + def __init__(self, sadtalker_path, device): - with open(config_path) as f: + with open(sadtalker_path['facerender_yaml']) as f: config = yaml.safe_load(f) generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], @@ -40,7 +40,6 @@ def __init__(self, free_view_checkpoint, mapping_checkpoint, **config['model_params']['common_params']) mapping = MappingNet(**config['model_params']['mapping_params']) - generator.to(device) kp_extractor.to(device) he_estimator.to(device) @@ -54,13 +53,16 @@ def __init__(self, free_view_checkpoint, mapping_checkpoint, for param in mapping.parameters(): param.requires_grad = False - if free_view_checkpoint is not None: - self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) + if sadtalker_path is not None: + if 'safetensors' in sadtalker_path['checkpoint']: + self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None) + else: + self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) else: raise AttributeError("Checkpoint should be specified for video head pose estimator.") - if mapping_checkpoint is not None: - self.load_cpk_mapping(mapping_checkpoint, mapping=mapping) + if sadtalker_path['mappingnet_checkpoint'] is not None: + self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping) else: raise AttributeError("Checkpoint should be specified for video head pose estimator.") @@ -76,6 +78,33 @@ def __init__(self, free_view_checkpoint, mapping_checkpoint, self.device = device + def load_cpk_facevid2vid_safetensor(self, checkpoint_path, 
generator=None, + kp_detector=None, he_estimator=None, + device="cpu"): + + checkpoint = safetensors.torch.load_file(checkpoint_path) + + if generator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'generator' in k: + x_generator[k.replace('generator.', '')] = v + generator.load_state_dict(x_generator) + if kp_detector is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'kp_extractor' in k: + x_generator[k.replace('kp_extractor.', '')] = v + kp_detector.load_state_dict(x_generator) + if he_estimator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'he_estimator' in k: + x_generator[k.replace('he_estimator.', '')] = v + he_estimator.load_state_dict(x_generator) + + return None + def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, kp_detector=None, he_estimator=None, optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, @@ -120,7 +149,7 @@ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, return checkpoint['epoch'] - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'): + def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): source_image=x['source_image'].type(torch.FloatTensor) source_semantics=x['source_semantics'].type(torch.FloatTensor) @@ -160,10 +189,10 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, backgr video.append(image) result = img_as_ubyte(video) - ### the generated video is 256x256, so we keep the aspect ratio, + ### the generated video is 256x256, so we keep the aspect ratio, original_size = crop_info[0] if original_size: - result = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in result ] + result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] video_name = x['video_name'] + '.mp4' path = os.path.join(video_save_dir, 'temp_'+video_name) @@ -185,14 +214,14 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, backgr word.export(new_audio_path, format="wav") save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated video is named {video_name} in {video_save_dir}') + print(f'The generated video is named {video_save_dir}/{video_name}') - if preprocess.lower() == 'full': + if 'full' in preprocess.lower(): # only add watermark to the full image. 
video_name_full = x['video_name'] + '_full.mp4' full_video_path = os.path.join(video_save_dir, video_name_full) return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path) + paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) print(f'The generated video is named {video_save_dir}/{video_name_full}') else: full_video_path = av_path diff --git a/src/facerender/modules/make_animation.py b/src/facerender/modules/make_animation.py index e7887a3f..b2616648 100644 --- a/src/facerender/modules/make_animation.py +++ b/src/facerender/modules/make_animation.py @@ -29,7 +29,7 @@ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale def headpose_pred_to_degree(pred): device = pred.device idx_tensor = [idx for idx in range(66)] - idx_tensor = torch.FloatTensor(idx_tensor).to(device) + idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device) pred = F.softmax(pred) degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 return degree @@ -102,7 +102,7 @@ def keypoint_transformation(kp_canonical, he, wo_exp=False): def make_animation(source_image, source_semantics, target_semantics, generator, kp_detector, he_estimator, mapping, yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, - use_exp=True): + use_exp=True, use_half=False): with torch.no_grad(): predictions = [] @@ -122,8 +122,6 @@ def make_animation(source_image, source_semantics, target_semantics, kp_driving = keypoint_transformation(kp_canonical, he_driving) - #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, - #kp_driving_initial=kp_driving_initial) kp_norm = kp_driving out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) ''' diff --git a/src/generate_facerender_batch.py b/src/generate_facerender_batch.py index 9ec7a169..7538b475 100644 --- a/src/generate_facerender_batch.py +++ b/src/generate_facerender_batch.py @@ -7,7 +7,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None, - expression_scale=1.0, still_mode = False, preprocess='crop'): + expression_scale=1.0, still_mode = False, preprocess='crop', size = 256): semantic_radius = 13 video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0] @@ -18,7 +18,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, img1 = Image.open(pic_path) source_image = np.array(img1) source_image = img_as_float32(source_image) - source_image = transform.resize(source_image, (256, 256, 3)) + source_image = transform.resize(source_image, (size, size, 3)) source_image = source_image.transpose((2, 0, 1)) source_image_ts = torch.FloatTensor(source_image).unsqueeze(0) source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1) @@ -26,7 +26,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, source_semantics_dict = scio.loadmat(first_coeff_path) - if preprocess.lower() != 'full': + if 'full' not in preprocess.lower(): source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 else: source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 70 @@ -41,7 +41,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, generated_3dmm = generated_dict['coeff_3dmm'] generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale - if preprocess.lower() == 'full': + if 'full' in preprocess.lower(): generated_3dmm = 
np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1) if still_mode: diff --git a/src/gradio_demo.py b/src/gradio_demo.py index 4bdb2177..11eaf9cd 100644 --- a/src/gradio_demo.py +++ b/src/gradio_demo.py @@ -6,8 +6,11 @@ from src.generate_batch import get_data from src.generate_facerender_batch import get_facerender_data +from src.utils.init_path import init_path + from pydub import AudioSegment + def mp3_to_wav(mp3_filename,wav_filename,frame_rate): mp3_file = AudioSegment.from_file(file=mp3_filename) mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") @@ -28,56 +31,17 @@ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy self.checkpoint_path = checkpoint_path self.config_path = config_path + - self.path_of_lm_croper = os.path.join( checkpoint_path, 'shape_predictor_68_face_landmarks.dat') - self.path_of_net_recon_model = os.path.join( checkpoint_path, 'epoch_20.pth') - self.dir_of_BFM_fitting = os.path.join( checkpoint_path, 'BFM_Fitting') - self.wav2lip_checkpoint = os.path.join( checkpoint_path, 'wav2lip.pth') - - self.audio2pose_checkpoint = os.path.join( checkpoint_path, 'auido2pose_00140-model.pth') - self.audio2pose_yaml_path = os.path.join( config_path, 'auido2pose.yaml') - - self.audio2exp_checkpoint = os.path.join( checkpoint_path, 'auido2exp_00300-model.pth') - self.audio2exp_yaml_path = os.path.join( config_path, 'auido2exp.yaml') - - self.free_view_checkpoint = os.path.join( checkpoint_path, 'facevid2vid_00189-model.pth.tar') + def test(self, source_image, driven_audio, preprocess='crop', + still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style = 0,result_dir='./results/'): - self.lazy_load = lazy_load - - if not self.lazy_load: - #init model + self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess) + print(self.sadtalker_paths) - print(self.audio2pose_checkpoint) - self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path, - self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device) - - print(self.path_of_lm_croper) - self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device) - - def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, result_dir='./results/'): - - ### crop: only model, - - if self.lazy_load: - #init model - - print(self.audio2pose_checkpoint) - self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path, - self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device) - - print(self.path_of_lm_croper) - self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device) - - if preprocess == 'full': - self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00109-model.pth.tar') - self.facerender_yaml_path = os.path.join(self.config_path, 'facerender_still.yaml') - else: - self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00229-model.pth.tar') - self.facerender_yaml_path = os.path.join(self.config_path, 'facerender.yaml') - - print(self.free_view_checkpoint) - self.animate_from_coeff = AnimateFromCoeff(self.free_view_checkpoint, self.mapping_checkpoint, - self.facerender_yaml_path, self.device) + self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device) + 
self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device) + self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device) time_tag = str(uuid.uuid4()) save_dir = os.path.join(result_dir, time_tag) @@ -104,11 +68,11 @@ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, os.makedirs(save_dir, exist_ok=True) - pose_style = 0 + #crop image and extract 3dmm from image first_frame_dir = os.path.join(save_dir, 'first_frame_dir') os.makedirs(first_frame_dir, exist_ok=True) - first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess) + first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size) if first_coeff_path is None: raise AttributeError("No face is detected") @@ -117,16 +81,14 @@ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=None, still=still_mode) # longer audio? coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style) #coeff2video - batch_size = 2 - data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess) - return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess) + data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size) + return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size) video_name = data['video_name'] print(f'The generated video is named {video_name} in {save_dir}') - if self.lazy_load: - del self.preprocess_model - del self.audio_to_coeff - del self.animate_from_coeff + del self.preprocess_model + del self.audio_to_coeff + del self.animate_from_coeff if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/src/test_audio2coeff.py b/src/test_audio2coeff.py index c3c6abcf..be66bffd 100644 --- a/src/test_audio2coeff.py +++ b/src/test_audio2coeff.py @@ -5,9 +5,13 @@ from yacs.config import CfgNode as CN from scipy.signal import savgol_filter +import safetensors +import safetensors.torch + from src.audio2pose_models.audio2pose import Audio2Pose from src.audio2exp_models.networks import SimpleWrapperV2 -from src.audio2exp_models.audio2exp import Audio2Exp +from src.audio2exp_models.audio2exp import Audio2Exp +from src.utils.safetensor_helper import load_x_from_safetensor def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) @@ -20,25 +24,28 @@ def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): class Audio2Coeff(): - def __init__(self, audio2pose_checkpoint, audio2pose_yaml_path, - audio2exp_checkpoint, audio2exp_yaml_path, - wav2lip_checkpoint, device): + def __init__(self, sadtalker_path, device): #load config - fcfg_pose = open(audio2pose_yaml_path) + fcfg_pose = open(sadtalker_path['audio2pose_yaml_path']) cfg_pose = CN.load_cfg(fcfg_pose) cfg_pose.freeze() - fcfg_exp = open(audio2exp_yaml_path) + fcfg_exp = open(sadtalker_path['audio2exp_yaml_path']) cfg_exp = CN.load_cfg(fcfg_exp) cfg_exp.freeze() # load audio2pose_model - self.audio2pose_model = Audio2Pose(cfg_pose, 
wav2lip_checkpoint, device=device) + self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device) self.audio2pose_model = self.audio2pose_model.to(device) self.audio2pose_model.eval() for param in self.audio2pose_model.parameters(): param.requires_grad = False + try: - load_cpk(audio2pose_checkpoint, model=self.audio2pose_model, device=device) + if sadtalker_path['use_safetensor']: + checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) + self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose')) + else: + load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device) except: raise Exception("Failed in loading audio2pose_checkpoint") @@ -49,7 +56,11 @@ def __init__(self, audio2pose_checkpoint, audio2pose_yaml_path, netG.requires_grad = False netG.eval() try: - load_cpk(audio2exp_checkpoint, model=netG, device=device) + if sadtalker_path['use_safetensor']: + checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) + netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp')) + else: + load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device) except: raise Exception("Failed in loading audio2exp_checkpoint") self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) @@ -106,7 +117,8 @@ def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path): refpose_coeff_list.append(refpose_coeff[:re, :]) refpose_coeff = np.concatenate(refpose_coeff_list, axis=0) - coeffs_pred_numpy[:, 64:70] = refpose_coeff[:num_frames, :] + #### relative head pose + coeffs_pred_numpy[:, 64:70] = coeffs_pred_numpy[:, 64:70] + ( refpose_coeff[:num_frames, :] - refpose_coeff[0:1, :] ) return coeffs_pred_numpy diff --git a/src/utils/croper.py b/src/utils/croper.py index 0ccf5dfc..3d9a0ac5 100644 --- a/src/utils/croper.py +++ b/src/utils/croper.py @@ -6,57 +6,38 @@ import scipy import numpy as np from PIL import Image +import torch from tqdm import tqdm from itertools import cycle -from torch.multiprocessing import Pool, Process, set_start_method - - -""" -brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) -author: lzhbrian (https://lzhbrian.me) -date: 2020.1.5 -note: code is heavily borrowed from - https://github.com/NVlabs/ffhq-dataset - http://dlib.net/face_landmark_detection.py.html -requirements: - apt install cmake - conda install Pillow numpy scipy - pip install dlib - # download face landmark model from: - # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 -""" +from src.face3d.extract_kp_videos_safe import KeypointExtractor +from facexlib.alignment import landmark_98_to_68 import numpy as np from PIL import Image -import dlib - -class Croper: - def __init__(self, path_of_lm): - # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 - self.predictor = dlib.shape_predictor(path_of_lm) +class Preprocesser: + def __init__(self, device='cuda'): + self.predictor = KeypointExtractor(device) def get_landmark(self, img_np): """get landmark with dlib :return: np.array shape=(68, 2) """ - detector = dlib.get_frontal_face_detector() - dets = detector(img_np, 1) - # print("Number of faces detected: {}".format(len(dets))) - # for k, d in enumerate(dets): + with torch.no_grad(): + dets = self.predictor.det_net.detect_faces(img_np, 0.97) + if len(dets) == 0: return None - d = dets[0] - # Get the landmarks/parts for the face in box d. 
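The `load_x_from_safetensor` helper used in the `src/test_audio2coeff.py` hunk above is imported from `src/utils/safetensor_helper.py`, which is not part of this diff. By analogy with `load_cpk_facevid2vid_safetensor` earlier in the patch, a plausible (hypothetical) sketch of that helper simply filters the merged `SadTalker_V0.0.2_*.safetensors` state dict by sub-model prefix:

```python
# Hypothetical sketch of src/utils/safetensor_helper.load_x_from_safetensor
# (not shown in this diff); mirrors the prefix filtering used by
# load_cpk_facevid2vid_safetensor above.
def load_x_from_safetensor(checkpoint, key):
    x_state_dict = {}
    for k, v in checkpoint.items():
        if key in k:
            # strip the sub-model prefix, e.g. 'audio2pose.' or 'audio2exp.'
            x_state_dict[k.replace(key + '.', '')] = v
    return x_state_dict
```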
- shape = self.predictor(img_np, d) - # print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1))) - t = list(shape.parts()) - a = [] - for tt in t: - a.append([tt.x, tt.y]) - lm = np.array(a) - # lm is a shape=(68,2) np.array + det = dets[0] + + img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :] + lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0] + + #### keypoints to the original location + lm[:,0] += int(det[0]) + lm[:,1] += int(det[1]) + return lm def align_face(self, img, lm, output_size=1024): @@ -138,34 +119,14 @@ def align_face(self, img, lm, output_size=1024): ly = max(min(quad[1], quad[7]), 0) rx = min(max(quad[4], quad[6]), img.size[0]) ry = min(max(quad[3], quad[5]), img.size[0]) - # img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), - # Image.BILINEAR) - # if output_size < transform_size: - # img = img.resize((output_size, output_size), Image.ANTIALIAS) # Save aligned image. return rsize, crop, [lx, ly, rx, ry] - - # def crop(self, img_np_list): - # for _i in range(len(img_np_list)): - # img_np = img_np_list[_i] - # lm = self.get_landmark(img_np) - # if lm is None: - # return None - # crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=512) - # clx, cly, crx, cry = crop - # lx, ly, rx, ry = quad - # lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - - # _inp = img_np_list[_i] - # _inp = _inp[cly:cry, clx:crx] - # _inp = _inp[ly:ry, lx:rx] - # img_np_list[_i] = _inp - # return img_np_list def crop(self, img_np_list, still=False, xsize=512): # first frame for all video img_np = img_np_list[0] lm = self.get_landmark(img_np) + if lm is None: raise 'can not detect the landmark from source image' rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize) @@ -176,124 +137,8 @@ def crop(self, img_np_list, still=False, xsize=512): # first frame for all vi _inp = img_np_list[_i] _inp = cv2.resize(_inp, (rsize[0], rsize[1])) _inp = _inp[cly:cry, clx:crx] - # cv2.imwrite('test1.jpg', _inp) if not still: _inp = _inp[ly:ry, lx:rx] - # cv2.imwrite('test2.jpg', _inp) img_np_list[_i] = _inp return img_np_list, crop, quad - -def read_video(filename, uplimit=100): - frames = [] - cap = cv2.VideoCapture(filename) - cnt = 0 - while cap.isOpened(): - ret, frame = cap.read() - if ret: - frame = cv2.resize(frame, (512, 512)) - frames.append(frame) - else: - break - cnt += 1 - if cnt >= uplimit: - break - cap.release() - assert len(frames) > 0, f'{filename}: video with no frames!' 
- return frames - - -def create_video(video_name, frames, fps=25, video_format='.mp4', resize_ratio=1): - # video_name = os.path.dirname(image_folder) + video_format - # img_list = glob.glob1(image_folder, 'frame*') - # img_list.sort() - # frame = cv2.imread(os.path.join(image_folder, img_list[0])) - # frame = cv2.resize(frame, (0, 0), fx=resize_ratio, fy=resize_ratio) - # height, width, layers = frames[0].shape - height, width, layers = 512, 512, 3 - if video_format == '.mp4': - fourcc = cv2.VideoWriter_fourcc(*'mp4v') - elif video_format == '.avi': - fourcc = cv2.VideoWriter_fourcc(*'XVID') - video = cv2.VideoWriter(video_name, fourcc, fps, (width, height)) - for _frame in frames: - _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR) - video.write(_frame) - -def create_images(video_name, frames): - height, width, layers = 512, 512, 3 - images_dir = video_name.split('.')[0] - os.makedirs(images_dir, exist_ok=True) - for i, _frame in enumerate(frames): - _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR) - _frame_path = os.path.join(images_dir, str(i)+'.jpg') - cv2.imwrite(_frame_path, _frame) - -def run(data): - filename, opt, device = data - os.environ['CUDA_VISIBLE_DEVICES'] = device - croper = Croper() - - frames = read_video(filename, uplimit=opt.uplimit) - name = filename.split('/')[-1] # .split('.')[0] - name = os.path.join(opt.output_dir, name) - - frames = croper.crop(frames) - if frames is None: - print(f'{name}: detect no face. should removed') - return - # create_video(name, frames) - create_images(name, frames) - - -def get_data_path(video_dir): - eg_video_files = ['/apdcephfs/share_1290939/quincheng/datasets/HDTF/backup_fps25/WDA_KatieHill_000.mp4'] - # filenames = list() - # VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} - # VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) - # extensions = VIDEO_EXTENSIONS - # for ext in extensions: - # filenames = sorted(glob.glob(f'{opt.input_dir}/**/*.{ext}')) - # print('Total number of videos:', len(filenames)) - return eg_video_files - - -def get_wra_data_path(video_dir): - if opt.option == 'video': - videos_path = sorted(glob.glob(f'{video_dir}/*.mp4')) - elif opt.option == 'image': - videos_path = sorted(glob.glob(f'{video_dir}/*/')) - else: - raise NotImplementedError - print('Example videos: ', videos_path[:2]) - return videos_path - - -if __name__ == '__main__': - set_start_method('spawn') - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--output_dir', type=str, help='the folder of the output files') - parser.add_argument('--device_ids', type=str, default='0,1') - parser.add_argument('--workers', type=int, default=8) - parser.add_argument('--uplimit', type=int, default=500) - parser.add_argument('--option', type=str, default='video') - - root = '/apdcephfs/share_1290939/quincheng/datasets/HDTF' - cmd = f'--input_dir {root}/backup_fps25_first20s_sync/ ' \ - f'--output_dir {root}/crop512_stylegan_firstframe_sync/ ' \ - '--device_ids 0 ' \ - '--workers 8 ' \ - '--option video ' \ - '--uplimit 500 ' - opt = parser.parse_args(cmd.split()) - # filenames = get_data_path(opt.input_dir) - filenames = get_wra_data_path(opt.input_dir) - os.makedirs(opt.output_dir, exist_ok=True) - print(f'Video numbers: {len(filenames)}') - pool = Pool(opt.workers) - args_list = cycle([opt]) - device_ids = 
opt.device_ids.split(",") - device_ids = cycle(device_ids) - for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): - None diff --git a/src/utils/init_path.py b/src/utils/init_path.py new file mode 100644 index 00000000..5f38d119 --- /dev/null +++ b/src/utils/init_path.py @@ -0,0 +1,47 @@ +import os +import glob + +def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): + + if old_version: + #### load all the checkpoint of `pth` + sadtalker_paths = { + 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), + 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), + 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), + 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), + 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') + } + + use_safetensor = False + elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): + print('using safetensor as default') + sadtalker_paths = { + "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), + } + use_safetensor = True + else: + print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. We run the old version of the checkpoint this time!") + use_safetensor = False + + sadtalker_paths = { + 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), + 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), + 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), + 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), + 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') + } + + sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' + sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') + sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') + sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') + + if 'full' in preprocess: + sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') + sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') + else: + sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') + sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') + + return sadtalker_paths \ No newline at end of file diff --git a/src/utils/model2safetensor.py b/src/utils/model2safetensor.py new file mode 100644 index 00000000..50c48500 --- /dev/null +++ b/src/utils/model2safetensor.py @@ -0,0 +1,141 @@ +import torch +import yaml +import os + +import safetensors +from safetensors.torch import save_file +from yacs.config import CfgNode as CN +import sys + +sys.path.append('/apdcephfs/private_shadowcun/SadTalker') + +from src.face3d.models import networks + +from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector +from src.facerender.modules.mapping import MappingNet +from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator + +from src.audio2pose_models.audio2pose import Audio2Pose +from src.audio2exp_models.networks import SimpleWrapperV2 +from src.test_audio2coeff import load_cpk + +size = 256 +############ 
face vid2vid +config_path = os.path.join('src', 'config', 'facerender.yaml') +current_root_path = '.' + +path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') +net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='') +checkpoint = torch.load(path_of_net_recon_model, map_location='cpu') +net_recon.load_state_dict(checkpoint['net_recon']) + +with open(config_path) as f: + config = yaml.safe_load(f) + +generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], + **config['model_params']['common_params']) +kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) +he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], + **config['model_params']['common_params']) +mapping = MappingNet(**config['model_params']['mapping_params']) + +def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, + kp_detector=None, he_estimator=None, optimizer_generator=None, + optimizer_discriminator=None, optimizer_kp_detector=None, + optimizer_he_estimator=None, device="cpu"): + + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint['he_estimator']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Discriminator will be randomly initialized') + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. 
Optimizer will not be initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) + + return checkpoint['epoch'] + + +def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, + kp_detector=None, he_estimator=None, + device="cpu"): + + checkpoint = safetensors.torch.load_file(checkpoint_path) + + if generator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'generator' in k: + x_generator[k.replace('generator.', '')] = v + generator.load_state_dict(x_generator) + if kp_detector is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'kp_extractor' in k: + x_generator[k.replace('kp_extractor.', '')] = v + kp_detector.load_state_dict(x_generator) + if he_estimator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'he_estimator' in k: + x_generator[k.replace('he_estimator.', '')] = v + he_estimator.load_state_dict(x_generator) + + return None + +free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar' +load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) + +wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') + +audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') +audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') + +audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth') +audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') + +fcfg_pose = open(audio2pose_yaml_path) +cfg_pose = CN.load_cfg(fcfg_pose) +cfg_pose.freeze() +audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint) +audio2pose_model.eval() +load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu') + +# load audio2exp_model +netG = SimpleWrapperV2() +netG.eval() +load_cpk(audio2exp_checkpoint, model=netG, device='cpu') + +class SadTalker(torch.nn.Module): + def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): + super(SadTalker, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.audio2exp = netG + self.audio2pose = audio2pose + self.face_3drecon = face_3drecon + + +model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon) + +# here, we want to convert it to safetensor +save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") + +### test +load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) \ No newline at end of file diff --git a/src/utils/paste_pic.py b/src/utils/paste_pic.py index cf25aebc..f9989e21 100644 --- a/src/utils/paste_pic.py +++ b/src/utils/paste_pic.py @@ -5,7 +5,7 @@ from src.utils.videoio import save_video_with_watermark -def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path): +def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): if not os.path.isfile(pic_path): raise ValueError('pic_path must be a valid path to video/image file') @@ -47,13 +47,16 @@ def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path): lx, ly, rx, ry = 
int(lx), int(ly), int(rx), int(ry) # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - oy1, oy2, ox1, ox2 = cly, cry, clx, crx + if extended_crop: + oy1, oy2, ox1, ox2 = cly, cry, clx, crx + else: + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx tmp_path = str(uuid.uuid4())+'.mp4' out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) for crop_frame in tqdm(crop_frames, 'seamlessClone:'): - p = cv2.resize(crop_frame.astype(np.uint8), (crx-clx, cry - cly)) + p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) mask = 255*np.ones(p.shape, p.dtype) location = ((ox1+ox2) // 2, (oy1+oy2) // 2) diff --git a/src/utils/preprocess.py b/src/utils/preprocess.py index db94004c..0f784e6c 100644 --- a/src/utils/preprocess.py +++ b/src/utils/preprocess.py @@ -4,21 +4,19 @@ from PIL import Image # 3dmm extraction +import safetensors +import safetensors.torch from src.face3d.util.preprocess import align_img from src.face3d.util.load_mats import load_lm3d from src.face3d.models import networks -try: - import webui - from src.face3d.extract_kp_videos_safe import KeypointExtractor - assert torch.cuda.is_available() == True -except: - from src.face3d.extract_kp_videos import KeypointExtractor - from scipy.io import loadmat, savemat -from src.utils.croper import Croper +from src.utils.croper import Preprocesser + + +import warnings -import warnings +from src.utils.safetensor_helper import load_x_from_safetensor warnings.filterwarnings("ignore") def split_coeff(coeffs): @@ -46,20 +44,24 @@ def split_coeff(coeffs): class CropAndExtract(): - def __init__(self, path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device): + def __init__(self, sadtalker_path, device): - self.croper = Croper(path_of_lm_croper) - self.kp_extractor = KeypointExtractor(device) + self.propress = Preprocesser(device) self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) - checkpoint = torch.load(path_of_net_recon_model, map_location=torch.device(device)) - self.net_recon.load_state_dict(checkpoint['net_recon']) + + if sadtalker_path['use_safetensor']: + checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) + self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) + else: + checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device)) + self.net_recon.load_state_dict(checkpoint['net_recon']) + self.net_recon.eval() - self.lm3d_std = load_lm3d(dir_of_BFM_fitting) + self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) self.device = device - def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False): + def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256): - pic_size = 256 pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') @@ -90,15 +92,15 @@ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_fla x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] #### crop images as the - if crop_or_resize.lower() == 'crop': # default crop - x_full_frames, crop, quad = self.croper.crop(x_full_frames, still=True, xsize=512) + if 'crop' in crop_or_resize.lower(): # default crop + x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in 
crop_or_resize.lower() else False, xsize=512) clx, cly, crx, cry = crop lx, ly, rx, ry = quad lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - elif crop_or_resize.lower() == 'full': - x_full_frames, crop, quad = self.croper.crop(x_full_frames, still=True, xsize=512) + elif 'full' in crop_or_resize.lower(): + x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) clx, cly, crx, cry = crop lx, ly, rx, ry = quad lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) @@ -119,7 +121,7 @@ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_fla # 2. get the landmark according to the detected face. if not os.path.isfile(landmarks_path): - lm = self.kp_extractor.extract_keypoint(frames_pil, landmarks_path) + lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) else: print(' Using saved landmarks.') lm = np.loadtxt(landmarks_path).astype(np.float32) diff --git a/src/utils/safetensor_helper.py b/src/utils/safetensor_helper.py new file mode 100644 index 00000000..3cdbdd21 --- /dev/null +++ b/src/utils/safetensor_helper.py @@ -0,0 +1,8 @@ + + +def load_x_from_safetensor(checkpoint, key): + x_generator = {} + for k,v in checkpoint.items(): + if key in k: + x_generator[k.replace(key+'.', '')] = v + return x_generator \ No newline at end of file diff --git a/src/utils/text2speech.py b/src/utils/text2speech.py index 6948edf1..00d165b6 100644 --- a/src/utils/text2speech.py +++ b/src/utils/text2speech.py @@ -3,7 +3,6 @@ from TTS.api import TTS - class TTSTalker(): def __init__(self) -> None: model_name = TTS.list_models()[0] diff --git a/src/utils/videoio.py b/src/utils/videoio.py index bbb5f099..08bfbdd7 100644 --- a/src/utils/videoio.py +++ b/src/utils/videoio.py @@ -19,7 +19,7 @@ def load_video_to_cv2(input_path): def save_video_with_watermark(video, audio, save_path, watermark=False): temp_file = str(uuid.uuid4())+'.mp4' - cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file) + cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file) os.system(cmd) if watermark is False: @@ -36,6 +36,6 @@ def save_video_with_watermark(video, audio, save_path, watermark=False): dir_path = os.path.dirname(os.path.realpath(__file__)) watarmark_path = dir_path+"/../../docs/sadtalker_logo.png" - cmd = r'ffmpeg -y -hide_banner -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) + cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) os.system(cmd) os.remove(temp_file) \ No newline at end of file diff --git a/webui.sh b/webui.sh new file mode 100755 index 00000000..24575023 --- /dev/null +++ b/webui.sh @@ -0,0 +1,140 @@ +#!/usr/bin/env bash + + +# If run from macOS, load defaults from webui-macos-env.sh +if [[ "$OSTYPE" == "darwin"* ]]; then + export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" +fi + +# python3 executable +if [[ -z "${python_cmd}" ]] +then + python_cmd="python3" +fi + +# git executable +if [[ -z "${GIT}" ]] +then + export GIT="git" +fi + +# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) +if [[ -z "${venv_dir}" ]] +then + venv_dir="venv" +fi 
+ +if [[ -z "${LAUNCH_SCRIPT}" ]] +then + LAUNCH_SCRIPT="launcher.py" +fi + +# this script cannot be run as root by default +can_run_as_root=1 + +# read any command line flags to the webui.sh script +while getopts "f" flag > /dev/null 2>&1 +do + case ${flag} in + f) can_run_as_root=1;; + *) break;; + esac +done + +# Disable sentry logging +export ERROR_REPORTING=FALSE + +# Do not reinstall existing pip packages on Debian/Ubuntu +export PIP_IGNORE_INSTALLED=0 + +# Pretty print +delimiter="################################################################" + +printf "\n%s\n" "${delimiter}" +printf "\e[1m\e[32mInstall script for SadTalker + Web UI\n" +printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" +printf "\n%s\n" "${delimiter}" + +# Do not run as root +if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]] +then + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" + printf "\n%s\n" "${delimiter}" + exit 1 +else + printf "\n%s\n" "${delimiter}" + printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)" + printf "\n%s\n" "${delimiter}" +fi + +if [[ -d .git ]] +then + printf "\n%s\n" "${delimiter}" + printf "Repo already cloned, using it as install directory" + printf "\n%s\n" "${delimiter}" + install_dir="${PWD}/../" + clone_dir="${PWD##*/}" +fi + +# Check prerequisites +gpu_info=$(lspci 2>/dev/null | grep VGA) +case "$gpu_info" in + *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 + ;; + *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 + printf "\n%s\n" "${delimiter}" + printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" + printf "\n%s\n" "${delimiter}" + ;; + *) + ;; +esac +if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +then + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" +fi + +for preq in "${GIT}" "${python_cmd}" +do + if ! hash "${preq}" &>/dev/null + then + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}" + printf "\n%s\n" "${delimiter}" + exit 1 + fi +done + +if ! "${python_cmd}" -c "import venv" &>/dev/null +then + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" + printf "\n%s\n" "${delimiter}" + exit 1 +fi + +printf "\n%s\n" "${delimiter}" +printf "Create and activate python venv" +printf "\n%s\n" "${delimiter}" +cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } +if [[ ! -d "${venv_dir}" ]] +then + "${python_cmd}" -m venv "${venv_dir}" + first_launch=1 +fi +# shellcheck source=/dev/null +if [[ -f "${venv_dir}"/bin/activate ]] +then + source "${venv_dir}"/bin/activate +else + printf "\n%s\n" "${delimiter}" + printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" + printf "\n%s\n" "${delimiter}" + exit 1 +fi + +printf "\n%s\n" "${delimiter}" +printf "Launching launcher.py..." +printf "\n%s\n" "${delimiter}" +exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" \ No newline at end of file