255 lines
6.3 KiB
Markdown
255 lines
6.3 KiB
Markdown
# Fine-tuning SAM2 with HF Trainer
|
|
|
|
Fine-tune SAM2.1 on a small part of the MicroMat dataset for image matting,
using the Hugging Face Trainer with a custom loss function.
|
|
|
|
```python
|
|
!pip install -q transformers datasets monai trackio
|
|
```
|
|
|
|
## Load and explore the dataset
|
|
|
|
```python
|
|
from datasets import load_dataset
|
|
|
|
dataset = load_dataset("merve/MicroMat-mini", split="train")
|
|
dataset
|
|
```
|
|
|
|
```python
|
|
# Hold out 10% of the samples for validation. A fixed seed keeps the split
# identical across runs, so `val_ds[1]` (used for the before/after comparison)
# always refers to the same sample.
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_ds = dataset["train"]
val_ds = dataset["test"]
|
|
```
|
|
|
|
```python
|
|
import json
|
|
|
|
train_ds[0]
|
|
```
|
|
|
|
```python
|
|
# Index the row first: selecting the whole "prompt" column just to read its
# first entry can materialize every prompt in the split.
json.loads(train_ds[0]["prompt"])["bbox"]
|
|
```
|
|
|
|
## Visualize a sample
|
|
|
|
```python
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
|
|
def show_mask(mask, ax, bbox):
    """Overlay a translucent mask and its box prompt on ``ax``.

    Args:
        mask: 2-D array-like of shape (H, W). Any value greater than zero is
            drawn as foreground, the same ``> 0`` rule ``SAMDataset`` uses to
            build the training labels.
        ax: Matplotlib axes to draw on.
        bbox: Box prompt as ``(x0, y0, x1, y1)`` in the axes' data coordinates.
    """
    color = np.array([0.12, 0.56, 1.0, 0.6])  # RGBA overlay color
    # Binarize before tinting: imshow clips float RGBA values to [0, 1], so a
    # mask stored on a 0-255 scale would otherwise saturate every channel
    # (including alpha) and render as an opaque, near-white blob.
    mask = (np.asarray(mask) > 0).astype(np.float32)
    h, w = mask.shape
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, 4)
    ax.imshow(mask_image)
    x0, y0, x1, y1 = bbox
    ax.add_patch(
        plt.Rectangle(
            (x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor="lime", linewidth=2
        )
    )
|
|
|
|
|
|
# Draw the first training sample with its ground-truth mask and box prompt.
example = train_ds[0]
image, ground_truth_mask = (np.array(example[key]) for key in ("image", "mask"))

fig, ax = plt.subplots()
ax.imshow(image)
show_mask(ground_truth_mask, ax, bbox=json.loads(example["prompt"])["bbox"])
ax.set(title="Ground truth mask")
ax.axis("off")
plt.show()
|
|
```
|
|
|
|
## Build the dataset and collator
|
|
|
|
`SAMDataset` wraps each sample into the format expected by the SAM2 processor.
Ground-truth masks are stored under the key `"labels"` so the Trainer
automatically pops them before calling `model.forward()`.
|
|
|
|
```python
|
|
from torch.utils.data import Dataset
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class SAMDataset(Dataset):
    """Map-style dataset that runs each raw sample through the processor.

    Every underlying sample must provide an ``"image"``, a ``"mask"`` and a
    JSON-encoded ``"prompt"`` that contains a ``"bbox"`` entry.
    """

    def __init__(self, dataset, processor):
        self.dataset, self.processor = dataset, processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processor output for sample ``idx`` plus its targets."""
        sample = self.dataset[idx]
        picture = sample["image"]
        box = json.loads(sample["prompt"])["bbox"]

        # One image with one box prompt, hence the double nesting.
        encoded = self.processor(picture, input_boxes=[[box]], return_tensors="pt")

        # Any non-zero mask value counts as foreground.
        foreground = np.array(sample["mask"]) > 0
        encoded["labels"] = foreground.astype(np.float32)

        # Stored reversed; assumes a PIL-style (width, height) ``size`` -- TODO confirm.
        encoded["original_image_size"] = torch.tensor(picture.size[::-1])
        return encoded
|
|
|
|
|
|
def collate_fn(batch):
    """Merge per-sample feature mappings into a single batch dictionary.

    ``pixel_values`` and ``input_boxes`` are concatenated along dim 0, while
    ``original_sizes`` and ``original_image_size`` are stacked. Each ``labels``
    mask is resized to 256x256 with nearest-neighbour interpolation and the
    result is returned as an integer tensor. ``multimask_output`` is always
    ``False``.
    """

    def _gather(key):
        return [sample[key] for sample in batch]

    def _resize_mask(mask):
        # F.interpolate needs a 4-D float tensor, so add two leading dims.
        expanded = torch.as_tensor(mask)[None, None].float()
        return F.interpolate(expanded, size=(256, 256), mode="nearest")

    collated = {}
    collated["pixel_values"] = torch.cat(_gather("pixel_values"), dim=0)
    collated["original_sizes"] = torch.stack(_gather("original_sizes"))
    collated["input_boxes"] = torch.cat(_gather("input_boxes"), dim=0)
    collated["labels"] = torch.cat(
        [_resize_mask(mask) for mask in _gather("labels")], dim=0
    ).long()
    collated["original_image_size"] = torch.stack(_gather("original_image_size"))
    collated["multimask_output"] = False
    return collated
|
|
```
|
|
|
|
```python
|
|
from transformers import Sam2Processor
|
|
|
|
# Processor for the same checkpoint name the model is loaded from.
processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-small")

# Wrap both splits with the shared processor.
train_dataset, val_dataset = (
    SAMDataset(dataset=split, processor=processor) for split in (train_ds, val_ds)
)
|
|
```
|
|
|
|
## Load model and freeze encoder layers
|
|
|
|
```python
|
|
from transformers import Sam2Model
|
|
|
|
model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-small")

# Freeze every parameter whose name starts with one of the encoder prefixes;
# all remaining parameters are left untouched.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)
|
|
```
|
|
|
|
## Inference before training
|
|
|
|
```python
|
|
# Baseline: run the not-yet-fine-tuned model on one validation sample.
item = val_ds[1]
img, bbox = item["image"], json.loads(item["prompt"])["bbox"]

inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt")
inputs = inputs.to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

# Post-process the predicted masks using the recorded original sizes.
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"],
)[0]
preds = masks.squeeze(0)
mask = (preds[0] > 0).cpu().numpy()

# Blend 45% green into the pixels selected by the mask.
overlay = np.array(img, dtype=np.uint8)
overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)

plt.imshow(overlay)
plt.title("Before training")
plt.axis("off")
plt.show()
|
|
```
|
|
|
|
## Define custom loss
|
|
|
|
SAM2 does not compute loss in its `forward()`, so we provide a
`compute_loss_func` to the Trainer. The Trainer pops `"labels"` from the
batch before calling `model(**inputs)`, then passes `(outputs, labels)` to
this function.
|
|
|
|
```python
|
|
import monai
|
|
from transformers import Trainer, TrainingArguments
|
|
import trackio
|
|
|
|
# Dice + cross-entropy loss; with sigmoid=True the loss applies the sigmoid to
# the predictions itself.
seg_loss = monai.losses.DiceCELoss(
    sigmoid=True,
    squared_pred=True,
    reduction="mean",
)


def compute_loss(outputs, labels, num_items_in_batch=None):
    """Return the DiceCE loss between predicted and ground-truth masks.

    Args:
        outputs: Model output exposing a ``pred_masks`` tensor.
        labels: Ground-truth masks; cast to float before the loss is applied.
        num_items_in_batch: Accepted so the signature matches what the Trainer
            passes to ``compute_loss_func``; not used here.
    """
    mask_predictions = outputs.pred_masks.squeeze(dim=1)
    targets = labels.float()
    return seg_loss(mask_predictions, targets)
|
|
```
|
|
|
|
## Train with Trainer
|
|
|
|
Key settings:
- `remove_unused_columns=False`: the Trainer must keep `input_boxes`,
  `original_sizes`, etc. that are not in the model's `forward()` signature.
- `compute_loss_func`: our custom DiceCE loss.
- `report_to="trackio"`: logs the training loss to trackio.
|
|
|
|
```python
|
|
training_args = TrainingArguments(
    output_dir="sam2-finetuned",
    # Optimization
    learning_rate=1e-5,
    weight_decay=0,
    num_train_epochs=30,
    per_device_train_batch_size=4,
    # Logging and checkpointing
    report_to="trackio",
    logging_steps=1,
    save_strategy="epoch",
    save_total_limit=2,
    # Data loading
    remove_unused_columns=False,
    dataloader_pin_memory=False,
)

trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    data_collator=collate_fn,
    compute_loss_func=compute_loss,
)

trainer.train()
|
|
```
|
|
|
|
## Inference after training
|
|
|
|
```python
|
|
# The Trainer leaves the model in training mode; switch to eval mode so any
# train-only behavior (e.g. dropout) is disabled for this prediction.
model.eval()

item = val_ds[1]
img = item["image"]
bbox = json.loads(item["prompt"])["bbox"]

inputs = processor(images=img, input_boxes=[[bbox]], return_tensors="pt").to(
    model.device
)

with torch.no_grad():
    outputs = model(**inputs)

# Post-process the predicted masks using the recorded original sizes.
preds = processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"]
)[0]
preds = preds.squeeze(0)
mask = (preds[0] > 0).cpu().numpy()

# Blend 45% green into the pixels selected by the mask.
overlay = np.asarray(img, dtype=np.uint8).copy()
overlay[mask] = 0.55 * overlay[mask] + 0.45 * np.array([0, 255, 0], dtype=np.float32)

plt.imshow(overlay)
plt.title("After training")
plt.axis("off")
plt.show()
|
|
```
|