v0.0.2rc: fixed multiple annoying problems
shadowcun committed Jun 5, 2023
1 parent 96740ee commit 1119361
Showing 27 changed files with 705 additions and 518 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -161,7 +161,8 @@ cython_debug/

examples/results/*
gfpgan/*
checkpoints/
checkpoints/*
assets/*
results/*
Dockerfile
start_docker.sh
177 changes: 83 additions & 94 deletions app.py
@@ -1,14 +1,70 @@
import os, sys
import tempfile
import gradio as gr
from src.gradio_demo import SadTalker

def get_source_image(image):
return image

def sadtalker_demo():
try:
import webui # in webui
in_webui = True
except:
in_webui = False

sad_talker = SadTalker(lazy_load=True)
# mimetypes.init()
# mimetypes.add_type('application/javascript', '.js')

# script_path = os.path.dirname(os.path.realpath(__file__))

# def webpath(fn):
# if fn.startswith(script_path):
# web_path = os.path.relpath(fn, script_path).replace('\\', '/')
# else:
# web_path = os.path.abspath(fn)

# return f'file={web_path}?{os.path.getmtime(fn)}'

# def javascript_html():
# # Ensure localization is in `window` before scripts
# # head = f'<script type="text/javascript">{localization.localization_js(opts.localization)}</script>\n'
# head = 'somehead'

# script_js = os.path.join(script_path, "assets", "script.js")
# head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'

# script_js = os.path.join(script_path, "assets", "aspectRatioOverlay.js")
# head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'

# return head

# def resize_from_to_html(width, height, scale_by):
# target_width = int(width * scale_by)
# target_height = int(height * scale_by)

# if not target_width or not target_height:
# return "no image selected"

# return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"

# def get_source_image(image):
# return image

# def reload_javascript():
# js = javascript_html()

# def template_response(*args, **kwargs):
# res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
# res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
# res.init_headers()
# return res

# gradio.routes.templates.TemplateResponse = template_response

# if not hasattr(shared, 'GradioTemplateResponseOriginal'):
# shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse


def sadtalker_demo(checkpoint_path='checkpoint', config_path='src/config', warpfn=None):

sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)

with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
@@ -21,120 +77,52 @@ def sadtalker_demo():
with gr.Tabs(elem_id="sadtalker_source_image"):
with gr.TabItem('Upload image'):
with gr.Row():
source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)

source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)


with gr.Tabs(elem_id="sadtalker_driven_audio"):
with gr.TabItem('Upload OR TTS'):
with gr.Column(variant='panel'):
driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

if sys.platform != 'win32':
if sys.platform != 'win32' and not in_webui:
from src.utils.text2speech import TTSTalker
tts_talker = TTSTalker()
with gr.Column(variant='panel'):
input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS.")
tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])


with gr.Column(variant='panel'):
with gr.Tabs(elem_id="sadtalker_checkbox"):
with gr.TabItem('Settings'):
gr.Markdown("need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
with gr.Column(variant='panel'):
preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?")
is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion, works with preprocess `full`)")
enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
# width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
# height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) #
size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
is_still_mode = gr.Checkbox(label="Still Mode (fewer hand motion, works with preprocess `full`)")
batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')


with gr.Tabs(elem_id="sadtalker_genearted"):
gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)



with gr.Row():
examples = [
[
'examples/source_image/full_body_1.png',
'examples/driven_audio/bus_chinese.wav',
'crop',
True,
False
],
[
'examples/source_image/full_body_2.png',
'examples/driven_audio/japanese.wav',
'crop',
False,
False
],
[
'examples/source_image/full3.png',
'examples/driven_audio/deyu.wav',
'crop',
False,
True
],
[
'examples/source_image/full4.jpeg',
'examples/driven_audio/eluosi.wav',
'full',
False,
True
],
[
'examples/source_image/full4.jpeg',
'examples/driven_audio/imagine.wav',
'full',
True,
True
],
[
'examples/source_image/full_body_1.png',
'examples/driven_audio/bus_chinese.wav',
'full',
True,
False
],
[
'examples/source_image/art_13.png',
'examples/driven_audio/fayu.wav',
'resize',
True,
False
],
[
'examples/source_image/art_5.png',
'examples/driven_audio/chinese_news.wav',
'resize',
False,
False
],
[
'examples/source_image/art_5.png',
'examples/driven_audio/RD_Radio31_000.wav',
'resize',
True,
True
],
]
gr.Examples(examples=examples,
inputs=[
source_image,
driven_audio,
preprocess_type,
is_still_mode,
enhancer],
outputs=[gen_video],
fn=sad_talker.test,
cache_examples=os.getenv('SYSTEM') == 'spaces') #

submit.click(
fn=sad_talker.test,
inputs=[source_image,
driven_audio,
preprocess_type,
is_still_mode,
enhancer],
enhancer,
batch_size,
size_of_image,
pose_style
],
outputs=[gen_video]
)

@@ -144,6 +132,7 @@ def sadtalker_demo():
if __name__ == "__main__":

demo = sadtalker_demo()
demo.launch(share=True)
demo.queue()
demo.launch()
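
For orientation only, a minimal sketch (not part of this commit) of driving the updated entry point from another script. It mirrors the new __main__ block above; the explicit checkpoints/config paths are assumptions about a local checkout rather than values taken from the diff.

# Sketch only: assumes a repository checkout with model weights in ./checkpoints
# and config files in src/config; sadtalker_demo() now accepts these paths.
from app import sadtalker_demo

demo = sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config')
demo.queue()    # queue() now precedes launch(), as in the updated __main__ block
demo.launch()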


72 changes: 29 additions & 43 deletions inference.py
@@ -1,3 +1,5 @@
from glob import glob
import shutil
import torch
from time import strftime
import os, sys, time
@@ -8,6 +10,7 @@
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path

def main(args):
#torch.backends.cudnn.enabled = False
@@ -25,51 +28,23 @@ def main(args):
ref_eyeblink = args.ref_eyeblink
ref_pose = args.ref_pose

current_code_path = sys.argv[0]
current_root_path = os.path.split(current_code_path)[0]
current_root_path = os.path.split(sys.argv[0])[0]

os.environ['TORCH_HOME']=os.path.join(current_root_path, args.checkpoint_dir)

path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')

audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')

audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')

free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')

if args.preprocess == 'full':
mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00109-model.pth.tar')
facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
else:
mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')
sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)

#init model
print(path_of_net_recon_model)
preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

print(audio2pose_checkpoint)
print(audio2exp_checkpoint)
audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
audio2exp_checkpoint, audio2exp_yaml_path,
wav2lip_checkpoint, device)
preprocess_model = CropAndExtract(sadtalker_paths, device)

audio_to_coeff = Audio2Coeff(sadtalker_paths, device)

print(free_view_checkpoint)
print(mapping_checkpoint)
animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
facerender_yaml_path, device)
animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)

#crop image and extract 3dmm from image
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
os.makedirs(first_frame_dir, exist_ok=True)
print('3DMM Extraction for source image')
first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess, source_image_flag=True)
first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
source_image_flag=True, pic_size=args.size)
if first_coeff_path is None:
print("Can't get the coeffs of the input")
return
@@ -79,7 +54,7 @@ def main(args):
ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
print('3DMM Extraction for the reference video providing eye blinking')
ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir)
ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
else:
ref_eyeblink_coeff_path=None

@@ -91,7 +66,7 @@ def main(args):
ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
os.makedirs(ref_pose_frame_dir, exist_ok=True)
print('3DMM Extraction for the reference video providing pose')
ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir)
ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
else:
ref_pose_coeff_path=None

@@ -107,22 +82,30 @@ def main(args):
#coeff2video
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
batch_size, input_yaw_list, input_pitch_list, input_roll_list,
expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)
expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)

animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)
result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)

shutil.move(result, save_dir+'.mp4')
print('The generated video is named:', save_dir+'.mp4')

if not args.verbose:
shutil.rmtree(save_dir)


if __name__ == '__main__':

parser = ArgumentParser()
parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
parser.add_argument("--source_image", default='./examples/source_image/full_body_2.png', help="path to source image")
parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
@@ -132,7 +115,10 @@ def main(args):
parser.add_argument("--cpu", dest="cpu", action="store_true")
parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize', 'full'], help="how to preprocess the images" )
parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )


# net structure and parameters
parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
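
As a usage note (not part of the diff), a hypothetical invocation that exercises the newly added flags might look like the sketch below; the example paths and chosen values are placeholders, not values prescribed by this change.

# Hypothetical run of the updated inference.py from the repository root,
# assuming model checkpoints have been downloaded to ./checkpoints.
import subprocess

subprocess.run([
    "python", "inference.py",
    "--source_image", "examples/source_image/full_body_1.png",
    "--driven_audio", "examples/driven_audio/bus_chinese.wav",
    "--preprocess", "full",   # 'extcrop' and 'extfull' are newly accepted choices
    "--size", "256",          # new flag: face model resolution, 256 or 512
    "--still",
    "--verbose",              # new flag: keep the intermediate outputs
], check=True)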