Training

We provide a simple trainer class for training on WildlifeDataset instances, as well as wrappers for ArcFace, Triplet, and Softmax losses.

Replicability

The model can be trained with a specified seed to ensure replicable results: call the set_seed function at the beginning of the training process. When the trainer is saved into a checkpoint, the random states are stored as well, so a run restarted from that checkpoint remains replicable.
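
A minimal sketch of that guarantee (assuming set_seed seeds PyTorch's global generator, which replicable training requires): re-seeding restores the random state, so repeated runs draw identical values.

import torch
from wildlife_tools.train import set_seed

set_seed(0)
a = torch.randn(3)

set_seed(0)  # re-seeding restores the generator state
b = torch.randn(3)

assert torch.equal(a, b)  # identical draws, hence replicable runs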

train.trainer

BasicTrainer(dataset, model, objective, optimizer, epochs, scheduler=None, device='cuda', batch_size=128, num_workers=1, accumulation_steps=1, epoch_callback=None)

Implements a basic training loop for PyTorch models. Checkpoints include random states, so any restart from a checkpoint preserves reproducibility.

Parameters:

    dataset (WildlifeDataset): Training dataset that gives (x, y) tensor pairs. Required.
    model (nn.Module): PyTorch nn.Module for the model / backbone. Required.
    objective (nn.Module): PyTorch nn.Module for the objective / loss function. Required.
    optimizer (Optimizer): PyTorch optimizer. Required.
    scheduler (optional, default None): PyTorch scheduler.
    epochs (int): Number of training epochs. Required.
    device (str, default 'cuda'): Device to be used for training.
    batch_size (int, default 128): Training batch size.
    num_workers (int, default 1): Number of data loading workers in torch DataLoader.
    accumulation_steps (int, default 1): Number of gradient accumulation steps.
    epoch_callback (callable, optional, default None): Callback function called after each epoch.
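
A sketch combining the optional arguments, assuming `dataset` is a prepared WildlifeDataset. The epoch_callback signature is an assumption (the callback below accepts any arguments), so treat this as a template rather than the definitive API.

import itertools
import timm
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from wildlife_tools.train import BasicTrainer, SoftmaxLoss

# Accepts any arguments, since the exact callback signature is an assumption here.
def on_epoch_end(*args, **kwargs):
    print('epoch finished')

backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)
objective = SoftmaxLoss(num_classes=dataset.num_classes, embedding_size=768)

# The classification head lives in the objective, so optimize its parameters too.
params = itertools.chain(backbone.parameters(), objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

trainer = BasicTrainer(
    dataset=dataset,            # a prepared WildlifeDataset
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    scheduler=scheduler,        # optional scheduler
    epochs=10,
    batch_size=32,
    accumulation_steps=4,       # effective batch size: 32 * 4 = 128
    epoch_callback=on_epoch_end,
    device='cuda',
)
trainer.train()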

get_random_states()

Returns a dictionary of random states for reproducibility.

train.objective

ArcFaceLoss(num_classes, embedding_size, margin=0.5, scale=64)

Bases: Module

Wraps the PyTorch Metric Learning ArcFaceLoss.

Parameters:

    num_classes (int): Number of classes. Required.
    embedding_size (int): Size of the input embeddings. Required.
    margin (float, default 0.5): Margin for ArcFace loss (in radians).
    scale (int, default 64): Scale parameter for ArcFace loss.
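
A usage sketch, assuming the wrapper keeps the pytorch-metric-learning calling convention of loss(embeddings, labels):

import torch
from wildlife_tools.train import ArcFaceLoss

objective = ArcFaceLoss(num_classes=10, embedding_size=768, margin=0.5, scale=64)

embeddings = torch.randn(4, 768)      # backbone outputs for a batch of 4
labels = torch.tensor([0, 1, 2, 3])   # identity labels
loss = objective(embeddings, labels)  # scalar loss tensor
loss.backward()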

TripletLoss(margin=0.2, mining='semihard', distance='l2_squared')

Bases: Module

Wraps the PyTorch Metric Learning TripletMarginLoss.

Parameters:

    margin (float, default 0.2): Margin for triplet loss.
    mining (str, default 'semihard'): Type of triplet mining. One of: 'all', 'hard', 'semihard'.
    distance (str, default 'l2_squared'): Distance metric for triplet loss. One of: 'cosine', 'l2', 'l2_squared'.
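
A usage sketch under the same assumed calling convention; mining needs at least two samples per identity in a batch to form triplets:

import torch
from wildlife_tools.train import TripletLoss

objective = TripletLoss(margin=0.2, mining='semihard', distance='l2_squared')

embeddings = torch.randn(8, 768)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # two samples per identity
loss = objective(embeddings, labels)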

SoftmaxLoss(num_classes, embedding_size)

Bases: Module

Cross-entropy loss with a single dense layer as classification head.

Parameters:

    num_classes (int): Number of classes. Required.
    embedding_size (int): Size of the input embeddings. Required.
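
A usage sketch under the same assumed calling convention. The dense classification head is a trainable part of the objective itself, which is why the example below chains objective.parameters() into the optimizer.

import torch
from wildlife_tools.train import SoftmaxLoss

objective = SoftmaxLoss(num_classes=10, embedding_size=768)

embeddings = torch.randn(4, 768)
labels = torch.tensor([0, 1, 2, 3])
loss = objective(embeddings, labels)  # cross-entropy over the head's logits

# The head's weights are trainable parameters of the objective.
print(sum(p.numel() for p in objective.parameters()))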

Examples

Fine-tuning MegaDescriptor-T from HuggingFace Hub

import timm
import itertools
from torch.optim import SGD
from wildlife_tools.train import ArcFaceLoss, BasicTrainer, set_seed
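
# `dataset` below is assumed to be a prepared WildlifeDataset instance
# that yields (x, y) tensor pairs.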

# Download MegaDescriptor-T backbone from HuggingFace Hub
backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)

# ArcFace loss - needs backbone output size and number of classes.
objective = ArcFaceLoss(
    num_classes=dataset.num_classes,
    embedding_size=768,
    margin=0.5,
    scale=64,
)

# Optimize parameters in the backbone and the objective with a single optimizer.
params = itertools.chain(backbone.parameters(), objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)

set_seed(0)
trainer = BasicTrainer(
    dataset=dataset,
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=20,
    device='cpu',
)

trainer.train()
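
After training, the fine-tuned backbone can be persisted with standard PyTorch for later feature extraction; a minimal sketch:

import torch

# Save only the backbone weights; the objective (the ArcFace head) is
# typically discarded once training is done.
torch.save(backbone.state_dict(), 'megadescriptor-t-finetuned.pth')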