Image segmentation with UNETR model#

This tutorial demonstrates how to implement and train a model on an image segmentation task. Below, we will use the Oxford Pets dataset containing images and masks of cats and dogs. We will implement the UNETR model from scratch using Flax NNX, train it on the training set, and compute image segmentation metrics on the training and validation sets. We will use the Orbax checkpoint manager to store the best models during training.

Prepare image segmentation dataset and dataloaders#

In this section we use the Oxford Pets dataset. We download the images and masks and provide code to work with the dataset. This approach can be easily extended to any image segmentation dataset, and users can reuse this code for their own datasets.

In the code below we choose OpenCV and Pillow to read images and masks as NumPy arrays, Albumentations for data augmentation, and grain for batched data loading. Alternatively, one could use tensorflow_datasets or torchvision for the same task.

Requirements installation#

We will need to install the following Python packages:

!pip install -U opencv-python-headless grain albumentations Pillow
!pip install -U flax optax orbax-checkpoint
import jax
import flax
import optax
import orbax.checkpoint as ocp
print("Jax version:", jax.__version__)
print("Flax version:", flax.__version__)
print("Optax version:", optax.__version__)
print("Orbax version:", ocp.__version__)
Jax version: 0.4.34
Flax version: 0.10.1
Optax version: 0.2.4
Orbax version: 0.9.1

Data download#

Let’s download the data and extract images and masks.

!rm -rf /tmp/data/oxford_pets
!mkdir -p /tmp/data/oxford_pets
!wget https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz -O /tmp/data/oxford_pets/images.tar.gz
!wget https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz -O /tmp/data/oxford_pets/annotations.tar.gz

!cd /tmp/data/oxford_pets && tar -xf images.tar.gz
!cd /tmp/data/oxford_pets && tar -xf annotations.tar.gz
!ls /tmp/data/oxford_pets

We can also inspect the downloaded images folder, listing a subset of these files:

!ls /tmp/data/oxford_pets/images | wc -l
!ls /tmp/data/oxford_pets/images | head
!ls /tmp/data/oxford_pets/annotations/trimaps | wc -l
!ls /tmp/data/oxford_pets/annotations/trimaps | head
7393
Abyssinian_1.jpg
Abyssinian_10.jpg
Abyssinian_100.jpg
Abyssinian_100.mat
Abyssinian_101.jpg
Abyssinian_101.mat
Abyssinian_102.jpg
Abyssinian_102.mat
Abyssinian_103.jpg
Abyssinian_104.jpg
ls: write error: Broken pipe
7390
Abyssinian_1.png
Abyssinian_10.png
Abyssinian_100.png
Abyssinian_101.png
Abyssinian_102.png
Abyssinian_103.png
Abyssinian_104.png
Abyssinian_105.png
Abyssinian_106.png
Abyssinian_107.png
ls: write error: Broken pipe

Train/Eval datasets#

Let’s implement the dataset class providing access to the images and masks. The class implements the __len__ and __getitem__ methods. In this example, we do not have a predefined training/validation split, so we will take the whole dataset and make a random training/validation split by indices. For this purpose, we provide a helper class that maps indices into the training and validation parts.

from typing import Any
from pathlib import Path

import cv2
import numpy as np
from PIL import Image  # we'll read images with opencv and use Pillow as a fallback


class OxfordPetsDataset:
    def __init__(self, path: Path):
        assert path.exists(), path
        self.path: Path = path
        self.images = sorted((self.path / "images").glob("*.jpg"))
        self.masks = [
            self.path / "annotations" / "trimaps" / path.with_suffix(".png").name
            for path in self.images
        ]
        assert len(self.images) == len(self.masks), (len(self.images), len(self.masks))

    def __len__(self) -> int:
        return len(self.images)

    def read_image_opencv(self, path: Path):
        img = cv2.imread(str(path))
        if img is not None:
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return None

    def read_image_pillow(self, path: Path):
        img = Image.open(str(path))
        img = img.convert("RGB")
        return np.asarray(img)

    def read_mask(self, path: Path):
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        # mask has values: 1, 2, 3
        # 1 - object mask
        # 2 - background
        # 3 - boundary
        # Define mask as 0-based int values
        mask = mask - 1
        return mask.astype("uint8")

    def __getitem__(self, index: int) -> dict[str, np.ndarray]:
        img_path, mask_path = self.images[index], self.masks[index]
        img = self.read_image_opencv(img_path)
        if img is None:
            # Fallback to Pillow if OpenCV fails to read an image
            img = self.read_image_pillow(img_path)
        mask = self.read_mask(mask_path)
        return {
            "image": img,
            "mask": mask,
        }


class SubsetDataset:
    def __init__(self, dataset, indices: list[int]):
        # Check input indices values:
        for i in indices:
            assert 0 <= i < len(dataset)
        self.dataset = dataset
        self.indices = indices

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, index: int) -> Any:
        i = self.indices[index]
        return self.dataset[i]

Now, let’s define the total dataset and compute data indices for training and validation splits:

seed = 12
train_split = 0.7
dataset_path = Path("/tmp/data/oxford_pets")

dataset = OxfordPetsDataset(dataset_path)

rng = np.random.default_rng(seed=seed)
le = len(dataset)
data_indices = list(range(le))

# Let's remove a few indices corresponding to corrupted images
# to avoid libjpeg warnings during data loading
corrupted_data_indices = [3017, 3425]
for index in corrupted_data_indices:
    data_indices.remove(index)

random_indices = rng.permutation(data_indices)

train_val_split_index = int(train_split * le)
train_indices = random_indices[:train_val_split_index]
val_indices = random_indices[train_val_split_index:]

# Ensure there is no overlapping
assert len(set(train_indices) & set(val_indices)) == 0

train_dataset = SubsetDataset(dataset, indices=train_indices)
val_dataset = SubsetDataset(dataset, indices=val_indices)

print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
Training dataset size: 5173
Validation dataset size: 2215

To verify our work so far, let’s display a few training and validation images and masks:

import matplotlib.pyplot as plt


def display_datapoint(datapoint, label=""):
    img, mask = datapoint["image"], datapoint["mask"]
    if img.dtype in (np.float32, ):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)
    fig, axs = plt.subplots(1, 3, figsize=(10, 10))
    axs[0].set_title(f"Image{label}")
    axs[0].imshow(img)
    axs[1].set_title(f"Mask{label}")
    axs[1].imshow(mask)
    axs[2].set_title("Image + Mask")
    axs[2].imshow(img)
    axs[2].imshow(mask, alpha=0.5)



display_datapoint(train_dataset[0], label=" (train set)")
display_datapoint(val_dataset[0], label=" (val set)")

Data augmentations#

Next, let’s define a simple data augmentation pipeline of joint image and mask transformations using Albumentations. We apply geometric and color transformations to increase the diversity of the training data. For more details on the Albumentations transformations, check the Albumentations reference API.

import albumentations as A


img_size = 256

train_transforms = A.Compose([
    A.Affine(rotate=(-35, 35), cval_mask=1, p=0.3),  # Random rotations -35 to 35 degrees
    A.RandomResizedCrop(width=img_size, height=img_size, scale=(0.7, 1.0)),  # Crop a random part of the input and rescale it to a specified size
    A.HorizontalFlip(p=0.5),  # Horizontal random flip
    A.RandomBrightnessContrast(p=0.4),  # Randomly changes the brightness and contrast
    A.Normalize(),  # Normalize the image and cast to float
])


val_transforms = A.Compose([
    A.Resize(width=img_size, height=img_size),
    A.Normalize(),  # Normalize the image and cast to float
])
output = train_transforms(**train_dataset[0])
img, mask = output["image"], output["mask"]
print("Image array info:", img.dtype, img.shape, img.min(), img.mean(), img.max())
print("Mask array info:", mask.dtype, mask.shape, mask.min(), mask.max())
Image array info: float32 (256, 256, 3) -1.5356623 0.5732621 2.6399999
Mask array info: uint8 (256, 256) 0 2
output = val_transforms(**val_dataset[0])
img, mask = output["image"], output["mask"]
print("Image array info:", img.dtype, img.shape, img.min(), img.mean(), img.max())
print("Mask array info:", mask.dtype, mask.shape, mask.min(), mask.max())
Image array info: float32 (256, 256, 3) -2.117904 -0.30076745 2.6399999
Mask array info: uint8 (256, 256) 0 2

Data loaders#

Let’s now use grain to perform data loading, augmentations and batching on a single device using multiple workers. We will create a random index sampler for training and an unshuffled sampler for validation.

from typing import Any, Callable

import grain.python as grain


class DataAugs(grain.MapTransform):
    def __init__(self, transforms: Callable):
        self.albu_transforms = transforms

    def map(self, data):
        output = self.albu_transforms(**data)
        return output
train_batch_size = 72
val_batch_size = 2 * train_batch_size


# Create an IndexSampler with no sharding for single-device computations
train_sampler = grain.IndexSampler(
    len(train_dataset),  # The total number of samples in the data source
    shuffle=True,            # Shuffle the data to randomize the order of samples
    seed=seed,               # Set a seed for reproducibility
    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup
    num_epochs=1,            # Iterate over the dataset for one epoch
)

val_sampler = grain.IndexSampler(
    len(val_dataset),  # The total number of samples in the data source
    shuffle=False,         # Do not shuffle the data
    seed=seed,             # Set a seed for reproducibility
    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup
    num_epochs=1,          # Iterate over the dataset for one epoch
)
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,                 # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among
    worker_buffer_size=2,                  # Count of output batches to produce in advance per worker
    operations=[
        DataAugs(train_transforms),
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

# Validation dataset loader
val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,                   # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among
    worker_buffer_size=2,
    operations=[
        DataAugs(val_transforms),
        grain.Batch(val_batch_size),
    ]
)

# Training dataset loader for evaluation (without data augmentations)
train_eval_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,                 # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among
    worker_buffer_size=2,                  # Count of output batches to produce in advance per worker
    operations=[
        DataAugs(val_transforms),
        grain.Batch(val_batch_size),
    ]
)
train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))
print("Train images batch info:", type(train_batch["image"]), train_batch["image"].shape, train_batch["image"].dtype)
print("Train masks batch info:", type(train_batch["mask"]), train_batch["mask"].shape, train_batch["mask"].dtype)
Train images batch info: <class 'grain._src.python.shared_memory_array.SharedMemoryArray'> (72, 256, 256, 3) float32
Train masks batch info: <class 'grain._src.python.shared_memory_array.SharedMemoryArray'> (72, 256, 256) uint8

Finally, let’s display the training and validation data:

images, masks = train_batch["image"], train_batch["mask"]

for img, mask in zip(images[:3], masks[:3]):
    display_datapoint({"image": img, "mask": mask}, label=" (augmented train set)")
images, masks = val_batch["image"], val_batch["mask"]

for img, mask in zip(images[:3], masks[:3]):
    display_datapoint({"image": img, "mask": mask}, label=" (augmented validation set)")

Model for Image Segmentation#

In this section we will implement the UNETR model from scratch using Flax NNX. The reference PyTorch implementation of this model can be found in the MONAI library's GitHub repository.

The UNETR model uses a transformer as the encoder to learn sequence representations of the input and to capture global multi-scale information, while also following the “U-shaped” network design of the UNet model (see the architecture figure in the UNETR paper).

The UNETR architecture in that figure processes 3D inputs, but it can easily be adapted to 2D inputs.

The transformer encoder of UNETR is a Vision Transformer (ViT). The feature maps returned by ViT all have the same spatial size, (H / 16, W / 16), and deconvolutions are used to upsample them. Finally, the feature maps are upsampled and concatenated up to the original image size.
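
To make these sizes concrete, here is a quick sanity check of the token grid for the settings used later in this tutorial (an illustrative sketch assuming img_size=256 and patch_size=16):

# Quick sanity check of the spatial sizes (assumes img_size=256 and patch_size=16,
# the settings used later in this tutorial).
img_size, patch_size = 256, 16
feat_size = img_size // patch_size   # 16: ViT feature maps have spatial size (16, 16)
n_patches = feat_size ** 2           # 256 tokens in the ViT sequence
print(feat_size, n_patches)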

from flax import nnx
import jax.numpy as jnp

Vision Transformer encoder implementation#

Below, we will implement the following modules:

  • Vision Transformer, ViT

    • PatchEmbeddingBlock: patch embedding block, which maps patches of pixels to a sequence of vectors

    • ViTEncoderBlock: vision transformer encoder block

      • MLPBlock: multilayer perceptron block

class PatchEmbeddingBlock(nnx.Module):
    """
    A patch embedding block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """
    def __init__(
        self,
        in_channels: int,  # dimension of input channels.
        img_size: int,  # dimension of input image.
        patch_size: int,  # dimension of patch size.
        hidden_size: int,  # dimension of hidden layer.
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches, hidden_size), jnp.float32)
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.patch_embeddings(x)
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


mod = PatchEmbeddingBlock(3, 256, 16, 768, 0.5)
x = jnp.ones((4, 256, 256, 3))
y = mod(x)
print(y.shape)
(4, 256, 768)
from typing import Callable


class MLPBlock(nnx.Sequential):
    """
    A multi-layer perceptron block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """
    def __init__(
        self,
        hidden_size: int,  # dimension of hidden layer.
        mlp_dim: int,      # dimension of feedforward layer
        dropout_rate: float = 0.0,
        activation_layer: Callable = nnx.gelu,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        layers = [
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            activation_layer,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        ]
        super().__init__(*layers)


mod = MLPBlock(768, 3072, 0.5)
x = jnp.ones((4, 256, 768))
y = mod(x)
print(y.shape)
(4, 256, 768)
class ViTEncoderBlock(nnx.Module):
    """
    A transformer encoder block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """
    def __init__(
        self,
        hidden_size: int,  # dimension of hidden layer.
        mlp_dim: int,      # dimension of feedforward layer.
        num_heads: int,    # number of attention heads
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate, rngs=rngs)
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


mod = ViTEncoderBlock(768, 3072, 12)
x = jnp.ones((4, 256, 768))
y = mod(x)
print(y.shape)
(4, 256, 768)
class ViT(nnx.Module):
    """
    Vision Transformer (ViT) Feature Extractor, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    """
    def __init__(
        self,
        in_channels: int,  # dimension of input channels
        img_size: int,  # dimension of input image
        patch_size: int,  # dimension of patch size
        hidden_size: int = 768,  # dimension of hidden layer
        mlp_dim: int = 3072,  # dimension of feedforward layer
        num_layers: int = 12,  # number of transformer blocks
        num_heads: int = 12,   # number of attention heads
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            rngs=rngs,
        )
        self.blocks = [
            ViTEncoderBlock(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
            for i in range(num_layers)
        ]
        self.norm = nnx.LayerNorm(hidden_size, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.patch_embedding(x)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        return x, hidden_states_out


mod = ViT(3, 224, 16)
x = jnp.ones((4, 224, 224, 3))
y, hstates = mod(x)
print(y.shape, [s.shape for s in hstates])
(4, 196, 768) [(4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768), (4, 196, 768)]

At this point we have implemented the encoder of the UNETR model. As we can see from the output above, ViT returns the final encoded feature map together with a list of intermediate feature maps; three of the intermediate maps will be used in the decoding part.
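
As a small illustration (not part of the model itself), the decoder will later map each ViT token sequence back onto a 2D grid via a simple reshape (the proj_feat method below); for a 256x256 input with 16x16 patches this is (B, 256, 768) -> (B, 16, 16, 768):

# A minimal sketch, assuming a 256x256 input with 16x16 patches:
# the 256 ViT tokens are mapped back onto a 16x16 spatial grid before entering the decoder.
tokens = jnp.ones((4, 256, 768))
feature_map = tokens.reshape(4, 16, 16, 768)
print(feature_map.shape)  # (4, 16, 16, 768)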

UNETR blocks implementation#

Now we can implement the remaining blocks and assemble them into the UNETR model.

Below, we will implement the following modules:

  • UNETR

    • UnetrBasicBlock: creates the first skip connection from the input.

      • UnetResBlock

    • UnetrPrUpBlock: projection upsampling modules to create skip connections from the intermediate feature maps provided by ViT.

    • UnetrUpBlock: upsampling modules used in the decoder

class Conv2dNormActivation(nnx.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int | None = None,
        groups: int = 1,
        norm_layer: Callable[..., nnx.Module] = nnx.BatchNorm,
        activation_layer: Callable = nnx.relu,
        dilation: int = 1,
        bias: bool | None = None,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.out_channels = out_channels

        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        if bias is None:
            bias = norm_layer is None

        # sequence integer pairs that give the padding to apply before
        # and after each spatial dimension
        padding = ((padding, padding), (padding, padding))

        layers = [
            nnx.Conv(
                in_channels,
                out_channels,
                kernel_size=(kernel_size, kernel_size),
                strides=(stride, stride),
                padding=padding,
                kernel_dilation=(dilation, dilation),
                feature_group_count=groups,
                use_bias=bias,
                rngs=rngs,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels, rngs=rngs))

        if activation_layer is not None:
            layers.append(activation_layer)

        super().__init__(*layers)
class InstanceNorm(nnx.GroupNorm):
    def __init__(self, num_features, **kwargs):
        num_groups, group_size = num_features, None
        super().__init__(
            num_features,
            num_groups=num_groups,
            group_size=group_size,
            **kwargs,
        )


class UnetResBlock(nnx.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        activation_layer: Callable = nnx.leaky_relu,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.conv_norm_act1 = Conv2dNormActivation(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            rngs=rngs,
        )
        self.conv_norm2 = Conv2dNormActivation(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_layer=norm_layer,
            activation_layer=None,
            rngs=rngs,
        )

        self.downsample = (in_channels != out_channels) or (stride != 1)
        if self.downsample:
            self.conv_norm3 = Conv2dNormActivation(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                norm_layer=norm_layer,
                activation_layer=None,
                rngs=rngs,
            )
        self.act = activation_layer

    def __call__(self, x: jax.Array) -> jax.Array:
        residual = x
        out = self.conv_norm_act1(x)
        out = self.conv_norm2(out)
        if self.downsample:
            residual = self.conv_norm3(residual)
        out += residual
        out = self.act(out)
        return out


mod = UnetResBlock(16, 32, 3, 1)
x = jnp.ones((4, 24, 24, 16))
y = mod(x)
print(y.shape)
(4, 24, 24, 32)
class UnetrBasicBlock(nnx.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.layer = UnetResBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_layer=norm_layer,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.layer(x)


mod = UnetrBasicBlock(16, 32, 3, 1)
x = jnp.ones((4, 24, 24, 16))
y = mod(x)
print(y.shape)
(4, 24, 24, 32)
class UnetrPrUpBlock(nnx.Module):
    """
    A projection upsampling module for UNETR: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        in_channels: int,  # number of input channels.
        out_channels: int, # number of output channels.
        num_layer: int,    # number of upsampling blocks.
        kernel_size: int,
        stride: int,
        upsample_kernel_size: int = 2,  # convolution kernel size for transposed convolution layers.
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        upsample_stride = upsample_kernel_size
        self.transp_conv_init = nnx.ConvTranspose(
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=(upsample_kernel_size, upsample_kernel_size),
            strides=(upsample_stride, upsample_stride),
            padding="VALID",
            rngs=rngs,
        )
        self.blocks = [
            nnx.Sequential(
                nnx.ConvTranspose(
                    in_features=out_channels,
                    out_features=out_channels,
                    kernel_size=(upsample_kernel_size, upsample_kernel_size),
                    strides=(upsample_stride, upsample_stride),
                    rngs=rngs,
                ),
                UnetResBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    norm_layer=norm_layer,
                    rngs=rngs,
                ),
            )
            for i in range(num_layer)
        ]

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x


mod = UnetrPrUpBlock(16, 32, 2, 3, 1)
x = jnp.ones((4, 24, 24, 16))
y = mod(x)
print(y.shape)
(4, 192, 192, 32)
class UnetrUpBlock(nnx.Module):
    """
    An upsampling module for UNETR: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        upsample_kernel_size: int = 2,  # convolution kernel size for transposed convolution layers.
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        upsample_stride = upsample_kernel_size
        self.transp_conv = nnx.ConvTranspose(
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=(upsample_kernel_size, upsample_kernel_size),
            strides=(upsample_stride, upsample_stride),
            padding="VALID",
            rngs=rngs,
        )
        self.conv_block = UnetResBlock(
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_layer=norm_layer,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array, skip: jax.Array) -> jax.Array:
        out = self.transp_conv(x)
        out = jnp.concat((out, skip), axis=-1)
        out = self.conv_block(out)
        return out


mod = UnetrUpBlock(16, 32, 3)
x = jnp.ones((4, 24, 24, 16))
skip = jnp.ones((4, 2 * 24, 2 * 24, 32))
y = mod(x, skip)
print(y.shape)
(4, 48, 48, 32)
class UNETR(nnx.Module):
    """UNETR model ported to NNX from MONAI implementation:
    - https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py
    """
    def __init__(
        self,
        out_channels: int,
        in_channels: int = 3,
        img_size: int = 256,
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.num_layers = 12
        self.patch_size = 16
        self.feat_size = img_size // self.patch_size
        self.hidden_size = hidden_size

        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            rngs=rngs,
        )
        self.encoder1 = UnetrBasicBlock(
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.encoder2 = UnetrPrUpBlock(
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.encoder3 = UnetrPrUpBlock(
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.encoder4 = UnetrPrUpBlock(
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.decoder5 = UnetrUpBlock(
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.decoder4 = UnetrUpBlock(
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.decoder3 = UnetrUpBlock(
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )
        self.decoder2 = UnetrUpBlock(
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=norm_layer,
            rngs=rngs,
        )

        self.out = nnx.Conv(
            in_features=feature_size,
            out_features=out_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        self.proj_axes = (0, 1, 2, 3)
        self.proj_view_shape = [self.feat_size, self.feat_size, self.hidden_size]

    def proj_feat(self, x: jax.Array) -> jax.Array:
        new_view = [x.shape[0]] + self.proj_view_shape
        x = x.reshape(new_view)
        x = jnp.permute_dims(x, self.proj_axes)
        return x

    def __call__(self, x_in: jax.Array) -> jax.Array:
        x, hidden_states_out = self.vit(x_in)
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4))
        dec4 = self.proj_feat(x)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)
# We'll use a different number of heads to make a smaller model
model = UNETR(out_channels=3, num_heads=4)
x = jnp.ones((4, 256, 256, 3))
y = model(x)
print(y.shape)
(4, 256, 256, 3)

We can visualize and inspect the architecture of the implemented model using nnx.display(model).
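
As a small optional check (a sketch, not required for training), we can also count the model's trainable parameters by collecting the nnx.Param leaves of the model state:

# A small, optional sketch: count the trainable parameters of the model.
params = nnx.state(model, nnx.Param)
n_params = sum(int(p.size) for p in jax.tree.leaves(params))
print(f"Number of trainable parameters: {n_params:,}")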

Train the model#

In the previous sections we defined the training and validation dataloaders and the model. In this section, we will define the loss function and the optimizer, and train the model.

For the semantic segmentation task, we can define the loss function as the sum of the Cross-Entropy and Jaccard loss functions. Cross-Entropy is a standard loss for multi-class classification tasks, and the Jaccard loss helps to directly optimize the Intersection-over-Union measure for semantic segmentation.
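
Concretely, writing p for the per-pixel softmax probabilities and y for the one-hot encoded masks, the combined loss implemented below is roughly

loss = cross_entropy(logits, masks) + 1 - sum(p * y) / sum(p + y - p * y)

where the sums run over pixels and classes; in the implementation below, the intersection and union sums are additionally averaged over the batch and classes before taking the ratio, and a small epsilon is added to the union for numerical stability.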

import optax

num_epochs = 50
total_steps = len(train_dataset) // train_batch_size
learning_rate = 0.003
momentum = 0.9
lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)

iterate_subsample = np.linspace(0, num_epochs * total_steps, 100)
plt.plot(
    np.linspace(0, num_epochs, len(iterate_subsample)),
    [lr_schedule(i) for i in iterate_subsample],
    lw=3,
)
plt.title("Learning rate")
plt.xlabel("Epochs")
plt.ylabel("Learning rate")
plt.grid()
plt.xlim((0, num_epochs))
plt.show()


optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum))

Let us implement the Jaccard loss and a loss function combining the Cross-Entropy and Jaccard losses.

def compute_softmax_jaccard_loss(logits, masks, reduction="mean"):
    assert reduction in ("mean", "sum")
    y_pred = nnx.softmax(logits, axis=-1)
    b, c = y_pred.shape[0], y_pred.shape[-1]
    y = nnx.one_hot(masks, num_classes=c, axis=-1)

    y_pred = y_pred.reshape((b, -1, c))
    y = y.reshape((b, -1, c))

    intersection = y_pred * y
    union = y_pred + y - intersection + 1e-8

    intersection = jnp.sum(intersection, axis=1)
    union = jnp.sum(union, axis=1)

    if reduction == "mean":
        intersection = jnp.mean(intersection)
        union = jnp.mean(union)
    elif reduction == "sum":
        intersection = jnp.sum(intersection)
        union = jnp.sum(union)

    return 1.0 - intersection / union


def compute_losses_and_logits(model: nnx.Module, images: jax.Array, masks: jax.Array):
    logits = model(images)

    xentropy_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=masks
    ).mean()

    jacc_loss = compute_softmax_jaccard_loss(logits=logits, masks=masks)
    loss = xentropy_loss + jacc_loss
    return loss, (xentropy_loss, jacc_loss, logits)
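
As a quick, hypothetical sanity check (not part of the training code), a near-perfect prediction should drive the Jaccard loss close to zero:

# Sanity check on dummy data: near-perfect logits should give a Jaccard loss close to 0.
dummy_masks = jnp.zeros((2, 8, 8), dtype=jnp.int32)
perfect_logits = nnx.one_hot(dummy_masks, num_classes=3, axis=-1) * 100.0
print(compute_softmax_jaccard_loss(perfect_logits, dummy_masks))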

Now, we will implement a confusion matrix metric derived from nnx.Metric. The confusion matrix will help us compute the Intersection-over-Union (IoU) metric per class and on average. Finally, we can also compute the accuracy metric from the confusion matrix.

class ConfusionMatrix(nnx.Metric):
    def __init__(
        self,
        num_classes: int,
        average: str | None = None,
    ):
        assert average in (None, "samples", "recall", "precision")
        assert num_classes > 0
        self.num_classes = num_classes
        self.average = average
        self.confusion_matrix = nnx.metrics.MetricState(
            jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)
        )
        self.count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self):
        self.confusion_matrix.value = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def _check_shape(self, y_pred: jax.Array, y: jax.Array):
        if y_pred.shape[-1] != self.num_classes:
            raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[-1]} vs {self.num_classes}")

        if not (y.ndim + 1 == y_pred.ndim):
            raise ValueError(
                f"y_pred must have shape (batch_size, num_classes (currently set to {self.num_classes}), ...) "
                "and y must have shape of (batch_size, ...), "
                f"but given {y.shape} vs {y_pred.shape}."
            )

    def update(self, **kwargs):
        # We assume that y.max() < self.num_classes and y.min() >= 0
        assert "y" in kwargs
        assert "y_pred" in kwargs
        y_pred = kwargs["y_pred"]
        y = kwargs["y"]
        self._check_shape(y_pred, y)
        self.count.value += y_pred.shape[0]

        y_pred = jnp.argmax(y_pred, axis=-1).ravel()
        y = y.ravel()
        indices = self.num_classes * y + y_pred
        matrix = jnp.bincount(indices, minlength=self.num_classes**2, length=self.num_classes**2)
        matrix = matrix.reshape((self.num_classes, self.num_classes))
        self.confusion_matrix.value += matrix

    def compute(self) -> jax.Array:
        if self.average:
            confusion_matrix = self.confusion_matrix.value.astype("float")
            if self.average == "samples":
                return confusion_matrix / self.count.value
            else:
                return self.normalize(self.confusion_matrix.value, self.average)
        return self.confusion_matrix.value

    @staticmethod
    def normalize(matrix: jax.Array, average: str) -> jax.Array:
        """Normalize given `matrix` with given `average`."""
        if average == "recall":
            return matrix / (jnp.expand_dims(matrix.sum(axis=1), axis=1) + 1e-15)
        elif average == "precision":
            return matrix / (matrix.sum(axis=0) + 1e-15)
        else:
            raise ValueError("Argument average should be one of 'samples', 'recall', 'precision'")


def compute_iou(cm: jax.Array) -> jax.Array:
    return jnp.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - jnp.diag(cm) + 1e-15)


def compute_mean_iou(cm: jax.Array) -> jax.Array:
    return compute_iou(cm).mean()


def compute_accuracy(cm: jax.Array) -> jax.Array:
    return jnp.diag(cm).sum() / (cm.sum() + 1e-15)
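
As a quick illustration, the helpers above can be checked on a tiny, hypothetical 3x3 confusion matrix (rows: ground truth, columns: predictions; the values below are made up):

# Tiny, hypothetical confusion matrix used only to illustrate the metric helpers.
cm_example = jnp.array([
    [50, 2, 3],
    [4, 60, 1],
    [2, 5, 20],
])
print("IoU per class:", compute_iou(cm_example))
print("Mean IoU:", compute_mean_iou(cm_example))
print("Accuracy:", compute_accuracy(cm_example))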

Next, let’s define training and evaluation steps:

@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, np.ndarray]
):
    # Convert numpy arrays to jax.Array on GPU
    images = jnp.array(batch["image"])
    masks = jnp.array(batch["mask"], dtype=jnp.int32)

    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, (xentropy_loss, jacc_loss, logits)), grads = grad_fn(model, images, masks)

    optimizer.update(grads)  # In-place updates.

    return loss, xentropy_loss, jacc_loss
@nnx.jit
def eval_step(
    model: nnx.Module, batch: dict[str, np.ndarray], eval_metrics: nnx.MultiMetric
):
    # Convert numpy arrays to jax.Array on GPU
    images = jnp.array(batch["image"])
    masks = jnp.array(batch["mask"], dtype=jnp.int32)
    loss, (_, _, logits) = compute_losses_and_logits(model, images, masks)

    eval_metrics.update(
        total_loss=loss,
        y_pred=logits,
        y=masks,
    )  # In-place updates.

We will also define the metrics we want to compute during the evaluation phase: the total loss and the confusion matrix, computed on the training and validation datasets. Finally, we define helper objects to store the metrics history. Metrics like per-class IoU, mean IoU and accuracy will be computed from the confusion matrix in the evaluation code.

eval_metrics = nnx.MultiMetric(
    total_loss=nnx.metrics.Average('total_loss'),
    confusion_matrix=ConfusionMatrix(num_classes=3),
)


eval_metrics_history = {
    "train_total_loss": [],
    "train_IoU": [],
    "train_mean_IoU": [],
    "train_accuracy": [],

    "val_total_loss": [],
    "val_IoU": [],
    "val_mean_IoU": [],
    "val_accuracy": [],
}

Let us define the training and evaluation logic. We also define a checkpoint manager to store the two best models, selected by the validation mean IoU value.

import time
import orbax.checkpoint as ocp


def train_one_epoch(epoch):
    start_time = time.time()

    model.train()  # Set model to the training mode: e.g. update batch statistics
    for step, batch in enumerate(train_loader):
        total_loss, xentropy_loss, jaccard_loss = train_step(model, optimizer, batch)

        print(
            f"\r[train] epoch: {epoch + 1}/{num_epochs}, iteration: {step}/{total_steps}, "
            f"total loss: {total_loss.item():.4f} ",
            f"xentropy loss: {xentropy_loss.item():.4f} ",
            f"jaccard loss: {jaccard_loss.item():.4f} ",
            end="")
        print("\r", end="")

    elapsed = time.time() - start_time
    print(
        f"\n[train] epoch: {epoch + 1}/{num_epochs}, elapsed time: {elapsed:.2f} seconds"
    )


def evaluate_model(epoch):
    start_time = time.time()

    # Compute the metrics on the train and val sets after each training epoch.
    model.eval()  # Set model to evaluation mode: e.g. use stored batch statistics

    for tag, eval_loader in [("train", train_eval_loader), ("val", val_loader)]:
        eval_metrics.reset()  # Reset the eval metrics
        for val_batch in eval_loader:
            eval_step(model, val_batch, eval_metrics)

        for metric, value in eval_metrics.compute().items():
            if metric == "confusion_matrix":
                eval_metrics_history[f"{tag}_IoU"].append(
                    compute_iou(value)
                )
                eval_metrics_history[f"{tag}_mean_IoU"].append(
                    compute_mean_iou(value)
                )
                eval_metrics_history[f"{tag}_accuracy"].append(
                    compute_accuracy(value)
                )
            else:
                eval_metrics_history[f'{tag}_{metric}'].append(value)

        print(
            f"[{tag}] epoch: {epoch + 1}/{num_epochs} "
            f"\n - total loss: {eval_metrics_history[f'{tag}_total_loss'][-1]:0.4f} "
            f"\n - IoU per class: {eval_metrics_history[f'{tag}_IoU'][-1].tolist()} "
            f"\n - Mean IoU: {eval_metrics_history[f'{tag}_mean_IoU'][-1]:0.4f} "
            f"\n - Accuracy: {eval_metrics_history[f'{tag}_accuracy'][-1]:0.4f} "
            "\n"
        )

    elapsed = time.time() - start_time
    print(
        f"[evaluation] epoch: {epoch + 1}/{num_epochs}, elapsed time: {elapsed:.2f} seconds"
    )

    return eval_metrics_history['val_mean_IoU'][-1]


path = ocp.test_utils.erase_and_create_empty("/tmp/output-oxford-model/")
options = ocp.CheckpointManagerOptions(max_to_keep=2)
mngr = ocp.CheckpointManager(path, options=options)


def save_model(epoch):
    state = nnx.state(model)
    # We should convert PRNGKeyArray to the old format for Dropout layers
    # https://github.com/google/flax/issues/4231
    def get_key_data(x):
        if isinstance(x, jax._src.prng.PRNGKeyArray):
            if isinstance(x.dtype, jax._src.prng.KeyTy):
                return jax.random.key_data(x)
        return x

    serializable_state = jax.tree.map(get_key_data, state)
    mngr.save(epoch, args=ocp.args.StandardSave(serializable_state))
    mngr.wait_until_finished()

Now we can start the training. It takes around 45 minutes on a single GPU and uses about 19 GB of GPU memory.

%%time

best_val_mean_iou = 0.0
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    if (epoch % 3 == 0) or (epoch == num_epochs - 1):
        val_mean_iou = evaluate_model(epoch)
        if val_mean_iou > best_val_mean_iou:
            save_model(epoch)
            best_val_mean_iou = val_mean_iou
[train] epoch: 1/50, iteration: 67/71, total loss: 1.5128  xentropy loss: 0.8543  jaccard loss: 0.6585 
[train] epoch: 1/50, elapsed time: 98.28 seconds
[train] epoch: 1/50 
 - total loss: 1.4808 
 - IoU per class: [0.3311152458190918, 0.5875526070594788, 0.07520102709531784] 
 - Mean IoU: 0.3313 
 - Accuracy: 0.6198 

[val] epoch: 1/50 
 - total loss: 1.4837 
 - IoU per class: [0.32379695773124695, 0.5863257050514221, 0.07609017938375473] 
 - Mean IoU: 0.3287 
 - Accuracy: 0.6174 

[evaluation] epoch: 1/50, elapsed time: 95.73 seconds
[train] epoch: 2/50, iteration: 67/71, total loss: 1.4376  xentropy loss: 0.8077  jaccard loss: 0.6299 
[train] epoch: 2/50, elapsed time: 42.02 seconds
[train] epoch: 3/50, iteration: 67/71, total loss: 1.3881  xentropy loss: 0.7764  jaccard loss: 0.6118 
[train] epoch: 3/50, elapsed time: 42.77 seconds
[train] epoch: 4/50, iteration: 67/71, total loss: 1.3697  xentropy loss: 0.7662  jaccard loss: 0.6035 
[train] epoch: 4/50, elapsed time: 42.69 seconds
[train] epoch: 4/50 
 - total loss: 1.3479 
 - IoU per class: [0.4200800955295563, 0.6404442191123962, 0.10423737019300461] 
 - Mean IoU: 0.3883 
 - Accuracy: 0.6735 

[val] epoch: 4/50 
 - total loss: 1.3526 
 - IoU per class: [0.4134039580821991, 0.6370245814323425, 0.10554961115121841] 
 - Mean IoU: 0.3853 
 - Accuracy: 0.6700 

[evaluation] epoch: 4/50, elapsed time: 18.13 seconds
[train] epoch: 5/50, iteration: 67/71, total loss: 1.3223  xentropy loss: 0.7382  jaccard loss: 0.5841 
[train] epoch: 5/50, elapsed time: 42.68 seconds
[train] epoch: 6/50, iteration: 67/71, total loss: 1.2856  xentropy loss: 0.7155  jaccard loss: 0.5701 
[train] epoch: 6/50, elapsed time: 41.52 seconds
[train] epoch: 7/50, iteration: 67/71, total loss: 1.2704  xentropy loss: 0.7096  jaccard loss: 0.5608 
[train] epoch: 7/50, elapsed time: 41.99 seconds
[train] epoch: 7/50 
 - total loss: 1.2590 
 - IoU per class: [0.4778529703617096, 0.6718385815620422, 0.12118919938802719] 
 - Mean IoU: 0.4236 
 - Accuracy: 0.7051 

[val] epoch: 7/50 
 - total loss: 1.2703 
 - IoU per class: [0.46884578466415405, 0.6662946939468384, 0.12407345324754715] 
 - Mean IoU: 0.4197 
 - Accuracy: 0.7002 

[evaluation] epoch: 7/50, elapsed time: 18.11 seconds
[train] epoch: 8/50, iteration: 67/71, total loss: 1.2547  xentropy loss: 0.7002  jaccard loss: 0.5545 
[train] epoch: 8/50, elapsed time: 41.78 seconds
[train] epoch: 9/50, iteration: 67/71, total loss: 1.2426  xentropy loss: 0.6930  jaccard loss: 0.5496 
[train] epoch: 9/50, elapsed time: 41.77 seconds
[train] epoch: 10/50, iteration: 67/71, total loss: 1.2336  xentropy loss: 0.6879  jaccard loss: 0.5456 
[train] epoch: 10/50, elapsed time: 42.35 seconds
[train] epoch: 10/50 
 - total loss: 1.2209 
 - IoU per class: [0.49634286761283875, 0.6871868371963501, 0.14223484694957733] 
 - Mean IoU: 0.4419 
 - Accuracy: 0.7183 

[val] epoch: 10/50 
 - total loss: 1.2344 
 - IoU per class: [0.4855667054653168, 0.6802844405174255, 0.14551958441734314] 
 - Mean IoU: 0.4371 
 - Accuracy: 0.7124 

[evaluation] epoch: 10/50, elapsed time: 18.03 seconds
[train] epoch: 11/50, iteration: 67/71, total loss: 1.2336  xentropy loss: 0.6888  jaccard loss: 0.5448 
[train] epoch: 11/50, elapsed time: 42.32 seconds
[train] epoch: 12/50, iteration: 67/71, total loss: 1.2240  xentropy loss: 0.6831  jaccard loss: 0.5410 
[train] epoch: 12/50, elapsed time: 42.16 seconds
[train] epoch: 13/50, iteration: 67/71, total loss: 1.2192  xentropy loss: 0.6807  jaccard loss: 0.5384 
[train] epoch: 13/50, elapsed time: 42.99 seconds
[train] epoch: 13/50 
 - total loss: 1.2033 
 - IoU per class: [0.5088780522346497, 0.6932766437530518, 0.14735452830791473] 
 - Mean IoU: 0.4498 
 - Accuracy: 0.7244 

[val] epoch: 13/50 
 - total loss: 1.2176 
 - IoU per class: [0.49794140458106995, 0.6858826875686646, 0.15074597299098969] 
 - Mean IoU: 0.4449 
 - Accuracy: 0.7182 

[evaluation] epoch: 13/50, elapsed time: 18.08 seconds
[train] epoch: 14/50, iteration: 67/71, total loss: 1.2125  xentropy loss: 0.6757  jaccard loss: 0.5367 
[train] epoch: 14/50, elapsed time: 41.59 seconds
[train] epoch: 15/50, iteration: 67/71, total loss: 1.2067  xentropy loss: 0.6716  jaccard loss: 0.5350 
[train] epoch: 15/50, elapsed time: 42.94 seconds
[train] epoch: 16/50, iteration: 67/71, total loss: 1.2003  xentropy loss: 0.6670  jaccard loss: 0.5333 
[train] epoch: 16/50, elapsed time: 42.27 seconds
[train] epoch: 16/50 
 - total loss: 1.1923 
 - IoU per class: [0.5148026943206787, 0.697281002998352, 0.15585048496723175] 
 - Mean IoU: 0.4560 
 - Accuracy: 0.7282 

[val] epoch: 16/50 
 - total loss: 1.2069 
 - IoU per class: [0.5041127800941467, 0.6899872422218323, 0.15917228162288666] 
 - Mean IoU: 0.4511 
 - Accuracy: 0.7221 

[evaluation] epoch: 16/50, elapsed time: 17.95 seconds
[train] epoch: 17/50, iteration: 67/71, total loss: 1.2013  xentropy loss: 0.6684  jaccard loss: 0.5330 
[train] epoch: 17/50, elapsed time: 42.13 seconds
[train] epoch: 18/50, iteration: 67/71, total loss: 1.1990  xentropy loss: 0.6673  jaccard loss: 0.5317 
[train] epoch: 18/50, elapsed time: 42.59 seconds
[train] epoch: 19/50, iteration: 67/71, total loss: 1.1928  xentropy loss: 0.6651  jaccard loss: 0.5277 
[train] epoch: 19/50, elapsed time: 42.52 seconds
[train] epoch: 19/50 
 - total loss: 1.1801 
 - IoU per class: [0.5213924646377563, 0.7013189792633057, 0.1597258597612381] 
 - Mean IoU: 0.4608 
 - Accuracy: 0.7320 

[val] epoch: 19/50 
 - total loss: 1.1945 
 - IoU per class: [0.5107904672622681, 0.6942348480224609, 0.1630152016878128] 
 - Mean IoU: 0.4560 
 - Accuracy: 0.7261 

[evaluation] epoch: 19/50, elapsed time: 18.10 seconds
[train] epoch: 20/50, iteration: 67/71, total loss: 1.1808  xentropy loss: 0.6580  jaccard loss: 0.5228 
[train] epoch: 20/50, elapsed time: 42.61 seconds
[train] epoch: 21/50, iteration: 67/71, total loss: 1.1872  xentropy loss: 0.6665  jaccard loss: 0.5207 
[train] epoch: 21/50, elapsed time: 41.39 seconds
[train] epoch: 22/50, iteration: 67/71, total loss: 1.1753  xentropy loss: 0.6563  jaccard loss: 0.5190 
[train] epoch: 22/50, elapsed time: 42.60 seconds
[train] epoch: 22/50 
 - total loss: 1.1565 
 - IoU per class: [0.5349406003952026, 0.7107416987419128, 0.1611461639404297] 
 - Mean IoU: 0.4689 
 - Accuracy: 0.7396 

[val] epoch: 22/50 
 - total loss: 1.1714 
 - IoU per class: [0.5250519514083862, 0.7039015889167786, 0.16501173377037048] 
 - Mean IoU: 0.4647 
 - Accuracy: 0.7341 

[evaluation] epoch: 22/50, elapsed time: 18.11 seconds
[train] epoch: 23/50, iteration: 67/71, total loss: 1.1682  xentropy loss: 0.6515  jaccard loss: 0.5166 
[train] epoch: 23/50, elapsed time: 42.13 seconds
[train] epoch: 24/50, iteration: 67/71, total loss: 1.1598  xentropy loss: 0.6455  jaccard loss: 0.5142 
[train] epoch: 24/50, elapsed time: 42.78 seconds
[train] epoch: 25/50, iteration: 67/71, total loss: 1.1578  xentropy loss: 0.6439  jaccard loss: 0.5138 
[train] epoch: 25/50, elapsed time: 41.81 seconds
[train] epoch: 25/50 
 - total loss: 1.1493 
 - IoU per class: [0.5394869446754456, 0.714061975479126, 0.16233977675437927] 
 - Mean IoU: 0.4720 
 - Accuracy: 0.7427 

[val] epoch: 25/50 
 - total loss: 1.1646 
 - IoU per class: [0.5292088389396667, 0.7074748277664185, 0.1662358045578003] 
 - Mean IoU: 0.4676 
 - Accuracy: 0.7373 

[evaluation] epoch: 25/50, elapsed time: 18.03 seconds
[train] epoch: 26/50, iteration: 67/71, total loss: 1.1541  xentropy loss: 0.6419  jaccard loss: 0.5121 
[train] epoch: 26/50, elapsed time: 42.32 seconds
[train] epoch: 27/50, iteration: 67/71, total loss: 1.1530  xentropy loss: 0.6423  jaccard loss: 0.5107 
[train] epoch: 27/50, elapsed time: 43.11 seconds
[train] epoch: 28/50, iteration: 67/71, total loss: 1.1454  xentropy loss: 0.6403  jaccard loss: 0.5050 
[train] epoch: 28/50, elapsed time: 42.75 seconds
[train] epoch: 28/50 
 - total loss: 1.1361 
 - IoU per class: [0.5401146411895752, 0.7192343473434448, 0.17713244259357452] 
 - Mean IoU: 0.4788 
 - Accuracy: 0.7459 

[val] epoch: 28/50 
 - total loss: 1.1509 
 - IoU per class: [0.5303367972373962, 0.713244616985321, 0.1806260049343109] 
 - Mean IoU: 0.4747 
 - Accuracy: 0.7409 

[evaluation] epoch: 28/50, elapsed time: 18.15 seconds
[train] epoch: 29/50, iteration: 67/71, total loss: 1.1461  xentropy loss: 0.6397  jaccard loss: 0.5063 
[train] epoch: 29/50, elapsed time: 41.27 seconds
[train] epoch: 30/50, iteration: 67/71, total loss: 1.1441  xentropy loss: 0.6386  jaccard loss: 0.5054 
[train] epoch: 30/50, elapsed time: 43.06 seconds
[train] epoch: 31/50, iteration: 67/71, total loss: 1.1406  xentropy loss: 0.6385  jaccard loss: 0.5021 
[train] epoch: 31/50, elapsed time: 42.35 seconds
[train] epoch: 31/50 
 - total loss: 1.1221 
 - IoU per class: [0.5476018190383911, 0.7231868505477905, 0.1709066480398178] 
 - Mean IoU: 0.4806 
 - Accuracy: 0.7496 

[val] epoch: 31/50 
 - total loss: 1.1382 
 - IoU per class: [0.5371024012565613, 0.7168474793434143, 0.17449183762073517] 
 - Mean IoU: 0.4761 
 - Accuracy: 0.7444 

[evaluation] epoch: 31/50, elapsed time: 18.18 seconds
[train] epoch: 32/50, iteration: 67/71, total loss: 1.1407  xentropy loss: 0.6383  jaccard loss: 0.5024 
[train] epoch: 32/50, elapsed time: 42.78 seconds
[train] epoch: 33/50, iteration: 67/71, total loss: 1.1362  xentropy loss: 0.6381  jaccard loss: 0.4981 
[train] epoch: 33/50, elapsed time: 42.83 seconds
[train] epoch: 34/50, iteration: 67/71, total loss: 1.1327  xentropy loss: 0.6366  jaccard loss: 0.4961 
[train] epoch: 34/50, elapsed time: 42.60 seconds
[train] epoch: 34/50 
 - total loss: 1.0938 
 - IoU per class: [0.5674961805343628, 0.735647976398468, 0.1631506383419037] 
 - Mean IoU: 0.4888 
 - Accuracy: 0.7596 

[val] epoch: 34/50 
 - total loss: 1.1083 
 - IoU per class: [0.5587170720100403, 0.7299370765686035, 0.16656328737735748] 
 - Mean IoU: 0.4851 
 - Accuracy: 0.7551 

[evaluation] epoch: 34/50, elapsed time: 17.95 seconds
[train] epoch: 35/50, iteration: 67/71, total loss: 1.1244  xentropy loss: 0.6319  jaccard loss: 0.4925 
[train] epoch: 35/50, elapsed time: 42.43 seconds
[train] epoch: 36/50, iteration: 67/71, total loss: 1.1275  xentropy loss: 0.6359  jaccard loss: 0.4916 
[train] epoch: 36/50, elapsed time: 42.50 seconds
[train] epoch: 37/50, iteration: 67/71, total loss: 1.1254  xentropy loss: 0.6342  jaccard loss: 0.4912 
[train] epoch: 37/50, elapsed time: 42.97 seconds
[train] epoch: 37/50 
 - total loss: 1.0758 
 - IoU per class: [0.5775947570800781, 0.7411153316497803, 0.1559378057718277] 
 - Mean IoU: 0.4915 
 - Accuracy: 0.7648 

[val] epoch: 37/50 
 - total loss: 1.0889 
 - IoU per class: [0.5697115063667297, 0.736099123954773, 0.15949904918670654] 
 - Mean IoU: 0.4884 
 - Accuracy: 0.7609 

[evaluation] epoch: 37/50, elapsed time: 18.29 seconds
[train] epoch: 38/50, iteration: 67/71, total loss: 1.1135  xentropy loss: 0.6283  jaccard loss: 0.4852 
[train] epoch: 38/50, elapsed time: 42.40 seconds
[train] epoch: 39/50, iteration: 67/71, total loss: 1.1064  xentropy loss: 0.6222  jaccard loss: 0.4842 
[train] epoch: 39/50, elapsed time: 42.89 seconds
[train] epoch: 40/50, iteration: 67/71, total loss: 1.0981  xentropy loss: 0.6188  jaccard loss: 0.4792 
[train] epoch: 40/50, elapsed time: 42.52 seconds
[train] epoch: 40/50 
 - total loss: 1.0575 
 - IoU per class: [0.5876496434211731, 0.7454202771186829, 0.1702071875333786] 
 - Mean IoU: 0.5011 
 - Accuracy: 0.7698 

[val] epoch: 40/50 
 - total loss: 1.0745 
 - IoU per class: [0.5783292055130005, 0.7390925288200378, 0.17342697083950043] 
 - Mean IoU: 0.4969 
 - Accuracy: 0.7649 

[evaluation] epoch: 40/50, elapsed time: 18.13 seconds
[train] epoch: 41/50, iteration: 67/71, total loss: 1.1015  xentropy loss: 0.6202  jaccard loss: 0.4812 
[train] epoch: 41/50, elapsed time: 42.75 seconds
[train] epoch: 42/50, iteration: 67/71, total loss: 1.0933  xentropy loss: 0.6148  jaccard loss: 0.4785 
[train] epoch: 42/50, elapsed time: 42.70 seconds
[train] epoch: 43/50, iteration: 67/71, total loss: 1.0860  xentropy loss: 0.6113  jaccard loss: 0.4748 
[train] epoch: 43/50, elapsed time: 42.48 seconds
[train] epoch: 43/50 
 - total loss: 1.0466 
 - IoU per class: [0.5935679078102112, 0.7484169006347656, 0.17425251007080078] 
 - Mean IoU: 0.5054 
 - Accuracy: 0.7726 

[val] epoch: 43/50 
 - total loss: 1.0649 
 - IoU per class: [0.5832134485244751, 0.7414273023605347, 0.17751547694206238] 
 - Mean IoU: 0.5007 
 - Accuracy: 0.7673 

[evaluation] epoch: 43/50, elapsed time: 18.06 seconds
[train] epoch: 44/50, iteration: 67/71, total loss: 1.0858  xentropy loss: 0.6108  jaccard loss: 0.4751 
[train] epoch: 44/50, elapsed time: 42.08 seconds
[train] epoch: 45/50, iteration: 67/71, total loss: 1.0846  xentropy loss: 0.6083  jaccard loss: 0.4763 
[train] epoch: 45/50, elapsed time: 42.20 seconds
[train] epoch: 46/50, iteration: 67/71, total loss: 1.0811  xentropy loss: 0.6053  jaccard loss: 0.4759 
[train] epoch: 46/50, elapsed time: 42.44 seconds
[train] epoch: 46/50 
 - total loss: 1.0358 
 - IoU per class: [0.5985013246536255, 0.7518817782402039, 0.18402163684368134] 
 - Mean IoU: 0.5115 
 - Accuracy: 0.7757 

[val] epoch: 46/50 
 - total loss: 1.0532 
 - IoU per class: [0.5885671377182007, 0.7452569603919983, 0.18743924796581268] 
 - Mean IoU: 0.5071 
 - Accuracy: 0.7707 

[evaluation] epoch: 46/50, elapsed time: 17.97 seconds
[train] epoch: 47/50, iteration: 67/71, total loss: 1.0795  xentropy loss: 0.6050  jaccard loss: 0.4746 
[train] epoch: 47/50, elapsed time: 42.24 seconds
[train] epoch: 48/50, iteration: 67/71, total loss: 1.0812  xentropy loss: 0.6076  jaccard loss: 0.4736 
[train] epoch: 48/50, elapsed time: 43.06 seconds
[train] epoch: 49/50, iteration: 67/71, total loss: 1.0755  xentropy loss: 0.6033  jaccard loss: 0.4722 
[train] epoch: 49/50, elapsed time: 42.87 seconds
[train] epoch: 49/50 
 - total loss: 1.0339 
 - IoU per class: [0.6011808514595032, 0.7519543766975403, 0.18255695700645447] 
 - Mean IoU: 0.5119 
 - Accuracy: 0.7763 

[val] epoch: 49/50 
 - total loss: 1.0526 
 - IoU per class: [0.5906183123588562, 0.7446607351303101, 0.18571272492408752] 
 - Mean IoU: 0.5070 
 - Accuracy: 0.7708 

[evaluation] epoch: 49/50, elapsed time: 17.92 seconds
[train] epoch: 50/50, iteration: 67/71, total loss: 1.0746  xentropy loss: 0.6023  jaccard loss: 0.4723 
[train] epoch: 50/50, elapsed time: 42.67 seconds
[train] epoch: 50/50 
 - total loss: 1.0333 
 - IoU per class: [0.6012441515922546, 0.7520167231559753, 0.18340939283370972] 
 - Mean IoU: 0.5122 
 - Accuracy: 0.7764 

[val] epoch: 50/50 
 - total loss: 1.0529 
 - IoU per class: [0.5903779864311218, 0.7444505095481873, 0.18650034070014954] 
 - Mean IoU: 0.5071 
 - Accuracy: 0.7707 

[evaluation] epoch: 50/50, elapsed time: 18.28 seconds
CPU times: user 21min 59s, sys: 2min 51s, total: 24min 51s
Wall time: 43min 44s

We can check the saved models:

!ls /tmp/output-oxford-model/
45  49
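The two listed directories correspond to the epochs at which the best models were stored by the Orbax checkpoint manager. As a reference, below is a minimal sketch of how such a checkpoint could be restored back into the model. It assumes the training loop saved the model state in the standard Orbax format, so treat it as one possible pattern rather than the exact code; the restore call must mirror how the checkpoints were written earlier in this tutorial.

from flax import nnx

# Hypothetical restore sketch (assumption: the model state was saved with
# ocp.args.StandardSave by the checkpoint manager used during training).
ckpt_dir = "/tmp/output-oxford-model/"
mngr = ocp.CheckpointManager(ckpt_dir)
best_step = mngr.latest_step()  # e.g. 49 in the listing above

# Split the live model into its graph definition and state, restore the saved
# state into a matching structure, and merge them back into a model instance.
graphdef, state = nnx.split(model)
restored_state = mngr.restore(best_step, args=ocp.args.StandardRestore(state))
model = nnx.merge(graphdef, restored_state)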

We can also visualize the collected metrics:

epochs = [i for i in range(num_epochs) if (i % 3 == 0) or (i == num_epochs - 1)]

plt.plot(epochs, eval_metrics_history["train_total_loss"], label="Loss value on training set")
plt.plot(epochs, eval_metrics_history["val_total_loss"], label="Loss value on validation set")
plt.legend()
<matplotlib.legend.Legend at 0x7fb370eddfd0>
[Figure: total loss on the training and validation sets over the evaluated epochs]
plt.plot(epochs, eval_metrics_history["train_mean_IoU"], label="Mean IoU on training set")
plt.plot(epochs, eval_metrics_history["val_mean_IoU"], label="Mean IoU on validation set")
plt.legend()
<matplotlib.legend.Legend at 0x7fb3eeed0490>
[Figure: mean IoU on the training and validation sets over the evaluated epochs]

Next, we will visualize model predictions on validation data:

# Switch the model to evaluation mode and predict on one validation batch.
model.eval()
val_batch = next(iter(val_loader))
images, masks = val_batch["image"], val_batch["mask"]
preds = model(images)
preds = jnp.argmax(preds, axis=-1)  # per-pixel class index

def display_image_mask_pred(img, mask, pred, label=""):
    # Rescale float images to uint8 for display.
    if img.dtype in (np.float32,):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)
    fig, axs = plt.subplots(1, 5, figsize=(15, 10))
    axs[0].set_title(f"Image{label}")
    axs[0].imshow(img)
    axs[1].set_title(f"Mask{label}")
    axs[1].imshow(mask)
    axs[2].set_title("Image + Mask")
    axs[2].imshow(img)
    axs[2].imshow(mask, alpha=0.5)
    axs[3].set_title(f"Pred{label}")
    axs[3].imshow(pred)
    axs[4].set_title("Image + Pred")
    axs[4].imshow(img)
    axs[4].imshow(pred, alpha=0.5)

for img, mask, pred in zip(images[:4], masks[:4], preds[:4]):
    display_image_mask_pred(img, mask, pred, label=" (validation set)")
[Figures: four validation samples, each showing the image, the ground-truth mask, the image/mask overlay, the predicted mask, and the image/prediction overlay]

We can see that the model roughly predicts the shape of the animal and the background, but it struggles with predicting the boundary between them. With more careful hyperparameter tuning we may achieve better results.
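To make this observation quantitative for a single example, we can compute the per-class IoU of one of the displayed predictions. The helper below is a minimal sketch: it assumes that preds and masks hold per-pixel integer class indices of the same spatial shape and that the three classes follow the trimap convention used above (pet, background, border); the class ordering here is an assumption, so adapt it to the mapping used when preparing the masks.

import jax.numpy as jnp

# Hypothetical helper: per-class IoU for a single predicted mask.
# Assumption: pred and target are H x W integer arrays with classes 0..num_classes-1.
def per_class_iou(pred, target, num_classes=3):
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        intersection = jnp.sum(pred_c & target_c)
        union = jnp.sum(pred_c | target_c)
        ious.append(float(intersection / jnp.maximum(union, 1)))
    return ious

# IoU of the first visualized validation sample; the border class is expected
# to score lowest, in line with the per-class IoU reported during training.
print(per_class_iou(preds[0], jnp.asarray(masks[0]).squeeze()))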