
Feature extraction

Feature extractors offer a standardized way to extract features from instances of WildlifeDataset.

Feature extractors are implemented as classes whose constructor arguments define the extraction properties. Once instantiated, an extractor is a callable that takes a single argument: the WildlifeDataset instance. The output type and shape depend on the chosen extractor, but in general the output is iterable and its first dimension equals the number of samples in the input WildlifeDataset.
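The calling convention can be illustrated with a minimal, library-free sketch. ToyDataset and MeanIntensity below are hypothetical stand-ins, not part of wildlife_tools; they only mimic the pattern described above: extraction properties fixed at construction time, a single dataset argument at call time, and one output entry per sample.

```python
from typing import Any, List, Sequence, Tuple


class ToyDataset:
    """Hypothetical stand-in for a WildlifeDataset: indexable, yields (image, label)."""

    def __init__(self, images: Sequence[Any], labels: Sequence[int]):
        self.images = list(images)
        self.labels = list(labels)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        return self.images[idx], self.labels[idx]


class MeanIntensity:
    """Hypothetical extractor: configured in __init__, applied by calling it on a dataset."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale  # extraction property fixed at construction time

    def __call__(self, dataset) -> List[float]:
        features = []
        for i in range(len(dataset)):
            image, _label = dataset[i]
            pixels = [p for row in image for p in row]  # flatten the 2D "image"
            features.append(self.scale * sum(pixels) / len(pixels))
        return features  # one entry per dataset sample


dataset = ToyDataset(images=[[[0, 2], [4, 6]], [[1, 1], [1, 1]]], labels=[0, 1])
extractor = MeanIntensity(scale=0.5)
features = extractor(dataset)
print(features)  # [1.5, 0.5]
```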

Deep features

The DeepFeatures extractor computes features with a forward pass of a PyTorch model. The output is a 2D array in which each row corresponds to an image and each column to an embedding dimension. The number of columns is therefore determined by the output size of the model performing the feature extraction.
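A framework-free sketch of the row-stacking logic in DeepFeatures.__call__ follows, for illustration only: samples are processed in order in fixed-size batches (as the DataLoader with shuffle=False does), and the per-batch outputs are concatenated so that row i belongs to sample i. extract_in_batches and toy_model are hypothetical names, and plain Python lists stand in for tensors and the PyTorch model.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def extract_in_batches(
    samples: Sequence[Vector],
    model: Callable[[List[Vector]], List[Vector]],
    batch_size: int = 128,
) -> List[Vector]:
    """Run `model` over `samples` in order, batch by batch, and stack the output rows."""
    outputs: List[Vector] = []
    for start in range(0, len(samples), batch_size):
        batch = list(samples[start:start + batch_size])  # no shuffling: row i <-> sample i
        outputs.extend(model(batch))  # plays the role of torch.cat over batch outputs
    return outputs


def toy_model(batch: List[Vector]) -> List[Vector]:
    """Hypothetical stand-in model: maps each input vector to a 2-dim embedding."""
    return [[sum(x), max(x)] for x in batch]


samples = [[float(i), float(i + 1), float(i + 2)] for i in range(5)]
features = extract_in_batches(samples, toy_model, batch_size=2)
n_input, dim_embedding = len(features), len(features[0])
print(n_input, dim_embedding)  # 5 2
```

Keeping the order fixed is what makes the rows meaningful: feature row i can be traced back to dataset[i] and its label, regardless of the batch size.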

Example

Here, dataset refers to any WildlifeDataset instance whose transforms convert each image into a tensor of the shape the model expects.

import timm
from wildlife_tools.features import DeepFeatures

backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)
extractor = DeepFeatures(backbone, device='cuda')
features = extractor(dataset)

Reference

Extracts features using the forward pass of a PyTorch model.

Parameters:

model (required): PyTorch model used for the feature extraction.
batch_size (int, default 128): Batch size used for the feature extraction.
num_workers (int, default 1): Number of workers used for data loading.
device (str, default 'cpu'): Device used for the feature extraction, either 'cuda' or 'cpu'.

Returns:

A NumPy array of shape n_input x dim_embedding.

Source code in features/deep.py
class DeepFeatures(FeatureExtractor):
    '''
    Extracts features using forward pass of pytorch model.

    Args:
        model: Pytorch model used for the feature extraction.
        batch_size: Batch size used for the feature extraction.
        num_workers: Number of workers used for data loading.
        device: Select between cuda and cpu devices.

    Returns:
        An array with a shape of `n_input` x `dim_embedding`.

    '''

    def __init__(
        self,
        model,
        batch_size: int = 128,
        num_workers: int = 1,
        device: str = 'cpu',
    ):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.device = device
        self.model = model


    def __call__(self, dataset: WildlifeDataset):
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            shuffle=False,
        )
        outputs = []
        for image, label in tqdm(loader, mininterval=1, ncols=100):
            with torch.no_grad():
                output = self.model(image.to(self.device))
                outputs.append(output.cpu())
        return torch.cat(outputs).numpy()


    @classmethod
    def from_config(cls, config):
        model = realize(config.pop('model'))
        return cls(model=model, **config)

SIFT features

The SIFTFeatures extractor computes a set of SIFT descriptors for each image. The output is a list of length n_inputs whose entries are 2D arrays of shape n_descriptors x 128, where n_descriptors is the number of descriptors extracted from that particular image. If one or fewer descriptors are extracted, the entry is None. The extractor uses the OpenCV implementation of SIFT.
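Because entries can be None, code that consumes this output should guard against them. The sketch below assumes a list shaped like the SIFTFeatures output, with plain nested lists standing in for the NumPy arrays that OpenCV returns; descriptor_counts and images_without_descriptors are hypothetical helpers, not library functions.

```python
from typing import List, Optional, Sequence

# One entry per image: rows of 128-dim descriptors, or None when
# one or fewer keypoints were detected for that image.
Descriptors = Optional[Sequence[Sequence[float]]]


def descriptor_counts(features: Sequence[Descriptors]) -> List[int]:
    """Number of descriptors per image, counting None entries as zero."""
    return [0 if d is None else len(d) for d in features]


def images_without_descriptors(features: Sequence[Descriptors]) -> List[int]:
    """Indices of images that produced no usable descriptors (None entries)."""
    return [i for i, d in enumerate(features) if d is None]


DIM = 128
features = [
    [[0.0] * DIM for _ in range(3)],  # image 0: 3 descriptors
    None,                             # image 1: one or fewer keypoints
    [[1.0] * DIM for _ in range(2)],  # image 2: 2 descriptors
]
print(descriptor_counts(features))           # [3, 0, 2]
print(images_without_descriptors(features))  # [1]
```

The explicit `is None` check matters with the real output: truth-testing a NumPy array with more than one element raises a ValueError, so a bare `if d:` would fail on exactly the images that do have descriptors.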

Example

Here, dataset refers to any WildlifeDataset instance whose transforms convert each image into a grayscale PIL image.

from wildlife_tools.features import SIFTFeatures

extractor = SIFTFeatures()
features = extractor(dataset)

Reference

Extracts SIFT descriptors for each image in the dataset.

Parameters:

max_keypoints (int | None, default None): Maximum number of extracted keypoints / descriptors. If None, the number is not limited.

Returns:

A list of arrays, one per input image, each of shape [n_descriptors x 128]. Images with one or fewer detected keypoints yield None instead of an array.

Source code in features/sift.py
class SIFTFeatures(FeatureExtractor):
    '''
    Extracts SIFT descriptors for each image in the dataset.

    Args:
        max_keypoints: Limit number of extracted keypoints / descriptors.

    Returns:
        list of arrays, each array corresponds to an input image and have shape `[n_descriptors x 128]`.
    '''

    descriptor_dim: int = 128

    def __init__(self, max_keypoints: int | None = None):
        self.max_keypoints = max_keypoints

    def __call__(self, dataset: WildlifeDataset):
        if self.max_keypoints:
            sift = cv2.SIFT_create(nfeatures=self.max_keypoints)
        else:
            sift = cv2.SIFT_create()

        descriptors = []
        for img, y in tqdm(dataset, mininterval=1, ncols=100):
            keypoint, d = sift.detectAndCompute(np.array(img), None)
            if len(keypoint) <= 1:
                descriptors.append(None)
            else:
                descriptors.append(d)
        return descriptors

Data to memory

The DataToMemory extractor loads the WildlifeDataset into memory. This is particularly useful for the LoftrMatcher, which operates directly on image tensors. Although the WildlifeDataset can be used directly, with images loaded from storage on demand, the LoftrMatcher has no loading buffer. Loading images on the fly can therefore become a significant bottleneck, especially when matching all query-database pairs, which involves n_query x n_database image loads.
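The cost difference can be made concrete by counting storage loads. The sketch below is an illustration under stated assumptions, not LoftrMatcher code: CountingDataset, match_all_pairs and to_memory are hypothetical, and a "load" is simulated by incrementing a counter. Fetching images on demand for every pair costs n_query + n_query x n_database loads (dominated by the product term), whereas loading both datasets into memory first costs only n_query + n_database.

```python
from typing import Any, List, Tuple


class CountingDataset:
    """Hypothetical dataset that counts how often an image is loaded from storage."""

    def __init__(self, n: int):
        self.n = n
        self.loads = 0

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        if idx >= self.n:
            raise IndexError(idx)  # also ends plain `for` iteration
        self.loads += 1  # stands in for reading and decoding an image file
        return f"image-{idx}", idx


def match_all_pairs(query, database) -> int:
    """Visit every query-database pair, fetching items by index; return the pair count."""
    pairs = 0
    for qi in range(len(query)):
        q_item = query[qi]  # a storage load if `query` is a dataset
        for di in range(len(database)):
            d_item = database[di]  # a storage load if `database` is a dataset
            pairs += 1  # a matcher would compare q_item and d_item here
    return pairs


def to_memory(dataset) -> List[Any]:
    """Same idea as DataToMemory: iterate once and keep every image in a list."""
    return [x for x, _ in dataset]


query, database = CountingDataset(4), CountingDataset(10)
match_all_pairs(query, database)
naive_loads = query.loads + database.loads  # 4 + 4 * 10

query, database = CountingDataset(4), CountingDataset(10)
match_all_pairs(to_memory(query), to_memory(database))
cached_loads = query.loads + database.loads  # 4 + 10
print(naive_loads, cached_loads)  # 44 14
```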

Loads the dataset into memory for faster access.

Source code in features/memory.py
class DataToMemory(FeatureExtractor):
    ''' Loads dataset to memory for faster access. '''

    def __call__(self, dataset: WildlifeDataset):
        features = []
        for x, y in tqdm(dataset, mininterval=1, ncols=100):
            features.append(x)
        return features