moondream / app.py
dmedhi's picture
Update app.py
a1ee7f1 verified
raw
history blame
1.49 kB
import base64
import io
import subprocess
from tempfile import NamedTemporaryFile

import uvicorn
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Application object; the @app.get / @app.post decorators in this file
# register their handlers on it.
app = FastAPI()
class RequestData(BaseModel):
    """Request body accepted by the ``/query`` endpoint."""

    # Question to ask about the image.
    prompt: str
    # Image content as a base64-encoded string.
    image: str
def load_model():
    """Load the moondream2 model and its tokenizer at a pinned revision.

    Returns:
        A ``(model, tokenizer)`` tuple; both are fetched from the same
        repository id and revision.
    """
    checkpoint = "vikhyatk/moondream2"
    pinned = {"revision": "2024-08-26"}

    # The model is loaded first, with trust_remote_code=True; the tokenizer
    # is loaded afterwards without that flag.
    return (
        AutoModelForCausalLM.from_pretrained(
            checkpoint, trust_remote_code=True, **pinned
        ),
        AutoTokenizer.from_pretrained(checkpoint, **pinned),
    )
# Load the model and tokenizer once, when this module is imported; the
# /query handler reads these module-level globals on every request.
MODEL, TOKENIZER = load_model()
print("INFO: Model loaded successfully!")
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for GET requests to the root path."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/query")
def query(data: RequestData):
    """Answer a text prompt about a base64-encoded image.

    Args:
        data: Request body carrying the prompt and the base64-encoded image.

    Returns:
        A dict ``{"response": <answer>}`` holding the model's answer as a
        string.

    Raises:
        HTTPException: 400 when the image is not valid base64 or cannot be
            decoded as an image; 500 for any other failure, including
            errors raised by the model.
    """
    try:
        raw_bytes = base64.b64decode(data.image)
        # Open from memory: a named temp file is unnecessary here and cannot
        # be reopened by name on Windows while it is still held open.
        image = Image.open(io.BytesIO(raw_bytes))
        # Image.open is lazy; force decoding now so corrupt pixel data is
        # reported as a client error instead of surfacing inside the model.
        image.load()
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; Pillow's unidentified/truncated
        # image errors are OSErrors. Both mean the request payload is bad.
        raise HTTPException(
            status_code=400, detail=f"Invalid image: {exc}"
        ) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    try:
        enc_image = MODEL.encode_image(image)
        response = MODEL.answer_question(enc_image, str(data.prompt), TOKENIZER)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"response": str(response)}
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: with the
    # string, uvicorn imports this file a second time as module ``app``,
    # which re-runs the module-level load_model() call and keeps two copies
    # of the model in memory.
    uvicorn.run(app)