Installation
JLNN requires Python 3.11 or later. Because the framework is built on the JAX stack, make sure you install the JAX build that matches your hardware (CPU, GPU, or TPU).
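As a quick pre-flight check, the version requirement above can be tested before installing anything. This is a minimal sketch using only the standard library; the helper name `python_is_supported` is hypothetical and not part of JLNN.

```python
import sys

# Minimum interpreter version, taken from the requirement stated above.
MIN_PYTHON = (3, 11)


def python_is_supported(version_info=None, minimum=MIN_PYTHON):
    """Return True if the given (or current) interpreter meets the minimum version."""
    if version_info is None:
        version_info = sys.version_info
    # Compare only (major, minor); patch releases do not matter here.
    return tuple(version_info[:2]) >= tuple(minimum)


current = ".".join(str(part) for part in sys.version_info[:3])
status = "supported" if python_is_supported() else "too old for JLNN"
print(f"Python {current}: {status}")
```

Running this with the interpreter you intend to use for JLNN catches a version mismatch before pip or uv reports a less obvious resolution error.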
Standard Installation
You can install the latest stable version directly from PyPI:
pip install jax-lnn
Alternatively, to get the latest development version, install directly from the GitHub repository:
pip install git+https://github.com/RadimKozl/JLNN.git
Development Installation
If you want to contribute to JLNN, run benchmarks, or modify the source code, we recommend a development installation.
Using uv (recommended):
git clone https://github.com/RadimKozl/JLNN.git
cd JLNN
uv sync
Using pip:
git clone https://github.com/RadimKozl/JLNN.git
cd JLNN
pip install -e ".[test]"
JAX & CUDA Specifics
JLNN runs best on accelerators (GPU/TPU). In Google Colab, the default installation might require a runtime restart to initialize CUDA correctly.
Example for GPU (CUDA 12) support:
pip install --upgrade "jax[cuda12]"
Note
If you are using Google Colab, remember to restart your session after installation to ensure JAX can access the GPU. You can do this programmatically:
import os
os.kill(os.getpid(), 9)  # kill the current process with signal 9 (SIGKILL) so Colab restarts the runtime
Dependencies
The framework automatically installs these key libraries:
* jax & jaxlib: Computing core.
* flax: For managing the state of neural networks (we use the modern NNX API).
* lark: For parsing logical formulas.
* networkx: For working with the graph structure of the model.
* optax: For optimization and learning.
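To confirm that the libraries listed above actually ended up in your environment, you can query the installed package metadata. The sketch below relies only on the standard library and assumes the distribution names match the names in the list; the helper `installed_versions` is hypothetical.

```python
from importlib import metadata

# Distribution names as listed above; assumed to match their PyPI names.
REQUIRED = ("jax", "jaxlib", "flax", "lark", "networkx", "optax")


def installed_versions(names=REQUIRED):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name:10s} {version or 'NOT INSTALLED'}")
```

Reading metadata instead of importing each package keeps the check fast and avoids initializing JAX just to look up a version number.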
Verification
To verify that everything works correctly, try a simple import:
import jlnn
import jax
print(f"JLNN version: {jlnn.__version__}")
print(f"Available devices: {jax.devices()}")
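Listing devices shows what JAX can see, but it does not prove that a computation compiles and runs there. As an optional extra step, the following sketch runs a tiny JIT-compiled matrix product and reports the active backend. The helper name `describe_backend` is hypothetical, and the function returns None when JAX is not importable.

```python
def describe_backend():
    """Run a small JIT-compiled computation and report where it executed.

    Returns None if JAX is not importable in this environment.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None

    # A tiny matrix product forces JAX to compile and execute on the default device.
    x = jnp.ones((4, 4))
    y = jax.jit(lambda a: a @ a.T)(x)
    y.block_until_ready()  # JAX dispatches asynchronously; wait for the result

    return {
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu", or "tpu"
        "device_count": jax.device_count(),
        "checksum": float(y.sum()),  # 16 entries, each equal to 4.0
    }


print(describe_backend())
```

If the reported backend is "cpu" on a machine with a GPU, the accelerator-enabled JAX build is most likely not installed or the runtime has not been restarted yet.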
Operating System Support
Linux: Full support (recommended).
macOS: Support for Apple Silicon processors (M1/M2/M3) via Metal acceleration.
Windows: For GPU acceleration, running under WSL2 (Windows Subsystem for Linux) is recommended.
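If you are unsure which of the categories above applies to your machine, a rough classification can be derived from the standard `platform` module. This is only a sketch: the function name `describe_platform` is hypothetical, and the WSL check is a common heuristic (looking for "microsoft" in the kernel release string), not an official detection method.

```python
import platform


def describe_platform(system=None, machine=None, release=None):
    """Classify the host into one of the support categories listed above."""
    system = platform.system() if system is None else system
    machine = platform.machine() if machine is None else machine
    release = platform.release() if release is None else release

    if system == "Linux":
        # Heuristic: WSL kernels usually carry "microsoft" in the release string.
        if "microsoft" in release.lower():
            return "Windows via WSL"
        return "Linux"
    if system == "Darwin":
        # Apple Silicon machines report the "arm64" architecture.
        return "macOS (Apple Silicon)" if machine == "arm64" else "macOS (Intel)"
    if system == "Windows":
        return "Windows (native)"
    return f"Other ({system})"


print(describe_platform())
```

The optional arguments exist only so the classification can be exercised with sample values; in normal use, call the function without arguments.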