20230623 blip diffusion documentation #441

Merged into main on Jul 21, 2023 (45 commits).
Commits
4f41867  first cut on blip-diffusion. (dxli94, Jun 13, 2023)
981cb3b  add blip-diffusion controlnet. (dxli94, Jun 14, 2023)
e211495  remove controlnet sub-model. (dxli94, Jun 14, 2023)
4edb577  add script to download annotator weights. (dxli94, Jun 14, 2023)
1a8677b  add editing capability and zero-shot editing notebook. (dxli94, Jun 15, 2023)
e776cb0  further trim the code. update notebook. (dxli94, Jun 15, 2023)
c87cc9d  minor formatting. (dxli94, Jun 15, 2023)
b20dd88  refactor prompt reps. fix cache embedding loading. (dxli94, Jun 15, 2023)
15bd3a2  renaming. (dxli94, Jun 16, 2023)
e48b9e1  wip edict. (dxli94, Jun 17, 2023)
b4222ed  use _init_latents (dxli94, Jun 19, 2023)
2c70f8e  fix init_latent. (dxli94, Jun 19, 2023)
87462b7  wip editing real images. (dxli94, Jun 19, 2023)
4f5e14e  fix ddim inversion. (dxli94, Jun 19, 2023)
8ba27b2  fix editing issue. (dxli94, Jun 19, 2023)
01bd2c2  working version of editing real. (dxli94, Jun 19, 2023)
db12364  minor. (dxli94, Jun 19, 2023)
e31e061  add controller reset to editing. (dxli94, Jun 19, 2023)
7214f77  all zero-shot capabilities done. (dxli94, Jun 19, 2023)
91d53b4  minor renaming. (dxli94, Jun 20, 2023)
72c48db  minor. (dxli94, Jun 20, 2023)
74ae950  add constant lr scheduler. (dxli94, Jun 21, 2023)
9dd7d3d  finetuning code. (dxli94, Jun 22, 2023)
791b406  change sample keywords for inference; add example notebooks. (dxli94, Jun 23, 2023)
f484804  update README. (dxli94, Jun 23, 2023)
a4a6b84  upload checkpoint to gcloud. (dxli94, Jun 23, 2023)
adc4e87  BLIP-Diffusion examples (#1) (dxli94, Jun 23, 2023)
03b4e19  Update train_db.sh to accept batch_size_train (LiJunnan1992, Jun 26, 2023)
1fa9463  remove unused print. (dxli94, Jul 10, 2023)
37789c9  20230623 blip diffusion documentation (#2) (dxli94, Jul 10, 2023)
2bf79be  add embed youtube video. (dxli94, Jul 10, 2023)
a3bdbcc  remove embedded video. (dxli94, Jul 10, 2023)
08e4f0e  remove unused script. (dxli94, Jul 10, 2023)
330a444  final touch (dxli94, Jul 10, 2023)
86150e3  remove unused notebooks. (dxli94, Jul 10, 2023)
7304080  remove unused script. (dxli94, Jul 10, 2023)
849fdd8  Merge branch 'main' into 20230623-blip-diffusion-documentation (dxli94, Jul 10, 2023)
59c865a  remove unused scripts. (dxli94, Jul 10, 2023)
ecf1d03  20230623 blip diffusion documentation (#3) (dxli94, Jul 10, 2023)
009477f  rerun notebook. (dxli94, Jul 10, 2023)
ade8110  re-merge the notebooks. (dxli94, Jul 10, 2023)
87ad99c  re-merge notebooks. (dxli94, Jul 10, 2023)
dbe7d05  re-merge notebooks. (dxli94, Jul 10, 2023)
5d7d6c2  remove editing and finetune. (dxli94, Jul 21, 2023)
6ea0b17  Merge remote-tracking branch 'salesforce/main' into 20230623-blip-dif… (dxli94, Jul 21, 2023)
6 changes: 6 additions & 0 deletions lavis/common/annotator/canny/__init__.py
@@ -0,0 +1,6 @@
import cv2


class CannyDetector:
    def __call__(self, img, low_threshold, high_threshold):
        return cv2.Canny(img, low_threshold, high_threshold)
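
A minimal usage sketch for the wrapper above; the import path, input file, and threshold values are illustrative assumptions:

import cv2
from lavis.common.annotator.canny import CannyDetector  # import path assumed from the diff above

apply_canny = CannyDetector()
img = cv2.imread("example.jpg")        # hypothetical HWC uint8 input
edge_map = apply_canny(img, 100, 200)  # low/high hysteresis thresholds; values are illustrative
cv2.imwrite("edges.png", edge_map)     # single-channel uint8 edge map
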
5 changes: 5 additions & 0 deletions lavis/common/annotator/ckpts/download.sh
@@ -0,0 +1,5 @@
#! /bin/bash

wget https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt
wget https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth
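
For environments without wget, a Python sketch with the same effect, using the basicsr helper the annotators themselves fall back on; the destination directory is an assumption matching this script's location:

from basicsr.utils.download_util import load_file_from_url

ckpt_dir = "lavis/common/annotator/ckpts"  # assumed destination directory
for url in [
    "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt",
    "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth",
]:
    load_file_from_url(url, model_dir=ckpt_dir)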

132 changes: 132 additions & 0 deletions lavis/common/annotator/hed/__init__.py
@@ -0,0 +1,132 @@
import numpy as np
import cv2
import os
import torch
from einops import rearrange
from annotator.util import annotator_ckpts_path


class Network(torch.nn.Module):
    # VGG16-style backbone with one side-output score per stage, as in the HED edge detector.
    def __init__(self, model_path):
        super().__init__()

        self.netVggOne = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggTwo = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggThr = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggFou = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netVggFiv = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)

        self.netCombine = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
            torch.nn.Sigmoid()
        )

        # Rename "module*" keys from the released checkpoint to match this module's "net*" attributes.
        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})

    def forward(self, tenInput):
        # Scale the [0, 1] input back to [0, 255] and subtract the BGR channel means used during training.
        tenInput = tenInput * 255.0
        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)

        tenVggOne = self.netVggOne(tenInput)
        tenVggTwo = self.netVggTwo(tenVggOne)
        tenVggThr = self.netVggThr(tenVggTwo)
        tenVggFou = self.netVggFou(tenVggThr)
        tenVggFiv = self.netVggFiv(tenVggFou)

        tenScoreOne = self.netScoreOne(tenVggOne)
        tenScoreTwo = self.netScoreTwo(tenVggTwo)
        tenScoreThr = self.netScoreThr(tenVggThr)
        tenScoreFou = self.netScoreFou(tenVggFou)
        tenScoreFiv = self.netScoreFiv(tenVggFiv)

        # Upsample every side output back to the input resolution before fusing.
        tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
        tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)

        return self.netCombine(torch.cat([tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv], 1))


class HEDdetector:
    def __init__(self):
        remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
        modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth")
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
        self.netNetwork = Network(modelpath).cuda().eval()

    def __call__(self, input_image):
        assert input_image.ndim == 3
        input_image = input_image[:, :, ::-1].copy()  # flip channel order (RGB to BGR) to match the training means
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image).float().cuda()
            image_hed = image_hed / 255.0
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edge = self.netNetwork(image_hed)[0]
            edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
            return edge[0]


def nms(x, t, s):
    # Thin a soft edge map: blur, keep only pixels that are local maxima along one
    # of four directions (computed via dilation), then threshold to a binary map.
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z
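
A usage sketch combining the detector and nms above; the import path, file names, and parameter values are assumptions, and a CUDA device is required:

import cv2
from lavis.common.annotator.hed import HEDdetector, nms  # import path assumed

hed = HEDdetector()                 # downloads network-bsds500.pth on first use
img = cv2.imread("example.jpg")     # hypothetical HWC uint8 input
edge = hed(img)                     # soft edge map, uint8 (H, W)
thin = nms(edge, t=127, s=3.0)      # threshold and blur sigma are illustrative
cv2.imwrite("hed_edges.png", thin)
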
38 changes: 38 additions & 0 deletions lavis/common/annotator/midas/__init__.py
@@ -0,0 +1,38 @@
import cv2
import numpy as np
import torch

from einops import rearrange
from .api import MiDaSInference


class MidasDetector:
    def __init__(self):
        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()

    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
        assert input_image.ndim == 3
        image_depth = input_image
        with torch.no_grad():
            image_depth = torch.from_numpy(image_depth).float().cuda()
            image_depth = image_depth / 127.5 - 1.0  # map uint8 input to [-1, 1]
            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
            depth = self.model(image_depth)[0]

            # Normalize the raw prediction to [0, 255] for the depth image.
            depth_pt = depth.clone()
            depth_pt -= torch.min(depth_pt)
            depth_pt /= torch.max(depth_pt)
            depth_pt = depth_pt.cpu().numpy()
            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)

            # Derive a pseudo normal map from depth gradients; pixels with low
            # normalized depth are treated as background and flattened.
            depth_np = depth.cpu().numpy()
            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
            z = np.ones_like(x) * a
            x[depth_pt < bg_th] = 0
            y[depth_pt < bg_th] = 0
            normal = np.stack([x, y, z], axis=2)
            normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
            normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

            return depth_image, normal_image
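
A usage sketch for the detector above; the import path and file names are assumptions, and a CUDA device plus the dpt_hybrid checkpoint are required:

import cv2
from lavis.common.annotator.midas import MidasDetector  # import path assumed

midas = MidasDetector()
img = cv2.imread("example.jpg")         # hypothetical HWC uint8 input
depth_image, normal_image = midas(img)  # uint8 depth map and pseudo normal map
cv2.imwrite("depth.png", depth_image)
cv2.imwrite("normal.png", normal_image)
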
169 changes: 169 additions & 0 deletions lavis/common/annotator/midas/api.py
@@ -0,0 +1,169 @@
# based on https://github.com/isl-org/MiDaS

import cv2
import os
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from .midas.dpt_depth import DPTDepthModel
from .midas.midas_net import MidasNet
from .midas.midas_net_custom import MidasNet_small
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
from annotator.util import annotator_ckpts_path


ISL_PATHS = {
    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
    "midas_v21": "",
    "midas_v21_small": "",
}

remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform


def load_model(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        # Only the dpt_hybrid checkpoint is auto-downloaded; the other checkpoints must exist locally.
        if not os.path.exists(model_path):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)

        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type)
        self.model = model
        self.model.train = disabled_train  # pin the model in eval mode

    def forward(self, x):
        with torch.no_grad():
            prediction = self.model(x)
        return prediction
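
A sketch of driving MiDaSInference directly with the matching transform; the import path and file name are assumptions. The MiDaS transforms operate on dicts with an "image" key holding a float HWC array in [0, 1]:

import cv2
import numpy as np
import torch
from lavis.common.annotator.midas.api import MiDaSInference, load_midas_transform  # import path assumed

transform = load_midas_transform("dpt_hybrid")
model = MiDaSInference(model_type="dpt_hybrid").cuda()

img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
sample = transform({"image": img})                         # resize, normalize, HWC -> CHW
x = torch.from_numpy(sample["image"]).unsqueeze(0).cuda()  # 1 x 3 x H x W
depth = model(x)                                           # inverse-depth prediction, 1 x H x W
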

Empty file.
16 changes: 16 additions & 0 deletions lavis/common/annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
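
A minimal sketch of how a subclass would use this loader; the layer and checkpoint path are purely illustrative (the real subclasses are the MiDaS models such as DPTDepthModel):

import torch
from lavis.common.annotator.midas.midas.base_model import BaseModel  # import path assumed


class TinyHead(BaseModel):
    # Illustrative subclass, not part of the PR.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        return self.proj(x)


model = TinyHead()
model.load("checkpoint.pt")  # hypothetical path; unwraps {"model": ..., "optimizer": ...} checkpoints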