
Reference training

train.trainer

BasicTrainer(dataset, model, objective, optimizer, epochs, scheduler=None, device='cuda', batch_size=128, num_workers=1, accumulation_steps=1, epoch_callback=None)

Implements a basic training loop for PyTorch models. Checkpoints include random states, so restarting from a checkpoint preserves reproducibility.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | ImageDataset | Training dataset that yields (x, y) tensor pairs. | required |
| model | Module | PyTorch nn.Module for the model / backbone. | required |
| objective | Module | PyTorch nn.Module for the objective / loss function. | required |
| optimizer | Optimizer | PyTorch optimizer. | required |
| epochs | int | Number of training epochs. | required |
| scheduler | _LRScheduler, optional | PyTorch learning rate scheduler. | None |
| device | str | Device to be used for training. | 'cuda' |
| batch_size | int | Training batch size. | 128 |
| num_workers | int | Number of data loading workers in the torch DataLoader. | 1 |
| accumulation_steps | int | Number of gradient accumulation steps. | 1 |
| epoch_callback | Callable | Callback function called after each epoch. | None |
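The accumulation_steps parameter refers to gradient accumulation: gradients from several micro-batches are summed before a single optimizer step, which emulates a larger batch. The following is a minimal, framework-free sketch of that idea, not the BasicTrainer implementation. The helper names (grad_mse, step_full_batch, step_accumulated) are hypothetical, and the sketch assumes each micro-batch gradient is divided by the number of accumulation steps; whether BasicTrainer scales the loss this way is not stated here.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def step_full_batch(w, data, lr):
    """One gradient descent step computed on the whole batch at once."""
    return w - lr * grad_mse(w, data)


def step_accumulated(w, data, lr, accumulation_steps):
    """One step built from several equally sized micro-batches."""
    size = len(data) // accumulation_steps
    accumulated = 0.0
    for i in range(accumulation_steps):
        micro_batch = data[i * size:(i + 1) * size]
        # Assumption: scale each micro-batch gradient so the sum is an average.
        accumulated += grad_mse(w, micro_batch) / accumulation_steps
    # The parameter update happens only once, after all micro-batches.
    return w - lr * accumulated


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
w_full = step_full_batch(0.0, data, lr=0.01)
w_acc = step_accumulated(0.0, data, lr=0.01, accumulation_steps=2)
```

With equally sized micro-batches and this scaling, the accumulated step matches the full-batch step up to floating point error, so accumulation trades memory for extra forward and backward passes.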

get_random_states()

Returns a dictionary of random states for reproducibility.
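To illustrate why storing random states makes a restart reproducible, here is a small sketch using only Python's built-in random module. The helper names and the "python" dictionary key are hypothetical; the keys actually returned by get_random_states() are not listed here. A real training setup would likely also store the NumPy and PyTorch generator states (for example via numpy.random.get_state() and torch.get_rng_state()).

```python
import random


def capture_random_states():
    """Hypothetical helper: collect generator states into a dictionary."""
    return {"python": random.getstate()}


def restore_random_states(states):
    """Hypothetical helper: put the generators back into the saved state."""
    random.setstate(states["python"])


random.seed(0)
states = capture_random_states()          # e.g. saved inside a checkpoint
first = [random.random() for _ in range(3)]

restore_random_states(states)             # e.g. after loading the checkpoint
second = [random.random() for _ in range(3)]
```

After restoring, the generator produces the same sequence again, so shuffling and augmentation would continue exactly as in the uninterrupted run.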

train.objective

ArcFaceLoss(num_classes, embedding_size, margin=0.5, scale=64)

Bases: Module

Wraps the ArcFaceLoss from PyTorch Metric Learning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_classes | int | Number of classes. | required |
| embedding_size | int | Size of the input embeddings. | required |
| margin | float | Margin for ArcFace loss (in radians). | 0.5 |
| scale | int | Scale parameter for ArcFace loss. | 64 |
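The roles of margin and scale can be seen in a small numeric sketch of the ArcFace formula: the angle between the embedding and its target class center is increased by the margin, and all cosines are multiplied by the scale before a standard cross-entropy. This is an illustration only, not the wrapped library code. The function names are hypothetical, and the sketch assumes the cosine similarities to each class center are already given rather than computed from embeddings and learned class weights.

```python
import math


def arcface_logits(cosines, target, margin=0.5, scale=64):
    """Apply the additive angular margin to the target class, then scale."""
    logits = []
    for j, cosine in enumerate(cosines):
        cosine = max(-1.0, min(1.0, cosine))  # guard acos against rounding
        if j == target:
            theta = math.acos(cosine)
            logits.append(scale * math.cos(theta + margin))
        else:
            logits.append(scale * cosine)
    return logits


def cross_entropy(logits, target):
    """Numerically stable cross-entropy for a single example."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]


cosines = [0.5, 0.4, -0.3]  # assumed similarities to three class centers
with_margin = cross_entropy(arcface_logits(cosines, target=0), target=0)
without_margin = cross_entropy(
    arcface_logits(cosines, target=0, margin=0.0), target=0
)
```

Because the margin lowers the target logit, the loss with margin is larger than without it for the same embedding, which pushes embeddings closer to their class center during training.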

TripletLoss(margin=0.2, mining='semihard', distance='l2_squared')

Bases: Module

Wraps the TripletMarginLoss from PyTorch Metric Learning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| margin | float | Margin for triplet loss. | 0.2 |
| mining | str | Type of triplet mining. One of: 'all', 'hard', 'semihard' | 'semihard' |
| distance | str | Distance metric for triplet loss. One of: 'cosine', 'l2', 'l2_squared' | 'l2_squared' |
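As a rough illustration of the defaults above, the sketch below computes the triplet margin loss with squared L2 distance and checks the usual semihard condition, where the negative is farther from the anchor than the positive but still inside the margin. The function names are hypothetical and the code is a plain Python sketch of the standard definitions, not the wrapped miner or loss.

```python
def l2_squared(u, v):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Hinge on d(anchor, positive) - d(anchor, negative) + margin."""
    d_ap = l2_squared(anchor, positive)
    d_an = l2_squared(anchor, negative)
    return max(d_ap - d_an + margin, 0.0)


def is_semihard(anchor, positive, negative, margin=0.2):
    """Negative is farther than the positive, but by less than the margin."""
    d_ap = l2_squared(anchor, positive)
    d_an = l2_squared(anchor, negative)
    return d_ap < d_an < d_ap + margin


anchor = [0.0, 0.0]
positive = [0.3, 0.0]           # squared distance to anchor: about 0.09
semihard_negative = [0.4, 0.0]  # squared distance to anchor: about 0.16
easy_negative = [1.0, 0.0]      # squared distance to anchor: 1.0

semihard_loss = triplet_loss(anchor, positive, semihard_negative)
easy_loss = triplet_loss(anchor, positive, easy_negative)
```

The easy negative already satisfies the margin and contributes zero loss, which is why mining strategies focus on triplets that still produce a gradient.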

SoftmaxLoss(num_classes, embedding_size)

Bases: Module

Cross-entropy loss with a single dense-layer classification head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_classes | int | Number of classes. | required |
| embedding_size | int | Size of the input embeddings. | required |
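The two parameters determine the shape of the classification head: a dense layer mapping an embedding of length embedding_size to num_classes logits, followed by cross-entropy. A minimal plain Python sketch of that computation is shown here, assuming a bias term and zero-initialized weights purely for illustration; the function names are hypothetical and the actual layer initialization is not described here.

```python
import math


def dense(embedding, weights, bias):
    """Single dense layer: one weight row and one bias value per class."""
    return [
        sum(w * x for w, x in zip(row, embedding)) + b
        for row, b in zip(weights, bias)
    ]


def cross_entropy(logits, target):
    """Numerically stable cross-entropy for a single example."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


num_classes, embedding_size = 3, 2
# Assumption: all-zero parameters, so every class gets the same logit.
weights = [[0.0] * embedding_size for _ in range(num_classes)]
bias = [0.0] * num_classes

logits = dense([0.5, -1.0], weights, bias)
loss = cross_entropy(logits, target=1)
```

With identical logits the predicted distribution is uniform, so the loss equals log(num_classes); training the head lowers it from this baseline.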