Commit 0ba3c76

updated support for gpt-4, pixtral, gemini and molmo

1 parent 51647db commit 0ba3c76

File tree

- models/model_loader.py
- models/responder.py
- requirements.txt

3 files changed: +138 −68 lines changed

models/model_loader.py

Lines changed: 23 additions & 20 deletions
@@ -4,9 +4,15 @@
 import torch
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from transformers import MllamaForConditionalGeneration
+from vllm.sampling_params import SamplingParams
 from transformers import AutoModelForCausalLM
+import google.generativeai as genai
 from vllm import LLM
-from vllm.sampling_params import SamplingParams
+
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()

 from logger import get_logger

@@ -21,8 +27,8 @@ def detect_device():
     """
     if torch.cuda.is_available():
         return 'cuda'
-    #elif torch.backends.mps.is_available():
-    #    return 'mps'
+    elif torch.backends.mps.is_available():
+        return 'mps'
     else:
         return 'cpu'

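For reference, a minimal standalone sketch of the selection order this hunk restores (CUDA first, then Apple-silicon MPS, then CPU):

import torch

def detect_device():
    # Prefer CUDA, fall back to Apple-silicon MPS, then CPU
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'

print(detect_device())  # e.g. 'mps' on an M-series Mac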
@@ -51,21 +57,13 @@ def load_model(model_choice):

     elif model_choice == 'gemini':
         # Load Gemini model
-        import genai
-        genai.api_key = os.environ.get('GENAI_API_KEY')
-        model = genai.GenerativeModel(model_name="gemini-1.5-pro")
-        processor = None
-        _model_cache[model_choice] = (model, processor)
-        logger.info("Gemini model loaded and cached.")
-        return _model_cache[model_choice]
+        api_key = os.getenv("GOOGLE_API_KEY")
+        if not api_key:
+            raise ValueError("GOOGLE_API_KEY not found in .env file")
+        genai.configure(api_key=api_key)
+        model = genai.GenerativeModel('gemini-1.5-flash-002')  # Use the appropriate model name
+        return model, None

-    elif model_choice == 'gpt4':
-        # Load OpenAI GPT-4 model
-        import openai
-        openai.api_key = os.environ.get('OPENAI_API_KEY')
-        _model_cache[model_choice] = (None, None)
-        logger.info("GPT-4 model ready and cached.")
-        return _model_cache[model_choice]

     elif model_choice == 'llama-vision':
         # Load Llama-Vision model
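A minimal sketch of the new Gemini flow in isolation, assuming GOOGLE_API_KEY is set in a .env file; the image path here is hypothetical:

import os
import google.generativeai as genai
from dotenv import load_dotenv
from PIL import Image

load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model = genai.GenerativeModel('gemini-1.5-flash-002')

# generate_content accepts a mixed list of text and PIL images
response = model.generate_content(["Describe this page.", Image.open('static/page_1.png')])
print(response.text)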
@@ -85,21 +83,26 @@ def load_model(model_choice):

     elif model_choice == "pixtral":
         device = detect_device()
-        model = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")
+        model = LLM(model="mistralai/Pixtral-12B-2409",
+                    tokenizer_mode="mistral",
+                    gpu_memory_utilization=0.8,  # Increase GPU memory utilization
+                    max_model_len=8192,  # Decrease max model length
+                    dtype="float16",  # Use half precision to save memory
+                    trust_remote_code=True)
         sampling_params = SamplingParams(max_tokens=1024)
         _model_cache[model_choice] = (model, sampling_params, device)
         return _model_cache[model_choice]

     elif model_choice == "molmo":
         device = detect_device()
         processor = AutoProcessor.from_pretrained(
-            'allenai/Molmo-7B-D-0924',
+            'allenai/MolmoE-1B-0924',
             trust_remote_code=True,
             torch_dtype='auto',
             device_map='auto'
         )
         model = AutoModelForCausalLM.from_pretrained(
-            'allenai/Molmo-7B-D-0924',
+            'allenai/MolmoE-1B-0924',
             trust_remote_code=True,
             torch_dtype='auto',
             device_map='auto'
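Taken together, the branches now return differently shaped tuples; a hypothetical smoke test of the unpacking each caller must use (the molmo return shape follows models/responder.py below):

from models.model_loader import load_model

model, processor = load_model('gemini')               # processor is None
llm, sampling_params, device = load_model('pixtral')  # vLLM engine + SamplingParams
molmo, molmo_processor, device = load_model('molmo')  # HF model + processor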

models/responder.py

Lines changed: 112 additions & 47 deletions
@@ -2,13 +2,24 @@

 from models.model_loader import load_model
 from transformers import GenerationConfig
+import google.generativeai as genai
+from dotenv import load_dotenv
 from logger import get_logger
+from openai import OpenAI
 from PIL import Image
 import torch
+import base64
 import os
+import io
+

 logger = get_logger(__name__)

+# Function to encode the image
+def encode_image(image_path):
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
+
 def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
     """
     Generates a response using the selected model based on the query and images.
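A standalone restatement of the new encode_image helper and the data-URL string the gpt4 branch below builds from it; the file path is hypothetical:

import base64

def encode_image(image_path):
    # Read a file from disk and return its contents as base64 text
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

data_url = f"data:image/jpeg;base64,{encode_image('static/page_1.jpg')}"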
@@ -56,18 +67,83 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
         )
         logger.info("Response generated using Qwen model.")
         return output_text[0]
+
     elif model_choice == 'gemini':
-        from models.gemini_responder import generate_gemini_response
-        model, processor = load_model('gemini')
-        response = generate_gemini_response(images, query, model, processor)
-        logger.info("Response generated using Gemini model.")
-        return response
+
+        model, _ = load_model('gemini')
+
+        try:
+            content = []
+            content.append(query)  # Add the text query first
+
+            for img_path in images:
+                full_path = os.path.join('static', img_path)
+                if os.path.exists(full_path):
+                    try:
+                        img = Image.open(full_path)
+                        content.append(img)
+                    except Exception as e:
+                        logger.error(f"Error opening image {full_path}: {e}")
+                else:
+                    logger.warning(f"Image file not found: {full_path}")
+
+            if len(content) == 1:  # Only text, no images
+                return "No images could be loaded for analysis."
+
+            response = model.generate_content(content)
+
+            if response.text:
+                generated_text = response.text
+                logger.info("Response generated using Gemini model.")
+                return generated_text
+            else:
+                return "The Gemini model did not generate any text response."
+
+        except Exception as e:
+            logger.error(f"Error in Gemini processing: {str(e)}", exc_info=True)
+            return f"An error occurred while processing the images: {str(e)}"
+
     elif model_choice == 'gpt4':
-        from models.gpt4_responder import generate_gpt4_response
-        model, _ = load_model('gpt4')
-        response = generate_gpt4_response(images, query, model)
-        logger.info("Response generated using GPT-4 model.")
-        return response
+        api_key = os.getenv("OPENAI_API_KEY")
+        client = OpenAI(api_key=api_key)
+
+        try:
+            content = [{"type": "text", "text": query}]
+
+            for img_path in images:
+                full_path = os.path.join('static', img_path)
+                if os.path.exists(full_path):
+                    base64_image = encode_image(full_path)
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    })
+                else:
+                    logger.warning(f"Image file not found: {full_path}")
+
+            if len(content) == 1:  # Only text, no images
+                return "No images could be loaded for analysis."
+
+            response = client.chat.completions.create(
+                model="gpt-4o",  # Make sure to use the correct model name
+                messages=[
+                    {
+                        "role": "user",
+                        "content": content
+                    }
+                ],
+                max_tokens=1024
+            )
+
+            generated_text = response.choices[0].message.content
+            logger.info("Response generated using GPT-4 model.")
+            return generated_text
+
+        except Exception as e:
+            logger.error(f"Error in GPT-4 processing: {str(e)}", exc_info=True)
+            return f"An error occurred while processing the images: {str(e)}"

     elif model_choice == 'llama-vision':
         # Load model, processor, and device
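With the Gemini and GPT-4 paths now handled inline, a hypothetical call into the responder looks like this (image paths are resolved relative to static/, per the branches above; the file and session names are assumed):

from models.responder import generate_response

answer = generate_response(
    images=['session_123/page_1.jpg'],  # resolved as static/session_123/page_1.jpg
    query="What does the table on this page show?",
    session_id='session_123',
    model_choice='gpt4',  # or 'gemini'
)
print(answer)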
@@ -98,20 +174,22 @@

         model, sampling_params, device = load_model('pixtral')

-        image_urls = []
-        for img in images:
-            # Convert PIL Image to base64
-            buffered = io.BytesIO()
-            img.save(buffered, format="PNG")
-            img_str = base64.b64encode(buffered.getvalue()).decode()
-            image_urls.append(f"data:image/png;base64,{img_str}")

+        def image_to_data_url(image_path):
+
+            image_path = os.path.join('static', image_path)
+
+            with open(image_path, "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
+            ext = os.path.splitext(image_path)[1][1:]  # Get the file extension
+            return f"data:image/{ext};base64,{encoded_string}"
+
         messages = [
             {
                 "role": "user",
                 "content": [
                     {"type": "text", "text": query},
-                    *[{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
+                    *[{"type": "image_url", "image_url": {"url": image_to_data_url(img_path)}} for i, img_path in enumerate(images) if i < 1]
                 ]
             },
         ]
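Restated standalone, the Pixtral payload above now reads images from disk and attaches only the first one (the `if i < 1` guard); paths here are hypothetical:

import base64
import os

def image_to_data_url(image_path):
    # Same shape as the helper above: static/<path> -> data:image/<ext> URI
    image_path = os.path.join('static', image_path)
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    ext = os.path.splitext(image_path)[1][1:]
    return f"data:image/{ext};base64,{encoded_string}"

images = ['session_123/page_1.png', 'session_123/page_2.png']
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Summarize the first page."},
        *[{"type": "image_url", "image_url": {"url": image_to_data_url(p)}}
          for i, p in enumerate(images) if i < 1],
    ],
}]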
@@ -120,10 +198,10 @@
         return outputs[0].outputs[0].text

     elif model_choice == "molmo":
-
         model, processor, device = load_model('molmo')
+        model = model.half()  # Convert model to half precision
         pil_images = []
-        for img_path in images:
+        for img_path in images[:1]:  # Process only the first image for now
             full_path = os.path.join('static', img_path)
             if os.path.exists(full_path):
                 try:
@@ -138,53 +216,40 @@
                 return "No images could be loaded for analysis."

         try:
-            # Log the types and shapes of the images
-            logger.info(f"Number of images: {len(pil_images)}")
-            logger.info(f"Image types: {[type(img) for img in pil_images]}")
-            logger.info(f"Image sizes: {[img.size for img in pil_images]}")
-
             # Process the images and text
             inputs = processor.process(
                 images=pil_images,
                 text=query
             )

-            # Log the keys and shapes of the inputs
-            logger.info(f"Input keys: {inputs.keys()}")
-            for k, v in inputs.items():
-                if isinstance(v, torch.Tensor):
-                    logger.info(f"Input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
-                else:
-                    logger.info(f"Input '{k}' type: {type(v)}")
-
             # Move inputs to the correct device and make a batch of size 1
-            inputs = {k: v.to(model.device).unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
-            # Log the updated shapes after moving to device
-            for k, v in inputs.items():
-                if isinstance(v, torch.Tensor):
-                    logger.info(f"Updated input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
+            # Convert float tensors to half precision, but keep integer tensors as they are
+            inputs = {k: (v.to(device).unsqueeze(0).half() if v.dtype in [torch.float32, torch.float64] else
+                          v.to(device).unsqueeze(0))
+                      if isinstance(v, torch.Tensor) else v
+                      for k, v in inputs.items()}

             # Generate output
-            output = model.generate_from_batch(
-                inputs,
-                GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
-                tokenizer=processor.tokenizer
-            )
+            with torch.no_grad():  # Disable gradient calculation
+                output = model.generate_from_batch(
+                    inputs,
+                    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+                    tokenizer=processor.tokenizer
+                )

             # Only get generated tokens; decode them to text
             generated_tokens = output[0, inputs['input_ids'].size(1):]
             generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

+            return generated_text
+
         except Exception as e:
             logger.error(f"Error in Molmo processing: {str(e)}", exc_info=True)
             return f"An error occurred while processing the images: {str(e)}"
         finally:
             # Close the opened images to free up resources
             for img in pil_images:
-                img.close()
-
-            return generated_text
+                img.close()
     else:
         logger.error(f"Invalid model choice: {model_choice}")
         return "Invalid model selected."

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -8,4 +8,6 @@ docx2pdf
 qwen-vl-utils
 vllm>=0.6.1.post1
 mistral_common>=1.4.1
-einops
+einops
+mistral_common[opencv]
+mistral_common
