grokking RWKV
WIP over winter holidays
Resources
- The Discord channel has an enormous stream of info
- https://wiki.rwkv.com/
- https://github.com/BlinkDL/RWKV-LM is the main repo
- https://raw.githubusercontent.com/BlinkDL/RWKV-LM/main/RWKV-v7.png describes the in-context-learning approach where a matrix state learns an "internal model" online, mapping r values to outputs
- https://arxiv.org/abs/2404.05892 describes v5 & v6
- https://ben.bolte.cc/posts/2023-06-16-rwkv-model good explanation of the math there
- https://github.com/JL-er/RWKV-PEFT does fine tuning
- https://github.com/johanwind/wind_rwkv Wind's CUDA & Triton kernels
- https://github.com/codekansas/ml-pretrained/blob/master/pretrained/rwkv.py another implementation?
- https://github.com/BlinkDL/modded-nanogpt-rwkv/blob/master/train_rwkv7.py another implementation
- https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/rwkv_v6.rs Rust impl of v6
- https://github.com/Jellyfish042/LongMamba/tree/main context length verification
Stuff to try
- Bare in-context learning to demo the concepts (see the NumPy sketch after this list)
- Implement a plain NumPy version of rwkv7 for inference
- Implement a plain C99 version with SIMD and maybe the llama matmul
- Being able to train smaller models on CPU would be great
- Consider log-scaled versions?
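For the bare in-context-learning item, a minimal NumPy sketch of the idea (not taken from any RWKV code; the dimension, learning rate, and target map are made up for the demo): a matrix state S is updated online with a delta rule so it learns a key-to-value mapping from the stream, then gets queried with a receptance-like vector.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                          # head dimension, arbitrary for the demo
S = np.zeros((d, d))           # matrix-valued state: the online "internal model"
eta = 0.5                      # in-context learning rate, also arbitrary

# Fixed target mapping the state should pick up from the token stream.
W_true = rng.standard_normal((d, d))

for t in range(200):
    k = rng.standard_normal(d)
    k /= np.linalg.norm(k)          # normalized key
    v = W_true @ k                  # value associated with this key
    err = v - S @ k                 # prediction error of the current state
    S += eta * np.outer(err, k)     # delta rule: rank-1 gradient step on 1/2 * ||v - S k||^2

# Query: a receptance-like vector retrieves the learned association.
r = rng.standard_normal(d)
r /= np.linalg.norm(r)
print("retrieval error:", np.linalg.norm(S @ r - W_true @ r))
```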
v5-v7 math overview
The innovations in RWKV are mainly in the time mixing parts, so let's cover those first.
Eagle (v5)
Drops v4's iterative ratio of exponential sums and uses a matrix-valued state that tracks key-value outer products.
- lerp
- box_t
- w
- wkv_t
- iterative version
- o_t
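Filled in from the Eagle/Finch paper (arXiv:2404.05892) as best I recall; the exact per-head shapes and transposes should be checked against the paper:
- $\mathrm{lerp}_\square(a, b) = a + (b - a) \odot \mu_\square$ (token-shift interpolation)
- $\square_t = \mathrm{lerp}_\square(x_t, x_{t-1})\, W_\square$ for $\square \in \{r, k, v, g\}$
- $w = \exp(-\exp(\omega))$, a per-channel, data-independent decay in $(0, 1)$
- $\mathrm{wkv}_t = \mathrm{diag}(u)\, k_t^{\top} v_t + \sum_{i=1}^{t-1} \mathrm{diag}(w)^{t-1-i}\, k_i^{\top} v_i$
- iterative version: $S_0 = 0$, $\quad \mathrm{wkv}_t = S_{t-1} + \mathrm{diag}(u)\, k_t^{\top} v_t$, $\quad S_t = \mathrm{diag}(w)\, S_{t-1} + k_t^{\top} v_t$
- $o_t = \big(\mathrm{SiLU}(g_t) \odot \mathrm{LayerNorm}(r_t\, \mathrm{wkv}_t)\big)\, W_o$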
Finch (v6)
Uses a data-dependent decay
- lora_box
- ddlerp_box
- box_t
- d_t
- w_t
- wkv_t
- iterative version
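Again from memory of the same paper; the LoRA shapes are the part I'd double-check:
- $\mathrm{lora}_\square(x) = \lambda_\square + \tanh(x\, A_\square)\, B_\square$ (low-rank, data-dependent offsets)
- $\mathrm{ddlerp}_\square(a, b) = a + (b - a) \odot \mathrm{lora}_\square\big(a + (b - a) \odot \mu_x\big)$
- $\square_t = \mathrm{ddlerp}_\square(x_t, x_{t-1})\, W_\square$ for $\square \in \{r, k, v, g\}$
- $d_t = \mathrm{lora}_d(\mathrm{ddlerp}_d(x_t, x_{t-1}))$
- $w_t = \exp(-\exp(d_t))$, now a per-token, per-channel decay
- $\mathrm{wkv}_t = \mathrm{diag}(u)\, k_t^{\top} v_t + \sum_{i=1}^{t-1} \mathrm{diag}\big(\textstyle\bigodot_{j=i+1}^{t-1} w_j\big)\, k_i^{\top} v_i$
- iterative version: $\mathrm{wkv}_t = S_{t-1} + \mathrm{diag}(u)\, k_t^{\top} v_t$, $\quad S_t = \mathrm{diag}(w_t)\, S_{t-1} + k_t^{\top} v_t$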
Goose (v7)
Employs an explicit in-context iterative least-squares update, reminiscent of the update step of a Kalman filter?
- $\partial L / \partial S = ... $
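One way to make that gradient concrete, as a standard least-squares / delta-rule derivation rather than the exact v7 formula (column-vector convention here, so $S$ maps keys to values via $S k$): treat the state as an online linear model and take a per-token squared-error loss,
$$L_t = \tfrac{1}{2}\,\lVert v_t - S k_t \rVert^2 \quad\Rightarrow\quad \frac{\partial L_t}{\partial S} = -(v_t - S k_t)\, k_t^{\top}.$$
A gradient step with an in-context learning rate $\eta_t$ then gives the rank-1 update
$$S_t = S_{t-1} - \eta_t\, \frac{\partial L_t}{\partial S}\bigg|_{S_{t-1}} = S_{t-1}\big(I - \eta_t\, k_t k_t^{\top}\big) + \eta_t\, v_t k_t^{\top}.$$
My understanding is that v7's actual rule adds per-channel decay and makes $\eta_t$ and the keys data-dependent, so consult the diagram/code for the exact form.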
Channel mixing
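Nothing exotic here as far as I can tell; my recollection of the v5/v6 channel-mix block from the Eagle/Finch paper (I believe v7 simplifies it further by dropping the receptance gate, but verify against the v7 code):
- $r_t = \mathrm{lerp}_r(x_t, x_{t-1})\, W_r$, $\quad k_t = \mathrm{lerp}_k(x_t, x_{t-1})\, W_k$
- $v_t = \mathrm{ReLU}(k_t)^2\, W_v$
- $o_t = \sigma(r_t) \odot v_t$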
Some tricks with ICL
- rewrite as ODE
- use rank-1 ops instead of forming S
- parallel scan for computing S
Working through these steps reinforced how S acts as a KV cache for retrieval.
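A tiny NumPy illustration of that last point (same toy setup as the earlier sketch; the scalar decay stands in for the per-channel decay of v5/v6): S is never anything but a decayed sum of value-key outer products, and querying it with an r that matches a stored key retrieves the corresponding value.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8
w = 0.9                                    # scalar decay, a toy stand-in for diag(w)

# A few orthonormal keys so retrieval is exact and easy to read off.
keys = np.linalg.qr(rng.standard_normal((d, d)))[0][:5]
vals = rng.standard_normal((5, d))

# Build S with rank-1 updates only.
S = np.zeros((d, d))
for k, v in zip(keys, vals):
    S = w * S + np.outer(v, k)

# Querying with the most recent key returns its value exactly;
# older entries come back decayed by w per intervening step.
print(np.allclose(S @ keys[-1], vals[-1]))        # True
print(np.allclose(S @ keys[0], w**4 * vals[0]))   # True
```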