Performance improvement
Great demo!
I have one suggestion to improve performance: you could call get_image_embeddings only once to retrieve the image embeddings, then iteratively run the classic forward pass with the input points, making sure you pop pixel_values from the inputs dict and pass the precomputed embeddings in its place.
Also, you may want to wrap the forward pass in a with torch.no_grad(): context manager for faster inference.
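A minimal, framework-free sketch of the suggested call pattern. encode_image and forward are hypothetical stand-ins for get_image_embeddings and the model's forward pass, and using "image_embeddings" as the dict key is an assumption about how the real model accepts precomputed embeddings. The point is that the expensive encoder runs once per image while the cheap prompt-conditioned pass runs once per point. In the real app, both calls would also sit under torch.no_grad().

```python
from typing import Any, Dict, List, Tuple

# Counts how many times the (hypothetical) expensive encoder runs.
ENCODER_CALLS = 0


def encode_image(pixel_values: List[float]) -> List[float]:
    """Hypothetical stand-in for get_image_embeddings: the expensive step."""
    global ENCODER_CALLS
    ENCODER_CALLS += 1
    return [v * 2.0 for v in pixel_values]


def forward(inputs: Dict[str, Any]) -> Tuple[float, Tuple[int, int]]:
    """Hypothetical stand-in for the forward pass; rejects raw pixels plus embeddings."""
    if "pixel_values" in inputs:
        raise ValueError("pass either pixel_values or image_embeddings, not both")
    score = sum(inputs["image_embeddings"])
    return score, inputs["input_points"]


def segment_many(pixel_values: List[float], point_prompts: List[Tuple[int, int]]):
    embeddings = encode_image(pixel_values)  # run the encoder once per image
    results = []
    for point in point_prompts:
        # Simulate the inputs dict a processor would return for each prompt.
        inputs = {"pixel_values": pixel_values, "input_points": point}
        inputs.pop("pixel_values")               # drop the raw image...
        inputs["image_embeddings"] = embeddings  # ...and reuse the cached embeddings
        results.append(forward(inputs))
    return results


results = segment_many([0.1, 0.2, 0.3], [(10, 20), (30, 40), (50, 60)])
print(ENCODER_CALLS)  # 1
```

Three point prompts produce three forward passes, but the encoder counter stays at one.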
Thanks for the feedback! I added torch.no_grad() to the function definition. I also added get_image_embeddings and now store the embedding until a new image is uploaded, which gives a significant speedup for the second inference on the same image. However, I might need to add another check, because I'm not sure how this will work when multiple people are using the app.
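One way to handle the multi-user concern, sketched under assumptions: if the app keeps a single global "current embedding", one user's upload can overwrite another's. Keying the cache by a hash of the image bytes gives each image its own entry, so concurrent users cannot clobber each other, and an LRU bound keeps memory in check. EmbeddingCache, compute, and max_items are hypothetical names; compute would wrap get_image_embeddings in the real app. The lock is there on the assumption that requests may be handled in parallel threads. If the demo is a Gradio app, gr.State is another option for keeping per-session values.

```python
import hashlib
import threading
from collections import OrderedDict
from typing import Callable


class EmbeddingCache:
    """Thread-safe LRU cache mapping an image content hash to its embeddings."""

    def __init__(self, compute: Callable[[bytes], object], max_items: int = 8):
        self._compute = compute  # hypothetical wrapper around get_image_embeddings
        self._max_items = max_items
        self._items: "OrderedDict[str, object]" = OrderedDict()
        self._lock = threading.Lock()

    def get(self, image_bytes: bytes) -> object:
        key = hashlib.sha256(image_bytes).hexdigest()
        with self._lock:
            if key in self._items:
                self._items.move_to_end(key)  # mark as most recently used
                return self._items[key]
        # Compute outside the lock so a slow encode does not block other users'
        # cache hits; two simultaneous misses on the same image may both compute.
        value = self._compute(image_bytes)
        with self._lock:
            self._items[key] = value
            self._items.move_to_end(key)
            while len(self._items) > self._max_items:
                self._items.popitem(last=False)  # evict least recently used
        return value


calls = []


def fake_encode(image_bytes: bytes) -> int:
    """Hypothetical encoder that records each real computation."""
    calls.append(image_bytes)
    return len(image_bytes)


cache = EmbeddingCache(fake_encode, max_items=2)
cache.get(b"user-a.png")
cache.get(b"user-b.png")
cache.get(b"user-a.png")  # cache hit, no recompute
print(len(calls))  # 2
```

Hashing the content rather than storing one "last image" means two users working on different images each keep their own cached embeddings, and two users uploading the same image share one entry.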