# dolphin-vision-72b-4bit / dolphin_vision_streamlit.py
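# Streamlit chat UI for the DolphinVision 72B multimodal model.
# Launch with: streamlit run dolphin_vision_streamlit.py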
import streamlit as st
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
# Disable warnings and progress bars
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
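# st.cache_resource keeps the model and tokenizer resident across Streamlit
# reruns, so the 72B weights are loaded only once per server process.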
@st.cache_resource
def load_model():
    model_name = 'cognitivecomputations/dolphin-vision-72b'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map='auto',
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    return model, tokenizer
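# The parent repo is the 4-bit variant, yet this script loads the base
# checkpoint in float16. A minimal sketch of loading in 4-bit instead,
# assuming bitsandbytes is installed (untested here):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       quantization_config=quant_config,
#       device_map='auto',
#       trust_remote_code=True,
#   )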
def generate_response(model, tokenizer, prompt, image=None):
    messages = [
        {"role": "user", "content": f'<image>\n{prompt}' if image else prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
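    # The prompt is split on '<image>' and the chunks are rejoined around a
    # placeholder token id. The -200 value appears to follow the LLaVA
    # convention (IMAGE_TOKEN_INDEX = -200), which the model's custom code
    # replaces with the projected image features.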
    if image:
        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
        input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
        image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    else:
        # No image: tokenize the plain prompt. Splitting on '<image>' here
        # would yield a single chunk, and indexing text_chunks[1] would raise
        # an IndexError.
        input_ids = torch.tensor(tokenizer(text).input_ids, dtype=torch.long).unsqueeze(0)
        image_tensor = None
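    # 'images' is forwarded to the custom generate() shipped with the model
    # via trust_remote_code; presumably it is ignored when image_tensor is None.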
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=2048,
        use_cache=True
    )[0]
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
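# Streamlit UI: image uploader, free-text prompt, and a send button.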
st.title("Chat with DolphinVision 🐬")
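# Loading happens once at startup; in float16 a 72B-parameter checkpoint is
# roughly 144 GB of weights (2 bytes/param), so the first run's download and
# load can take a long time.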
model, tokenizer = load_model()
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
image = None
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)
user_input = st.text_input("You:", "")
if st.button("Send"):
    if user_input:
        with st.spinner("Generating response..."):
            response = generate_response(model, tokenizer, user_input, image)
        st.text_area("DolphinVision:", value=response, height=200)
    else:
        st.warning("Please enter a message.")