Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks

Bibliographic Details
Main Authors: Thomas M Summe, Siddharth Joshi
Format: Article
Language: English
Published: IOP Publishing 2025-01-01
Series: Neuromorphic Computing and Engineering
Online Access: https://doi.org/10.1088/2634-4386/ada9a8
Collection: DOAJ
Description: Spiking neural networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules and algorithms could help further develop models and systems across the spectrum of performance, bio-plausibility, and complexity. However, these alternatives are not consistently implemented with the same, if any, SNN framework, often complicating their comparison and use. To address this, we introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and evaluation. Slax is compatible with the broader JAX and Flax ecosystem and provides optimized implementations of diverse training algorithms, enabling direct performance comparisons. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training. By streamlining the implementation and evaluation of novel SNN learning algorithms, Slax aims to facilitate research and development in this promising field.
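The description names backpropagation through time (BPTT) with surrogate gradients as the standard SNN training method and lists gradient similarity among the diagnostics Slax offers. Purely as an illustration, here is a minimal sketch in plain JAX. It is not Slax's API, and every name in it is hypothetical. It assumes a leaky integrate-and-fire (LIF) layer, a Heaviside spike with a fast-sigmoid surrogate derivative (slope 10), a soft reset, and a spike-rate matching loss. It computes the full BPTT gradient through `jax.lax.scan` and compares it, by cosine similarity, with a one-step truncated gradient.

```python
import jax
import jax.numpy as jnp

SURROGATE_SLOPE = 10.0  # hypothetical sharpness of the fast-sigmoid surrogate


@jax.custom_jvp
def spike(v):
    # Forward pass: exact Heaviside step (1.0 where v > 0, else 0.0).
    v = jnp.asarray(v)
    return (v > 0.0).astype(v.dtype)


@spike.defjvp
def spike_jvp(primals, tangents):
    # Backward pass: replace the step's zero/undefined derivative with
    # the fast-sigmoid surrogate 1 / (slope * |v| + 1)^2.
    (v,), (v_dot,) = primals, tangents
    v = jnp.asarray(v)
    surrogate = 1.0 / (SURROGATE_SLOPE * jnp.abs(v) + 1.0) ** 2
    return (v > 0.0).astype(v.dtype), surrogate * v_dot


def lif_loss(w, inputs, target, beta=0.9, threshold=1.0, truncate=False):
    """Hypothetical LIF layer. inputs: (T, n_in), w: (n_in, n_out), target: (n_out,)."""

    def step(v, x_t):
        if truncate:
            # Cut the gradient path through earlier time steps (truncation length 1).
            v = jax.lax.stop_gradient(v)
        v = beta * v + x_t @ w              # leaky integration of input current
        s = spike(v - threshold)            # spike when v crosses the threshold
        v = v - jax.lax.stop_gradient(s) * threshold  # soft reset, no gradient through reset
        return v, s

    v0 = jnp.zeros(w.shape[1])
    _, spikes = jax.lax.scan(step, v0, inputs)  # unroll over time (BPTT when differentiated)
    rate = spikes.mean(axis=0)
    return jnp.mean((rate - target) ** 2)


def cosine_similarity(g1, g2):
    # Flattened cosine similarity between two gradient arrays.
    a, b = g1.ravel(), g2.ravel()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-12)


key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
w = 0.5 * jax.random.normal(k_w, (20, 5))
inputs = jax.random.bernoulli(k_x, 0.3, (50, 20)).astype(jnp.float32)
target = jnp.full((5,), 0.2)

g_bptt = jax.grad(lif_loss)(w, inputs, target)
g_trunc = jax.grad(lif_loss)(w, inputs, target, truncate=True)
print("cosine similarity (BPTT vs. truncated):", float(cosine_similarity(g_bptt, g_trunc)))
```

Because BPTT differentiates through the whole `scan`, its memory grows with the sequence length T. This is the scaling issue the description mentions. A similarity close to 1 would suggest that the cheaper truncated gradient points in nearly the same direction as full BPTT on this toy problem.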
Institution: Kabale University
ISSN: 2634-4386
Volume: 5
Issue: 1
Article number: 014007
Author ORCIDs: Thomas M Summe, https://orcid.org/0009-0004-9833-0477; Siddharth Joshi, https://orcid.org/0000-0002-9201-9678
Affiliation (both authors): Department of Computer Science and Engineering, University of Notre Dame, Notre Dame, IN 46556, United States of America
Keywords: online and offline learning; JAX; real-time recurrent learning (RTRL); neural network visualization; neuromorphic computing; spiking neural network (SNN) training