
Wildlife dataset

WildlifeDataset is a class for creating PyTorch-style datasets from the datasets provided by the wildlife-datasets library. It implements the __len__ and __getitem__ methods, so it can be used with PyTorch data loaders for both training and inference.
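PyTorch data loaders treat any object with __len__ and __getitem__ as a map-style dataset: they draw indices in the range 0 to len(dataset) - 1 and call dataset[idx] for each one. The sketch below illustrates that protocol without PyTorch. ToyDataset and iterate_batches are hypothetical names, and iterate_batches only imitates sequential, unshuffled batching; a real DataLoader can also shuffle, collate samples into tensors and use worker processes.

```python
from typing import Any, Iterator


class ToyDataset:
    """Hypothetical map-style dataset holding (image, label) pairs in memory."""

    def __init__(self, images: list[Any], labels: list[int]):
        self.images = images
        self.labels = labels

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> tuple[Any, int]:
        return self.images[idx], self.labels[idx]


def iterate_batches(dataset, batch_size: int) -> Iterator[list[tuple[Any, int]]]:
    """Imitate a sequential loader: visit indices 0..len-1 and group items into batches."""
    batch = []
    for idx in range(len(dataset)):
        batch.append(dataset[idx])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # keep the last, possibly smaller batch (like drop_last=False)
        yield batch


dataset = ToyDataset(images=["img0", "img1", "img2"], labels=[0, 0, 1])
batches = list(iterate_batches(dataset, batch_size=2))
# batches == [[('img0', 0), ('img1', 0)], [('img2', 1)]]
```

With PyTorch installed, an object like this can be passed directly to torch.utils.data.DataLoader(dataset, batch_size=2); the two methods are all the loader needs.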

Metadata dataframe

An integral part of WildlifeDataset is the metadata dataframe, which holds all information about the images in the dataset. A typical dataset from the wildlife-datasets library has the following metadata table:

image_id  identity  path              split  bbox  segmentation
image_1   a         images/a/image_1  train  bbox  compressed rle
image_2   a         images/a/image_2  test   bbox  compressed rle
image_3   b         images/b/image_3  train  bbox  compressed rle

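The columns image_id, identity and path are required; the other columns are optional. In the table above, bbox is a bounding box in the form [x, y, width, height], stored either as a list or as a string (a string is parsed with json.loads). compressed rle is a segmentation mask in the compressed RLE format described by pycocotools; if it is stored as a string, it is parsed back into a Python object before decoding.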
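As a rough illustration of what "compressed RLE" means, the pure-Python sketch below encodes and decodes the counts string of a COCO-style RLE, i.e. a dict of the form {'size': [height, width], 'counts': ...}. It is modelled on the rleToString and rleFrString routines in the pycocotools C code: run lengths alternate between background and foreground pixels in column-major order, runs from the fourth onward are stored as differences from the run two places earlier, and each number is written in 5-bit chunks offset by 48, with a continuation bit. The function names are hypothetical; in practice pycocotools.mask.decode does all of this for you.

```python
def encode_counts(counts: list[int]) -> str:
    """Encode RLE run lengths into a COCO-style compressed string (sketch)."""
    chars = []
    for i, x in enumerate(counts):
        if i > 2:
            x -= counts[i - 2]  # store a delta from the run two places earlier
        more = True
        while more:
            c = x & 0x1F  # lowest 5 bits of the value
            x >>= 5       # arithmetic shift keeps the sign of negative deltas
            # stop once the remaining bits are only sign extension
            more = (x != -1) if (c & 0x10) else (x != 0)
            if more:
                c |= 0x20  # continuation bit
            chars.append(chr(c + 48))
    return "".join(chars)


def decode_counts(s: str) -> list[int]:
    """Inverse of encode_counts: recover run lengths from the compressed string."""
    counts: list[int] = []
    p = 0
    while p < len(s):
        x, k, more = 0, 0, True
        while more:
            c = ord(s[p]) - 48
            x |= (c & 0x1F) << (5 * k)
            more = bool(c & 0x20)
            p += 1
            k += 1
            if not more and (c & 0x10):
                x |= -1 << (5 * k)  # sign-extend a negative delta
        if len(counts) > 2:
            x += counts[-2]  # undo the delta encoding
        counts.append(x)
    return counts


def counts_to_mask(counts: list[int], h: int, w: int) -> list[list[int]]:
    """Expand alternating 0/1 runs (column-major order) into an h x w mask."""
    flat: list[int] = []
    value = 0
    for run in counts:
        flat.extend([value] * run)
        value = 1 - value
    return [[flat[col * h + row] for col in range(w)] for row in range(h)]


encoded = encode_counts([5, 3, 10, 2, 4])  # "53:OJ"
```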

Loading methods

If the metadata table has the optional bbox or segmentation columns, additional image loading methods can be selected with the img_load argument.

Argument     Loading effect                               Required columns
full         Full image                                   none
full_mask    Full image, background redacted              segmentation
full_hide    Full image, foreground redacted              segmentation
bbox         Bounding box crop                            bbox
bbox_mask    Bounding box crop, background redacted       bbox, segmentation
bbox_hide    Bounding box crop, foreground redacted       bbox, segmentation
crop_black   Crop away a black border, if there is one    none

Redacted pixels are set to black (zero).
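The effect of each method can be reproduced on a plain NumPy array. The sketch below is a simplified re-implementation for illustration, not the library's code: apply_img_load, parse_bbox and REQUIRED_COLUMNS are hypothetical names, the mask is passed in already decoded (a boolean array of shape (H, W)) rather than as compressed RLE, and the bounding box is [x, y, width, height] as in the metadata table. Boxes that extend past the image edge are simply clipped by NumPy slicing here, which may differ from PIL's crop.

```python
import json

import numpy as np

# Columns each loading method needs, as implied by WildlifeDataset.__getitem__.
REQUIRED_COLUMNS = {
    "full": set(),
    "full_mask": {"segmentation"},
    "full_hide": {"segmentation"},
    "bbox": {"bbox"},
    "bbox_mask": {"bbox", "segmentation"},
    "bbox_hide": {"bbox", "segmentation"},
    "crop_black": set(),
}


def parse_bbox(value) -> list:
    """Accept a bbox stored either as a list or as a JSON string like '[x, y, w, h]'."""
    return json.loads(value) if isinstance(value, str) else list(value)


def apply_img_load(img: np.ndarray, method: str, mask=None, bbox=None) -> np.ndarray:
    """Apply one loading method to an (H, W, 3) image array."""
    if method not in REQUIRED_COLUMNS:
        raise ValueError(f"Invalid img_load argument: {method}")
    if method.endswith("_mask"):
        img = img * mask[..., np.newaxis]   # keep the object, zero the background
    elif method.endswith("_hide"):
        img = img * ~mask[..., np.newaxis]  # zero the object, keep the background
    elif method == "crop_black":
        ys, xs, _ = np.nonzero(img)         # rows/columns holding any non-zero value
        return img[ys.min(): ys.max() + 1, xs.min(): xs.max() + 1]
    if method.startswith("bbox"):
        x, y, w, h = bbox
        # Same region as PIL's crop((x, y, x + w, y + h)); corners rounded to pixels.
        img = img[round(y): round(y + h), round(x): round(x + w)]
    return img


img = np.zeros((4, 5, 3), dtype=np.uint8)
img[1:3, 1:4] = 200                       # a bright 2x3 "object" on black
mask = np.zeros((4, 5), dtype=bool)
mask[1:3, 1:3] = True                     # segmentation covers only part of it

masked = apply_img_load(img, "full_mask", mask=mask)       # only the 2x2 masked area survives
hidden = apply_img_load(img, "full_hide", mask=mask)       # only the unmasked column survives
crop = apply_img_load(img, "bbox", bbox=parse_bbox("[1, 1, 3, 2]"))
trimmed = apply_img_load(img, "crop_black")
```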

[Figure: Image loading methods]

Example

from wildlife_tools.data.dataset import WildlifeDataset
import pandas as pd

metadata = pd.read_csv('ExampleDataset/metadata.csv')
dataset = WildlifeDataset(metadata, 'ExampleDataset')

# Load the first image and its integer-encoded label.
image, label = dataset[0]

Reference

PyTorch-style dataset for wildlife datasets.

Parameters:

metadata (pd.DataFrame, required): A pandas DataFrame containing image metadata.

root (str | None, default None): Root directory used when the paths in metadata are relative. If None, the paths in metadata are used as they are.

split (Callable | None, default None): A function that splits the metadata, e.g., an instance of data.Split.

transform (Callable | None, default None): A function that takes an image and returns its transformed version.

img_load (str, default 'full'): Method used to load images. Options: 'full', 'full_mask', 'full_hide', 'bbox', 'bbox_mask', 'bbox_hide', and 'crop_black'.

col_path (str, default 'path'): Name of the metadata column containing image file paths.

col_label (str, default 'identity'): Name of the metadata column containing class labels.

load_label (bool, default True): If False, __getitem__ returns only the image instead of an (image, label) tuple.
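Any callable that takes the metadata DataFrame and returns a DataFrame can serve as split; WildlifeDataset applies it first and then resets the index. Relative paths are joined with root via os.path.join. A minimal sketch assuming the optional split column from the metadata table above; split_by_column and resolve_path are hypothetical helpers, not data.Split.

```python
import os

import pandas as pd


def split_by_column(value: str, column: str = "split"):
    """Return a callable that keeps only rows whose `column` equals `value`."""
    def _split(metadata: pd.DataFrame) -> pd.DataFrame:
        return metadata[metadata[column] == value]
    return _split


def resolve_path(root, path: str) -> str:
    """Mirror WildlifeDataset: prefix `root` only when it is set."""
    return os.path.join(root, path) if root else path


metadata = pd.DataFrame({
    "image_id": ["image_1", "image_2", "image_3"],
    "identity": ["a", "a", "b"],
    "path": ["images/a/image_1", "images/a/image_2", "images/b/image_3"],
    "split": ["train", "test", "train"],
})

# Same two steps WildlifeDataset performs on its metadata.
train = split_by_column("train")(metadata).reset_index(drop=True)
```

Passing split=split_by_column("train") and root='ExampleDataset' to WildlifeDataset should then keep image_1 and image_3 and read them from paths under ExampleDataset.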

Attributes:

labels (np.ndarray): An integer array with the ordinal encoding of the labels.

labels_string (np.ndarray): The original labels as an array of strings.

labels_map (np.ndarray): The unique original labels; labels_map[i] is the label encoded as i.

num_classes (int): The number of unique classes in the dataset.
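These attributes come from pd.factorize applied to the label column. Because the column is passed as a NumPy array (.values), labels_map is an array of the unique labels in order of first appearance, so indexing it with the integer codes recovers the original labels. A small sketch using the identities from the metadata table above:

```python
import numpy as np
import pandas as pd

identity = np.array(["a", "a", "b"], dtype=object)  # label column as a NumPy array

labels, labels_map = pd.factorize(identity)
num_classes = len(labels_map)
original = labels_map[labels]  # map integer codes back to the original labels

# labels_string converts labels to str, which matters for non-string identities.
labels_string = pd.Series([7, 7, 3]).astype(str).values
```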

Source code in data/dataset.py
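import ast
import json
import os
from typing import Callable

import cv2
import numpy as np
import pandas as pd
import pycocotools.mask as mask_coco
from PIL import Image

from wildlife_tools.tools import realize  # assumed import path for the config helper


class WildlifeDataset: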
    '''
    PyTorch-style dataset for wildlife datasets.

    Args:
        metadata: A pandas dataframe containing image metadata.
        root: Root directory if paths in metadata are relative. If None, paths in metadata are used as they are.
        split: A function that splits metadata, e.g., instance of data.Split.
        transform: A function that takes in an image and returns its transformed version.
        img_load: Method to load images.
            Options: 'full', 'full_mask', 'full_hide', 'bbox', 'bbox_mask', 'bbox_hide', and 'crop_black'.
        col_path: Column name in the metadata containing image file paths.
        col_label: Column name in the metadata containing class labels.
        load_label: If False, \_\_getitem\_\_ returns only image instead of (image, label) tuple.

    Attributes:
        labels (np.ndarray): An integer array with the ordinal encoding of the labels.
        labels_string (np.ndarray): The original labels as an array of strings.
        labels_map (np.ndarray): The unique labels; labels_map[i] is the label encoded as i.
        num_classes (int): The number of unique classes in the dataset.
    '''

    def __init__(
        self,
        metadata: pd.DataFrame,
        root: str | None = None,
        split: Callable | None = None,
        transform: Callable | None = None,
        img_load: str = 'full',
        col_path: str = 'path',
        col_label: str = 'identity',
        load_label: bool = True,
    ):
        self.split = split
        if self.split:
            metadata = self.split(metadata)

        self.metadata = metadata.reset_index(drop=True)
        self.root = root
        self.transform = transform
        self.img_load = img_load
        self.col_path = col_path
        self.col_label = col_label
        self.load_label = load_label
        self.labels, self.labels_map = pd.factorize(self.metadata[self.col_label].values)

    @property
    def labels_string(self):
        return self.metadata[self.col_label].astype(str).values

    @property
    def num_classes(self):
        return len(self.labels_map)


    def __len__(self):
        return len(self.metadata)


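    def get_image(self, path):
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising when a file cannot be read.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return Image.fromarray(img)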


    def __getitem__(self, idx):
        data = self.metadata.iloc[idx]
        if self.root:
            img_path = os.path.join(self.root, data[self.col_path])
        else:
            img_path = data[self.col_path]
        img = self.get_image(img_path)

        if self.img_load in ['full_mask', 'full_hide', 'bbox_mask', 'bbox_hide']:
            if 'segmentation' not in data:
                raise ValueError(f"{self.img_load} selected but no segmentation found.")
            segmentation = data['segmentation']
            if isinstance(segmentation, str):
                # Parse the stored literal safely instead of calling eval().
                segmentation = ast.literal_eval(segmentation)

        if self.img_load in ['bbox', 'bbox_mask', 'bbox_hide']:
            if 'bbox' not in data:
                raise ValueError(f"{self.img_load} selected but no bbox found.")
            bbox = data['bbox']
            if isinstance(bbox, str):
                bbox = json.loads(bbox)

        # Load full image as it is.
        if self.img_load == 'full':
            pass

        # Mask background using segmentation mask.
        elif self.img_load == 'full_mask':
            mask = mask_coco.decode(segmentation).astype('bool')
            img = Image.fromarray(np.asarray(img) * mask[..., np.newaxis])

        # Hide object using segmentation mask.
        elif self.img_load == 'full_hide':
            mask = mask_coco.decode(segmentation).astype('bool')
            img = Image.fromarray(np.asarray(img) * ~mask[..., np.newaxis])

        # Crop to bounding box.
        elif self.img_load == 'bbox':
            img = img.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))

        # Mask background using segmentation mask and crop to bounding box.
        elif self.img_load == 'bbox_mask':
            mask = mask_coco.decode(segmentation).astype('bool')
            img = Image.fromarray(np.asarray(img) * mask[..., np.newaxis])
            img = img.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))

        # Hide object using segmentation mask and crop to bounding box.
        elif self.img_load == 'bbox_hide':
            mask = mask_coco.decode(segmentation).astype('bool')
            img = Image.fromarray(np.asarray(img) * ~mask[..., np.newaxis])
            img = img.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))

        # Crop black background around images.
        elif self.img_load == 'crop_black':
            y_nonzero, x_nonzero, _ = np.nonzero(np.asarray(img))
            # PIL crop boxes exclude the right and lower edges, so add 1 to keep the last non-zero row and column.
            img = img.crop((np.min(x_nonzero), np.min(y_nonzero), np.max(x_nonzero) + 1, np.max(y_nonzero) + 1))

        else:
            raise ValueError(f'Invalid img_load argument: {self.img_load}')

        if self.transform:
            img = self.transform(img)

        if self.load_label:
            return img, self.labels[idx]
        else:
            return img


    @classmethod
    def from_config(cls, config):
        config = dict(config)  # work on a copy so the caller's config is not modified
        config['split'] = realize(config.get('split'))
        config['transform'] = realize(config.get('transform'))
        config['metadata'] = pd.read_csv(config['metadata'], index_col=False)
        return cls(**config)