
How can an image be input into a model to output its scene graph information and bounding box information for visualization? #7

Open
PiPiSang opened this issue Aug 29, 2024 · 8 comments

Comments

@PiPiSang

First of all, the work is very instructive, thank you! I am new to scene graph generation, and my current research requires extracting information from images with scene graph generation. I'm curious how to use the model's output to build a scene graph, and I'm confused about how to visualize that output. Any guidance would be greatly appreciated.

@JeongSooHwan

Hi :). I was also deeply impressed by the EGTR paper, and I am facing the same issue. I am not very knowledgeable in the SGG domain; I understand that it involves predicting relationships between objects. I need to perform SGG as a preliminary step so that I can use the scene graph embedding for downstream tasks. I would like to use EGTR as the SGG module, but could you please let me know how the output scene graph is generated?

@jinbae
Collaborator

jinbae commented Aug 30, 2024

First of all, I have not tried to extract the scene graph itself beyond computing the evaluation metrics.
So, although I'm sharing a code snippet for extracting the scene graph, I recommend that you adapt it to your own needs.
The following code snippet is based on evaluate_egtr.py.

from glob import glob

import torch
from PIL import Image

from model.deformable_detr import DeformableDetrConfig, DeformableDetrFeatureExtractor
from model.egtr import DetrForSceneGraphGeneration

# config
architecture = "SenseTime/deformable-detr"
min_size = 800
max_size = 1333
artifact_path = YOUR_ARTIFACT_PATH

# feature extractor
feature_extractor = DeformableDetrFeatureExtractor.from_pretrained(
    architecture, size=min_size, max_size=max_size
)

# inference image
image = Image.open(YOUR_IMAGE_PATH)
image = feature_extractor(image, return_tensors="pt")

# model
config = DeformableDetrConfig.from_pretrained(artifact_path)
model = DetrForSceneGraphGeneration.from_pretrained(
    architecture, config=config, ignore_mismatched_sizes=True
)
ckpt_path = sorted(
    glob(f"{artifact_path}/checkpoints/epoch=*.ckpt"),
    key=lambda x: int(x.split("epoch=")[1].split("-")[0]),
)[-1]
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
for k in list(state_dict.keys()):
    state_dict[k[6:]] = state_dict.pop(k)  # strip the "model." prefix from the Lightning checkpoint keys

model.load_state_dict(state_dict)
model.cuda()
model.eval()

# output
outputs = model(
    pixel_values=image['pixel_values'].cuda(), 
    pixel_mask=image['pixel_mask'].cuda(), 
    output_attention_states=True
)

pred_logits = outputs['logits'][0]
obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)
pred_boxes = outputs['pred_boxes'][0]

pred_connectivity = outputs['pred_connectivity'][0]
pred_rel = outputs['pred_rel'][0]
pred_rel = torch.mul(pred_rel, pred_connectivity)

# get valid objects and triplets
obj_threshold = YOUR_OBJ_THRESHOLD
valid_obj_indices = (obj_scores >= obj_threshold).nonzero()[:, 0]

valid_obj_classes = pred_classes[valid_obj_indices] # [num_valid_objects]
valid_obj_boxes = pred_boxes[valid_obj_indices] # [num_valid_objects, 4]

rel_threshold = YOUR_REL_THRESHOLD
valid_triplets = (pred_rel[valid_obj_indices][:, valid_obj_indices] >= rel_threshold).nonzero() # [num_valid_triplets, 3]

You can generate a scene graph based on valid_obj_classes, valid_obj_boxes, and valid_triplets.

  • valid_obj_classes: object classes
  • valid_obj_boxes: object bounding boxes (cxcywh format)
  • valid_triplets: relation triplets (subject entity index, object entity index, relation class)
    • Please note that subject entity index and object entity index indicate the indices of valid objects.

In this example I built the scene graph using thresholds, but you could also select the top-k objects or triplets instead.
Since these thresholds have never been explored, it may be important to tune them carefully.
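
For example, here is a rough (untested) sketch of decoding these tensors into a readable edge list. decode_scene_graph is just a hypothetical helper, and obj_names / rel_names are assumed to be the class-name lists of your dataset (e.g. Visual Genome), which are not defined here.

def decode_scene_graph(valid_obj_classes, valid_obj_boxes, valid_triplets, obj_names, rel_names):
    # Nodes: one (class name, cxcywh box) pair per valid object.
    nodes = [
        (obj_names[int(c)], box.tolist())
        for c, box in zip(valid_obj_classes, valid_obj_boxes)
    ]
    # Edges: triplet indices point into the *valid* objects, as noted above.
    edges = [
        (obj_names[int(valid_obj_classes[subj])], rel_names[int(rel)], obj_names[int(valid_obj_classes[obj])])
        for subj, obj, rel in valid_triplets.tolist()
    ]
    return nodes, edges

nodes, edges = decode_scene_graph(valid_obj_classes, valid_obj_boxes, valid_triplets, obj_names, rel_names)
for subj, rel, obj in edges:
    print(f"{subj} --{rel}--> {obj}")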

@jinbae
Collaborator

jinbae commented Sep 2, 2024

@PiPiSang

(1) obj boxes

valid_obj_boxes: object bounding boxes (cxcywh format)

As I mentioned before, pred_boxes are in cxcywh format.
Please make sure they have been converted to xyxy format before bbox visualization.
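
For illustration, a minimal sketch of that conversion (not code from the repository), assuming the boxes are normalized to [0, 1] as in DETR-style models:

import torch

def cxcywh_to_xyxy(boxes, img_w, img_h):
    # boxes: [N, 4] in normalized (cx, cy, w, h); returns absolute (x1, y1, x2, y2).
    cx, cy, w, h = boxes.unbind(-1)
    x1 = (cx - 0.5 * w) * img_w
    y1 = (cy - 0.5 * h) * img_h
    x2 = (cx + 0.5 * w) * img_w
    y2 = (cy + 0.5 * h) * img_h
    return torch.stack([x1, y1, x2, y2], dim=-1)

img_w, img_h = Image.open(YOUR_IMAGE_PATH).size  # original image size
boxes_xyxy = cxcywh_to_xyxy(valid_obj_boxes.cpu(), img_w, img_h)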

(2) obj scores

We used Deformable DETR rather than DETR, and Deformable DETR uses focal loss for object detection instead of cross-entropy loss.

obj_scores, pred_classes = torch.max(pred_logits.softmax(-1), -1)

Therefore, it is more natural to use sigmoid instead of softmax (https://github.com/huggingface/transformers/blob/409fcfdfccde77a14b7cc36972b774cabc371ae1/src/transformers/models/deformable_detr/image_processing_deformable_detr.py#L1555), but we used softmax to get obj_scores.
obj_scores may therefore be low compared to models trained with a cross-entropy loss.
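
If you want to try the sigmoid variant, a rough sketch (an assumption about how it would look, not our evaluation code) is:

# Score each query with per-class sigmoid probabilities, as in Deformable DETR's
# focal-loss setup; the scale of these scores will differ from the softmax version.
obj_probs = pred_logits.sigmoid()              # [num_queries, num_classes]
obj_scores, pred_classes = obj_probs.max(-1)   # best class and its score per query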

@Aoihigashi

@jinbae
Thank you very much for your response regarding scene graph visualization. I followed the code you provided to run inference on a single image, using the pretrained weights for the OI dataset mentioned in the README, but I encountered the following error.
(screenshot of a size-mismatch warning)
I then tried the weights I trained myself on the OI dataset, but the same error occurred. I would like to know why the class dimension in the weights is 91 rather than 601, the number of classes in the OI dataset. I am eagerly awaiting your response.

@PiPiSang
Author

PiPiSang commented Sep 5, 2024

@Aoihigashi
In fact, this is not an error. The message is emitted during model initialization: if you debug this test code, you will see it printed right after line 229 executes. The code on line 229 does not actually load the trained weights into the model; it simply creates a model matching the structure specified in the configuration file and initializes it with the official default weights. The message is only telling you that this structure does not fully match the official model, so some weights cannot be loaded at initialization. The code that truly loads our own trained weights into the model is on line 240. So you can relax, there is no issue.

@Aoihigashi

@PiPiSang
Thank you! I see, that’s exactly the case, and it does run. However, I have to set the obj_threshold and rel_threshold very low (0.1) to get any output. Have you encountered a similar issue?

@PiPiSang
Author

PiPiSang commented Sep 5, 2024

@Aoihigashi Yes. For me, obj_threshold and rel_threshold are 0.3 and 1e-4. But based on my tests, I strongly suggest selecting the top-k triplets directly by score. You can also read the function where the author calculates the loss in train_egtr.py, which shows how to select the top-k triplets and perform the transformation; a rough sketch of the idea is below.
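
This is only a sketch (not the repository's code); the combined scoring formula (relation probability times both object scores) and the value of k are my own choices:

k = 20  # hypothetical number of triplets to keep

# triplet_scores[i, j, r] = score(subject i) * score(object j) * P(relation r)
triplet_scores = pred_rel * obj_scores[:, None, None] * obj_scores[None, :, None]

num_obj, _, num_rel = triplet_scores.shape
topk_scores, flat_idx = triplet_scores.flatten().topk(k)
subj_idx = flat_idx // (num_obj * num_rel)              # subject query index
obj_idx = (flat_idx % (num_obj * num_rel)) // num_rel   # object query index
rel_idx = flat_idx % num_rel                            # relation class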

@Aoihigashi

@PiPiSang
Thank you for your suggestions and response. I will try following the advice.
