How to get individual labels and separate segmented masks for each category

#14
by omerjadoon1 - opened

How do I get individual labels and separate segmentation masks for each category? For example, the code you provided returns an image with all the segments drawn on top of each other. How can I get them separately, and how do I know which segment belongs to which category?

Reference Code:
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch.nn as nn

processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

url = "https://img.ltwebstatic.com/images3_pi/2022/12/30/16723672772d7cd1f75a62454976de69df7d4b2bd2_thumbnail_600x.webp"

image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
logits = outputs.logits.cpu()

upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.size[::-1],  # PIL gives (width, height); interpolate expects (height, width)
    mode="bilinear",
    align_corners=True,
)

pred_seg = upsampled_logits.argmax(dim=1)[0]
plt.imshow(pred_seg)
plt.show()
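The combined plot is really a single label map: each pixel of pred_seg holds one class index, and model.config.id2label translates that index into a category name. A minimal sketch, assuming the map has been converted to NumPy (e.g. pred_seg.numpy()), lists which categories are present and how much of the image each covers. The toy map and the three-entry id2label subset below are hypothetical stand-ins, not the model's actual output.

```python
import numpy as np

def summarize_label_map(label_map, id2label):
    """Return {category name: fraction of pixels} for every class id present."""
    ids, counts = np.unique(label_map, return_counts=True)
    total = label_map.size
    return {id2label[int(i)]: float(c) / total for i, c in zip(ids, counts)}

# Hypothetical stand-ins: in practice use pred_seg.numpy() and model.config.id2label.
toy_id2label = {0: "Background", 4: "Upper-clothes", 6: "Pants"}
toy_map = np.array([[0, 4, 4],
                    [0, 6, 6]])

summary = summarize_label_map(toy_map, toy_id2label)
for name, share in summary.items():
    print(f"{name}: {share:.1%} of pixels")
```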

Hi, the code below loops over the predicted categories, builds a mask for each one, and sets that label's name as the plot title. Does that help?

import numpy as np
import torch

segments = torch.unique(pred_seg)  # class ids that appear in the prediction
for i in segments:
    mask = pred_seg == i  # boolean mask: True where the pixel belongs to class i
    img = Image.fromarray((mask * 255).numpy().astype(np.uint8))  # white = class i
    name = model.config.id2label[i.item()]  # map the class id to its category name
    plt.imshow(img)
    plt.title(name)
    plt.show()
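If you want to keep the masks rather than only plot them, a minimal sketch (NumPy only) collects one boolean mask per category in a dict and uses a mask to cut that category out of the original image. The helper names split_masks and cutout, the toy label map, and the id2label subset are hypothetical, chosen for illustration.

```python
import numpy as np

def split_masks(label_map, id2label):
    """Return {category name: boolean H x W mask} for each class id in label_map."""
    return {id2label[int(i)]: label_map == i for i in np.unique(label_map)}

def cutout(image, mask):
    """Zero out every pixel of an H x W x 3 uint8 image where mask is False."""
    return image * mask[..., None].astype(image.dtype)

# Hypothetical stand-ins for pred_seg.numpy(), model.config.id2label and the photo.
toy_id2label = {0: "Background", 4: "Upper-clothes", 6: "Pants"}
toy_map = np.array([[0, 4, 4],
                    [0, 6, 6]])
toy_image = np.full((2, 3, 3), 200, dtype=np.uint8)

masks = split_masks(toy_map, toy_id2label)
for name, mask in masks.items():
    print(name, int(mask.sum()), "pixels")

pants_only = cutout(toy_image, masks["Pants"])  # keeps only the "Pants" pixels
```

With the thread's objects, you would pass pred_seg.numpy(), model.config.id2label and np.array(image.convert("RGB")); each mask can then be saved on its own, for example with Image.fromarray(mask.astype(np.uint8) * 255).save(f"{name}.png").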

Thanks a lot. :)

omerjadoon1 changed discussion status to closed
