Description
Hi, thanks for the great work!
While reviewing the code for 5-shot segmentation, I noticed something in `FSSAM5s.py` that may need attention.
In the following snippet:
```python
sup_mask = F.interpolate(
    s_mask[:, 0, ...].unsqueeze(1).float(),
    size=qry_sizes[-1],
    mode='nearest'
)
output_query, weights = self.sam2.propagate_in_video_batch_final(
    qry_feats, qry_poss, qry_sizes,
    sup_fg_mem_feats, sup_fg_mem_poss, sup_mask, sup_fg_preds, sup_fg_obj_ptrs,
    qry_ae_mem_feats, qry_ae_mem_poss, qry_ae_preds, qry_ae_obj_ptrs
)
output_query = output_query.squeeze(1)
```
Only the first support mask (`s_mask[:, 0, ...]`) is selected as `sup_mask`, which is confusing in the 5-shot setting.
This may cause inconsistencies later, because the memory is concatenated in the order support image features → query image features → object pointers, so any offset derived from the number of masks must match the number of support images actually stored in memory.
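To make the ordering concrete, here is a small pure-Python sketch that lists the token range each block occupies along the memory sequence axis. It assumes 32×32 feature maps (as the comments in `RoPEAttention` suggest) and one memory block per support image; `memory_layout` and the object-pointer token count `num_obj_ptr_tokens` are hypothetical names and values chosen only for illustration.

```python
def memory_layout(shot: int, num_pixel: int, num_obj_ptr_tokens: int) -> dict:
    """Return [start, end) token ranges along the memory sequence axis.

    Assumed order: one block per support image, then the query block,
    then the object-pointer tokens (hypothetical count).
    """
    layout = {}
    start = 0
    for s in range(shot):
        layout[f"support_{s}"] = (start, start + num_pixel)
        start += num_pixel
    layout["query"] = (start, start + num_pixel)
    start += num_pixel
    layout["obj_ptrs"] = (start, start + num_obj_ptr_tokens)
    return layout


layout = memory_layout(shot=5, num_pixel=32 * 32, num_obj_ptr_tokens=16)
print(layout["query"])  # (5120, 6144)
```

Under these assumptions, the query block in the 5-shot case only starts after all five support blocks, at token `5 * num_pixel`.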
In RoPEAttention, the code is:
```python
sup_mask = memory_mask
shot = sup_mask.size(0) // b  # 1 / 5
# parameters
use_attn_bias = True
drop_support = False
# query-support memory cosine similarity
cos_eps = 1e-7
k_qry = k[:, :, shot * num_pixel: (shot + 1) * num_pixel, :]  # b, #head, 32*32, c - query
k_qry_norm = torch.norm(k_qry, 2, -1, True)  # b, #head, 32*32, c
```
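Here, `shot` is calculated as 1, because `sup_mask` holds only one mask per batch element. `k_qry` then slices tokens `[num_pixel, 2 * num_pixel)`, which, with five support images in memory, would correspond to the second support image rather than the query features. This would affect all subsequent computations that use `k_qry`.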
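The arithmetic can be checked in isolation. This sketch mirrors only the `shot = memory_mask.size(0) // b` line and the `k_qry` slice bounds; `rope_query_slice` and `block_name` are hypothetical helpers, and the block order (five support blocks, then query, then object pointers) is the assumption stated in this issue.

```python
def rope_query_slice(memory_mask_batch: int, b: int, num_pixel: int):
    """Mirror RoPEAttention: shot = memory_mask.size(0) // b, then slice."""
    shot = memory_mask_batch // b
    return shot, (shot * num_pixel, (shot + 1) * num_pixel)


def block_name(start: int, true_shot: int, num_pixel: int) -> str:
    """Name the memory block starting at `start`, assuming the order
    support_0 .. support_{true_shot - 1}, query, object pointers."""
    block = start // num_pixel
    if block < true_shot:
        return f"support_{block}"
    return "query" if block == true_shot else "obj_ptrs"


b, num_pixel, true_shot = 2, 32 * 32, 5

# Current code: only s_mask[:, 0] is passed, so memory_mask has b entries.
shot, (lo, hi) = rope_query_slice(b, b, num_pixel)
print(shot, (lo, hi), block_name(lo, true_shot, num_pixel))
# 1 (1024, 2048) support_1

# If all five support masks were passed: b * 5 entries.
shot, (lo, hi) = rope_query_slice(b * true_shot, b, num_pixel)
print(shot, (lo, hi), block_name(lo, true_shot, num_pixel))
# 5 (5120, 6144) query
```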
Could you please double-check this logic?
In the 5-shot setting, this mask selection does not seem to align with the memory concatenation order, which might lead to incorrect feature referencing.
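One possible change, offered only as a sketch under assumptions: resize every support mask and flatten the shot axis into the batch axis, so that `memory_mask.size(0) // b` evaluates to 5. This assumes `s_mask` has shape `(b, shot, H, W)` (as the `s_mask[:, 0, ...]` indexing suggests) and that support memories are ordered batch-major. The helper name `all_support_masks` is hypothetical, and I have not checked whether `propagate_in_video_batch_final` uses `sup_mask` elsewhere in a way that expects a single mask.

```python
import torch
import torch.nn.functional as F


def all_support_masks(s_mask: torch.Tensor, size) -> torch.Tensor:
    """Resize every support mask instead of only s_mask[:, 0].

    Assumes s_mask has shape (b, shot, H, W). Returns (b * shot, 1, h, w),
    so that memory_mask.size(0) // b recovers shot in RoPEAttention.
    """
    b, shot, H, W = s_mask.shape
    # Flatten (b, shot) into one batch axis: row i * shot + s holds
    # support s of sample i (batch-major order is an assumption).
    flat = s_mask.reshape(b * shot, 1, H, W).float()
    return F.interpolate(flat, size=size, mode='nearest')


b, shot = 2, 5
s_mask = torch.randint(0, 2, (b, shot, 128, 128))
sup_mask = all_support_masks(s_mask, size=(32, 32))
print(sup_mask.shape)         # torch.Size([10, 1, 32, 32])
print(sup_mask.size(0) // b)  # 5
```

If the support memories are instead concatenated shot-major, the mask would need `s_mask.permute(1, 0, 2, 3)` before the reshape so that row `s * b + i` holds support `s` of sample `i`.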
Thanks again for your work!