Hello, I'm trying to fine-tune the mask decoder of TinySAM on a custom dataset while freezing the weights of the image_encoder and prompt_encoder. I'm having an issue in my training loop: `Sam.forward()` requires a `multimask_output` argument, but `MaskDecoder.forward()` doesn't accept a `multimask_output` keyword, so the call fails either way.
I'm not an ML engineer, so I don't know much about the underlying code. If anyone with more knowledge than me has some insight into how I can resolve this, I would appreciate it. Thanks!
Here is how I'm freezing the image encoder and prompt encoder to keep their original weights:
```python
for name, param in sam_model.named_parameters():
    if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
```
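As a quick sanity check (my own addition, assuming `sam_model` follows the usual SAM submodule naming), this prints which top-level modules still have trainable parameters:

```python
# Sanity check: only mask_decoder parameters should still require gradients.
trainable = {name.split(".")[0] for name, p in sam_model.named_parameters() if p.requires_grad}
frozen = {name.split(".")[0] for name, p in sam_model.named_parameters() if not p.requires_grad}
print("trainable top-level modules:", trainable)  # expected: {'mask_decoder'}
print("frozen top-level modules:", frozen)        # expected: {'image_encoder', 'prompt_encoder'}
```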
I am also providing bounding-box prompts as the input. Here is my custom `Dataset` class:
```python
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor, resize


class SAMDataset(Dataset):
    """
    Dataset class for SAM, serving images with associated bounding boxes and masks.
    """
    def __init__(self, dataset, bbox_mapping, sam_model, device='cuda'):
        self.dataset = dataset
        self.bbox_mapping = bbox_mapping
        self.sam_model = sam_model
        self.device = device
        self.target_size = (1024, 1024)  # expected input size of the model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Assuming dataset[idx] returns a dict with 'image' and 'label' keys
        pil_image = self.dataset[idx]['image']
        pil_mask = self.dataset[idx]['label']
        image_tensor = to_tensor(np.array(pil_image)).to(self.device)
        mask_tensor = to_tensor(np.array(pil_mask)).to(self.device)

        # Resize image and mask to the target size
        image_tensor = resize(image_tensor, self.target_size)
        mask_tensor = resize(mask_tensor, self.target_size)

        # Fetch bounding boxes directly, without padding
        bboxes = self.bbox_mapping.get(idx + 1, [])  # adjust index if necessary
        bboxes_tensor = torch.tensor(bboxes, dtype=torch.float, device=self.device)

        return {
            'image': image_tensor,
            'bboxes': bboxes_tensor,
            'mask': mask_tensor
        }
```
### Create a DataLoader instance for the training dataset

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=False)
```
Printing the shapes of one batch gives:

```
image  torch.Size([1, 3, 1024, 1024])
bboxes torch.Size([1, 1, 4])
mask   torch.Size([1, 1, 1024, 1024])
```
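One thing to keep in mind: `DataLoader` defaults to `batch_size=1`, which is where the leading dimension of 1 above comes from. With a larger batch size, the default collate would fail whenever images have different numbers of boxes; a minimal sketch of a custom `collate_fn` (the helper name is mine, not part of any library) could look like this:

```python
import torch

def collate_keep_lists(samples):
    # Stack fixed-shape tensors, but keep variable-length 'bboxes'
    # as a list, since each image can have a different number of boxes.
    return {
        'image': torch.stack([s['image'] for s in samples]),
        'mask': torch.stack([s['mask'] for s in samples]),
        'bboxes': [s['bboxes'] for s in samples],
    }

# train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True,
#                               collate_fn=collate_keep_lists)
```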
### Training Loop
```python
from statistics import mean
from tqdm import tqdm

num_epochs = 1
device = "cuda"

sam_model.to(device)
sam_model.train()

for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Prepare batched_input according to the TinySAM model's expected input format
        batched_input = [{
            'image': batch['image'].squeeze(0).to(device),
            'bboxes': batch['bboxes'].squeeze(0).to(device)
        }]

        # forward pass
        outputs_list = sam_model(batched_input, multimask_output=True)

        # Assuming outputs_list is a list of dictionaries; aggregate the masks
        # for loss computation (this needs to match the actual output structure)
        predicted_masks = torch.stack([output['pred_mask'] for output in outputs_list]).squeeze(0)
        ground_truth_masks = batch["mask"].float().squeeze(1).to(device)

        # seg_loss and optimizer are defined earlier (omitted here)
        loss = seg_loss(predicted_masks, ground_truth_masks)

        # backward pass (compute parameter gradients)
        optimizer.zero_grad()
        loss.backward()

        # update parameters
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
```
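For reference, upstream segment-anything (which TinySAM's code is based on) defines `Sam.forward(batched_input, multimask_output)` where each input dict uses the keys `'image'`, `'original_size'`, and `'boxes'` (not `'bboxes'`), and each output dict is keyed `'masks'`, `'iou_predictions'`, and `'low_res_logits'` (not `'pred_mask'`). A minimal sketch under the assumption that TinySAM keeps that interface:

```python
# Sketch of the upstream segment-anything input/output contract
# (assumed, not verified, to match TinySAM).
batched_input = [{
    'image': batch['image'].squeeze(0).to(device),   # 3xHxW, already resized to 1024x1024
    'boxes': batch['bboxes'].squeeze(0).to(device),  # Bx4 box prompts; the upstream key is 'boxes'
    'original_size': (1024, 1024),                   # (H, W) of the image before model-side transforms
}]
outputs_list = sam_model(batched_input, multimask_output=True)
predicted_masks = outputs_list[0]['masks']           # upstream output key; 'pred_mask' does not exist there
```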
Error when I DON'T provide `multimask_output`:
```
TypeError                                 Traceback (most recent call last)
<ipython-input-108-f41ebba752d9> in <cell line: 12>()
     21
     22 # forward pass
---> 23 outputs_list = sam_model(batched_input)
     24
     25 # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

2 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)
    116
    117     return decorate_context

TypeError: Sam.forward() missing 1 required positional argument: 'multimask_output'
```
Error when I do provide the `multimask_output` argument:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-9d874c2eda3d> in <cell line: 12>()
     19 }]
     20 # forward pass
---> 21 outputs_list = sam_model(batched_input, multimask_output = True)
     22
     23 # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1521
   1522 try:

TypeError: MaskDecoder.forward() got an unexpected keyword argument 'multimask_output'
```
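One way to pin down where the mismatch comes from (a suggestion of mine, not from the TinySAM docs) is to inspect both signatures directly. Note also that the first traceback runs through `torch/utils/_contextlib.py`, which suggests `Sam.forward` is wrapped in a decorator such as `@torch.no_grad()`, as it is in upstream segment-anything; if so, that would also stop gradients from flowing through `sam_model(...)` during fine-tuning:

```python
import inspect

# Compare what the top-level model and the decoder actually accept.
print(inspect.signature(sam_model.forward))               # e.g. (batched_input, multimask_output)
print(inspect.signature(sam_model.mask_decoder.forward))  # whatever TinySAM's decoder takes

# If Sam.forward is wrapped by a decorator like @torch.no_grad(),
# functools.wraps exposes the original function here:
print(getattr(sam_model.forward, "__wrapped__", None))
```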