add sam_gradio #177

Merged: 5 commits, Apr 20, 2023
194 changes: 134 additions & 60 deletions gradio_app.py
@@ -1,5 +1,7 @@
 import os
 import random
+import cv2
+from scipy import ndimage

 import gradio as gr
 import argparse
@@ -16,7 +18,7 @@
 from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

 # segment anything
-from segment_anything import build_sam, SamPredictor
+from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
 import numpy as np

 # diffusers
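
Note on the import change above: `SamAutomaticMaskGenerator` backs the new `automask` task type added further down. As a minimal standalone sketch of how that generator is driven (the checkpoint path and dummy image here are assumptions for illustration, not taken from this diff):

```python
import numpy as np
from segment_anything import build_sam, SamAutomaticMaskGenerator

# Build SAM from a local checkpoint (hypothetical path, adjust to your setup).
sam = build_sam(checkpoint="sam_vit_h_4b8939.pth")
generator = SamAutomaticMaskGenerator(sam)

# Any HxWx3 uint8 image works; a blank placeholder is used here.
image = np.zeros((512, 512, 3), dtype=np.uint8)
masks = generator.generate(image)

# Each entry is a dict with keys like 'segmentation' (bool HxW) and 'area'.
print(len(masks))
```

The PR reuses one `sam` instance for both `SamPredictor` and the generator, so the checkpoint is only loaded once.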
@@ -26,6 +28,30 @@
 # BLIP
 from transformers import BlipProcessor, BlipForConditionalGeneration

+def show_anns(anns):
+    if len(anns) == 0:
+        return
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+    full_img = None
+
+    # for ann in sorted_anns:
+    for i in range(len(sorted_anns)):
+        ann = anns[i]
+        m = ann['segmentation']
+        if full_img is None:
+            full_img = np.zeros((m.shape[0], m.shape[1], 3))
+            map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
+        map[m != 0] = i + 1
+        color_mask = np.random.random((1, 3)).tolist()[0]
+        full_img[m != 0] = color_mask
+    full_img = full_img*255
+    # anno encoding from https://github.com/LUSSeg/ImageNet-S
+    res = np.zeros((map.shape[0], map.shape[1], 3))
+    res[:, :, 0] = map % 256
+    res[:, :, 1] = map // 256
+    res.astype(np.float32)
+    full_img = Image.fromarray(np.uint8(full_img))
+    return full_img, res
+
 def generate_caption(processor, blip_model, raw_image):
     # unconditional image captioning
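
A note on `show_anns` above: besides the random-color overlay, it returns `res`, a per-pixel id map encoded in the ImageNet-S style referenced in the comment: the 1-indexed mask id is split into a low byte (channel 0, `id % 256`) and a high byte (channel 1, `id // 256`). The bare `res.astype(np.float32)` line has no effect, since the converted array is never assigned. A minimal sketch of decoding that map back (illustrative, not part of the PR):

```python
import numpy as np

def decode_mask_ids(res: np.ndarray) -> np.ndarray:
    # Recombine the two byte channels written by show_anns:
    # channel 0 holds id % 256, channel 1 holds id // 256.
    # Returns an HxW array where 0 is background and k is the k-th mask.
    return res[:, :, 0].astype(np.uint16) + res[:, :, 1].astype(np.uint16) * 256
```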
@@ -137,85 +163,132 @@ def draw_box(box, draw, label):
 blip_model = None
 groundingdino_model = None
 sam_predictor = None
+sam_automask_generator = None
 inpaint_pipeline = None

-def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
+def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, scribble_mode):

-    global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline
+    global blip_processor, blip_model, groundingdino_model, sam_predictor, sam_automask_generator, inpaint_pipeline

     # make dir
     os.makedirs(output_dir, exist_ok=True)
     # load image
-    image_pil = input_image.convert("RGB")
-    transformed_image = transform_image(image_pil)
+    image = input_image["image"]
+    scribble = input_image["mask"]
+    size = image.size # w, h

+    if sam_predictor is None:
+        # initialize SAM
+        assert sam_checkpoint, 'sam_checkpoint is not found!'
+        sam = build_sam(checkpoint=sam_checkpoint)
+        sam.to(device=device)
+        sam_predictor = SamPredictor(sam)
+        sam_automask_generator = SamAutomaticMaskGenerator(sam)
+
     if groundingdino_model is None:
         groundingdino_model = load_model(config_file, ckpt_filenmae, device=device)

-    if task_type == 'automatic':
-        # generate caption and tags
-        # using Tag2Text can generate better captions
-        # https://huggingface.co/spaces/xinyu1205/Tag2Text
-        # but there are some bugs...
-        blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-        blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
-        text_prompt = generate_caption(blip_processor, blip_model, image_pil)
-        print(f"Caption: {text_prompt}")
-
-    # run grounding dino model
-    boxes_filt, scores, pred_phrases = get_grounding_output(
-        groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
-    )
-
-    size = image_pil.size
+    image_pil = image.convert("RGB")
+    image = np.array(image_pil)

-    # process boxes
-    H, W = size[1], size[0]
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
+    if task_type == 'scribble':
+        sam_predictor.set_image(image)
+        scribble = scribble.convert("RGB")
+        scribble = np.array(scribble)
+        scribble = scribble.transpose(2, 1, 0)[0]
+
+        # label the connected components of the scribble
+        labeled_array, num_features = ndimage.label(scribble >= 255)
+
+        # compute the centroid of each connected component
+        centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
+        centers = np.array(centers)
+
+        point_coords = torch.from_numpy(centers)
+        point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
+        point_coords = point_coords.unsqueeze(0).to(device)
+        point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
+        if scribble_mode == 'split':
+            point_coords = point_coords.permute(1, 0, 2)
+            point_labels = point_labels.permute(1, 0)
+        masks, _, _ = sam_predictor.predict_torch(
+            point_coords=point_coords if len(point_coords) > 0 else None,
+            point_labels=point_labels if len(point_coords) > 0 else None,
+            mask_input = None,
+            boxes = None,
+            multimask_output = False,
+        )
+    elif task_type == 'automask':
+        masks = sam_automask_generator.generate(image)
+    else:
+        transformed_image = transform_image(image_pil)

-    boxes_filt = boxes_filt.cpu()
+        if task_type == 'automatic':
+            # generate caption and tags
+            # using Tag2Text can generate better captions
+            # https://huggingface.co/spaces/xinyu1205/Tag2Text
+            # but there are some bugs...
+            blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+            blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
+            text_prompt = generate_caption(blip_processor, blip_model, image_pil)
+            print(f"Caption: {text_prompt}")
+
+        # run grounding dino model
+        boxes_filt, scores, pred_phrases = get_grounding_output(
+            groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
+        )
+
+        # process boxes
+        H, W = size[1], size[0]
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]

-    if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
-        if sam_predictor is None:
-            # initialize SAM
-            assert sam_checkpoint, 'sam_checkpoint is not found!'
-            sam = build_sam(checkpoint=sam_checkpoint)
-            sam.to(device=device)
-            sam_predictor = SamPredictor(sam)
+        boxes_filt = boxes_filt.cpu()

-        image = np.array(image_pil)
-        sam_predictor.set_image(image)
-
-        if task_type == 'automatic':
-            # use NMS to handle overlapped boxes
-            print(f"Before NMS: {boxes_filt.shape[0]} boxes")
-            nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-            boxes_filt = boxes_filt[nms_idx]
-            pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-            print(f"After NMS: {boxes_filt.shape[0]} boxes")
-            print(f"Revise caption with number: {text_prompt}")
+        if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
+            sam_predictor.set_image(image)

-        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+            if task_type == 'automatic':
+                # use NMS to handle overlapped boxes
+                print(f"Before NMS: {boxes_filt.shape[0]} boxes")
+                nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
+                boxes_filt = boxes_filt[nms_idx]
+                pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+                print(f"After NMS: {boxes_filt.shape[0]} boxes")
+                print(f"Revise caption with number: {text_prompt}")

-        masks, _, _ = sam_predictor.predict_torch(
-            point_coords = None,
-            point_labels = None,
-            boxes = transformed_boxes,
-            multimask_output = False,
-        )
+            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+
+            # masks: [1, 1, 512, 512]
+            masks, _, _ = sam_predictor.predict_torch(
+                point_coords = None,
+                point_labels = None,
+                boxes = transformed_boxes,
+                multimask_output = False,
+            )

     if task_type == 'det':
         image_draw = ImageDraw.Draw(image_pil)
         for box, label in zip(boxes_filt, pred_phrases):
             draw_box(box, image_draw, label)

         return [image_pil]
+    elif task_type == 'automask':
+        full_img, res = show_anns(masks)
+        return [full_img]
+    elif task_type == 'scribble':
+        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
+
+        mask_draw = ImageDraw.Draw(mask_image)
+
+        for mask in masks:
+            draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
+
+        image_pil = image_pil.convert('RGBA')
+        image_pil.alpha_composite(mask_image)
+        return [image_pil, mask_image]
     elif task_type == 'seg' or task_type == 'automatic':

         mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
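
The new `scribble` branch above turns the user's strokes into SAM point prompts: threshold the sketch mask, label its connected components with `scipy.ndimage.label`, take one centroid per stroke via `ndimage.center_of_mass`, and pass the centroids as foreground points to `predict_torch`. The `transpose(2, 1, 0)[0]` both selects a single channel and swaps the axes, so the centroids come out in the (x, y) order SAM expects. A standalone sketch of the centroid step (the toy scribble array is an assumption for illustration):

```python
import numpy as np
from scipy import ndimage

# Toy scribble mask: two white strokes on a black canvas.
scribble = np.zeros((64, 64), dtype=np.uint8)
scribble[10:14, 10:14] = 255
scribble[40:44, 50:54] = 255

# Label connected strokes, then compute one centroid per stroke.
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features + 1))
centers = np.array(centers)  # shape (num_strokes, 2)

print(num_features, centers)  # 2 strokes -> 2 candidate point prompts
```

With `scribble_mode == 'split'`, the permutes move the point dimension into the batch dimension, so each stroke becomes its own prompt and yields its own mask; with 'merge', all centroids form a single prompt producing one mask.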
@@ -274,30 +347,31 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
 with block:
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(source='upload', type="pil", value="assets/demo1.jpg")
-            task_type = gr.Dropdown(["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
+            input_image = gr.Image(source='upload', type="pil", value="assets/demo1.jpg", tool="sketch")
+            task_type = gr.Dropdown(["scribble", "automask", "det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
             text_prompt = gr.Textbox(label="Text Prompt")
             inpaint_prompt = gr.Textbox(label="Inpaint Prompt")
             run_button = gr.Button(label="Run")
             with gr.Accordion("Advanced options", open=False):
                 box_threshold = gr.Slider(
-                    label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
+                    label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05
                 )
                 text_threshold = gr.Slider(
-                    label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                    label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05
                 )
                 iou_threshold = gr.Slider(
-                    label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
+                    label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05
                 )
                 inpaint_mode = gr.Dropdown(["merge", "first"], value="merge", label="inpaint_mode")
+                scribble_mode = gr.Dropdown(["merge", "split"], value="split", label="scribble_mode")

         with gr.Column():
             gallery = gr.Gallery(
                 label="Generated images", show_label=False, elem_id="gallery"
             ).style(preview=True, grid=2, object_fit="scale-down")

     run_button.click(fn=run_grounded_sam, inputs=[
-        input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=gallery)
+        input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, scribble_mode], outputs=gallery)

 block.queue(concurrency_count=100)
 block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
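
For readers unfamiliar with the `tool="sketch"` change: with that option, the `gr.Image` component no longer hands the callback a plain `PIL.Image` but a dict carrying the upload and the drawn strokes, which is why `run_grounded_sam` now reads `input_image["image"]` and `input_image["mask"]`. A sketch of the payload shape under that assumption (Gradio 3.x era API; illustrative values):

```python
from PIL import Image

# Hypothetical payload, mirroring what the callback receives.
input_image = {
    "image": Image.new("RGB", (512, 512)),    # the uploaded picture
    "mask": Image.new("RGBA", (512, 512)),    # user strokes on a blank canvas
}

image = input_image["image"]
scribble = input_image["mask"].convert("RGB")  # the PR also converts before use
print(image.size, scribble.size)
```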