Inference

This inference module uses similarity scores to make predictions on unseen data. It covers classification (the KnnClassifier class) and ranking (the TopkClassifier class) using nearest neighbours. Similarity scores are expected as a 2D array of shape n_query x n_database.
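
To make the expected layout concrete, here is a minimal sketch using plain nested lists in place of the 2D array. It assumes that a higher score means a closer match; the labels and the helper name `nearest_labels` are hypothetical and not part of the module.

```python
# Toy similarity matrix: 2 queries (rows) x 3 database entries (columns).
# Assumption: a higher score means a closer match.
similarity = [
    [0.9, 0.2, 0.4],  # query 0 scored against database entries 0, 1, 2
    [0.1, 0.3, 0.8],  # query 1 scored against database entries 0, 1, 2
]
database_labels = ["cat", "dog", "fox"]  # hypothetical labels, one per column


def nearest_labels(similarity, database_labels):
    """Return the label of the single best-scoring database entry per query."""
    preds = []
    for row in similarity:
        if len(row) != len(database_labels):
            raise ValueError("each row needs one score per database entry")
        # Column index of the highest score in this row.
        best = max(range(len(row)), key=row.__getitem__)
        preds.append(database_labels[best])
    return preds


print(nearest_labels(similarity, database_labels))  # ['cat', 'fox']
```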

inference.classifier

KnnClassifier(database_labels, k=1, return_scores=False)

Predicts the label of each query from the labels of its k nearest matches in the database. If there is a tie at a given k, the prediction with the best score is used.

Parameters:

  • database_labels (array, required): Array containing the labels of the database.
  • k (int, default 1): The number of nearest neighbors to consider.
  • return_scores (bool, default False): Indicates whether to return scores along with predictions.

__call__(similarity)

Predicts the label for each query based on the k nearest matches in the database.

Parameters:

  • similarity (required): A 2D similarity matrix with n_query x n_database shape.

Returns:

If return_scores is False:

  • preds: Prediction for each query.

If return_scores is True, a tuple of two arrays:

  • preds: Prediction for each query.
  • scores: The similarity scores corresponding to the predictions (mean for k > 1).
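
The behaviour described above can be approximated with the following sketch. It is not the class's actual implementation: the function name `knn_predict`, the use of plain lists, and the toy data are hypothetical. It also assumes that higher scores mean closer matches, that the prediction is a majority vote over the k best matches with ties broken by the best single score, and that the returned score is the mean over the winning label's matches.

```python
from collections import defaultdict


def knn_predict(similarity, database_labels, k=1, return_scores=False):
    """Majority vote over the k best-scoring database entries of each query."""
    preds, scores = [], []
    for row in similarity:
        # Indices of the k highest scores, best first.
        top = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        # Collect the scores of the top matches, grouped by label.
        by_label = defaultdict(list)
        for j in top:
            by_label[database_labels[j]].append(row[j])
        # Most votes wins; a tie in votes is broken by the best single score.
        label, hits = max(
            by_label.items(), key=lambda kv: (len(kv[1]), max(kv[1]))
        )
        preds.append(label)
        scores.append(sum(hits) / len(hits))  # assumed: mean over winning matches
    return (preds, scores) if return_scores else preds


similarity = [
    [0.9, 0.8, 0.7, 0.1],
    [0.2, 0.95, 0.3, 0.6],
]
database_labels = ["a", "b", "b", "c"]

print(knn_predict(similarity, database_labels, k=3))  # ['b', 'b']
print(knn_predict(similarity, database_labels, k=2))  # ['a', 'b'] (ties broken by best score)
```

With k=2, each query sees two labels with one vote each, so the label of the higher-scoring match wins.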

TopkClassifier(database_labels, k=10, return_all=False)

Predicts the top k labels for each query from its nearest matches in the database.

Parameters:

  • database_labels (array, required): Array containing the labels of the database.
  • k (int, default 10): The number of top predictions to return.
  • return_all (bool, default False): Indicates whether to return scores and database indices along with predictions.

__call__(similarity)

Predicts the top k labels for each query based on the similarity matrix.

Parameters:

  • similarity (required): A 2D similarity matrix with n_query x n_database shape.

Returns:

If return_all is False, a single 2D array of shape n_query x k:

  • preds: The top k predicted labels for each query.

If return_all is True, a tuple of three 2D arrays of shape n_query x k:

  • preds: The top k predicted labels for each query.
  • scores: The similarity scores corresponding to the top k predictions.
  • idx: The indices of the database entries corresponding to the top k predictions.
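
As a rough illustration of the three return values, the sketch below ranks database entries per query using plain lists. The name `topk_predict` and the toy data are hypothetical, and this is not the class's actual code. It assumes that higher scores mean closer matches and that the k best database entries are taken directly, so the same label may appear more than once in a row; whether the real class deduplicates labels is not stated here.

```python
def topk_predict(similarity, database_labels, k=10, return_all=False):
    """Return the labels of the k best-scoring database entries per query."""
    preds, scores, idx = [], [], []
    for row in similarity:
        # Indices of the k highest scores, best first (fewer if the row is short).
        top = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        idx.append(top)
        scores.append([row[j] for j in top])
        preds.append([database_labels[j] for j in top])
    return (preds, scores, idx) if return_all else preds


similarity = [
    [0.9, 0.8, 0.7, 0.1],
    [0.2, 0.95, 0.3, 0.6],
]
database_labels = ["a", "b", "b", "c"]

preds, scores, idx = topk_predict(similarity, database_labels, k=2, return_all=True)
print(preds)   # [['a', 'b'], ['b', 'c']]
print(scores)  # [[0.9, 0.8], [0.95, 0.6]]
print(idx)     # [[0, 1], [1, 3]]
```

Each row of `idx` indexes into `database_labels`, so `preds[i][r]` is always `database_labels[idx[i][r]]`.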