Use BLIP directly + some simplifications
This commit is contained in:
parent
a7d98f494a
commit
8c3ef58e00
3 changed files with 48 additions and 39 deletions
|
@ -1,9 +1,14 @@
|
|||
from nataili_blip.model_manager import BlipModelManager
|
||||
from nataili_blip.caption import Caption
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BlipForConditionalGeneration
|
||||
from transformers import BlipProcessor
|
||||
|
||||
def load_model():
    """Download and load the BLIP captioning model through nataili.

    Uses the nataili BlipModelManager to fetch and load the "BLIP" model,
    then wraps the loaded model and its device in a Caption helper.

    Returns:
        Caption: caption generator bound to the loaded BLIP model/device.
    """
    name = "BLIP"
    manager = BlipModelManager()
    # Ensure weights are present locally, then load onto the manager's device.
    manager.download_model(name)
    manager.load_blip(name)
    entry = manager.loaded_models[name]
    return Caption(entry["model"], entry["device"])
|
||||
# Module-level BLIP captioning pipeline, initialized once at import time.
# Loading here downloads/caches the Hugging Face checkpoint on first run.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# NOTE(review): .to("cuda") assumes a CUDA device is available — confirm;
# fp16 weights halve GPU memory versus the default fp32.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
# raw_image = Image.open('/tmp/istockphoto-470604022-612x612.jpg').convert('RGB')
|
||||
def caption_image(raw_image):
    """Generate a text caption for an image with the module-level BLIP model.

    Args:
        raw_image: a PIL image (presumably RGB — the processor tensorizes it).

    Returns:
        str: the decoded caption, special tokens stripped.
    """
    # Tensorize and move to the same device/dtype as the fp16 CUDA model.
    batch = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
    token_ids = model.generate(**batch, max_new_tokens=100)
    return processor.decode(token_ids[0], skip_special_tokens=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue