3D Segmentation - Spleen - PyTorch Lightning
Original tutorial: https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d_lightning.ipynb
This tutorial demonstrates how MONAI can be used in conjunction with the PyTorch Lightning framework.
If PyTorch Lightning is not installed, install it:
pip install pytorch-lightning
Set up imports
from monai.utils import set_determinism
from monai.transforms import (
AsDiscrete,
AddChanneld,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureTyped,
EnsureType,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import pytorch_lightning
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
Data Dictionary and download
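Set up the data directory and download the data if needed. The spleen dataset is expected in `MONAI_DATA_DIRECTORY/MSD/Task09_Spleen/`; if the `MONAI_DATA_DIRECTORY` environment variable is not set, a temporary directory is used instead: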
# Resolve the data root: $MONAI_DATA_DIRECTORY when it is set, otherwise a
# freshly created temporary directory. The MSD tasks live under "<root>/MSD".
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    root_dir = directory
else:
    root_dir = tempfile.mkdtemp()
root_dir = os.path.join(root_dir, "MSD")
data_dir = os.path.join(root_dir, "Task09_Spleen")

# Fetch and unpack the Task09 (spleen) archive only when the extracted
# folder is not already present; the md5 is checked by download_and_extract.
if not os.path.exists(data_dir):
    resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
    md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
    compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
    download_and_extract(resource, compressed_file, root_dir, md5)
Define the LightningModule
A LightningModule organizes your training code into well-defined hooks. The following module is a refactoring of the code in `spleen_segmentation_3d.ipynb`:
class Net(pytorch_lightning.LightningModule):
    """LightningModule wrapping a 3D UNet for 2-class (background/spleen) segmentation.

    Training draws random 96x96x96 patches (4 per volume, balanced pos/neg);
    validation runs sliding-window inference over whole volumes and reports the
    mean Dice score (background excluded), tracking the best score and epoch.
    """

    def __init__(self):
        super().__init__()
        self._model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        # The loss consumes raw network logits: DiceLoss applies the softmax
        # over the output channels and one-hot encodes the label map itself.
        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        # Post-processing for the Dice metric: predictions are argmax'ed and
        # one-hot encoded to 2 classes; labels are one-hot encoded to 2 classes.
        self.post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        # Best validation Dice seen so far and the epoch it was reached.
        self.best_val_dice = 0
        self.best_val_epoch = 0

    def forward(self, x):
        """Run the UNet on a batch of images and return its raw logits."""
        return self._model(x)

    def prepare_data(self):
        """Build the cached training and validation datasets.

        Images/labels are read from ``data_dir`` (a module-level path); the
        last 9 cases are held out for validation.

        NOTE(review): Lightning calls ``prepare_data`` on a single process only,
        so the datasets assigned to ``self`` here would not exist on other ranks
        under multi-GPU (DDP) training; move this into ``setup()`` if needed.
        """
        # set up the correct data path
        train_images = sorted(
            glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
        train_labels = sorted(
            glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
        data_dicts = [
            {"image": image_name, "label": label_name}
            for image_name, label_name in zip(train_images, train_labels)
        ]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        # NOTE(review): AddChanneld is deprecated in newer MONAI releases
        # (EnsureChannelFirstd is the replacement); kept to match the imports.
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AddChanneld(keys=["image", "label"]),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    mode=("bilinear", "nearest"),
                ),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                # Clip intensities to [-57, 164] and rescale to [0, 1].
                ScaleIntensityRanged(
                    keys=["image"], a_min=-57, a_max=164,
                    b_min=0.0, b_max=1.0, clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                # randomly crop out patch samples from
                # big image based on pos / neg ratio
                # the image centers of negative samples
                # must be in valid image area
                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=(96, 96, 96),
                    pos=1,
                    neg=1,
                    num_samples=4,
                    image_key="image",
                    image_threshold=0,
                ),
                # user can also add other random transforms
                # RandAffined(
                #     keys=['image', 'label'],
                #     mode=('bilinear', 'nearest'),
                #     prob=1.0,
                #     spatial_size=(96, 96, 96),
                #     rotate_range=(0, 0, np.pi/15),
                #     scale_range=(0.1, 0.1, 0.1)),
                EnsureTyped(keys=["image", "label"]),
            ]
        )
        # Validation uses the same deterministic preprocessing, without cropping
        # into patches (whole volumes are evaluated with a sliding window).
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AddChanneld(keys=["image", "label"]),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    mode=("bilinear", "nearest"),
                ),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                ScaleIntensityRanged(
                    keys=["image"], a_min=-57, a_max=164,
                    b_min=0.0, b_max=1.0, clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                EnsureTyped(keys=["image", "label"]),
            ]
        )

        # we use cached datasets - these are 10x faster than regular datasets
        self.train_ds = CacheDataset(
            data=train_files, transform=train_transforms,
            cache_rate=1.0, num_workers=4,
        )
        self.val_ds = CacheDataset(
            data=val_files, transform=val_transforms,
            cache_rate=1.0, num_workers=4,
        )
        # self.train_ds = monai.data.Dataset(
        #     data=train_files, transform=train_transforms)
        # self.val_ds = monai.data.Dataset(
        #     data=val_files, transform=val_transforms)

    def train_dataloader(self):
        """Shuffled loader over the training set, 2 volumes per batch.

        ``list_data_collate`` is used because RandCropByPosNegLabeld yields
        several patch samples per volume.
        """
        train_loader = torch.utils.data.DataLoader(
            self.train_ds, batch_size=2, shuffle=True,
            num_workers=4, collate_fn=list_data_collate,
        )
        return train_loader

    def val_dataloader(self):
        """Loader over the validation set, one whole volume per batch."""
        val_loader = torch.utils.data.DataLoader(
            self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        """Adam over the UNet parameters with learning rate 1e-4."""
        optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute the Dice loss for one batch of patches and log it."""
        images, labels = batch["image"], batch["label"]
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {"train_loss": loss.item()}
        # Lightning 1.x+ ignores a returned "log" dict; self.log_dict is what
        # actually sends the value to the attached logger. The "log" key is
        # still returned so the step output keeps its original shape.
        self.log_dict(tensorboard_logs)
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Sliding-window inference on one batch; accumulate Dice, return loss.

        Returns a dict with the batch's (mean) loss and the number of items in
        the batch, consumed by ``validation_epoch_end``.
        """
        images, labels = batch["image"], batch["label"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(
            images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_function(outputs, labels)
        # Split the batch into per-item tensors and discretize them for Dice.
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        return {"val_loss": loss, "val_number": len(outputs)}

    def validation_epoch_end(self, outputs):
        """Aggregate validation loss/Dice over the epoch, log them, track the best Dice.

        NOTE(review): this hook was removed in Lightning 2.0 (use
        ``on_validation_epoch_end`` with self-collected step outputs there).
        It also runs after the sanity-check pass (``num_sanity_val_steps``), so
        that pass can set ``best_val_dice``/``best_val_epoch``.
        """
        val_loss, num_items = 0, 0
        for output in outputs:
            # Each "val_loss" is a mean over its batch; weight it by the batch
            # size so the epoch value is a true per-item mean for any batch size.
            val_loss += output["val_loss"].sum().item() * output["val_number"]
            num_items += output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
        }
        # A returned "log" dict is ignored by Lightning 1.x+; log explicitly.
        self.log_dict(tensorboard_logs)
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        return {"log": tensorboard_logs}
Run the training
# initialise the LightningModule
net = Net()

# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger(
    save_dir=log_dir
)

# initialise Lightning's trainer.
# `checkpoint_callback=` was removed from Trainer in Lightning 1.7 and `gpus=`
# in 2.0; `enable_checkpointing=` and `accelerator=`/`devices=` replace them.
# Training runs on GPU 0 only.
trainer = pytorch_lightning.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=600,
    logger=tb_logger,
    enable_checkpointing=True,
    # Run one validation batch before training to catch errors early.
    num_sanity_val_steps=1,
)

# train
trainer.fit(net)