NumPy performance and fragmentation
NumPy is the standard package for working with arrays in Python. The algorithms I work on were originally written in MATLAB and were ported to NumPy long ago; they provide the reference implementation for our simulator engine. However, these routines pose two problems:
- They are written in a heavily object-oriented style that doesn't lend itself to maintenance and modification by scientific users
- They are slower than users want them to be
These factors have given rise to a myriad of minor simulator reimplementations (including several of my own).
why are our simulations slow?
The simulations take many time steps, and the memory access pattern within each step is irregular, so the CPU spends much of its time waiting on memory rather than computing.
example data, kernels and timing
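As a stand-in for the real data and kernels, here is a toy kernel with the same shape: a state array updated over many time steps, gathering values through an irregular index array. All names and sizes here are made up for illustration.

```python
import timeit
import numpy as np

rng = np.random.default_rng(42)
n_state, n_steps = 1_000_000, 100
state = rng.standard_normal(n_state)
# irregular connectivity: each element reads from a random other element
neighbours = rng.integers(0, n_state, size=n_state)

def run(state, neighbours, n_steps):
    for _ in range(n_steps):
        # the gather state[neighbours] jumps all over memory
        state = 0.9 * state + 0.1 * state[neighbours]
    return state

print(timeit.timeit(lambda: run(state, neighbours, n_steps), number=1))
```

On typical hardware the gather dominates this loop; the arithmetic itself is cheap.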
let's build NumPy from scratch
From a checkout of the NumPy source tree, a build with a wider SIMD baseline looks like this:

```sh
pip install -r requirements/build_requirements.txt
pip install . -Csetup-args=-Dcpu-baseline="avx2 fma3" -Csetup-args=-Dcpu-dispatch="max"
```
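To check what the resulting build actually enabled, NumPy can print its build configuration (the exact output format varies between versions):

```python
import numpy as np

# Prints compiler info plus the SIMD extensions that were compiled in
# as the baseline and those enabled by the runtime dispatcher.
np.show_config()
```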
let's optimize the memory access instead
Intuitively, retrieving values that are nearby in memory requires less work: consecutive accesses share cache lines, and the hardware prefetcher can stay ahead of them.
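A rough illustration with toy sizes (not our actual kernels): the same gather is much cheaper through sorted indices, because memory is then touched roughly in order.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal(10_000_000)
idx = rng.integers(0, data.size, size=1_000_000)

t_random = timeit.timeit(lambda: data[idx], number=100)
sorted_idx = np.sort(idx)
t_sorted = timeit.timeit(lambda: data[sorted_idx], number=100)
print(f"random gather: {t_random:.3f}s, sorted gather: {t_sorted:.3f}s")
```

The data is identical in both cases; only the order of the accesses changes.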
how well does a very highly optimized implementation do?
An ISPC implementation fuses the whole time step, including the random number generation, into a single kernel.
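To see what fusion means without leaving NumPy: a chain of array expressions materializes a full temporary per operation, while reusing one scratch buffer via `out=` removes the allocations (the operations here are placeholders, not the real kernel):

```python
import numpy as np

x = np.random.default_rng(0).standard_normal(1_000_000)
a, b = 1.5, 0.5

# naive chain: every step allocates and fills a temporary array
y = np.exp(x) * a + b

# partially fused: one preallocated buffer, updated in place
tmp = np.empty_like(x)
np.exp(x, out=tmp)
np.multiply(tmp, a, out=tmp)
np.add(tmp, b, out=tmp)

assert np.allclose(y, tmp)
```

A fully fused ISPC kernel goes one step further: a single pass over memory computes everything, so intermediate values never round-trip through RAM at all.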
fragmentation
what is JAX about anyway?
JAX tries to bring TensorFlow-style thinking to NumPy users: write array code in Python, let the framework trace and compile it.
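A minimal sketch of the model (shapes and the function body are arbitrary): decorate a NumPy-like function with `jax.jit` and it is traced once, compiled by XLA, and reused on later calls with the same shapes.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x, w):
    # traced on the first call, then runs as compiled XLA code
    return jnp.tanh(x @ w) + x

x = jnp.ones((512, 512))
w = jnp.ones((512, 512))
step(x, w).block_until_ready()  # JAX is asynchronous; wait for the result
```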
problems
- it's not that fast on CPU
- it hogs the CPU pretty hard even on single-threaded workloads
- it adds hundreds of MB to the deployment
- it can't be deployed to the browser, e.g. in JupyterLite