
Reference datasets

This file describes methods associated with dataset creation and metadata.

DatasetFactory

Base class for creating datasets.

Attributes:

    df (DataFrame): A full dataframe of the data.
    summary (dict): Summary of the dataset.
    root (str): Root directory for the data.
    update_wrong_labels (bool): Whether fix_labels should be called.
    unknown_name (str): Name of the unknown class.
    outdated_dataset (bool): Tracks whether the dataset was replaced by a newer version.
    determined_by_df (bool): Specifies whether the dataset is completely determined by its dataframe.
    saved_to_system_folder (bool): Specifies whether the dataset is saved to system (hidden) folders.
    transform (Callable): Transform applied when loading an image.
    img_load (str): Image loading mode such as 'full', 'bbox' or 'bbox_mask'.
    labels_string (List[str]): List of labels as strings.

Source code in wildlife_datasets/datasets/datasets.py
class DatasetFactory:
    """Base class for creating datasets.

    Attributes:    
      df (pd.DataFrame): A full dataframe of the data.
      summary (dict): Summary of the dataset.
      root (str): Root directory for the data.
      update_wrong_labels (bool): Whether `fix_labels` should be called.
      unknown_name (str): Name of the unknown class.
      outdated_dataset (bool): Tracks whether dataset was replaced by a new version.
      determined_by_df (bool): Specifies whether dataset is completely determined by its dataframe.
      saved_to_system_folder (bool): Specifies whether dataset is saved to system (hidden) folders.
      transform (Callable): Applied transform when loading the image.
      img_load (str): Image loading mode such as 'full', 'bbox' or 'bbox_mask'.
      labels_string (List[str]): List of labels in strings.
    """

    unknown_name = 'unknown'
    outdated_dataset = False
    determined_by_df = True
    saved_to_system_folder = False
    download_warning = '''You are trying to download an already downloaded dataset.
        This message may be caused by an interrupted download or extraction.
        To force the download use the `force=True` keyword such as
        get_data(..., force=True) or download(..., force=True).
        '''
    download_mark_name = 'already_downloaded'
    license_file_name = 'LICENSE_link'

    def __init__(
            self, 
            root: Optional[str] = None,
            df: Optional[pd.DataFrame] = None,
            update_wrong_labels: bool = True,
            transform: Optional[Callable] = None,
            img_load: str = "full",
            remove_unknown: bool = False,
            **kwargs) -> None:
        """Initializes the class.

        If `df` is specified, it copies it. Otherwise, it creates it
        by the `create_catalogue` method.

        Args:
            root (Optional[str], optional): Root directory for the data.
            df (Optional[pd.DataFrame], optional): A full dataframe of the data.
            update_wrong_labels (bool, optional): Whether `fix_labels` should be called.
            transform (Optional[Callable], optional): Applied transform when loading the image.
            img_load (str, optional): Image loading mode such as 'full', 'bbox' or 'bbox_mask'.
            remove_unknown (bool, optional): Whether unknown identities should be removed.
        """

        if not self.saved_to_system_folder and not os.path.exists(root):
            raise Exception('root does not exist. You may have misspelled it.')
        if self.outdated_dataset:
            print('This dataset is outdated. You may want to call a newer version such as %sv2.' % self.__class__.__name__)
        self.update_wrong_labels = update_wrong_labels
        self.root = root
        if df is None:
            df = self.create_catalogue(**kwargs)
        else:
            if not self.determined_by_df:
                print('This dataset is not determined by its dataframe, but you are constructing it from one.')
        if remove_unknown:
            df = df[df['identity'] != self.unknown_name]
        self.df = df.reset_index(drop=True)
        self.metadata = self.df # Alias to df to unify with wildlife-tools
        self.transform = transform
        self.img_load = img_load
        if self.img_load == "auto":
            if "segmentation" in self.df:
                self.img_load = "bbox_mask"
            elif "bbox" in self.df:
                self.img_load = "bbox"
            else:
                self.img_load = "full"

    @property
    def labels_string(self):
        return self.df['identity'].astype(str).to_numpy()

    @property
    def num_classes(self):
        return self.df['identity'].nunique()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int) -> Image:
        """Load an image with iloc `idx` with transforms `self.transform` and `self.img_load` applied.

        Args:
            idx (int): Index of the image.

        Returns:
            Loaded image.
        """

        img = self.get_image(idx)
        return self.apply_segmentation(img, idx)

    def get_image(self, idx: int) -> Image:
        """Load an image with iloc `idx`.

        Args:
            idx (int): Index of the image.

        Returns:
            Loaded image.
        """

        data = self.df.iloc[idx]
        if self.root:
            img_path = os.path.join(self.root, data['path'])
        else:
            img_path = data['path']
        img = self.load_image(img_path)
        return img

    def load_image(self, path: str) -> Image:
        """Load an image with `path`.

        Args:
            path (str): Path to the image.

        Returns:
            Loaded image.
        """

        return utils.load_image(path)

    def apply_segmentation(self, img: Image, idx: int) -> Image:
        """Applies segmentation or bounding box when loading an image.

        Args:
            img (Image): Loaded image.
            idx (int): Index of the image.

        Returns:
            Loaded image.
        """

        # Prepare for segmentations        
        if self.img_load in ["full_mask", "full_hide", "bbox_mask", "bbox_hide"]:
            data = self.df.iloc[idx]
            if not ("segmentation" in data):
                raise ValueError(f"{self.img_load} selected but no segmentation found.")
            segmentation = data["segmentation"]
            if isinstance(segmentation, list) or isinstance(segmentation, np.ndarray):
                # Convert polygon to compressed RLE
                w, h = img.size
                rles = mask_coco.frPyObjects([segmentation], h, w)
                segmentation = mask_coco.merge(rles)
            elif isinstance(segmentation, dict) and (isinstance(segmentation['counts'], list) or isinstance(segmentation['counts'], np.ndarray)):            
                # Convert uncompressed RLE to compressed RLE
                h, w = segmentation['size']
                segmentation = mask_coco.frPyObjects(segmentation, h, w)
            elif isinstance(segmentation, str):
                # Load image mask and convert it to compressed RLE
                segmentation = np.asfortranarray(utils.load_image(os.path.join(self.root, segmentation)))
                if segmentation.ndim == 3:
                    segmentation = segmentation[:,:,0]
                segmentation = mask_coco.encode(segmentation)
            elif not np.any(pd.isnull(segmentation)):
                raise Exception('Segmentation type not recognized')
        # Prepare for bounding boxes
        if self.img_load in ["bbox"]:
            data = self.df.iloc[idx]
            if not ("bbox" in data):
                raise ValueError(f"{self.img_load} selected but no bbox found.")
            if type(data["bbox"]) == str:
                bbox = json.loads(data["bbox"])
            else:
                bbox = data["bbox"]

        # Load full image as it is.
        if self.img_load == "full":
            img = img
        # Mask background using segmentation mask.
        elif self.img_load == "full_mask":
            if not np.any(pd.isnull(segmentation)):
                mask = mask_coco.decode(segmentation).astype("bool")
                img = Image.fromarray(img * mask[..., np.newaxis])
        # Hide object using segmentation mask
        elif self.img_load == "full_hide":
            if not np.any(pd.isnull(segmentation)):
                mask = mask_coco.decode(segmentation).astype("bool")
                img = Image.fromarray(img * ~mask[..., np.newaxis])
        # Crop to bounding box
        elif self.img_load == "bbox":
            if not np.any(pd.isnull(bbox)):
                img = img.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))
        # Mask background using segmentation mask and crop to bounding box.
        elif self.img_load == "bbox_mask":
            if (not np.any(pd.isnull(segmentation))):
                mask = mask_coco.decode(segmentation).astype("bool")
                img = Image.fromarray(img * mask[..., np.newaxis])
                img = utils.crop_black(img)
        # Hide object using segmentation mask and crop to bounding box.
        elif self.img_load == "bbox_hide":
            if (not np.any(pd.isnull(segmentation))):
                mask = mask_coco.decode(segmentation).astype("bool")
                img = Image.fromarray(img * ~mask[..., np.newaxis])
                img = utils.crop_black(img)
        # Crop black background around images
        elif self.img_load == "crop_black":
            img = utils.crop_black(img)
        else:
            raise ValueError(f"Invalid img_load argument: {self.img_load}")

        if self.transform:
            img = self.transform(img)

        return img

    @classmethod
    def get_data(
            cls,
            root: str,
            force: bool = False,
            **kwargs
            ) -> None:
        """Downloads and extracts the data. Wrapper around `cls._download` and `cls._extract.`

        Args:
            root (str): Where the data should be stored.
            force (bool, optional): It the root exists, whether it should be overwritten.
        """

        dataset_name = cls.__name__
        mark_file_name = os.path.join(root, cls.download_mark_name)

        already_downloaded = os.path.exists(mark_file_name)
        if not cls.saved_to_system_folder and already_downloaded and not force:
            print('DATASET %s: DOWNLOADING STARTED.' % dataset_name)
            print(cls.download_warning)
        else:
            print('DATASET %s: DOWNLOADING STARTED.' % dataset_name)
            cls.download(root, force=force, **kwargs)
            print('DATASET %s: EXTRACTING STARTED.' % dataset_name)
            cls.extract(root,  **kwargs)
            print('DATASET %s: FINISHED.\n' % dataset_name)

    @classmethod
    def download(
            cls,
            root: str,
            force: bool = False,
            **kwargs
            ) -> None:
        """Downloads the data. Wrapper around `cls._download`.

        Args:
            root (str): Where the data should be stored.
            force (bool, optional): If the root exists, whether it should be overwritten.
        """

        dataset_name = cls.__name__
        mark_file_name = os.path.join(root, cls.download_mark_name)

        already_downloaded = os.path.exists(mark_file_name)
        if cls.saved_to_system_folder:
            cls._download(**kwargs)
        elif already_downloaded and not force:
            print('DATASET %s: DOWNLOADING STARTED.' % dataset_name)            
            print(cls.download_warning)
        else:
            if os.path.exists(mark_file_name):
                os.remove(mark_file_name)
            with utils.data_directory(root):
                cls._download(**kwargs)
            open(mark_file_name, 'a').close()
            if hasattr(cls, 'summary') and 'licenses_url' in cls.summary:
                with open(os.path.join(root, cls.license_file_name), 'w') as file:
                    file.write(cls.summary['licenses_url'])

    @classmethod    
    def extract(cls, root: str, **kwargs) -> None:
        """Extract the data. Wrapper around `cls._extract`.

        Args:
            root (str): Where the data should be stored.
        """

        if cls.saved_to_system_folder:
            cls._extract(**kwargs)
        else:
            with utils.data_directory(root):
                cls._extract(**kwargs)
            mark_file_name = os.path.join(root, cls.download_mark_name)
            open(mark_file_name, 'a').close()

    @classmethod
    def display_name(cls) -> str:
        """Returns name of the dataset without the v2 ending.

        Returns:
            Name of the dataset.
        """

        cls_parent = cls.__bases__[0]
        while cls_parent != object and cls_parent.outdated_dataset:
            cls = cls_parent
            cls_parent = cls.__bases__[0]            
        return cls.__name__

    def _download(self):
        """Downloads the dataset. Needs to be implemented by subclasses.

        Raises:
            NotImplementedError: Needs to be implemented by subclasses.
        """

        raise NotImplementedError('Needs to be implemented by subclasses.')

    def _extract(self):
        """Extracts the dataset. Needs to be implemented by subclasses.

        Raises:
            NotImplementedError: Needs to be implemented by subclasses.
        """

        raise NotImplementedError('Needs to be implemented by subclasses.')

    def create_catalogue(self):
        """Creates the dataframe.

        Raises:
            NotImplementedError: Needs to be implemented by subclasses.
        """

        raise NotImplementedError('Needs to be implemented by subclasses.')

    def fix_labels(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fixes labels in dataframe.

        Automatically called in `finalize_catalogue`.                
        """

        return df

    def fix_labels_replace_identity(
            self,
            df: pd.DataFrame,
            replace_identity: List[Tuple],
            col: str = 'identity'
            ) -> pd.DataFrame:
        """Replaces all instances of identities.

        Args:
            df (pd.DataFrame): A full dataframe of the data.
            replace_identity (List[Tuple]): List of (old_identity, new_identity)
            col (str, optional): Column to replace in.

        Returns:
            A full dataframe of the data.
        """
        for old_identity, new_identity in replace_identity:
            df[col] = df[col].replace({old_identity: new_identity})
        return df

    def fix_labels_remove_identity(
            self,
            df: pd.DataFrame,
            identities_to_remove: List,
            col: str = 'identity'
            ) -> pd.DataFrame:
        """Removes all instances of identities.

        Args:
            df (pd.DataFrame): A full dataframe of the data.
            identities_to_remove (List): List of identities to remove.
            col (str, optional): Column to remove from.

        Returns:
            A full dataframe of the data.
        """
        idx_remove = [identity in identities_to_remove for identity in df[col]]
        return df[~np.array(idx_remove)]

    def fix_labels_replace_images(
            self,
            df: pd.DataFrame,
            replace_identity: List[Tuple],
            col: str = 'identity'
            ) -> pd.DataFrame:
        """Replaces specified images with specified identities.

        It matches `image_name` as a substring of df['path'],
        which may cause problems with `os.path.sep`.

        Args:
            df (pd.DataFrame): A full dataframe of the data.
            replace_identity (List[Tuple]): List of (image_name, old_identity, new_identity).
            col (str, optional): Column to replace in.

        Returns:
            A full dataframe of the data.
        """
        for image_name, old_identity, new_identity in replace_identity:
            n_replaced = 0
            for index, df_row in df.iterrows():
                # Check that there is an image with the required name and identity
                if image_name in df_row['path'] and old_identity == df_row[col]:
                    df.loc[index, col] = new_identity
                    n_replaced += 1
            if n_replaced == 0:
                print('File name %s with identity %s was not found.' % (image_name, str(old_identity)))
            elif n_replaced > 1:
                print('File name %s with identity %s was found multiple times.' % (image_name, str(old_identity)))
        return df

    def finalize_catalogue(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reorders the dataframe and check file paths.

        Reorders the columns and removes constant columns.
        Checks if columns are in correct formats.
        Checks if ids are unique and if all files exist.

        Args:
            df (pd.DataFrame): A full dataframe of the data.

        Returns:
            A full dataframe of the data, slightly modified.
        """

        if self.update_wrong_labels:
            df = self.fix_labels(df)
        self.check_required_columns(df)
        self.check_types_columns(df)
        df = self.reorder_df(df)
        df = self.remove_constant_columns(df)
        self.check_unique_id(df)
        self.check_files_exist(df['path'])
        self.check_files_names(df['path'])
        if 'segmentation' in df.columns:
            self.check_files_exist(df['segmentation'])
        return df

    def check_required_columns(self, df: pd.DataFrame) -> None:
        """Check if all required columns are present.

        Args:
            df (pd.DataFrame): A full dataframe of the data.
        """

        for col_name in ['image_id', 'identity', 'path']:
            if col_name not in df.columns:
                raise(Exception('Column %s must be in the dataframe columns.' % col_name))

    def check_types_columns(self, df: pd.DataFrame) -> None:
        """Checks if columns are in correct formats.

        The formats are specified in `requirements`, which is a list
        of tuples. The first value is the name of the column
        and the second value is a list of allowed formats. The column
        must match at least one of the formats.

        Args:
            df (pd.DataFrame): A full dataframe of the data.
        """

        requirements = [
            ('image_id', ['int', 'str']),
            ('identity', ['int', 'str']),
            ('path', ['str']),
            ('bbox', ['list_numeric']),
            ('date', ['date']),
            ('keypoints', ['list_numeric']),
            ('position', ['str']),
            ('species', ['str', 'list']),
            ('video', ['int']),
        ]
        # Verify if the columns are in correct formats
        for col_name, allowed_types in requirements:
            if col_name in df.columns:
                # Remove empty values to be sure
                col = df[col_name][~df[col_name].isnull()]
                if len(col) > 0:
                    self.check_types_column(col, col_name, allowed_types)

    def check_types_column(self, col: pd.Series, col_name: str, allowed_types: List[str]) -> None:
        """Checks if the column `col` is in the format `allowed_types`.

        Args:
            col (pd.Series): Column to be checked.
            col_name (str): Column name used only for raising exceptions.
            allowed_types (List[str]): List of strings with allowed values:
                `int` (all values must be integers),
                `str` (strings),
                `list` (lists),
                `list_numeric` (lists with numeric values),
                `date` (dates as tested by `pd.to_datetime`).
        """

        if 'int' in allowed_types and pd.api.types.is_integer_dtype(col):
            return None
        if 'str' in allowed_types and pd.api.types.is_string_dtype(col):
            return None
        if 'list' in allowed_types and pd.api.types.is_list_like(col):
            check = True
            for val in col:
                if not pd.api.types.is_list_like(val):
                    check = False
                    break
            if check:                
                return None        
        if 'list_numeric' in allowed_types and pd.api.types.is_list_like(col):
            check = True
            for val in col:            
                if not pd.api.types.is_list_like(val) and not pd.api.types.is_numeric_dtype(pd.Series(val)):
                    check = False
                    break
            if check:                
                return None
        if 'date' in allowed_types:
            try:
                pd.to_datetime(col)
                return None
            except:
                pass
        raise(Exception('Column %s has wrong type. Allowed types = %s' % (col_name, str(allowed_types))))

    def reorder_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reorders rows and columns in the dataframe.

        Rows are sorted based on id.
        Columns are reordered based on the `default_order` list.

        Args:
            df (pd.DataFrame): A full dataframe of the data.

        Returns:
            A full dataframe of the data, slightly modified.
        """

        default_order = ['image_id', 'identity', 'path', 'bbox', 'date', 'keypoints', 'orientation', 'segmentation', 'species']
        df_names = list(df.columns)
        col_names = []
        for name in default_order:
            if name in df_names:
                col_names.append(name)
        for name in df_names:
            if name not in default_order:
                col_names.append(name)

        df = df.sort_values('image_id').reset_index(drop=True)
        return df.reindex(columns=col_names)

    def remove_constant_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Removes columns with a single unique value.

        Args:
            df (pd.DataFrame): A full dataframe of the data.

        Returns:
            A full dataframe of the data, slightly modified.
        """ 

        for df_name in list(df.columns):
            if df[df_name].astype('str').nunique() == 1:
                df = df.drop([df_name], axis=1)
        return df

    def check_unique_id(self, df: pd.DataFrame) -> None:
        """Checks if values in the id column are unique.

        Args:
            df (pd.DataFrame): A full dataframe of the data.
        """

        if len(df['image_id'].unique()) != len(df):
            raise(Exception('Image ID not unique.'))

    def check_files_exist(self, col: pd.Series) -> None:
        """Checks if paths in a given column exist.

        Args:
            col (pd.Series): A column of a dataframe.
        """

        for path in col:
            if type(path) == str and not os.path.exists(os.path.join(self.root, path)):
                raise(Exception('Path does not exist: ' + os.path.join(self.root, path)))

    def check_files_names(self, col: pd.Series) -> None:
        """Checks if paths contain .

        Args:
            col (pd.Series): A column of a dataframe.
        """

        for path in col:
            try:
                path.encode("iso-8859-1")
            except UnicodeEncodeError:
                raise(Exception('Characters in path may cause problems. Please use only ISO-8859-1 characters: ' + path))

    def plot_grid(
            self,
            n_rows: int = 5,
            n_cols: int = 8,
            offset: float = 10,
            img_min: float = 100,
            rotate: bool = True,
            header_cols: Optional[List[str]] = None,
            idx: Optional[Union[List[bool],List[int]]] = None,
            background_color: Tuple[int] = (0, 0, 0),
            **kwargs
            ) -> Optional[plt.Figure]:
        """Plots a grid of size (n_rows, n_cols) with images from the dataframe.

        Args:
            n_rows (int, optional): The number of rows in the grid.
            n_cols (int, optional): The number of columns in the grid.
            offset (float, optional): The offset between images.
            img_min (float, optional): The minimal size of the plotted images.
            rotate (bool, optional): Rotates the images to have the same orientation.
            header_cols (Optional[List[str]], optional): List of headers for each column.
            idx (Optional[Union[List[bool],List[int]]], optional): List of indices to plot. None plots random images. Index -1 plots an empty image.
            background_color (Tuple[int], optional): Background color of the grid.
        """

        if len(self.df) == 0:
            return None

        # Select indices of images to be plotted
        if idx is None:
            n = min(len(self.df), n_rows*n_cols)
            idx = np.random.permutation(len(self.df))[:n]
        else:
            if isinstance(idx, pd.Series):
                idx = idx.values
            if isinstance(idx[0], (bool, np.bool_)):
                idx = np.where(idx)[0]
            n = min(np.array(idx).size, n_rows*n_cols)
            idx = np.matrix.flatten(np.array(idx))[:n]

        # Load images and compute their ratio
        ratios = []
        ims = []
        for k in idx:
            if k >= 0:
                # Load the image with index k
                im = self[k]
                ims.append(im)
                ratios.append(im.size[0] / im.size[1])
            else:
                # Load a black image
                ims.append(Image.fromarray(np.zeros((2, 2), dtype = "uint8")))

        # Safeguard when all indices are -1
        if len(ratios) == 0:
            return None

        # Get the size of the images after being resized
        ratio = np.median(ratios)
        if ratio > 1:    
            img_w, img_h = int(img_min*ratio), int(img_min)
        else:
            img_w, img_h = int(img_min), int(img_min/ratio)

        # Compute height offset if headers are present
        if header_cols is not None:
            offset_h = 30
            if len(header_cols) != n_cols:
                raise(Exception("Length of header_cols must be the same as n_cols."))
        else:
            offset_h = 0

        # Create an empty image grid
        im_grid = Image.new('RGB', (n_cols*img_w + (n_cols-1)*offset, offset_h + n_rows*img_h + (n_rows-1)*offset), background_color)

        # Fill the grid image by image
        pos_y = offset_h
        for i in range(n_rows):
            row_h = 0
            for j in range(n_cols):
                k = (n_cols)*i + j
                if k < n:
                    # Possibly rotate the image
                    im = ims[k]
                    if rotate and ((ratio > 1 and im.size[0] < im.size[1]) or (ratio < 1 and im.size[0] > im.size[1])):
                        im = im.transpose(Image.Transpose.ROTATE_90)

                    # Rescale the image
                    im.thumbnail((img_w,img_h))
                    row_h = max(row_h, im.size[1])

                    # Place the image on the grid
                    pos_x = j*img_w + j*offset
                    im_grid.paste(im, (pos_x,pos_y))
            if row_h > 0:
                pos_y += row_h + offset
        im_grid = im_grid.crop((0, 0, im_grid.size[0], pos_y-offset))

        # Plot the image and add column headers if present
        fig = plt.figure()
        fig.patch.set_visible(False)
        ax = fig.add_subplot(111)
        plt.axis('off')
        plt.imshow(im_grid)
        if header_cols is not None:
            color = kwargs.pop('color', 'white')
            ha = kwargs.pop('ha', 'center')
            va = kwargs.pop('va', 'center')
            for i, header in enumerate(header_cols):
                pos_x = (i+0.5)*img_w + i*offset
                pos_y = offset_h/2
                plt.text(pos_x, pos_y, str(header), color=color, ha=ha, va=va, **kwargs)
        return fig
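
For illustration, a minimal usage sketch. It assumes an existing DatasetFactory subclass (MacaqueFaces is used here as an example) whose data were already downloaded to the given root; any subclass works the same way.

from wildlife_datasets import datasets

# Hypothetical root; point it to wherever the data were downloaded.
dataset = datasets.MacaqueFaces('data/MacaqueFaces')
print(len(dataset), dataset.num_classes)  # number of images and identities
img = dataset[0]                          # PIL image with transforms applied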

__getitem__(idx)

Load an image with iloc idx with transforms self.transform and self.img_load applied.

Parameters:

    idx (int): Index of the image. Required.

Returns:

    Image: Loaded image.

Source code in wildlife_datasets/datasets/datasets.py
def __getitem__(self, idx: int) -> Image:
    """Load an image with iloc `idx` with transforms `self.transform` and `self.img_load` applied.

    Args:
        idx (int): Index of the image.

    Returns:
        Loaded image.
    """

    img = self.get_image(idx)
    return self.apply_segmentation(img, idx)

__init__(root=None, df=None, update_wrong_labels=True, transform=None, img_load='full', remove_unknown=False, **kwargs)

Initializes the class.

If df is specified, it copies it. Otherwise, it creates it by the create_catalogue method.

Parameters:

    root (Optional[str]): Root directory for the data. Default: None.
    df (Optional[DataFrame]): A full dataframe of the data. Default: None.
    update_wrong_labels (bool): Whether fix_labels should be called. Default: True.
    transform (Optional[Callable]): Transform applied when loading an image. Default: None.
    img_load (str): Image loading mode such as 'full', 'bbox' or 'bbox_mask'. Default: 'full'.
    remove_unknown (bool): Whether unknown identities should be removed. Default: False.
Source code in wildlife_datasets/datasets/datasets.py
def __init__(
        self, 
        root: Optional[str] = None,
        df: Optional[pd.DataFrame] = None,
        update_wrong_labels: bool = True,
        transform: Optional[Callable] = None,
        img_load: str = "full",
        remove_unknown: bool = False,
        **kwargs) -> None:
    """Initializes the class.

    If `df` is specified, it copies it. Otherwise, it creates it
    by the `create_catalogue` method.

    Args:
        root (Optional[str], optional): Root directory for the data.
        df (Optional[pd.DataFrame], optional): A full dataframe of the data.
        update_wrong_labels (bool, optional): Whether `fix_labels` should be called.
        transform (Optional[Callable], optional): Applied transform when loading the image.
        img_load (str, optional): Image loading mode such as 'full', 'bbox' or 'bbox_mask'.
        remove_unknown (bool, optional): Whether unknown identities should be removed.
    """

    if not self.saved_to_system_folder and not os.path.exists(root):
        raise Exception('root does not exist. You may have misspelled it.')
    if self.outdated_dataset:
        print('This dataset is outdated. You may want to call a newer version such as %sv2.' % self.__class__.__name__)
    self.update_wrong_labels = update_wrong_labels
    self.root = root
    if df is None:
        df = self.create_catalogue(**kwargs)
    else:
        if not self.determined_by_df:
            print('This dataset is not determined by its dataframe, but you are constructing it from one.')
    if remove_unknown:
        df = df[df['identity'] != self.unknown_name]
    self.df = df.reset_index(drop=True)
    self.metadata = self.df # Alias to df to unify with wildlife-tools
    self.transform = transform
    self.img_load = img_load
    if self.img_load == "auto":
        if "segmentation" in self.df:
            self.img_load = "bbox_mask"
        elif "bbox" in self.df:
            self.img_load = "bbox"
        else:
            self.img_load = "full"
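
A hedged sketch of reusing a precomputed catalogue (again assuming the MacaqueFaces subclass and previously downloaded data):

from wildlife_datasets import datasets

# Build the catalogue once, then reuse it; img_load='auto' picks
# 'bbox_mask', 'bbox' or 'full' depending on the available columns.
dataset = datasets.MacaqueFaces('data/MacaqueFaces')
dataset2 = datasets.MacaqueFaces(
    'data/MacaqueFaces',
    df=dataset.df,
    img_load='auto',
    remove_unknown=True,
)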

apply_segmentation(img, idx)

Applies segmentation or bounding box when loading an image.

Parameters:

    img (Image): Loaded image. Required.
    idx (int): Index of the image. Required.

Returns:

    Image: Loaded image.

Source code in wildlife_datasets/datasets/datasets.py
def apply_segmentation(self, img: Image, idx: int) -> Image:
    """Applies segmentation or bounding box when loading an image.

    Args:
        img (Image): Loaded image.
        idx (int): Index of the image.

    Returns:
        Loaded image.
    """

    # Prepare for segmentations        
    if self.img_load in ["full_mask", "full_hide", "bbox_mask", "bbox_hide"]:
        data = self.df.iloc[idx]
        if not ("segmentation" in data):
            raise ValueError(f"{self.img_load} selected but no segmentation found.")
        segmentation = data["segmentation"]
        if isinstance(segmentation, list) or isinstance(segmentation, np.ndarray):
            # Convert polygon to compressed RLE
            w, h = img.size
            rles = mask_coco.frPyObjects([segmentation], h, w)
            segmentation = mask_coco.merge(rles)
        elif isinstance(segmentation, dict) and (isinstance(segmentation['counts'], list) or isinstance(segmentation['counts'], np.ndarray)):            
            # Convert uncompressed RLE to compressed RLE
            h, w = segmentation['size']
            segmentation = mask_coco.frPyObjects(segmentation, h, w)
        elif isinstance(segmentation, str):
            # Load image mask and convert it to compressed RLE
            segmentation = np.asfortranarray(utils.load_image(os.path.join(self.root, segmentation)))
            if segmentation.ndim == 3:
                segmentation = segmentation[:,:,0]
            segmentation = mask_coco.encode(segmentation)
        elif not np.any(pd.isnull(segmentation)):
            raise Exception('Segmentation type not recognized')
    # Prepare for bounding boxes
    if self.img_load in ["bbox"]:
        data = self.df.iloc[idx]
        if not ("bbox" in data):
            raise ValueError(f"{self.img_load} selected but no bbox found.")
        if type(data["bbox"]) == str:
            bbox = json.loads(data["bbox"])
        else:
            bbox = data["bbox"]

    # Load full image as it is.
    if self.img_load == "full":
        img = img
    # Mask background using segmentation mask.
    elif self.img_load == "full_mask":
        if not np.any(pd.isnull(segmentation)):
            mask = mask_coco.decode(segmentation).astype("bool")
            img = Image.fromarray(img * mask[..., np.newaxis])
    # Hide object using segmentation mask
    elif self.img_load == "full_hide":
        if not np.any(pd.isnull(segmentation)):
            mask = mask_coco.decode(segmentation).astype("bool")
            img = Image.fromarray(img * ~mask[..., np.newaxis])
    # Crop to bounding box
    elif self.img_load == "bbox":
        if not np.any(pd.isnull(bbox)):
            img = img.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))
    # Mask background using segmentation mask and crop to bounding box.
    elif self.img_load == "bbox_mask":
        if (not np.any(pd.isnull(segmentation))):
            mask = mask_coco.decode(segmentation).astype("bool")
            img = Image.fromarray(img * mask[..., np.newaxis])
            img = utils.crop_black(img)
    # Hide object using segmentation mask and crop to bounding box.
    elif self.img_load == "bbox_hide":
        if (not np.any(pd.isnull(segmentation))):
            mask = mask_coco.decode(segmentation).astype("bool")
            img = Image.fromarray(img * ~mask[..., np.newaxis])
            img = utils.crop_black(img)
    # Crop black background around images
    elif self.img_load == "crop_black":
        img = utils.crop_black(img)
    else:
        raise ValueError(f"Invalid img_load argument: {self.img_load}")

    if self.transform:
        img = self.transform(img)

    return img
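
The bounding-box branch assumes COCO-style [x, y, w, h] boxes, which are converted to PIL's (left, upper, right, lower) crop box. A self-contained sketch of that conversion with made-up values:

from PIL import Image

img = Image.new('RGB', (200, 100))
bbox = [10, 20, 50, 30]  # hypothetical [x, y, width, height]
crop = img.crop((bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]))
print(crop.size)  # (50, 30)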

check_files_exist(col)

Checks if paths in a given column exist.

Parameters:

    col (Series): A column of a dataframe. Required.
Source code in wildlife_datasets/datasets/datasets.py
def check_files_exist(self, col: pd.Series) -> None:
    """Checks if paths in a given column exist.

    Args:
        col (pd.Series): A column of a dataframe.
    """

    for path in col:
        if type(path) == str and not os.path.exists(os.path.join(self.root, path)):
            raise(Exception('Path does not exist: ' + os.path.join(self.root, path)))

check_files_names(col)

Checks that paths contain only ISO-8859-1 characters.

Parameters:

    col (Series): A column of a dataframe. Required.
Source code in wildlife_datasets/datasets/datasets.py
def check_files_names(self, col: pd.Series) -> None:
    """Checks if paths contain .

    Args:
        col (pd.Series): A column of a dataframe.
    """

    for path in col:
        try:
            path.encode("iso-8859-1")
        except UnicodeEncodeError:
            raise(Exception('Characters in path may cause problems. Please use only ISO-8859-1 characters: ' + path))
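
The check mirrors what str.encode does; a short sketch with made-up file names:

# ISO-8859-1 (Latin-1) covers only Western European characters.
'girafe.jpg'.encode('iso-8859-1')        # passes
try:
    'žirafa.jpg'.encode('iso-8859-1')    # 'ž' is not Latin-1
except UnicodeEncodeError:
    print('non Latin-1 character found')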

check_required_columns(df)

Check if all required columns are present.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.
Source code in wildlife_datasets/datasets/datasets.py
def check_required_columns(self, df: pd.DataFrame) -> None:
    """Check if all required columns are present.

    Args:
        df (pd.DataFrame): A full dataframe of the data.
    """

    for col_name in ['image_id', 'identity', 'path']:
        if col_name not in df.columns:
            raise(Exception('Column %s must be in the dataframe columns.' % col_name))

check_types_column(col, col_name, allowed_types)

Checks if the column col is in the format allowed_types.

Parameters:

    col (Series): Column to be checked. Required.
    col_name (str): Column name used only for raising exceptions. Required.
    allowed_types (List[str]): List of strings with allowed values: 'int' (all values must be integers), 'str' (strings), 'list' (lists), 'list_numeric' (lists with numeric values), 'date' (dates as tested by pd.to_datetime). Required.
Source code in wildlife_datasets/datasets/datasets.py
def check_types_column(self, col: pd.Series, col_name: str, allowed_types: List[str]) -> None:
    """Checks if the column `col` is in the format `allowed_types`.

    Args:
        col (pd.Series): Column to be checked.
        col_name (str): Column name used only for raising exceptions.
        allowed_types (List[str]): List of strings with allowed values:
            `int` (all values must be integers),
            `str` (strings),
            `list` (lists),
            `list_numeric` (lists with numeric values),
            `date` (dates as tested by `pd.to_datetime`).
    """

    if 'int' in allowed_types and pd.api.types.is_integer_dtype(col):
        return None
    if 'str' in allowed_types and pd.api.types.is_string_dtype(col):
        return None
    if 'list' in allowed_types and pd.api.types.is_list_like(col):
        check = True
        for val in col:
            if not pd.api.types.is_list_like(val):
                check = False
                break
        if check:                
            return None        
    if 'list_numeric' in allowed_types and pd.api.types.is_list_like(col):
        check = True
        for val in col:            
            if not pd.api.types.is_list_like(val) and not pd.api.types.is_numeric_dtype(pd.Series(val)):
                check = False
                break
        if check:                
            return None
    if 'date' in allowed_types:
        try:
            pd.to_datetime(col)
            return None
        except:
            pass
    raise(Exception('Column %s has wrong type. Allowed types = %s' % (col_name, str(allowed_types))))
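
The individual checks are thin wrappers over pandas dtype inspection; for example, the 'int' and 'str' cases reduce to:

import pandas as pd

col = pd.Series([1, 2, 3])
print(pd.api.types.is_integer_dtype(col))  # True
print(pd.api.types.is_string_dtype(col))   # False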

check_types_columns(df)

Checks if columns are in correct formats.

The formats are specified in requirements, which is a list of tuples. The first value is the name of the column and the second value is a list of allowed formats. The column must match at least one of the formats.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.
Source code in wildlife_datasets/datasets/datasets.py
def check_types_columns(self, df: pd.DataFrame) -> None:
    """Checks if columns are in correct formats.

    The formats are specified in `requirements`, which is a list
    of tuples. The first value is the name of the column
    and the second value is a list of allowed formats. The column
    must match at least one of the formats.

    Args:
        df (pd.DataFrame): A full dataframe of the data.
    """

    requirements = [
        ('image_id', ['int', 'str']),
        ('identity', ['int', 'str']),
        ('path', ['str']),
        ('bbox', ['list_numeric']),
        ('date', ['date']),
        ('keypoints', ['list_numeric']),
        ('position', ['str']),
        ('species', ['str', 'list']),
        ('video', ['int']),
    ]
    # Verify if the columns are in correct formats
    for col_name, allowed_types in requirements:
        if col_name in df.columns:
            # Remove empty values to be sure
            col = df[col_name][~df[col_name].isnull()]
            if len(col) > 0:
                self.check_types_column(col, col_name, allowed_types)

check_unique_id(df)

Checks if values in the id column are unique.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.
Source code in wildlife_datasets/datasets/datasets.py
def check_unique_id(self, df: pd.DataFrame) -> None:
    """Checks if values in the id column are unique.

    Args:
        df (pd.DataFrame): A full dataframe of the data.
    """

    if len(df['image_id'].unique()) != len(df):
        raise(Exception('Image ID not unique.'))

create_catalogue()

Creates the dataframe.

Raises:

    NotImplementedError: Needs to be implemented by subclasses.

Source code in wildlife_datasets/datasets/datasets.py
def create_catalogue(self):
    """Creates the dataframe.

    Raises:
        NotImplementedError: Needs to be implemented by subclasses.
    """

    raise NotImplementedError('Needs to be implemented by subclasses.')

display_name() classmethod

Returns name of the dataset without the v2 ending.

Returns:

    str: Name of the dataset.

Source code in wildlife_datasets/datasets/datasets.py
@classmethod
def display_name(cls) -> str:
    """Returns name of the dataset without the v2 ending.

    Returns:
        Name of the dataset.
    """

    cls_parent = cls.__bases__[0]
    while cls_parent != object and cls_parent.outdated_dataset:
        cls = cls_parent
        cls_parent = cls.__bases__[0]            
    return cls.__name__
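
A sketch of the version walk with hypothetical subclasses (only the inheritance chain and the outdated_dataset flag matter):

from wildlife_datasets import datasets

class Dataset(datasets.DatasetFactory):
    outdated_dataset = True   # replaced by Datasetv2

class Datasetv2(Dataset):
    outdated_dataset = False

print(Datasetv2.display_name())  # 'Dataset'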

download(root, force=False, **kwargs) classmethod

Downloads the data. Wrapper around cls._download.

Parameters:

    root (str): Where the data should be stored. Required.
    force (bool): If the root exists, whether it should be overwritten. Default: False.
Source code in wildlife_datasets/datasets/datasets.py
@classmethod
def download(
        cls,
        root: str,
        force: bool = False,
        **kwargs
        ) -> None:
    """Downloads the data. Wrapper around `cls._download`.

    Args:
        root (str): Where the data should be stored.
        force (bool, optional): If the root exists, whether it should be overwritten.
    """

    dataset_name = cls.__name__
    mark_file_name = os.path.join(root, cls.download_mark_name)

    already_downloaded = os.path.exists(mark_file_name)
    if cls.saved_to_system_folder:
        cls._download(**kwargs)
    elif already_downloaded and not force:
        print('DATASET %s: DOWNLOADING STARTED.' % dataset_name)            
        print(cls.download_warning)
    else:
        if os.path.exists(mark_file_name):
            os.remove(mark_file_name)
        with utils.data_directory(root):
            cls._download(**kwargs)
        open(mark_file_name, 'a').close()
        if hasattr(cls, 'summary') and 'licenses_url' in cls.summary:
            with open(os.path.join(root, cls.license_file_name), 'w') as file:
                file.write(cls.summary['licenses_url'])

extract(root, **kwargs) classmethod

Extract the data. Wrapper around cls._extract.

Parameters:

    root (str): Where the data should be stored. Required.
Source code in wildlife_datasets/datasets/datasets.py
@classmethod    
def extract(cls, root: str, **kwargs) -> None:
    """Extract the data. Wrapper around `cls._extract`.

    Args:
        root (str): Where the data should be stored.
    """

    if cls.saved_to_system_folder:
        cls._extract(**kwargs)
    else:
        with utils.data_directory(root):
            cls._extract(**kwargs)
        mark_file_name = os.path.join(root, cls.download_mark_name)
        open(mark_file_name, 'a').close()

finalize_catalogue(df)

Reorders the dataframe and checks file paths.

Reorders the columns and removes constant columns. Checks if columns are in correct formats. Checks if ids are unique and if all files exist.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.

Returns:

    DataFrame: A full dataframe of the data, slightly modified.

Source code in wildlife_datasets/datasets/datasets.py
def finalize_catalogue(self, df: pd.DataFrame) -> pd.DataFrame:
    """Reorders the dataframe and check file paths.

    Reorders the columns and removes constant columns.
    Checks if columns are in correct formats.
    Checks if ids are unique and if all files exist.

    Args:
        df (pd.DataFrame): A full dataframe of the data.

    Returns:
        A full dataframe of the data, slightly modified.
    """

    if self.update_wrong_labels:
        df = self.fix_labels(df)
    self.check_required_columns(df)
    self.check_types_columns(df)
    df = self.reorder_df(df)
    df = self.remove_constant_columns(df)
    self.check_unique_id(df)
    self.check_files_exist(df['path'])
    self.check_files_names(df['path'])
    if 'segmentation' in df.columns:
        self.check_files_exist(df['segmentation'])
    return df
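
A typical create_catalogue implementation in a subclass builds the dataframe and passes it through finalize_catalogue. A hedged sketch (the class and paths are made up; finalize_catalogue verifies that the files in 'path' actually exist under root):

import pandas as pd
from wildlife_datasets import datasets

class ToyDataset(datasets.DatasetFactory):
    def create_catalogue(self) -> pd.DataFrame:
        df = pd.DataFrame({
            'image_id': [1, 2],
            'identity': ['a', 'b'],
            'path': ['images/1.jpg', 'images/2.jpg'],  # must exist under root
        })
        return self.finalize_catalogue(df)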

fix_labels(df)

Fixes labels in dataframe.

Automatically called in finalize_catalogue.

Source code in wildlife_datasets/datasets/datasets.py
def fix_labels(self, df: pd.DataFrame) -> pd.DataFrame:
    """Fixes labels in dataframe.

    Automatically called in `finalize_catalogue`.                
    """

    return df

fix_labels_remove_identity(df, identities_to_remove, col='identity')

Removes all instances of identities.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.
    identities_to_remove (List): List of identities to remove. Required.
    col (str): Column to remove from. Default: 'identity'.

Returns:

    DataFrame: A full dataframe of the data.

Source code in wildlife_datasets/datasets/datasets.py
def fix_labels_remove_identity(
        self,
        df: pd.DataFrame,
        identities_to_remove: List,
        col: str = 'identity'
        ) -> pd.DataFrame:
    """Removes all instances of identities.

    Args:
        df (pd.DataFrame): A full dataframe of the data.
        identities_to_remove (List): List of identities to remove.
        col (str, optional): Column to remove from.

    Returns:
        A full dataframe of the data.
    """
    idx_remove = [identity in identities_to_remove for identity in df[col]]
    return df[~np.array(idx_remove)]
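
A small sketch on a toy dataframe; here dataset stands for any existing DatasetFactory instance (the method does not touch self):

import pandas as pd

df = pd.DataFrame({'identity': ['a', 'b', 'a'], 'path': ['1.jpg', '2.jpg', '3.jpg']})
df = dataset.fix_labels_remove_identity(df, ['a'])
print(df['identity'].tolist())  # ['b']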

fix_labels_replace_identity(df, replace_identity, col='identity')

Replaces all instances of identities.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.
    replace_identity (List[Tuple]): List of (old_identity, new_identity) pairs. Required.
    col (str): Column to replace in. Default: 'identity'.

Returns:

    DataFrame: A full dataframe of the data.

Source code in wildlife_datasets/datasets/datasets.py
def fix_labels_replace_identity(
        self,
        df: pd.DataFrame,
        replace_identity: List[Tuple],
        col: str = 'identity'
        ) -> pd.DataFrame:
    """Replaces all instances of identities.

    Args:
        df (pd.DataFrame): A full dataframe of the data.
        replace_identity (List[Tuple]): List of (old_identity, new_identity)
        col (str, optional): Column to replace in.

    Returns:
        A full dataframe of the data.
    """
    for old_identity, new_identity in replace_identity:
        df[col] = df[col].replace({old_identity: new_identity})
    return df
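
Again a toy sketch, with dataset standing for any existing DatasetFactory instance:

import pandas as pd

df = pd.DataFrame({'identity': ['a', 'b', 'a'], 'path': ['1.jpg', '2.jpg', '3.jpg']})
df = dataset.fix_labels_replace_identity(df, [('a', 'c'), ('b', 'd')])
print(df['identity'].tolist())  # ['c', 'd', 'c']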

fix_labels_replace_images(df, replace_identity, col='identity')

Replaces specified images with specified identities.

It matches image_name as a substring of df['path'], which may cause problems with os.path.sep.

Parameters:

    df (DataFrame): A full dataframe of the data. Required.
    replace_identity (List[Tuple]): List of (image_name, old_identity, new_identity) triples. Required.
    col (str): Column to replace in. Default: 'identity'.

Returns:

    DataFrame: A full dataframe of the data.

Source code in wildlife_datasets/datasets/datasets.py
def fix_labels_replace_images(
        self,
        df: pd.DataFrame,
        replace_identity: List[Tuple],
        col: str = 'identity'
        ) -> pd.DataFrame:
    """Replaces specified images with specified identities.

    It matches `image_name` as a substring of df['path'],
    which may cause problems with `os.path.sep`.

    Args:
        df (pd.DataFrame): A full dataframe of the data.
        replace_identity (List[Tuple]): List of (image_name, old_identity, new_identity).
        col (str, optional): Column to replace in.

    Returns:
        A full dataframe of the data.
    """
    for image_name, old_identity, new_identity in replace_identity:
        n_replaced = 0
        for index, df_row in df.iterrows():
            # Check that there is an image with the required name and identity
            if image_name in df_row['path'] and old_identity == df_row[col]:
                df.loc[index, col] = new_identity
                n_replaced += 1
        if n_replaced == 0:
            print('File name %s with identity %s was not found.' % (image_name, str(old_identity)))
        elif n_replaced > 1:
            print('File name %s with identity %s was found multiple times.' % (image_name, str(old_identity)))
    return df

get_data(root, force=False, **kwargs) classmethod

Downloads and extracts the data. Wrapper around cls._download and cls._extract.

Parameters:

    root (str): Where the data should be stored. Required.
    force (bool): If the root exists, whether it should be overwritten. Default: False.
Source code in wildlife_datasets/datasets/datasets.py
@classmethod
def get_data(
        cls,
        root: str,
        force: bool = False,
        **kwargs
        ) -> None:
    """Downloads and extracts the data. Wrapper around `cls._download` and `cls._extract.`

    Args:
        root (str): Where the data should be stored.
        force (bool, optional): It the root exists, whether it should be overwritten.
    """

    dataset_name = cls.__name__
    mark_file_name = os.path.join(root, cls.download_mark_name)

    already_downloaded = os.path.exists(mark_file_name)
    if not cls.saved_to_system_folder and already_downloaded and not force:
        print('DATASET %s: DOWNLOADING STARTED.' % dataset_name)
        print(cls.download_warning)
    else:
        print('DATASET %s: DOWNLOADING STARTED.' % dataset_name)
        cls.download(root, force=force, **kwargs)
        print('DATASET %s: EXTRACTING STARTED.' % dataset_name)
        cls.extract(root,  **kwargs)
        print('DATASET %s: FINISHED.\n' % dataset_name)
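
Typical usage on a subclass (MacaqueFaces is used as an example; the root is created as needed):

from wildlife_datasets import datasets

datasets.MacaqueFaces.get_data('data/MacaqueFaces')
# A second call prints the already-downloaded warning unless forced:
datasets.MacaqueFaces.get_data('data/MacaqueFaces', force=True)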

get_image(idx)

Load an image with iloc idx.

Parameters:

    idx (int): Index of the image. Required.

Returns:

    Image: Loaded image.

Source code in wildlife_datasets/datasets/datasets.py
def get_image(self, idx: int) -> Image:
    """Load an image with iloc `idx`.

    Args:
        idx (int): Index of the image.

    Returns:
        Loaded image.
    """

    data = self.df.iloc[idx]
    if self.root:
        img_path = os.path.join(self.root, data['path'])
    else:
        img_path = data['path']
    img = self.load_image(img_path)
    return img

load_image(path)

Load an image with path.

Parameters:

    path (str): Path to the image. Required.

Returns:

    Image: Loaded image.

Source code in wildlife_datasets/datasets/datasets.py
def load_image(self, path: str) -> Image:
    """Load an image with `path`.

    Args:
        path (str): Path to the image.

    Returns:
        Loaded image.
    """

    return utils.load_image(path)

plot_grid(n_rows=5, n_cols=8, offset=10, img_min=100, rotate=True, header_cols=None, idx=None, background_color=(0, 0, 0), **kwargs)

Plots a grid of size (n_rows, n_cols) with images from the dataframe.

Parameters:

- n_rows (int, default 5): The number of rows in the grid.
- n_cols (int, default 8): The number of columns in the grid.
- offset (float, default 10): The offset between images.
- img_min (float, default 100): The minimal size of the plotted images.
- rotate (bool, default True): Rotates the images to have the same orientation.
- header_cols (Optional[List[str]], default None): List of headers for each column.
- idx (Optional[Union[List[bool], List[int]]], default None): List of indices to plot. None plots random images. Index -1 plots an empty image.
- background_color (Tuple[int], default (0, 0, 0)): Background color of the grid.
Source code in wildlife_datasets/datasets/datasets.py
def plot_grid(
        self,
        n_rows: int = 5,
        n_cols: int = 8,
        offset: float = 10,
        img_min: float = 100,
        rotate: bool = True,
        header_cols: Optional[List[str]] = None,
        idx: Optional[Union[List[bool],List[int]]] = None,
        background_color: Tuple[int] = (0, 0, 0),
        **kwargs
        ) -> None:
    """Plots a grid of size (n_rows, n_cols) with images from the dataframe.

    Args:
        n_rows (int, optional): The number of rows in the grid.
        n_cols (int, optional): The number of columns in the grid.
        offset (float, optional): The offset between images.
        img_min (float, optional): The minimal size of the plotted images.
        rotate (bool, optional): Rotates the images to have the same orientation.
        header_cols (Optional[List[str]], optional): List of headers for each column.
        idx (Optional[Union[List[bool],List[int]]], optional): List of indices to plot. None plots random images. Index -1 plots an empty image.
        background_color (Tuple[int], optional): Background color of the grid.
    """

    if len(self.df) == 0:
        return None

    # Select indices of images to be plotted
    if idx is None:
        n = min(len(self.df), n_rows*n_cols)
        idx = np.random.permutation(len(self.df))[:n]
    else:
        if isinstance(idx, pd.Series):
            idx = idx.values
        if isinstance(idx[0], (bool, np.bool_)):
            idx = np.where(idx)[0]
        n = min(np.array(idx).size, n_rows*n_cols)
        idx = np.array(idx).flatten()[:n]

    # Load images and compute their ratio
    ratios = []
    ims = []
    for k in idx:
        if k >= 0:
            # Load the image with index k
            im = self[k]
            ims.append(im)
            ratios.append(im.size[0] / im.size[1])
        else:
            # Load a black image
            ims.append(Image.fromarray(np.zeros((2, 2), dtype = "uint8")))

    # Safeguard when all indices are -1
    if len(ratios) == 0:
        return None

    # Get the size of the images after being resized
    ratio = np.median(ratios)
    if ratio > 1:    
        img_w, img_h = int(img_min*ratio), int(img_min)
    else:
        img_w, img_h = int(img_min), int(img_min/ratio)

    # Compute height offset if headers are present
    if header_cols is not None:
        offset_h = 30
        if len(header_cols) != n_cols:
            raise Exception("Length of header_cols must be the same as n_cols.")
    else:
        offset_h = 0

    # Create an empty image grid
    im_grid = Image.new('RGB', (n_cols*img_w + (n_cols-1)*offset, offset_h + n_rows*img_h + (n_rows-1)*offset), background_color)

    # Fill the grid image by image
    pos_y = offset_h
    for i in range(n_rows):
        row_h = 0
        for j in range(n_cols):
            k = (n_cols)*i + j
            if k < n:
                # Possibly rotate the image
                im = ims[k]
                if rotate and ((ratio > 1 and im.size[0] < im.size[1]) or (ratio < 1 and im.size[0] > im.size[1])):
                    im = im.transpose(Image.Transpose.ROTATE_90)

                # Rescale the image
                im.thumbnail((img_w,img_h))
                row_h = max(row_h, im.size[1])

                # Place the image on the grid
                pos_x = j*img_w + j*offset
                im_grid.paste(im, (pos_x,pos_y))
        if row_h > 0:
            pos_y += row_h + offset
    im_grid = im_grid.crop((0, 0, im_grid.size[0], pos_y-offset))

    # Plot the image and add column headers if present
    fig = plt.figure()
    fig.patch.set_visible(False)
    ax = fig.add_subplot(111)
    plt.axis('off')
    plt.imshow(im_grid)
    if header_cols is not None:
        color = kwargs.pop('color', 'white')
        ha = kwargs.pop('ha', 'center')
        va = kwargs.pop('va', 'center')
        for i, header in enumerate(header_cols):
            pos_x = (i+0.5)*img_w + i*offset
            pos_y = offset_h/2
            plt.text(pos_x, pos_y, str(header), color=color, ha=ha, va=va, **kwargs)
    return fig
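
For example, a minimal usage sketch (reusing the dataset object from the get_image example):

# A random 2x4 grid with one header per column.
dataset.plot_grid(n_rows=2, n_cols=4, header_cols=['A', 'B', 'C', 'D'])

# An explicit selection; index -1 leaves its cell empty.
dataset.plot_grid(n_rows=1, n_cols=4, idx=[0, 1, 2, -1])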

remove_constant_columns(df)

Removes columns with a single unique value.

Parameters:

- df (DataFrame, required): A full dataframe of the data.

Returns:

- DataFrame: A full dataframe of the data, slightly modified.

Source code in wildlife_datasets/datasets/datasets.py
def remove_constant_columns(self, df: pd.DataFrame) -> pd.DataFrame:
    """Removes columns with a single unique value.

    Args:
        df (pd.DataFrame): A full dataframe of the data.

    Returns:
        A full dataframe of the data, slightly modified.
    """ 

    for df_name in list(df.columns):
        if df[df_name].astype('str').nunique() == 1:
            df = df.drop([df_name], axis=1)
    return df
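
For example, a minimal sketch with a toy dataframe (reusing the dataset object from the get_image example):

import pandas as pd

df = pd.DataFrame({
    'image_id': [1, 2, 3],
    'identity': ['a', 'b', 'a'],
    'species': ['zebra', 'zebra', 'zebra'],  # a single unique value
})
df = dataset.remove_constant_columns(df)
# the constant 'species' column is dropped; 'image_id' and 'identity' remain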

reorder_df(df)

Reorders rows and columns in the dataframe.

Rows are sorted based on id. Columns are reordered based on the default_order list.

Parameters:

- df (DataFrame, required): A full dataframe of the data.

Returns:

- DataFrame: A full dataframe of the data, slightly modified.

Source code in wildlife_datasets/datasets/datasets.py
def reorder_df(self, df: pd.DataFrame) -> pd.DataFrame:
    """Reorders rows and columns in the dataframe.

    Rows are sorted based on id.
    Columns are reordered based on the `default_order` list.

    Args:
        df (pd.DataFrame): A full dataframe of the data.

    Returns:
        A full dataframe of the data, slightly modified.
    """

    default_order = ['image_id', 'identity', 'path', 'bbox', 'date', 'keypoints', 'orientation', 'segmentation', 'species']
    df_names = list(df.columns)
    col_names = []
    for name in default_order:
        if name in df_names:
            col_names.append(name)
    for name in df_names:
        if name not in default_order:
            col_names.append(name)

    df = df.sort_values('image_id').reset_index(drop=True)
    return df.reindex(columns=col_names)
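
For example, a minimal sketch with a toy dataframe (reusing the dataset object from the get_image example):

import pandas as pd

df = pd.DataFrame({'date': ['2021-01-01'], 'identity': ['a'], 'image_id': [1]})
df = dataset.reorder_df(df)
# the columns now follow the default order: image_id, identity, date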

Metadata

Class for storing metadata.

Attributes:

- df (DataFrame): A dataframe of the metadata.

Source code in wildlife_datasets/datasets/summary.py
class Summary():
    """Class for storing metadata.

    Attributes:
      df (pd.DataFrame): A dataframe of the metadata.
    """

    def __init__(self, path: str):
        """Loads the metadata from a csv file into a dataframe.

        The `animals` column is evaluated from its string representation.

        Args:
            path (str): Path of the csv file.
        """

        df = pd.read_csv(path, index_col='name')
        if 'animals' in df.columns:
            df.loc[df['animals'].isnull(), 'animals'] = '{}'
            df['animals'] = df['animals'].apply(lambda x: eval(x))
        self.df = df

    def __getitem__(self, item):
        return self.df.loc[item].dropna().to_dict()
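
For example, a minimal usage sketch (summary.csv is a hypothetical metadata file indexed by a 'name' column; the import path follows the module location shown above):

from wildlife_datasets.datasets.summary import Summary

summary = Summary('summary.csv')
record = summary['MacaqueFaces']  # the metadata of one dataset as a dict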

__init__(path)

Loads the metadata from a csv file into a dataframe.

The animals column is evaluated from its string representation.

Parameters:

- path (str, required): Path of the csv file.
Source code in wildlife_datasets/datasets/summary.py
def __init__(self, path: str):
    """Loads the metadata from a csv file into a dataframe.

    The `animals` column is evaluated from its string representation.

    Args:
        path (str): Path of the csv file.
    """

    df = pd.read_csv(path, index_col='name')
    if 'animals' in df.columns:
        df.loc[df['animals'].isnull(), 'animals'] = '{}'
        df['animals'] = df['animals'].apply(lambda x: eval(x))
    self.df = df

Utils

bbox_segmentation(bbox)

Convert bounding box to segmentation.

Parameters:

- bbox (List[float], required): Bounding box in the form [x, y, w, h].

Returns:

- List[float]: Segmentation mask in the form [x1, y1, x2, y2, ...].

Source code in wildlife_datasets/datasets/utils.py
def bbox_segmentation(bbox: List[float]) -> List[float]:
    """Convert bounding box to segmentation.

    Args:
        bbox (List[float]): Bounding box in the form [x, y, w, h].

    Returns:
        Segmentation mask in the form [x1, y1, x2, y2, ...].
    """

    return [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3], bbox[0], bbox[1]+bbox[3], bbox[0], bbox[1]]
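
For example, a worked sketch (assuming the module is importable as shown):

from wildlife_datasets.datasets import utils

utils.bbox_segmentation([10, 20, 30, 40])
# [10, 20, 40, 20, 40, 60, 10, 60, 10, 20]
# the four corners in order, with the first corner repeated to close the polygon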

create_id(string_col)

Creates unique ids from strings based on an MD5 hash.

Parameters:

- string_col (Series, required): Series of string ids.

Returns:

- Series: Series of encoded ids.

Source code in wildlife_datasets/datasets/utils.py
def create_id(string_col: pd.Series) -> pd.Series:
    """Creates unique ids from string based on MD5 hash.

    Args:
        string_col (pd.Series): List of ids.

    Returns:
        List of encoded ids.
    """

    entity_id = string_col.apply(lambda x: hashlib.md5(x.encode()).hexdigest()[:16])
    assert len(entity_id.unique()) == len(entity_id)
    return entity_id
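
For example, a minimal sketch (assuming the module is importable as shown):

import pandas as pd
from wildlife_datasets.datasets import utils

paths = pd.Series(['images/0001.jpg', 'images/0002.jpg'])
ids = utils.create_id(paths)  # two distinct 16-character hex ids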

crop_black(img)

Crops black borders from an image.

Parameters:

- img (Image, required): Image to be cropped.

Returns:

- Image: Cropped image.

Source code in wildlife_datasets/datasets/utils.py
def crop_black(img: Image) -> Image:
    """Crops black borders from an image.    

    Args:
        img (Image): Image to be cropped.

    Returns:
        Cropped image.
    """

    y_nonzero, x_nonzero, _ = np.nonzero(img)
    return img.crop(
        (
            np.min(x_nonzero),
            np.min(y_nonzero),
            np.max(x_nonzero),
            np.max(y_nonzero),
        )
    )
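
For example, a minimal sketch with a synthetic image (assuming the module is importable as shown):

import numpy as np
from PIL import Image
from wildlife_datasets.datasets import utils

# A white square surrounded by a black border.
arr = np.zeros((100, 100, 3), dtype='uint8')
arr[25:75, 25:75] = 255
img = utils.crop_black(Image.fromarray(arr))
print(img.size)  # (49, 49): only the nonzero region is kept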

data_directory(dir)

Changes the context so that the data directory is used as the current working directory. The data directory is created if it does not exist.

Source code in wildlife_datasets/datasets/utils.py
@contextmanager
def data_directory(dir):
    '''
    Changes the context so that the data directory is used as the current
    working directory. The data directory is created if it does not exist.
    '''
    current_dir = os.getcwd()
    if not os.path.exists(dir):
        os.makedirs(dir)
    os.chdir(dir)
    try:
        yield
    finally:
        os.chdir(current_dir)
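
For example, a minimal sketch (assuming the module is importable as shown):

from wildlife_datasets.datasets import utils

# Files created inside the block land in 'data'; the previous working
# directory is restored afterwards, even if an exception is raised.
with utils.data_directory('data'):
    open('downloaded.txt', 'w').close()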

find_images(root, img_extensions=('.png', '.jpg', '.jpeg'))

Finds all image files in a folder and its subfolders.

Parameters:

- root (str, required): The root folder where to look for images.
- img_extensions (Tuple[str, ...], default ('.png', '.jpg', '.jpeg')): Image extensions to look for.

Returns:

- DataFrame: Dataframe of relative paths of the images.

Source code in wildlife_datasets/datasets/utils.py
def find_images(
        root: str,
        img_extensions: Tuple[str, ...] = ('.png', '.jpg', '.jpeg')
        ) -> pd.DataFrame:
    """Finds all image files in folder and subfolders.

    Args:
        root (str): The root folder where to look for images.
        img_extensions (Tuple[str, ...], optional): Image extensions to look for, by default ('.png', '.jpg', '.jpeg').

    Returns:
        Dataframe of relative paths of the images.
    """

    data = [] 
    for path, directories, files in os.walk(root):
        for file in files:
            if file.lower().endswith(tuple(img_extensions)):
                data.append({'path': os.path.relpath(path, start=root), 'file': file})
    return pd.DataFrame(data)
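
For example, a minimal sketch (it assumes images were downloaded under data/MacaqueFaces):

import os
from wildlife_datasets.datasets import utils

df = utils.find_images('data/MacaqueFaces')
# each row holds the folder (relative to the root) and the file name
df['full_path'] = df.apply(lambda r: os.path.join(r['path'], r['file']), axis=1)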

is_annotation_bbox(segmentation, bbox, tol=0)

Checks whether the segmentation is a bounding box.

Parameters:

- segmentation (List[float], required): Segmentation mask in the form [x1, y1, x2, y2, ...].
- bbox (List[float], required): Bounding box in the form [x, y, w, h].
- tol (float, default 0): Tolerance for the difference.

Returns:

- bool: True if the segmentation is the bounding box within the tolerance.

Source code in wildlife_datasets/datasets/utils.py
def is_annotation_bbox(
        segmentation: List[float],
        bbox: List[float],
        tol: float = 0
        ) -> bool:
    """Checks whether segmentation is bounding box.

    Args:
        segmentation (List[float]): Segmentation mask in the form [x1, y1, x2, y2, ...].
        bbox (List[float]): Bounding box in the form [x, y, w, h].
        tol (float, optional): Tolerance for difference.

    Returns:
        True if segmentation is bounding box within tolerance.
    """

    bbox_seg = bbox_segmentation(bbox)
    if len(segmentation) == len(bbox_seg):
        for x, y in zip(segmentation, bbox_seg):
            if abs(x-y) > tol:
                return False
    else:
        return False
    return True
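
For example, a worked sketch together with bbox_segmentation (assuming the module is importable as shown):

from wildlife_datasets.datasets import utils

bbox = [10, 20, 30, 40]
utils.is_annotation_bbox(utils.bbox_segmentation(bbox), bbox)   # True
utils.is_annotation_bbox([10, 20, 40, 20, 40, 60], bbox)        # False: lengths differ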

load_image(path, max_size=None)

Loads an image.

Parameters:

- path (str, required): Path of the image.
- max_size (int, default None): Maximal size of the longer image side, or None (no restriction).

Returns:

- Image: Loaded image.

Source code in wildlife_datasets/datasets/utils.py
def load_image(path: str, max_size: int = None) -> Image:
    """Loads an image.

    Args:
        path (str): Path of the image.
        max_size (int, optional): Maximal size of the longer image side, or None (no restriction).

    Returns:
        Loaded image.
    """

    # We load it with OpenCV because PIL does not apply EXIF orientation metadata.
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(img)
    if max_size is not None:
        w, h = img.size
        if max(w, h) > max_size:
            c = max_size / max(w, h)
            img = img.resize((int(c*w), int(c*h)))
    return img
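
For example, a minimal sketch (example.jpg is a hypothetical file; the module import is assumed):

from wildlife_datasets.datasets import utils

img = utils.load_image('example.jpg', max_size=300)
# the longer side is now at most 300 pixels; the aspect ratio is preserved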

segmentation_bbox(segmentation)

Convert segmentation to bounding box.

Parameters:

- segmentation (List[float], required): Segmentation mask in the form [x1, y1, x2, y2, ...].

Returns:

- List[float]: Bounding box in the form [x, y, w, h].

Source code in wildlife_datasets/datasets/utils.py
def segmentation_bbox(segmentation: List[float]) -> List[float]:
    """Convert segmentation to bounding box.

    Args:
        segmentation (List[float]): Segmentation mask in the form [x1, y1, x2, y2, ...].

    Returns:
        Bounding box in the form [x, y, w, h].
    """

    x = segmentation[0::2]
    y = segmentation[1::2]
    x_min = np.min(x)
    x_max = np.max(x)
    y_min = np.min(y)
    y_max = np.max(y)
    return [x_min, y_min, x_max-x_min, y_max-y_min]
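
For example, a worked sketch (the inverse of bbox_segmentation for axis-aligned rectangles; the module import is assumed):

from wildlife_datasets.datasets import utils

utils.segmentation_bbox([10, 20, 40, 20, 40, 60, 10, 60, 10, 20])
# [10, 20, 30, 40]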