use variance in gradient guide instead of std
jannerm committed Oct 17, 2022
1 parent da8b87c commit 3d7361c
Showing 2 changed files with 3 additions and 2 deletions.
diffuser/sampling/functions.py (2 additions, 1 deletion)

@@ -12,13 +12,14 @@ def n_step_guided_p_sample(
 ):
     model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape)
     model_std = torch.exp(0.5 * model_log_variance)
+    model_var = torch.exp(model_log_variance)

     for _ in range(n_guide_steps):
         with torch.enable_grad():
             y, grad = guide.gradients(x, cond, t)

         if scale_grad_by_std:
-            grad = model_std * grad
+            grad = model_var * grad

         grad[t < t_stopgrad] = 0
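The effect of the change can be sketched in plain Python. The helper below is illustrative (its name and signature are not from the repository, and it uses floats where the real code uses PyTorch tensors): given the posterior log-variance, the old behavior scaled the guidance gradient by the standard deviation exp(0.5 * log_var), while the new behavior scales it by the variance exp(log_var).

```python
import math

def scale_guide_gradient(grad, model_log_variance, use_variance=True):
    """Scale a guidance gradient by the posterior variance (the behavior
    after this commit) or by the standard deviation (the behavior before).

    `grad` is a list of floats standing in for the tensor in the real
    code; the function name is hypothetical, for illustration only.
    """
    if use_variance:
        coef = math.exp(model_log_variance)        # var = exp(log_var)
    else:
        coef = math.exp(0.5 * model_log_variance)  # std = exp(0.5 * log_var)
    return [coef * g for g in grad]
```

Since var = std**2 and std is below 1 at low-noise timesteps, variance scaling damps the guidance signal more strongly there than std scaling did.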
slurm/plan.sh (1 addition, 1 deletion)

@@ -14,7 +14,7 @@ do
     python -u scripts/plan_guided.py \
         --logbase logs/pretrained \
         --dataset $env-$buffer-v2 \
-        --prefix plans/reference \
+        --prefix plans/reference_var \
         --vis_freq 500 \
         --verbose False \
         --suffix {1} \
