Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions bindsnet/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,36 @@ def assign_labels(
:return: Tuple of class assignments, per-class spike proportions, and per-class
firing rates.
"""

n_neurons = spikes.size(2)

if rates is None:
rates = torch.zeros((n_neurons, n_labels), device=spikes.device)

# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)

for i in range(n_labels):
# Create mask (faster and allows future steps to stay on GPU).
mask = (labels == i)
# Count the number of samples with this label.
n_labeled = torch.sum(labels == i).float()
n_labeled = mask.sum().float()

if n_labeled > 0:
# Get indices of samples with this label.
indices = torch.nonzero(labels == i).view(-1)

# Compute average firing rates for this label.
selected_spikes = torch.index_select(
spikes, dim=0, index=torch.tensor(indices)
)
rates[:, i] = alpha * rates[:, i] + (
torch.sum(selected_spikes, 0) / n_labeled
)
# Get indices of samples with this label (masking is faster and stays on the GPU).
label_sum = spikes[mask].sum(0)
# Update rates.
rates[:, i] = alpha * rates[:, i] + (label_sum / n_labeled)

# Compute proportions of spike activity per class.
proportions = rates / rates.sum(1, keepdim=True)
proportions[proportions != proportions] = 0 # Set NaNs to 0
# Compute proportions (and use 'torch.where' to avoid NaN bug).
total_activity = rates.sum(1, keepdim=True)
proportions = torch.where(total_activity > 0, rates / total_activity, torch.zeros_like(rates))

# Neuron assignments are the labels they fire most for.
assignments = torch.max(proportions, 1)[1]
max_vals, assignments = torch.max(proportions, 1)

# Set unassigned (silent) neurons to -1 instead of defaulting to 0.
assignments[max_vals == 0] = -1

return assignments, proportions, rates

Expand Down