mædoc's notes

mucking about with small transformers

nanoGPT is a great example of something which "has teeth" but is still quite hackable; I took advantage of that and ran with it (dropping the MLP and layer norms), testing a DeltaNet-style attention mechanism here and then porting the whole thing to an array-oriented JAX version (which I vastly prefer) here, complete with training & generation; see for instance this notebook.

attention, delta style

The attention part is straightforward,

import jax
import jax.numpy as jp

# B, T, C, nh, hs are the usual batch size, sequence length, embedding width,
# number of heads and head size, set elsewhere in the script

Z = lambda x: (x - x.mean(axis=-1)[...,None])/x.std(ddof=1, axis=-1)[...,None]

def op(Sb, qkv):
    # one scan step over time: decay the per-head state S, damp it where k is
    # active, add the new v⊗k outer product, then read out with q
    S, b = Sb
    b0, b1, b2 = b
    q, k, v = qkv
    Skk = -jp.einsum('bhij,bhi,bhj,h->bhij', S, k, k, b1)  # -b1 * S ∘ (k⊗k)
    vk = jp.einsum('bhi,bhj,h->bhij', v, k, b2)            # +b2 * v⊗k
    S = S*(1-b0.reshape(nh,1,1)) + Skk + vk
    vt = jp.einsum('bhi,bhij->bhj', q, S)                  # q · S
    return (S, b), vt

def model1(params, x, phi=jax.nn.gelu, delta=True):
    Wi, Wo, lm_head, wte, b = params
    x = wte[x] # B, T, C
    for wi, wo in zip(Wi, Wo):
        q, k, v = jp.einsum('ij,bti->btj', wi, x).reshape(B,T,3,nh,hs).swapaxes(0,2)  # (3,T,B,nh,hs)
        q, k, v = phi(Z(q)), phi(Z(k)), Z(v)
        S0 = jp.einsum('bhi,bhj->bhij', v[0], k[0])
        _, vt = jax.lax.scan(op, (S0,b), (q, k, v))
        x = vt.swapaxes(0,1).reshape(B,T,C) @ wo + x
    return Z(x) @ lm_head

here, the Z function, like a z-score, replaces the trainable layer norm.
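
just to spell that out (a quick check I added, not from the original notebook): Z gives every feature vector zero mean and unit (ddof=1) standard deviation, i.e. a layer norm with no learnable scale or shift,

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
zx = Z(x)
print(jp.abs(zx.mean(axis=-1)).max(), jp.abs(zx.std(ddof=1, axis=-1) - 1).max())  # both ≈ 0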

This is fairly easy to port to C as well; it just needs some loops.
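
for reference, here's what those loops compute: a plain-numpy sketch of the scan for a single batch element (the function name and the batch-free shapes are my own; the C version is the same thing with the einsums unrolled into index loops),

import numpy as np

# q, k, v have shape (T, nh, hs); b0, b1, b2 have shape (nh,)
def delta_scan_loops(q, k, v, b0, b1, b2):
    T, nh, hs = q.shape
    S = np.zeros((nh, hs, hs))
    for h in range(nh):
        S[h] = np.outer(v[0, h], k[0, h])          # S0, as in model1
    out = np.zeros((T, nh, hs))
    for t in range(T):
        for h in range(nh):
            kk = np.outer(k[t, h], k[t, h])
            S[h] = S[h]*(1 - b0[h]) - S[h]*kk*b1[h] + np.outer(v[t, h], k[t, h])*b2[h]
            out[t, h] = q[t, h] @ S[h]             # vt for this step and head
    return out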

quantized arith?

My idea was to work towards a version that runs on an ESP32 microcontroller, which has terrible floating-point performance, so most existing work on quantizing these models doesn't apply, since there the model is (maybe?) only quantized in memory and the arithmetic stays floating point. Here all arithmetic is on integers, and one has to watch for overflow. So, for instance, I have a function which squares the maximum-magnitude element and checks that it's still in range,

def overhead(a, raise_=True):
    "checks arrays can be multiplied w/o overflowing int32"
    aa = lambda a: np.abs(a).max().astype(np.int64)**2
    imax = np.iinfo(np.int32).max
    of = 0
    while aa(a >> of) >= imax:
        of += 1
    if of > 0:
        if raise_:
            assert aa(a) < imax, f'overflow, a>>={of}'
        else:
            msg = f'a>>={of}'
            print(msg)
            a >>= of
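
for instance, with a made-up int32 activation tensor quantized over 2**16,

xq = (np.random.randn(4, 64) * 2**16).astype(np.int32)  # fake activations, denominator 2**16
overhead(xq, raise_=False)  # prints e.g. 'a>>=3' if squaring the max would overflow int32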

then during the layer computation (an excerpt from the quantized forward pass; the q-suffixed names are the quantized counterparts of the floats above), one does the multiplies and some right shifts, and checks that the result is still safe to multiply further,

    for i, (wiq, woq, sq) in enumerate(zip(Wiq.astype(np.int32), Woq.astype(np.int32), Sq)):
        qq, kq, vq = (xq @ wiq.astype(np.int32)).reshape(3, nh, hs)
        qp, kp, vp = [xp + Wip for _ in range(3)]
        qq, kq, vq = phiq(Zq(qq, qp), qp), phiq(Zq(kq, kp), kp), Zq(vq, vp)
        rs = 13  # rshift loses precision but avoids overflow
        kkq = (kq[:, None]*kq[:, :, None]) >> rs
        vkq = (vq[:, :, None]*kq[:, None]) >> rs
        kkp = kp+kp-rs
        vkp = vp+kp-rs
        overhead(kkq, raise_=False)
        overhead(vkq, raise_=False)

where the p-suffixed variables hold the power-of-two exponent of the denominator for the corresponding q-suffixed fixed-point numerator, i.e. a real value is roughly xq / 2**xp.
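
to make the convention concrete, here's a tiny standalone example (the numbers are arbitrary): multiplying two fixed-point values adds their exponents, and a right shift subtracts from the exponent,

a, b = 0.37, -1.25
ap, bp = 14, 14                           # denominators 2**14
aq = np.int32(round(a * 2**ap))
bq = np.int32(round(b * 2**bp))
rs = 13                                   # shed bits, like the rs above
cq = (np.int64(aq) * bq) >> rs            # widen first so the product can't overflow
cp = ap + bp - rs                         # exponents add, the shift subtracts
assert abs(cq / 2**cp - a * b) < 1e-3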

log-quantized weights

however, the weights I get from training are not uniformly distributed, so it kinda bugged me to quantize them uniformly. also, ggml gets quite far with 4-bit quantization of weights. lastly, the cheapest integer multiplies are shifts, right?

so if we quantize the weights as (scaled) powers of two,

w = Wi[0].copy()  # weights from x to qkv
sw = np.sign(w)
w *= sw
w = np.log2(w)
wp = -int(np.percentile(w.reshape(-1), 2))
w += wp
w[w < 0] = 0
w = (sw * w).astype(np.int8)
# would fit into 4-bit, just 16 bins
nu = np.unique(w).size
assert nu <= 16, nu

that last assert says that all the weights fit into 16 bins; in other words, this is 4-bit quantization.
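
for reference, decoding those codes back to floats is just a signed power of two (my sketch, reusing w, sw, wp and Wi[0] from above),

w_hat = sw * 2.0**(np.abs(w).astype(np.float32) - wp)  # sign * 2**(code - wp)
# the log2 was truncated rather than rounded, so each magnitude comes back
# within a factor of two of the original
print(np.abs(w_hat - Wi[0]).max())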

then applying the weights required some head-scratching, since the signs have to be applied after the shifts but before the summation (e.g. in a dot product),

# x is the embedded token, xq the quantized version
x = wte.copy()
sx = np.sign(x)
x *= sx
xp = 16
xq = (x * 2**xp).astype(np.uint32)
# shift by log2 quant weight
qq = (xq[0][:, None] << np.abs(w)).astype(np.int32)
# apply sign
sq = (sx[0][:, None] * sw).astype(np.int8)
# sum
qq = (sq*qq).sum(axis=0)
# rescale to avoid overflow
qqp = xp + wp
rs = qqp - 16  # aim for 16 bit res
qq >>= rs
qqp -= rs
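
as a rough sanity check (my addition, not part of the pipeline above), the fixed-point result should sit close to a float matmul against the dequantized weights, up to the truncation in xq and the final right shift,

w_hat = sw * 2.0**(np.abs(w).astype(np.float32) - wp)
ref = wte[0] @ w_hat                      # float version of the same dot products
print(np.abs(qq / 2.0**qqp - ref).max())  # small, but not zero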

4-bit ints in C?

in the numpy version, we're kind of just using regular ints, but in C we can pack them with bit-fields,

#include <stdio.h>
#include <stdint.h>

struct ww {
  union {
    struct {
      int8_t w : 4;
      int8_t v : 4;
    };
    int8_t wv;
  };
};

int main()
{
  struct ww w = {.w=-3, .v=10};
  return printf("sizeof(w) = %lu bytes, w = {.w=%d, .v=%d}\n", sizeof(w), w.w, w.v);
}

then compiling, we get a nice warning about the range,

$ gcc -Wpedantic q4.c && ./a.out
q4.c: In function ‘main’:
q4.c:16:28: warning: overflow in conversion from ‘int’ to ‘signed char:4’ changes value from ‘10’ to ‘-6’ [-Woverflow]
   16 |   struct ww w = {.w=-3, .v=10};
      |                            ^~
sizeof(w) = 1 bytes, w = {.w=-3, .v=-6}

and importantly, the compiler says the struct only eats 1 byte.
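
on the numpy side, the corresponding packing of two codes per byte might look like this (the helper names and the low-nibble-first layout are my own; gcc happens to put the first bit-field in the low nibble here, though that ordering is technically implementation-defined),

def pack_q4(codes):                       # codes: int8, even length, values in [-8, 7]
    c = codes.astype(np.uint8) & 0x0F     # keep the two's-complement low nibble
    return (c[0::2] | (c[1::2] << 4)).astype(np.uint8)

def unpack_q4(packed):
    lo = (packed & 0x0F).astype(np.int16)
    hi = (packed >> 4).astype(np.int16)
    lo = np.where(lo > 7, lo - 16, lo)    # sign-extend the 4-bit values
    hi = np.where(hi > 7, hi - 16, hi)
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2], out[1::2] = lo, hi
    return out

codes = np.array([-3, -6, 7, 0], dtype=np.int8)
assert np.array_equal(unpack_q4(pack_q4(codes)), codes)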

activation function

a lot of activation functions involve exp(-x), but I only know how to multiply, add, and shift integers. this paper details some polynomial approximations that make use of the symmetry of 1/(1+exp(-|x|)) around x=0. they don't match exp(-x) exactly, but since we're training our network from scratch, we can just use the inexact version and let the training make it work.

so I ended up with just

def phi(x_):
    x = -np.abs(x_)
    emx = 1 - x*(1 - x/2*(1 - x/2))
    s = 1/(1 + emx)
    return np.where(x_ >= 0, 1-s, s)

which is a lot more obvious to translate to integer arithmetic; just the integer division in 1/(1 + emx) might be expensive.
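
to check that claim, here's a fixed-point sketch of phi I put together (the p=14 scale and the name are mine, not code from the actual port): every product is followed by a shift to keep the 2**p scale, and there's exactly one integer division,

def phiq_sketch(xq, p=14):
    # xq holds x * 2**p as an integer; returns phi(x) * 2**p as an integer
    one = 1 << p
    x = -np.abs(xq.astype(np.int64))                 # widen so products can't overflow
    t = one - (((x >> 1) * (one - (x >> 1))) >> p)   # 1 - x/2*(1 - x/2)
    emx = one - ((x * t) >> p)                       # 1 - x*t  ≈ exp(-x)
    s = (one * one) // (one + emx)                   # the one integer division
    return np.where(xq >= 0, one - s, s)

xs = np.linspace(-4, 4, 9)
print(np.abs(phiq_sketch(np.round(xs * 2**14).astype(np.int64)) / 2**14 - phi(xs)).max())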

next steps

since I initially wanted to scratch the itch of running a transformer on the ESP32-C6, I should

ofc there's always more yak shaving that could make that happen faster, like