# Reference: training

## train.trainer

### `BasicTrainer(dataset, model, objective, optimizer, epochs, scheduler=None, device='cuda', batch_size=128, num_workers=1, accumulation_steps=1, epoch_callback=None)`
Implements a basic training loop for PyTorch models. Checkpoints include random states, so any restart from a checkpoint preserves reproducibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `ImageDataset` | Training dataset that yields `(x, y)` tensor pairs. | *required* |
| `model` | `Module` | PyTorch `nn.Module` for the model / backbone. | *required* |
| `objective` | `Module` | PyTorch `nn.Module` for the objective / loss function. | *required* |
| `optimizer` | `Optimizer` | PyTorch optimizer. | *required* |
| `epochs` | `int` | Number of training epochs. | *required* |
| `scheduler` | `_LRScheduler`, optional | PyTorch scheduler. | `None` |
| `device` | `str` | Device to be used for training. | `'cuda'` |
| `batch_size` | `int` | Training batch size. | `128` |
| `num_workers` | `int` | Number of data-loading workers in the torch `DataLoader`. | `1` |
| `accumulation_steps` | `int` | Number of gradient accumulation steps. | `1` |
| `epoch_callback` | `Callable` | Callback function called after each epoch. | `None` |
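A minimal usage sketch follows. The import paths mirror the module names documented on this page (`train.trainer`, `train.objective`); the torchvision backbone, the 100-class setup, and the final `trainer.train()` call are assumptions for illustration, and `train_dataset` stands in for an `ImageDataset` that yields `(x, y)` tensor pairs.

```python
import torch
import torchvision

from train.trainer import BasicTrainer      # assumed import path
from train.objective import ArcFaceLoss     # assumed import path

# Embedding backbone: ResNet-18 with its classification head removed,
# so forward() returns 512-dimensional embeddings.
backbone = torchvision.models.resnet18(weights=None)
backbone.fc = torch.nn.Identity()

# ArcFace objective holds learnable class weights; optimize them together
# with the backbone parameters.
objective = ArcFaceLoss(num_classes=100, embedding_size=512, margin=0.5, scale=64)
params = list(backbone.parameters()) + list(objective.parameters())
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9)

trainer = BasicTrainer(
    dataset=train_dataset,       # placeholder: an ImageDataset of (x, y) pairs
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=20,
    device='cuda',
    batch_size=128,
    num_workers=2,
    accumulation_steps=1,
)
trainer.train()                  # assumed entry point; not documented above
```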
#### `get_random_states()`

Returns a dictionary of random states for reproducibility.
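Since checkpoints are described as including random states, a hypothetical checkpointing sketch (continuing the example above) might pair `get_random_states()` with the usual state dicts. The checkpoint keys and file name below are illustrative, not part of the documented API.

```python
import torch

# Persist random states alongside the weights so that resuming from this
# checkpoint reproduces the same data order and augmentations.
checkpoint = {
    'model': backbone.state_dict(),
    'optimizer': optimizer.state_dict(),
    'random_states': trainer.get_random_states(),
}
torch.save(checkpoint, 'checkpoint.pth')
```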
## train.objective

### `ArcFaceLoss(num_classes, embedding_size, margin=0.5, scale=64)`
Bases: Module
Wraps the PyTorch Metric Learning `ArcFaceLoss`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_classes` | `int` | Number of classes. | *required* |
| `embedding_size` | `int` | Size of the input embeddings. | *required* |
| `margin` | `float` | Margin for ArcFace loss (in radians). | `0.5` |
| `scale` | `int` | Scale parameter for ArcFace loss. | `64` |
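A minimal sketch of the objective on its own, assuming it is importable from `train.objective` and, like the wrapped PyTorch Metric Learning loss, is called with a batch of embeddings and integer labels:

```python
import torch
from train.objective import ArcFaceLoss     # assumed import path

objective = ArcFaceLoss(num_classes=100, embedding_size=512, margin=0.5, scale=64)

embeddings = torch.randn(32, 512)            # batch of 32 embeddings
labels = torch.randint(0, 100, (32,))        # integer class labels
loss = objective(embeddings, labels)
loss.backward()

# The loss holds its own learnable class weights, so include
# objective.parameters() in the optimizer alongside the model's parameters.
```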
### `TripletLoss(margin=0.2, mining='semihard', distance='l2_squared')`
Bases: Module
Wraps the PyTorch Metric Learning `TripletMarginLoss`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `margin` | `float` | Margin for triplet loss. | `0.2` |
| `mining` | `str` | Type of triplet mining. One of: `'all'`, `'hard'`, `'semihard'`. | `'semihard'` |
| `distance` | `str` | Distance metric for triplet loss. One of: `'cosine'`, `'l2'`, `'l2_squared'`. | `'l2_squared'` |
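A minimal sketch under the same assumptions (importable from `train.objective`, called with embeddings and labels). Triplets are mined within the batch, so each batch should contain several samples per class:

```python
import torch
from train.objective import TripletLoss     # assumed import path

objective = TripletLoss(margin=0.2, mining='semihard', distance='l2_squared')

embeddings = torch.randn(32, 512)
labels = torch.randint(0, 8, (32,))          # few classes -> several samples per class
loss = objective(embeddings, labels)
loss.backward()
```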
### `SoftmaxLoss(num_classes, embedding_size)`
Bases: Module
Cross-entropy loss with a single dense-layer classification head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_classes` | `int` | Number of classes. | *required* |
| `embedding_size` | `int` | Size of the input embeddings. | *required* |
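A minimal sketch under the same assumptions. The dense classification head lives inside the loss module, so its parameters must be optimized together with the backbone's:

```python
import torch
from train.objective import SoftmaxLoss     # assumed import path

objective = SoftmaxLoss(num_classes=100, embedding_size=512)

embeddings = torch.randn(32, 512)
labels = torch.randint(0, 100, (32,))
loss = objective(embeddings, labels)
loss.backward()

# Train the head together with the backbone, e.g.:
# optimizer = torch.optim.SGD(
#     list(backbone.parameters()) + list(objective.parameters()), lr=0.001)
```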