inference.classifier

KnnClassifier(database_labels, k=1, return_scores=False)

Predicts the label of each query from the labels of its k nearest matches in the database. If there is a tie at a given k, the prediction with the best score is used.

Parameters:

  • database_labels (ndarray, required): Array containing the labels of the database.
  • k (int, default 1): The number of nearest neighbors to consider.
  • return_scores (bool, default False): Indicates whether to return scores along with predictions.

__call__(similarity)

Predicts the label for each query based on the k nearest matches in the database.

Parameters:

  • similarity (ndarray, required): A 2D similarity matrix of shape n_query x n_database.
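To make the expected layout concrete, the sketch below builds a matrix with one row per query and one column per database entry. Cosine similarity, the helper name `cosine_similarity`, and the toy feature arrays are assumptions made for illustration; the reference above only requires a 2D array of shape n_query x n_database.

```python
import numpy as np


def cosine_similarity(query_features, database_features):
    """Hypothetical helper: cosine similarity between every query and database row."""
    # Normalize each feature vector to unit length.
    q = query_features / np.linalg.norm(query_features, axis=1, keepdims=True)
    d = database_features / np.linalg.norm(database_features, axis=1, keepdims=True)
    # Rows index queries, columns index database entries.
    return q @ d.T


# Toy data (assumed): 2 queries and 3 database entries, 2 features each.
query = np.array([[1.0, 0.0], [0.0, 2.0]])
database = np.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

similarity = cosine_similarity(query, database)
print(similarity.shape)  # (2, 3), i.e. n_query x n_database
```

Entry `similarity[i, j]` then scores query `i` against database entry `j`, which is the orientation `__call__` expects.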

Returns:

Type: ndarray | Tuple[ndarray, ndarray]

If return_scores is False:

  • preds (np.ndarray): Prediction for each query.

If return_scores is True, a tuple of two arrays:

  • preds (np.ndarray): Prediction for each query.
  • scores (np.ndarray): The similarity scores corresponding to the predictions (mean for k > 1).
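
The voting behavior described for KnnClassifier can be sketched in plain NumPy. This is an illustrative approximation, not the library's implementation: `knn_predict` is a hypothetical helper, and the tie-break rule used here (among tied labels, the one with the highest mean similarity wins) is only one reading of "the prediction with the best score". Reporting the mean similarity of the winning label's neighbors as the score is likewise an assumption.

```python
import numpy as np


def knn_predict(similarity, database_labels, k=1):
    """Hypothetical sketch: majority vote over the k most similar database entries."""
    similarity = np.asarray(similarity, dtype=float)
    database_labels = np.asarray(database_labels)
    # Indices of the k highest similarities per query, best first.
    idx = np.argsort(-similarity, axis=1, kind="stable")[:, :k]

    preds, scores = [], []
    for row, neighbors in zip(similarity, idx):
        labels = database_labels[neighbors]
        sims = row[neighbors]
        candidates, counts = np.unique(labels, return_counts=True)
        # Keep only the labels that received the most votes.
        tied = candidates[counts == counts.max()]
        # Assumed tie-break: highest mean similarity among the tied labels.
        means = np.array([sims[labels == c].mean() for c in tied])
        best = int(np.argmax(means))
        preds.append(tied[best])
        scores.append(means[best])
    return np.array(preds), np.array(scores)


# Toy data (assumed): 2 queries, 4 database entries.
labels = np.array(["a", "a", "b", "b"])
similarity = np.array([[0.9, 0.2, 0.8, 0.7],
                       [0.1, 0.3, 0.6, 0.5]])

preds, scores = knn_predict(similarity, labels, k=2)
print(preds)   # ['a' 'b']
print(scores)  # approximately [0.9, 0.55]
```

For the first query the two nearest neighbors carry different labels, so the vote is tied and the label of the more similar neighbor wins. For the second query both neighbors agree, and the score is the mean of their similarities.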

TopkClassifier(database_labels, k=10, return_all=False)

Predict top k query labels given nearest matches in the database.

Parameters:

  • database_labels (ndarray, required): Array containing the labels of the database.
  • k (int, default 10): The number of top predictions to return.
  • return_all (bool, default False): Indicates whether to return the scores and database indices along with the predictions.

__call__(similarity)

Predicts the top k labels for each query based on the similarity matrix.

Parameters:

  • similarity (ndarray, required): A 2D similarity matrix of shape n_query x n_database.

Returns:

Type: ndarray | Tuple[ndarray, ndarray, ndarray]

If return_all is False, a single 2D array of shape n_query x k:

  • preds (np.ndarray): The top k predicted labels for each query.

If return_all is True, a tuple of three 2D arrays of shape n_query x k:

  • preds (np.ndarray): The top k predicted labels for each query.
  • scores (np.ndarray): The similarity scores corresponding to the top k predictions.
  • idx (np.ndarray): The indices of the database entries corresponding to the top k predictions.
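
The three returned arrays can be illustrated with a short NumPy sketch. `topk_predict` is a hypothetical helper, not the library's code. It assumes that larger similarity means a closer match and simply ranks database entries, so the same label may appear more than once in a row; whether TopkClassifier deduplicates labels is not stated in this reference.

```python
import numpy as np


def topk_predict(similarity, database_labels, k=10):
    """Hypothetical sketch: labels, scores and indices of the k best matches per query."""
    similarity = np.asarray(similarity, dtype=float)
    database_labels = np.asarray(database_labels)
    # Database indices sorted by descending similarity, truncated to k columns.
    idx = np.argsort(-similarity, axis=1, kind="stable")[:, :k]
    # Similarity values at those indices, row by row.
    scores = np.take_along_axis(similarity, idx, axis=1)
    # Fancy indexing maps each database index to its label.
    preds = database_labels[idx]
    return preds, scores, idx


# Toy data (assumed): 2 queries, 4 database entries.
labels = np.array(["a", "a", "b", "b"])
similarity = np.array([[0.9, 0.2, 0.8, 0.7],
                       [0.1, 0.3, 0.6, 0.5]])

preds, scores, idx = topk_predict(similarity, labels, k=2)
print(preds)   # [['a' 'b'] ['b' 'b']]
print(scores)  # [[0.9 0.8] [0.6 0.5]]
print(idx)     # [[0 2] [2 3]]
```

All three arrays share the shape n_query x k, so `preds[i, j]`, `scores[i, j]` and `idx[i, j]` describe the same match.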