Args:
``features`` (``torch.Tensor``): The feature matrix.
``k`` (``int``): The number of nearest neighbors.
"""
assert features.ndim == 2, "The feature matrix should be 2-D."
assert (
k <= features.shape[0]
), "The number of nearest neighbors should be less than or equal to the number of vertices."
dist_matrix = torch.cdist(features, features, p=2)
_, nbr_indices = torch.topk(dist_matrix, k, largest=False)
return nbr_indices.tolist()