Training

We provide a simple trainer class for training on WildlifeDataset instances, as well as wrappers for ArcFace and Triplet losses.

Training Example

Fine-tuning MegaDescriptor-T from HuggingFace Hub
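
The example assumes that dataset is an existing WildlifeDataset instance. As a rough sketch (the metadata columns, image root and transform below are illustrative assumptions, not a verbatim recipe), such a dataset might be created like this:

import pandas as pd
import torchvision.transforms as T
from wildlife_tools.data import WildlifeDataset

# Hypothetical metadata: image paths and identity labels.
metadata = pd.DataFrame({
    'path': ['images/zebra_01.jpg', 'images/zebra_02.jpg'],
    'identity': ['zebra_a', 'zebra_b'],
})

transform = T.Compose([T.Resize([224, 224]), T.ToTensor()])
dataset = WildlifeDataset(metadata, 'data/', transform=transform)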

import timm
import itertools
from torch.optim import SGD
from wildlife_tools.train import ArcFaceLoss, BasicTrainer

# Download MegaDescriptor-T backbone from HuggingFace Hub
backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)

# ArcFace loss - needs the backbone embedding size and the number of classes.
objective = ArcFaceLoss(
    num_classes=dataset.num_classes,
    embedding_size=768,
    margin=0.5,
    scale=64,
)

# Optimize parameters of both the backbone and the objective with a single optimizer.
params = itertools.chain(backbone.parameters(), objective.parameters())
optimizer = SGD(params=params, lr=0.001, momentum=0.9)


trainer = BasicTrainer(
    dataset=dataset,
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=20,
    device='cpu',
)

trainer.train()
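
After training, the trainer state (model, objective, optimizer, scheduler and random number generator states) can be written to a checkpoint file with the save method documented in the reference below; the folder name here is illustrative.

trainer.save(folder='checkpoints', file_name='checkpoint.pth')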

Reference

Source code in train/trainer.py
class BasicTrainer():
    def __init__(
        self,
        dataset,
        model,
        objective,
        optimizer,
        epochs,
        scheduler=None,
        device='cuda',
        batch_size=128,
        num_workers=1,
        accumulation_steps=1,
        epoch_callback=None,
    ):
        self.dataset = dataset
        self.model = model.to(device)
        self.objective = objective.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epochs = epochs
        self.epoch = 0
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.accumulation_steps = accumulation_steps
        self.epoch_callback = epoch_callback


    def train(self):
        loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True
        )

        for e in range(self.epochs):
            epoch_data = self.train_epoch(loader)
            self.epoch += 1

            if self.epoch_callback:
                self.epoch_callback(trainer=self, epoch_data=epoch_data)


    def train_epoch(self, loader):
        model = self.model.train()
        losses = []
        for i, batch in enumerate(tqdm(loader, desc=f'Epoch {self.epoch}: ', mininterval=1, ncols=100)):
            x, y = batch
            x, y = x.to(self.device), y.to(self.device)

            out = model(x)
            loss = self.objective(out, y)
            loss.backward()
            if (i-1) % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            losses.append(loss.detach().cpu())

        if self.scheduler:
            self.scheduler.step()

        return {'train_loss_epoch_avg': np.mean(losses)}

    def save(self, folder, file_name='checkpoint.pth', **kwargs):
        if not os.path.exists(folder):
            os.makedirs(folder)
        if self.scheduler:
            scheduler_state = self.scheduler.state_dict()
        else:
            scheduler_state = None

        checkpoint = {
            'model': self.model.state_dict(),
            'objective': self.objective.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'scheduler': scheduler_state,
            'rng_states': get_random_states(device=self.device),
        }
        torch.save(checkpoint, os.path.join(folder, file_name))


    def load(self, path, device='cpu'):
        checkpoint = torch.load(path, map_location=torch.device(device))

        if 'rng_states' in checkpoint:
            set_random_states(checkpoint['rng_states'], device=self.device)
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        if 'objective' in checkpoint:
            self.objective.load_state_dict(checkpoint['objective'])
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        if 'epoch' in checkpoint:
            self.epoch = checkpoint['epoch']
        if 'scheduler' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
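
The epoch_callback argument accepts a callable that is invoked after every epoch with the trainer instance and the metrics dictionary returned by train_epoch. A minimal sketch of a callback that logs the average training loss and saves periodic checkpoints (the folder name and saving interval are illustrative):

def log_and_checkpoint(trainer, epoch_data):
    # epoch_data is the dictionary returned by train_epoch.
    print(f"Epoch {trainer.epoch}: loss {epoch_data['train_loss_epoch_avg']:.4f}")
    # Save an illustrative checkpoint every 5 epochs.
    if trainer.epoch % 5 == 0:
        trainer.save(folder='checkpoints', file_name=f'epoch_{trainer.epoch}.pth')

trainer = BasicTrainer(
    dataset=dataset,
    model=backbone,
    objective=objective,
    optimizer=optimizer,
    epochs=20,
    device='cpu',
    epoch_callback=log_and_checkpoint,
)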

Replicability

The model can be trained with a specified seed to ensure replicable results by calling the set_seed function at the beginning of the training process. When the trainer is saved to a checkpoint, the random number generator states are stored as well, so training can later be restarted from the checkpoint without breaking replicability.

Example

from wildlife_tools.train import set_seed, BasicTrainer


set_seed(0)
trainer = BasicTrainer(<add trainer parameters here>)
trainer.train()
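
To restart training from a saved checkpoint, a new trainer can be built with the same configuration and its state restored with the load method, which also restores the stored random number generator states. A sketch, with an illustrative checkpoint path:

trainer = BasicTrainer(<add trainer parameters here>)
trainer.load('checkpoints/checkpoint.pth')
trainer.train()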