moondream / app.py
dmedhi's picture
add transformers moondream
25a0dc3
raw
history blame
1.42 kB
import base64
import io
import subprocess
from tempfile import NamedTemporaryFile

from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI()
# define request body
class RequestData(BaseModel):
    # JSON body accepted by the POST /query endpoint.

    # Question text that the handler forwards to MODEL.answer_question.
    prompt: str
    # Image payload as a base64 string; the handler base64-decodes it and
    # opens the resulting bytes with PIL.
    image: str
def load_model(model_id: str = "vikhyatk/moondream2", revision: str = "2024-08-26") -> tuple:
    """Load a causal-LM checkpoint and its tokenizer via ``from_pretrained``.

    Args:
        model_id: Repository id handed to ``from_pretrained``. Defaults to
            the moondream2 checkpoint this app was written for.
        revision: Repository revision to load. The same value is used for
            both the model and the tokenizer so the two stay in sync.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # trust_remote_code=True allows transformers to execute the modeling code
    # shipped inside the model repository; passing an explicit revision keeps
    # that code (and the weights) from changing underneath the app.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    return model, tokenizer
# Load the model and tokenizer once, at import time, into module-level globals
# so the request handlers reuse the same instances instead of reloading them.
MODEL, TOKENIZER = load_model()
print("INFO: Model loaded successfully!")
@app.get("/")
def greet_json():
    # Root endpoint: replies with a fixed greeting payload.
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/query")
def query(data: RequestData):
    """Answer a question about a base64-encoded image.

    The request body carries a text ``prompt`` and an ``image`` encoded as
    base64. The image is decoded in memory, encoded by the model, and the
    model's answer is returned as ``{"response": "<answer>"}``.

    Raises:
        HTTPException: 400 when the image payload is not valid base64 or
            cannot be decoded as an image; 500 when the model fails while
            producing an answer.
    """
    # Validate the upload separately from inference so that malformed client
    # input is reported as a client error instead of a server failure.
    try:
        raw_bytes = base64.b64decode(data.image)
        # Decode straight from memory: no temporary file is needed, and this
        # avoids reopening a still-open temp file by name (which fails on
        # Windows).
        image = Image.open(io.BytesIO(raw_bytes))
        # Image.open only identifies the file; force the pixel data to be
        # read now so truncated or corrupt images are caught here.
        image.load()
    except (ValueError, OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid image payload: {exc}"
        ) from exc

    try:
        enc_image = MODEL.encode_image(image)
        response = MODEL.answer_question(enc_image, data.prompt, TOKENIZER)
    except Exception as exc:
        # Boundary handler: surface any inference failure as a 500 response.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"response": str(response)}