# Training

We provide a simple trainer class for training on `WildlifeDataset` instances, as well as wrappers for ArcFace and Triplet losses and a Softmax cross-entropy objective.
## Replicability

The model can be trained with a specified seed to ensure replicable results by calling the `set_seed` function at the beginning of the training process. If the trainer is saved into a checkpoint, the seed is stored as well, so training can be restarted from the checkpoint while maintaining replicability.
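In code, this amounts to one call before any model or dataloader is created (the import path follows the full example at the end of this page):

```python
from wildlife_tools.train import set_seed

# Fix random seeds before building the model and dataloaders,
# so weight init and shuffling are deterministic across runs.
set_seed(0)
```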
## train.trainer

`BasicTrainer(dataset, model, objective, optimizer, epochs, scheduler=None, device='cuda', batch_size=128, num_workers=1, accumulation_steps=1, epoch_callback=None)`

Implements a basic training loop for Pytorch models. Checkpoints include random states, so any restart from a checkpoint preserves reproducibility.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`dataset` | `WildlifeDataset` | Training dataset that yields `(x, y)` tensor pairs. | required |
`model` | `nn.Module` | Pytorch `nn.Module` for the model / backbone. | required |
`objective` | `nn.Module` | Pytorch `nn.Module` for the objective / loss function. | required |
`optimizer` | `Optimizer` | Pytorch optimizer. | required |
`scheduler` | optional | Pytorch scheduler. | `None` |
`epochs` | `int` | Number of training epochs. | required |
`device` | `str` | Device to be used for training. | `'cuda'` |
`batch_size` | `int` | Training batch size. | `128` |
`num_workers` | `int` | Number of data loading workers in the torch `DataLoader`. | `1` |
`accumulation_steps` | `int` | Number of gradient accumulation steps (see the sketch after this table). | `1` |
`epoch_callback` | optional | Callback function called after each epoch. | `None` |
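`accumulation_steps` lets the effective batch size exceed what fits in memory: gradients from several consecutive batches are summed before a single optimizer step, so the effective batch size is `batch_size * accumulation_steps`. The trainer handles this internally; the sketch below only illustrates the standard PyTorch pattern, with `loader`, `model`, `objective`, and `optimizer` standing in for your own objects.

```python
accumulation_steps = 4  # effective batch size = batch_size * 4

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = objective(model(x), y)
    # Scale the loss so accumulated gradients match one large-batch pass.
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```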
`get_random_states()`

Returns a dictionary of random states for reproducibility.
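Such a dictionary typically snapshots every random generator in play. A minimal sketch of what it can contain, assuming the usual Python, NumPy, and torch generators; the exact keys returned by `get_random_states` are an assumption here:

```python
import random

import numpy as np
import torch

# Hypothetical layout of a random-state snapshot; actual keys may differ.
states = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "torch": torch.get_rng_state(),
}
if torch.cuda.is_available():
    states["torch_cuda"] = torch.cuda.get_rng_state_all()
```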
## train.objective

`ArcFaceLoss(num_classes, embedding_size, margin=0.5, scale=64)`

Bases: `Module`

Wraps Pytorch Metric Learning ArcFaceLoss.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`num_classes` | `int` | Number of classes. | required |
`embedding_size` | `int` | Size of the input embeddings. | required |
`margin` | `float` | Margin for ArcFace loss (in radians). | `0.5` |
`scale` | `int` | Scale parameter for ArcFace loss. | `64` |
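Pytorch Metric Learning losses are called as `loss(embeddings, labels)`, and it is assumed this wrapper keeps that convention; the tensors below are dummies. Note that ArcFace holds trainable class-proxy weights, which is why the example at the end of this page passes `objective.parameters()` to the optimizer alongside the backbone parameters.

```python
import torch
from wildlife_tools.train import ArcFaceLoss

objective = ArcFaceLoss(num_classes=100, embedding_size=768, margin=0.5, scale=64)

embeddings = torch.randn(32, 768)      # batch of backbone outputs
labels = torch.randint(0, 100, (32,))  # identity label per image
loss = objective(embeddings, labels)   # scalar loss to backpropagate
```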
`TripletLoss(margin=0.2, mining='semihard', distance='l2_squared')`

Bases: `Module`

Wraps Pytorch Metric Learning TripletMarginLoss.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`margin` | `float` | Margin for triplet loss. | `0.2` |
`mining` | `str` | Type of triplet mining. One of: `'all'`, `'hard'`, `'semihard'`. | `'semihard'` |
`distance` | `str` | Distance metric for triplet loss. One of: `'cosine'`, `'l2'`, `'l2_squared'`. | `'l2_squared'` |
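For intuition, `mining='semihard'` with `distance='l2_squared'` plausibly maps onto Pytorch Metric Learning primitives as below. This is an assumed reconstruction for illustration, not the wrapper's actual code.

```python
from pytorch_metric_learning import distances, losses, miners

# Squared L2 distance, as suggested by distance='l2_squared'.
distance = distances.LpDistance(p=2, power=2)

# Triplet loss plus a miner selecting semihard triplets within each batch.
loss_fn = losses.TripletMarginLoss(margin=0.2, distance=distance)
miner = miners.TripletMarginMiner(margin=0.2, type_of_triplets="semihard", distance=distance)

# Usage: loss = loss_fn(embeddings, labels, miner(embeddings, labels))
```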
`SoftmaxLoss(num_classes, embedding_size)`

Bases: `Module`

Cross-entropy loss with a single dense layer as a classification head.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`num_classes` | `int` | Number of classes. | required |
`embedding_size` | `int` | Size of the input embeddings. | required |
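The description pins the structure down: a single linear head maps embeddings to class logits, followed by cross-entropy. A minimal sketch of that idea, as an assumed reconstruction rather than the library's code:

```python
import torch.nn as nn

class SoftmaxLossSketch(nn.Module):
    """Cross-entropy over a single dense classification head."""

    def __init__(self, num_classes: int, embedding_size: int):
        super().__init__()
        self.head = nn.Linear(embedding_size, num_classes)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, embeddings, labels):
        return self.ce(self.head(embeddings), labels)
```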
## Examples

### Fine-tuning MegaDescriptor-T from HuggingFace Hub
```python
import itertools

import timm
from torch.optim import SGD

from wildlife_tools.train import ArcFaceLoss, BasicTrainer, set_seed

# `dataset` is assumed to be a WildlifeDataset instance prepared beforehand,
# yielding (x, y) tensor pairs.

# Download the MegaDescriptor-T backbone from HuggingFace Hub.
backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)

# ArcFace loss - needs the backbone output size and the number of classes.
objective = ArcFaceLoss(
    num_classes=dataset.num_classes,
    embedding_size=768,
    margin=0.5,
    scale=64,
)

# Optimize parameters of both the backbone and the objective with a single optimizer.
params = itertools.chain(backbone.parameters(), objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)

set_seed(0)
trainer = BasicTrainer(
    dataset=dataset,
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=20,
    device='cpu',
)
trainer.train()
```
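After training, the fine-tuned weights can be persisted with standard PyTorch tooling (the filename below is only an example):

```python
import torch

# Save the fine-tuned backbone weights for later feature extraction.
torch.save(backbone.state_dict(), 'megadescriptor-t-finetuned.pth')
```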