From 21c13d36d31b1c102144d97559f6ee86f56e07cb Mon Sep 17 00:00:00 2001 From: "Rollin M. Omari" Date: Wed, 11 Feb 2026 20:52:19 +1030 Subject: [PATCH 1/3] Improve assign_labels accuracy and performance Resolved a bug where silent neurons defaulted to the first class label. Optimized indexing to prevent unnecessary host-to-device transfers and added robust NaN handling for firing rate proportions. --- bindsnet/evaluation/evaluation.py | 49 +++++++++++-------------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index cd4fe2a6..aba0fb96 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -12,51 +12,36 @@ def assign_labels( rates: Optional[torch.Tensor] = None, alpha: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # language=rst - """ - Assign labels to the neurons based on highest average spiking activity. - :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single - layer's spiking activity. - :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to - spiking activity. - :param n_labels: The number of target labels in the data. - :param rates: If passed, these represent spike rates from a previous - ``assign_labels()`` call. - :param alpha: Rate of decay of label assignments. - :return: Tuple of class assignments, per-class spike proportions, and per-class - firing rates. - """ n_neurons = spikes.size(2) + device = spikes.device # Keep everything on the same device. if rates is None: - rates = torch.zeros((n_neurons, n_labels), device=spikes.device) + rates = torch.zeros((n_neurons, n_labels), device=device) # Sum over time dimension (spike ordering doesn't matter). - spikes = spikes.sum(1) - + summed_spikes = spikes.sum(1) + for i in range(n_labels): # Count the number of samples with this label. - n_labeled = torch.sum(labels == i).float() + mask = (labels == i) + n_labeled = mask.sum().float() if n_labeled > 0: - # Get indices of samples with this label. - indices = torch.nonzero(labels == i).view(-1) - - # Compute average firing rates for this label. - selected_spikes = torch.index_select( - spikes, dim=0, index=torch.tensor(indices) - ) - rates[:, i] = alpha * rates[:, i] + ( - torch.sum(selected_spikes, 0) / n_labeled - ) + # Get indices of samples with this label (masking is faster and stays on the GPU). + label_sum = summed_spikes[mask].sum(0) + # Update rates. + rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled) - # Compute proportions of spike activity per class. - proportions = rates / rates.sum(1, keepdim=True) - proportions[proportions != proportions] = 0 # Set NaNs to 0 + # 3. Compute proportions (and use 'torch.where' to avoid NaN bug). + total_activity = rates.sum(1, keepdim=True) + proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates)) # Neuron assignments are the labels they fire most for. - assignments = torch.max(proportions, 1)[1] + max_vals, assignments = torch.max(proportions, 1) + + # Set unassigned (silent) neurons to -1 instead of defaulting to 0. + assignments[max_vals == 0] = -1 return assignments, proportions, rates From 9febfff3041e4a5e4451f12cc01bca8861a0de14 Mon Sep 17 00:00:00 2001 From: "Rollin M. Omari" Date: Wed, 11 Feb 2026 21:20:39 +1030 Subject: [PATCH 2/3] Reverted variable names to original Reverted variable names to original. --- bindsnet/evaluation/evaluation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index aba0fb96..6648411a 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -14,26 +14,26 @@ def assign_labels( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: n_neurons = spikes.size(2) - device = spikes.device # Keep everything on the same device. if rates is None: - rates = torch.zeros((n_neurons, n_labels), device=device) + rates = torch.zeros((n_neurons, n_labels), device=spikes.device) # Sum over time dimension (spike ordering doesn't matter). - summed_spikes = spikes.sum(1) + spikes = spikes.sum(1) for i in range(n_labels): - # Count the number of samples with this label. + # Create mask (faster and allows future steps to stay on GPU). mask = (labels == i) + # Count the number of samples with this label. n_labeled = mask.sum().float() if n_labeled > 0: # Get indices of samples with this label (masking is faster and stays on the GPU). - label_sum = summed_spikes[mask].sum(0) + label_sum = spikes[mask].sum(0) # Update rates. rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled) - # 3. Compute proportions (and use 'torch.where' to avoid NaN bug). + # Compute proportions (and use 'torch.where' to avoid NaN bug). total_activity = rates.sum(1, keepdim=True) proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates)) From a2cbace526b7bae1f46fbcfdc4040efc85ed32e4 Mon Sep 17 00:00:00 2001 From: "Rollin M. Omari" Date: Wed, 11 Feb 2026 21:25:28 +1030 Subject: [PATCH 3/3] Reverted accidental removal of docstring Restored docstring --- bindsnet/evaluation/evaluation.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index 6648411a..f2f44dfc 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -12,6 +12,21 @@ def assign_labels( rates: Optional[torch.Tensor] = None, alpha: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # language=rst + """ + Assign labels to the neurons based on highest average spiking activity. + + :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single + layer's spiking activity. + :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to + spiking activity. + :param n_labels: The number of target labels in the data. + :param rates: If passed, these represent spike rates from a previous + ``assign_labels()`` call. + :param alpha: Rate of decay of label assignments. + :return: Tuple of class assignments, per-class spike proportions, and per-class + firing rates. + """ n_neurons = spikes.size(2)