exp_mplug.py
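"""Generate image descriptions with mPLUG-Owl2 for the GenCeption benchmark."""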
import os
import torch
import argparse
from functools import partial
import logging
from tqdm import tqdm
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)
from genception.utils import test_sample, encode_image_os, prompt
from genception.file_utils import find_image_files

logging.basicConfig(level=logging.INFO)
torch.backends.cudnn.enabled = False


def get_desc_mPLUG(image, image_processor, lmm_model, tokenizer, prompt, device):
    """
    Given an image, generate a description of it using the mPLUG-Owl2 model.

    Args:
        image: PIL.Image.Image: The image to describe
        image_processor: The CLIP image processor used to produce pixel tensors
        lmm_model: The mPLUG-Owl2 model
        tokenizer: The tokenizer
        prompt: str: The prompt for the model
        device: str: The device to run on ("cuda" or "cpu")

    Returns:
        str: The description of the image
    """
    conv = conv_templates["mplug_owl2"].copy()
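    # Stretch the image to a square whose side equals its longer edge,
    # then run it through the CLIP image processor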
    max_edge = max(image.size)
    resized_image = image.resize((max_edge, max_edge))
    image_tensor = process_images([resized_image], image_processor)
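    # On GPU, move the pixel tensor to the model's device in half precision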
if device == "cuda":
image_tensor = image_tensor.to(lmm_model.device, dtype=torch.float16)
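    # Prepend the image placeholder token to the user prompt and render the
    # full conversation prompt (a user turn plus an empty assistant turn)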
    inp = DEFAULT_IMAGE_TOKEN + prompt
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
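    # Tokenize the prompt, substituting IMAGE_TOKEN_INDEX for the image placeholder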
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(lmm_model.device)
    )
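    # Stop generation once the conversation separator string is produced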
    stop_str = conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
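    # All-ones attention mask: a single unpadded sequence attends to every token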
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)