rank-1 take on rwkv7's in context learning
The explicit in-context learning (ICL) approach of RWKV7 nerd sniped 🤓 me when I saw it about 2 weeks ago. A "state" matrix integrates keys 🔑 & values 🔒 over the sequence, acting as a compressed kv cache which the query can fish 🎣 information out of later on.
While it may not pan out in terms of efficiency, the formulation told me that the queries can be performed via a recurrence relation with just rank-1 operations (element-wise multiplies, adds, and vector-vector dot products), which seemed interesting 🧐 since it could reduce register pressure in some cases?
reminder
here's a reminder of the scheme for a single head (state $S_t \in \mathbb{R}^{D \times D}$, $\odot$ element-wise multiply):

$$S_t = S_{t-1}\left(\operatorname{diag}(w_t) - (n_t \odot k_t)\,k_t^{\top}\right) + (n_t \odot v_t)\,k_t^{\top}, \qquad o_t = S_t\, r_t$$
expanding
it’s useful to start with $o_T = S_T r_T$, then continuing with

$$o_T = \left[\,S_{T-1}\left(\operatorname{diag}(w_T) - (n_T \odot k_T)\,k_T^{\top}\right) + (n_T \odot v_T)\,k_T^{\top}\,\right] r_T$$

observing that the outer products become a dot product scaling the column vector, e.g. $(n_T \odot v_T)\,k_T^{\top} r_T = (k_T \cdot r_T)\,(n_T \odot v_T)$, and the diagonal $w$ is just an element-wise multiply,
- let $\tilde r_T = r_T$ and $\tilde r_{t-1} = w_t \odot \tilde r_t - (k_t \cdot \tilde r_t)\,(n_t \odot k_t)$, so that $o_T = \sum_{t=1}^{T} (k_t \cdot \tilde r_t)\,(n_t \odot v_t)$
from this we can start to see that the outputs are going to be a scan backwards in time, over keys and values, with a carried r value; for instance
```python
import numpy as np

T, D = 32, 8
r, w, k, v, n = np.random.randn(5, T, D)

# output for the final token, accumulated backwards in time
ot = np.zeros(D)
cr = r[-1]  # carried r value
for i in range(T - 1, -1, -1):
    krn = (k[i] @ cr) * n[i]     # dot product scales the column vector
    ot += v[i] * krn             # rank-1 contribution of token i
    cr = w[i] * cr - k[i] * krn  # decay + removal applied to the carry
```
this seems to be an interesting view of how the context of a token influences it: for instance, one could compute the norm of the increments, i.e. np.linalg.norm(v[i]*krn), to dynamically truncate the sum, or with RAG to cite highly relevant parts of the context.
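a sketch of that attribution idea on the scan above (the top-k policy at the end is made up for illustration):

```python
import numpy as np

T, D = 32, 8
r, w, k, v, n = np.random.randn(5, T, D)

# backward scan for the last token, recording each token's contribution norm
ot = np.zeros(D)
cr = r[-1]
contrib = np.zeros(T)
for i in range(T - 1, -1, -1):
    krn = (k[i] @ cr) * n[i]
    inc = v[i] * krn                 # increment token i adds to the output
    contrib[i] = np.linalg.norm(inc)
    ot += inc
    cr = w[i] * cr - k[i] * krn

# e.g. surface the 3 most influential context tokens for this query
top = np.argsort(contrib)[-3:][::-1]
```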
since the recurrence only applies per token, all of k, v, w and n have to be revisited for every query token, which is terrible compared to just updating S once and then applying it to r. unless the corresponding head size is enormous? idk yet. so as-is this wouldn't be beneficial for training or inference.
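as a sanity check, the rank-1 backward scan agrees with explicitly maintaining S and applying it to r; the matrix update written out below is the form I'm reconstructing from the scan, so treat it as an assumption:

```python
import numpy as np

rng = np.random.default_rng(1)
T, D = 32, 8
r, w, k, v, n = rng.standard_normal((5, T, D))

# explicit state matrix recurrence (assumed form, matching the scan):
# S_t = S_{t-1} (diag(w_t) - (n_t*k_t) k_t^T) + (n_t*v_t) k_t^T,  o_t = S_t r_t
S = np.zeros((D, D))
for t in range(T):
    S = S @ (np.diag(w[t]) - np.outer(n[t] * k[t], k[t])) \
        + np.outer(n[t] * v[t], k[t])
o_state = S @ r[-1]

# rank-1 backward scan for the last token
ot = np.zeros(D)
cr = r[-1]
for i in range(T - 1, -1, -1):
    krn = (k[i] @ cr) * n[i]
    ot += v[i] * krn
    cr = w[i] * cr - k[i] * krn

assert np.allclose(ot, o_state)
```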
there are two ways to parallelize a bit
- compute the carry r values for many tokens at once and process with same k, v, w, n during the recurrence? not sure if better than computing state matrix but might be lower memory
- parallel scan, tho this may require creating the state matrices explicitly..
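the first idea can be sketched by carrying one r vector per query token through a single backward sweep over the sequence, sharing each step's k, v, w, n across all still-active queries (layout and naming here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 16, 4
r, w, k, v, n = rng.standard_normal((5, T, D))

# naive reference: one backward scan per query token t
ref = np.zeros((T, D))
for t in range(T):
    cr = r[t]
    for i in range(t, -1, -1):
        krn = (k[i] @ cr) * n[i]
        ref[t] += v[i] * krn
        cr = w[i] * cr - k[i] * krn

# batched: carry all query vectors through a single sweep
out = np.zeros((T, D))
CR = r.copy()          # CR[t] is the carry for query t
for i in range(T - 1, -1, -1):
    act = CR[i:]       # queries t >= i are still active at step i
    krn = act @ k[i]   # one dot product per active query
    out[i:] += krn[:, None] * (v[i] * n[i])
    CR[i:] = w[i] * act - np.outer(krn, k[i]) * n[i]

assert np.allclose(out, ref)
```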
ODE interlude
go back to the gradient descent of the L2 loss $\mathcal{L}_t(S) = \tfrac12\,\lVert S k_t - v_t\rVert^2$ for the in context learning, and formulate as a gradient descent in time,

$$\frac{dS}{d\tau} = -\nabla_S \mathcal{L}_t(S) = -\left(S k_t - v_t\right) k_t^{\top}$$
then let $S_t = S(\tau_t)$ and assume an Euler discretization with step size $\eta_t$,

$$S_t = S_{t-1} - \eta_t \left(S_{t-1} k_t - v_t\right) k_t^{\top} = S_{t-1}\left(I - \eta_t\, k_t k_t^{\top}\right) + \eta_t\, v_t k_t^{\top}$$

and (together with the decay $w_t$) the rwkv7 scheme is recovered. of course, we could apply the 2nd order Heun scheme instead
which yields the following (tho TODO check algebra) structurally similar update formula, effectively Euler with a corrected step size:

$$S_t = S_{t-1} - \eta_t \left(1 - \tfrac{\eta_t \lVert k_t \rVert^2}{2}\right)\left(S_{t-1} k_t - v_t\right) k_t^{\top}$$
this could in principle provide more accurate/stable gradients, under the interpretation of ICL as a smooth, kv-driven gradient descent. it would also work with the rank-1 recurrence above.
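a quick numerical check of the Heun algebra (symbols and the step size h are illustrative): for $f(S) = -(S k - v)\,k^{\top}$ the corrector slope is just a rescaling of the predictor slope, so Heun collapses to Euler with an adjusted step size.

```python
import numpy as np

D = 8
rng = np.random.default_rng(0)
S = rng.standard_normal((D, D))
kk, vv = rng.standard_normal((2, D))
h = 0.1  # illustrative step size

def f(S):
    # right-hand side of the gradient flow: dS/dtau = -(S k - v) k^T
    return -np.outer(S @ kk - vv, kk)

# Heun: average the slope at S and at the Euler predictor S*
S_star = S + h * f(S)
S_heun = S + 0.5 * h * (f(S) + f(S_star))

# closed form: Euler with effective step h * (1 - h * |k|^2 / 2)
h_eff = h * (1 - 0.5 * h * (kk @ kk))
S_closed = S + h_eff * f(S)

assert np.allclose(S_heun, S_closed)
```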