mædoc's notes

rank-1 take on rwkv7's in-context learning

The explicit in-context learning (ICL) approach of RWKV7 nerd sniped 🤓 me when I saw it about 2 weeks ago. A "state" matrix integrates keys 🔑 & values 🔒 over the sequence, acting as an effective kv cache from which the query can later fish 🎣 information back out.

While it may not pan out in terms of efficiency, the formulation suggested to me that the queries can be evaluated via a recurrence relation using just rank-1 operations (element-wise multiplies, adds and vector-vector dot products), which seemed interesting 🧐 since it could reduce register pressure in some cases?

reminder

here's a reminder of the scheme for a single head
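
roughly, the scheme as used in these notes is (a sketch reconstructed from the ODE and code further down, with $n_t$ an element-wise per-channel learning rate and column vectors $r_t, k_t, v_t, n_t, w_t \in \mathbb{R}^D$; the exact rwkv7 parameterization differs in details):

$$S_t = S_{t-1}\left(\operatorname{diag}(w_t) - (n_t \odot k_t)\,k_t^T\right) + (n_t \odot v_t)\,k_t^T, \qquad o_t = S_t\, r_t$$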

expanding o_t

it’s useful to start with o_0 and then continue with o_1, observing that the outer products become dot products scaling the column vector, and the diagonal w is just an element-wise multiply.
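
spelled out with the sketched update above (my reconstruction, so the exact placement of $n$ is an assumption):

$$o_0 = S_0\, r_0 = (n_0 \odot v_0)\,(k_0^T r_0)$$

$$o_1 = S_1\, r_1 = (n_1 \odot v_1)\,(k_1^T r_1) + (n_0 \odot v_0)\; k_0^T\!\left(w_1 \odot r_1 - (n_1 \odot k_1)\,(k_1^T r_1)\right)$$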

from this we can see that the outputs o_t amount to a scan backwards in time over the keys and values, with a carried r value, for instance

import numpy as np

T, D = 32, 8                        # sequence length, head size
r, w, k, v, n = np.random.randn(5, T, D)

# output for the last token, accumulated by scanning backwards over the context
ot = np.zeros(D)
cr = r[-1]                          # carried query vector
for i in range(T - 1, -1, -1):
    krn = (k[i] @ cr) * n[i]        # dot product with the carried r, scaled per channel
    ot += v[i] * krn                # rank-1 contribution of token i to the output
    cr = w[i] * cr - k[i] * krn     # decay, then subtract the removal term
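
as a sanity check, here's the forward version that builds the full $D \times D$ state matrix explicitly and reads it out with the last query; it implements the same (assumed) update the backward scan unrolls, so the two outputs should agree:

# explicit state matrix, same update the backward scan above assumes
S = np.zeros((D, D))
for t in range(T):
    S = S @ (np.diag(w[t]) - np.outer(n[t] * k[t], k[t])) + np.outer(n[t] * v[t], k[t])
o_fwd = S @ r[-1]
assert np.allclose(o_fwd, ot)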

this seems to be an interesting view of how the context of a token influences it: for instance, one could compute the norm of each increment, i.e. np.linalg.norm(v[i]*krn), to dynamically truncate the sum, or, in a RAG setting, to cite the most relevant parts of the context (sketched below).
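
a sketch of what that might look like, reusing the arrays from the snippet above (the cutoff of 4 positions is arbitrary):

contrib = np.zeros(T)
cr = r[-1]
for i in range(T - 1, -1, -1):
    krn = (k[i] @ cr) * n[i]
    contrib[i] = np.linalg.norm(v[i] * krn)   # size of token i's increment to the output
    cr = w[i] * cr - k[i] * krn
top = np.argsort(contrib)[::-1][:4]           # most influential context positions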

since the recurrence runs per output token, all of k, v, w and n have to be revisited for each token: a single output costs O(T·D), but producing all T outputs this way is O(T²·D) per head, versus O(T·D²) for just updating S and then applying it to r. so it's terrible unless the corresponding head size is enormous relative to the sequence length? idk yet. so this wouldn’t be beneficial for training or inference as-is.

there are two ways to parallelize a bit

ODE interlude

go back to the gradient descent of the L2 loss for the in-context learning, and formulate it as a gradient descent in continuous time,

$$\dot S = -\tilde w\, S + n\left(v^T k - S\, k^T k\right)$$
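
presumably the loss in question is something like an L2 regression of $v$ onto $k$ through $S$, with a weight-decay term supplying the $\tilde w S$ leak (my reading):

$$L(S) = \frac{n}{2}\,\lVert v^T - S k^T\rVert^2 + \frac{\tilde w}{2}\,\lVert S\rVert_F^2, \qquad \dot S = -\frac{\partial L}{\partial S} = -\tilde w\, S + n\left(v^T k - S\, k^T k\right)$$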

then let $\tilde w = 1 - w$ and assume an Euler discretization $S_t = S_{t-1} + \Delta t\, f(S_{t-1})$ with $\Delta t = 1$; the rwkv7 scheme is recovered. of course, we could apply the 2nd-order Heun scheme instead
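
(Heun here is just the standard explicit predictor-corrector: take an Euler step, then average the slopes, still with $\Delta t = 1$)

$$S^{*} = S_{t-1} + f(S_{t-1}), \qquad S_t = S_{t-1} + \tfrac{1}{2}\left(f(S_{t-1}) + f(S^{*})\right)$$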

applied to the ODE above, this yields the following structurally similar update formula (tho TODO check algebra):

$$S_t = S_{t-1}\left(1 - \tilde w + \frac{\tilde w^2}{2} + n\,k^T k\left(\tilde w - 1 + \frac{n\,k^T k}{2}\right)\right) + n\,v^T k\left(1 - \frac{\tilde w}{2} - \frac{n\,k^T k}{2}\right)$$
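
a quick scalar check of that algebra with made-up numbers (kk and vk stand in for the scalars $k^T k$ and $v^T k$):

s0, w0, n0, kk, vk = 0.3, 0.9, 0.05, 1.7, -0.4   # arbitrary test values
wt = 1 - w0                                       # w~ = 1 - w
f = lambda s: -wt * s + n0 * (vk - s * kk)        # ODE right-hand side
s_star = s0 + f(s0)                               # Euler predictor
heun = s0 + 0.5 * (f(s0) + f(s_star))             # Heun corrector
closed = (s0 * (1 - wt + wt**2 / 2 + n0 * kk * (wt - 1 + n0 * kk / 2))
          + n0 * vk * (1 - wt / 2 - n0 * kk / 2))
assert abs(heun - closed) < 1e-12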

this could in principle provide more accurate/stable gradients, under the interpretation of ICL as a smooth kv-driven gradient descent. it would also work with the rank-1 recurrence above.

#llm #rwkv