Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks
Spiking neural networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules an...
Main Authors: | Thomas M Summe, Siddharth Joshi |
---|---|
Format: | Article |
Language: | English |
Published: | IOP Publishing, 2025-01-01 |
Series: | Neuromorphic Computing and Engineering |
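Subjects: | online and offline learning; JAX; real-time recurrent learning (RTRL); neural network visualization; neuromorphic computing; spiking neural network (SNN) training |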
Online Access: | https://doi.org/10.1088/2634-4386/ada9a8 |
_version_ | 1825202062998110208 |
---|---|
author | Thomas M Summe; Siddharth Joshi |
collection | DOAJ |
description | Spiking neural networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules and algorithms could help further develop models and systems across the spectrum of performance, bio-plausibility, and complexity. However, these alternatives are not consistently implemented with the same, if any, SNN framework, often complicating their comparison and use. To address this, we introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and evaluation. Slax is compatible with the broader JAX and Flax ecosystem and provides optimized implementations of diverse training algorithms, enabling direct performance comparisons. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training. By streamlining the implementation and evaluation of novel SNN learning algorithms, Slax aims to facilitate research and development in this promising field. |
format | Article |
id | doaj-art-54d2ce5e74084dc19c74833729db3a41 |
institution | Kabale University |
issn | 2634-4386 |
language | English |
publishDate | 2025-01-01 |
publisher | IOP Publishing |
record_format | Article |
series | Neuromorphic Computing and Engineering |
spelling | doaj-art-54d2ce5e74084dc19c74833729db3a41; 2025-02-07T15:21:23Z; eng; IOP Publishing; Neuromorphic Computing and Engineering; 2634-4386; 2025-01-01; 5; 1; 014007; 10.1088/2634-4386/ada9a8; Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks; Thomas M Summe (0), https://orcid.org/0009-0004-9833-0477; Siddharth Joshi (1), https://orcid.org/0000-0002-9201-9678; Department of Computer Science and Engineering, University of Notre Dame, Notre Dame, IN 46556, United States of America; Department of Computer Science and Engineering, University of Notre Dame, Notre Dame, IN 46556, United States of America; Spiking neural networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules and algorithms could help further develop models and systems across the spectrum of performance, bio-plausibility, and complexity. However, these alternatives are not consistently implemented with the same, if any, SNN framework, often complicating their comparison and use. To address this, we introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and evaluation. Slax is compatible with the broader JAX and Flax ecosystem and provides optimized implementations of diverse training algorithms, enabling direct performance comparisons. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training. By streamlining the implementation and evaluation of novel SNN learning algorithms, Slax aims to facilitate research and development in this promising field.; https://doi.org/10.1088/2634-4386/ada9a8; online and offline learning; JAX; real-time recurrent learning (RTRL); neural network visualization; neuromorphic computing; spiking neural network (SNN) training |
title | Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks |
topic | online and offline learning; JAX; real-time recurrent learning (RTRL); neural network visualization; neuromorphic computing; spiking neural network (SNN) training |
url | https://doi.org/10.1088/2634-4386/ada9a8 |