-
Notifications
You must be signed in to change notification settings - Fork 14
/
system_gm_ldm.py
116 lines (81 loc) · 4.93 KB
/
system_gm_ldm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler
from gm_ldm import GaussianDrivenLDM
import tqdm
class GMLDMSystem(nn.Module):
def __init__(self, opt):
    """Assemble the Gaussian-driven LDM system from a configuration object.

    Args:
        opt: configuration object. This constructor reads
            ``opt.network.image_size``, ``opt.network.latent_size``,
            ``opt.network.latent_channel`` and ``opt.network.num_input_views``,
            and hands ``opt`` unchanged to ``GaussianDrivenLDM``.
    """
    super().__init__()
    self.opt = opt

    # Image / latent geometry, copied straight from the network config.
    net_cfg = opt.network
    self.image_size = net_cfg.image_size
    self.latent_size = net_cfg.latent_size
    self.latent_channel = net_cfg.latent_channel

    self.model = GaussianDrivenLDM(opt)

    # prediction_type="sample": per diffusers' convention the model output is
    # treated as the denoised sample rather than the added noise.
    scheduler = DDIMScheduler(
        beta_schedule='scaled_linear',
        beta_start=0.00085,
        beta_end=0.012,
        prediction_type="sample",
        clip_sample=False,
        steps_offset=9,
        rescale_betas_zero_snr=True,
        set_alpha_to_one=True,
    )
    self.scheduler = scheduler
    # Registered as a buffer so it follows the module across devices;
    # persistent=False keeps it out of the state_dict.
    self.register_buffer("alphas_cumprod", scheduler.alphas_cumprod, persistent=False)

    n_train_steps = scheduler.config.num_train_timesteps
    self.num_train_timesteps = n_train_steps
    # Timestep bounds; how they are consumed (inclusive/exclusive upper end)
    # is decided by code outside this constructor.
    self.min_step = 0
    self.max_step = int(n_train_steps)

    self.num_input_views = opt.network.num_input_views
def to(self, device):
self.device = device
return super().to(device)
def inference_one_step(self, cameras, latents_noisy, text_embeddings, uncond_text_embeddings, t, guidance_scale=7.5, use_3d_mode=True):
_latents_noisy = latents_noisy.clone()
B, N, _, _ ,_ = latents_noisy.shape
_t = t[..., None].repeat(1, N)
uncond_latents_noisy = latents_noisy.clone()
uncond_t = _t.clone()
if use_3d_mode:
latents_noisy = latents_noisy
cameras = cameras
text_embeddings = text_embeddings
tt = _t