grokking the FFT
WIP during winter break 🎄
Inspired by the now-famous ggml and its mission to make mind-bending models runnable by nearly anyone, I wanted to contribute, and found an issue about adding the FFT. Having previously implemented the spherical harmonic transform (SHT) from scratch, I couldn't satisfy my curiosity with a simple batched Cooley-Tukey implementation. Some aspects I got interested in:
- higher radix, like 4 or 8, to improve instruction throughput
- the Stockham formulation, which avoids the bit-reversal permutation and is perhaps more amenable to SIMD
- a trick, new AFAICT, inspired by state-of-the-art SHT implementations
resources
a few things I've leaned on while working on this
- the Wikipedia page on Cooley-Tukey
- a few clever pages on various implementations
- my own simplest-possible implementation, which works with SymPy for code generation:
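import numpy as np
from sympy import Symbol, I, pi, exp, re, im, cse, ccode
def ft(x):
    # recursive radix-2 decimation in time; len(x) must be a power of 2
    N = len(x)
    if N == 2:
        a, b = x
        return np.array([a + b, a - b])
    else:
        E, O = ft(x[::2]), ft(x[1::2])
        Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
        return np.r_[E + Ws*O, E - Ws*O]
N = 8  # transform size to generate code for (any power of 2)
# build input sequence symbolically as array elements y[k]; declaring them
# real makes re()/im() below reduce to plain real-valued expressions
y = np.array([Symbol(f'y[{k}]', real=True) for k in range(N)])
# apply FFT symbolically
Y = ft(y).tolist()
assert len(Y) == N
# for ccode to work, we use just floats, not complex numbers
Yr, Yi = [re(_) for _ in Y], [im(_) for _ in Y]
# cse improves speed
aux, Yri = cse(Yr + Yi)
# emit C: temporaries first, then real and imaginary parts of each output
for tmp, expr in aux:
    print(f'float {tmp} = {ccode(expr)};')
for k in range(N):
    print(f'Yr[{k}] = {ccode(Yri[k])}; Yi[{k}] = {ccode(Yri[N + k])};')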
math
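a few quick reminders. For a length-N sequence $x_n$, the DFT is $X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i k n/N}$ for $k = 0, \dots, N-1$, the same sign convention as ft above.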
radix 2
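From the above we can see by hand that the short radix transforms are fairly simple. For radix 2 the transform is just a sum and a difference, $X_0 = x_0 + x_1$ and $X_1 = x_0 - x_1$, which is exactly the base case of ft.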
radix 4
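Writing out the N = 4 case the same way, with $W_4 = e^{-2\pi i/4} = -i$ and the same sign convention as ft, gives the radix-4 butterfly (a worked sketch):

```latex
\begin{aligned}
X_0 &= x_0 + x_1 + x_2 + x_3 \\
X_1 &= x_0 - i\,x_1 - x_2 + i\,x_3 \\
X_2 &= x_0 - x_1 + x_2 - x_3 \\
X_3 &= x_0 + i\,x_1 - x_2 - i\,x_3
\end{aligned}
```

Grouping as $a = x_0 + x_2$, $b = x_0 - x_2$, $c = x_1 + x_3$, $d = x_1 - x_3$ gives $X_0 = a + c$, $X_2 = a - c$, $X_1 = b - i\,d$, $X_3 = b + i\,d$. Multiplying by $\pm i$ only swaps real and imaginary parts and flips a sign, so the whole butterfly costs 16 real additions and no multiplications, which is part of the appeal of higher radices.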
twiddles
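An interesting aspect of the FFT is that this building block can then be used to build larger transforms by adding the so-called twiddle factors: a length-N transform combines the transforms $E_k$ and $O_k$ of the even- and odd-indexed samples as $X_k = E_k + W_N^k O_k$ and $X_{k+N/2} = E_k - W_N^k O_k$ for $k < N/2$, with $W_N^k = e^{-2\pi i k/N}$, which is exactly the recursive step in ft.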
TODO add derivation of radix 4 by applying the radix 2 to itself, then adding the twiddle factors.
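A quick numerical sketch of that idea in plain Python: the 4-point transform built from two radix-2 butterflies (one on the even samples, one on the odd samples) plus the twiddles $W_4^0 = 1$ and $W_4^1 = -i$, checked against a naive DFT. The helper names are made up for illustration.

```python
import cmath

def dft(x):
    """Naive O(N^2) DFT, same sign convention as ft: exp(-2*pi*i*k*n/N)."""
    N = len(x)
    return [sum(x[n] * cmath.exp(-2j * cmath.pi * k * n / N) for n in range(N))
            for k in range(N)]

def radix2(a, b):
    """Radix-2 butterfly, i.e. a 2-point DFT."""
    return a + b, a - b

def radix4_from_radix2(x0, x1, x2, x3):
    """4-point DFT from two radix-2 transforms plus twiddle factors."""
    E0, E1 = radix2(x0, x2)       # transform of the even-indexed samples
    O0, O1 = radix2(x1, x3)       # transform of the odd-indexed samples
    w = (1, -1j)                  # twiddles W_4^k = exp(-2*pi*i*k/4), k = 0, 1
    # combine: X_k = E_k + W^k O_k, X_{k+2} = E_k - W^k O_k
    return (E0 + w[0] * O0, E1 + w[1] * O1,
            E0 - w[0] * O0, E1 - w[1] * O1)

x = [1.0, 2.0 + 1j, -0.5, 3j]
print(radix4_from_radix2(*x))
print(dft(x))
```

Applying the same split recursively to larger N is what ft does symbolically.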
Stockham
WIP
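A minimal sketch of the idea in plain Python, assuming N is a power of two. This is one common radix-2 (decimation-in-frequency) Stockham formulation, not necessarily what would end up in ggml, and the function names are illustrative. Each stage reads one buffer and writes the other, so the result comes out in natural order without a bit-reversal pass:

```python
import cmath

def dft(x):
    """Naive O(N^2) DFT used as a reference."""
    N = len(x)
    return [sum(x[n] * cmath.exp(-2j * cmath.pi * k * n / N) for n in range(N))
            for k in range(N)]

def stockham_fft(x):
    """Radix-2 Stockham autosort FFT (sketch); len(x) must be a power of two."""
    N = len(x)
    a = [complex(v) for v in x]   # stage input
    b = [0j] * N                  # stage output (ping-pong buffer)
    n, s = N, 1                   # n: sub-transform length, s: stride between interleaved sub-transforms
    while n > 1:
        m = n // 2
        for p in range(m):
            wp = cmath.exp(-2j * cmath.pi * p / n)  # twiddle for this butterfly
            for q in range(s):                      # unit-stride loop over q in both buffers
                u = a[q + s * p]
                v = a[q + s * (p + m)]
                b[q + s * (2 * p)] = u + v
                b[q + s * (2 * p + 1)] = (u - v) * wp
        a, b = b, a               # swap buffers instead of permuting in place
        n, s = m, 2 * s
    return a                      # natural order, no bit reversal

x = [complex(n % 3, -n) for n in range(16)]
print(max(abs(u - v) for u, v in zip(stockham_fft(x), dft(x))))
```

The inner loop over q is unit-stride in both buffers, which is what hints at SIMD friendliness, though it is short in the early stages (s = 1, 2, ...).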
new trick?
A common optimization for FFTs is to precompute the twiddle factors. The same can be done for the Legendre transform part of the SHT, but an interesting aspect of the SHTns implementation is that it computes the transform coefficients iteratively during the summation loop, reducing memory pressure. It seemed like a similar trick could work for the FFT: instead of precomputing the factors or evaluating each one from scratch (online), compute each one from the previous one.
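Concretely, for a fixed output bin $k$ the twiddle for sample $n$ is $w_n = e^{-2\pi i k n/N}$, and consecutive twiddles differ by a constant factor, so each one follows from the previous one via a 2×2 rotation (a sketch of the recurrence, same sign convention as ft):

```latex
\begin{gathered}
w_0 = 1, \qquad w_{n+1} = e^{i\theta}\, w_n, \qquad \theta = -\frac{2\pi k}{N}, \\
\begin{pmatrix} \operatorname{Re} w_{n+1} \\ \operatorname{Im} w_{n+1} \end{pmatrix}
= \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
  \begin{pmatrix} \operatorname{Re} w_n \\ \operatorname{Im} w_n \end{pmatrix}
\end{gathered}
```

Here $\cos\theta$ and $\sin\theta$ are evaluated once per bin, outside the loop, and each step then costs 4 real multiplies and 2 adds.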
For example, computing a single output bin $X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i k n/N}$ directly would require something like
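#include <complex.h>
#include <math.h>
float complex dft_bin(const float complex *x, int N, int k) {  /* one output bin X_k */
    float complex Xk = 0.0f;
    for (int n = 0; n < N; n++) {
        float complex w = cexpf(-2.0f * acosf(-1.0f) * I * k * n / N);  /* exp(-2*pi*i*k*n/N), from scratch */
        Xk += w * x[n];
    }
    return Xk;
}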
While the exp call is likely cheaper than the equivalent from-scratch computation of coefficients in the SHT, it's still a lot of instructions to churn through in a tight loop. The trick is to notice that the coefficient w changes by a constant angle on each iteration, so we can evaluate a rotation matrix once and apply it on every iteration:
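#include <complex.h>
#include <math.h>
float complex dft_bin_rot(const float complex *x, int N, int k) {  /* same bin, no transcendentals in the loop */
    const float theta = -2.0f * acosf(-1.0f) * k / N;  /* constant rotation angle per step */
    const float c = cosf(theta), s = sinf(theta);       /* rotation matrix [c -s; s c] */
    float complex Xk = 0.0f, w = 1.0f;                  /* w_0 = exp(0) = 1 */
    for (int n = 0; n < N; n++) {
        Xk += w * x[n];                                 /* accumulate with w_n first... */
        w = (c * crealf(w) - s * cimagf(w))             /* ...then rotate w_n into w_{n+1} */
            + I * (s * crealf(w) + c * cimagf(w));
    }
    return Xk;
}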
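I'm not entirely sure about the C11 complex float syntax, but this should take far fewer flops than the transcendentals: each step is a single complex multiply (4 real multiplies and 2 adds). The trade-off is that rounding error accumulates in w as n grows, especially in single precision.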
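To sanity-check the recurrence numerically, here's a small Python sketch (double precision, rather than the C float code) comparing one bin computed with from-scratch twiddles against one computed with the rotated twiddles. The function names are made up for illustration.

```python
import cmath
import math

def dft_bin_direct(x, k):
    """X_k with every twiddle evaluated from scratch."""
    N = len(x)
    return sum(xn * cmath.exp(-2j * math.pi * k * n / N) for n, xn in enumerate(x))

def dft_bin_rotated(x, k):
    """X_k with the twiddle advanced by a fixed rotation each step."""
    N = len(x)
    theta = -2 * math.pi * k / N
    c, s = math.cos(theta), math.sin(theta)   # evaluated once, outside the loop
    w, Xk = 1 + 0j, 0j                        # w_0 = exp(0) = 1
    for xn in x:
        Xk += w * xn                          # accumulate with w_n ...
        w = complex(c * w.real - s * w.imag,  # ... then rotate into w_{n+1}
                    s * w.real + c * w.imag)  # (4 multiplies, 2 adds)
    return Xk

x = [complex(n % 5, (n * 7) % 3) for n in range(1024)]
for k in (1, 17, 300):
    print(k, abs(dft_bin_rotated(x, k) - dft_bin_direct(x, k)))
```

The differences should be tiny in double precision; float32 or much larger N is where the drift in w would start to show.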
TODO compare on the compiler site and also timing
implementation for ggml
In the end, my feeling is that it should be easy to crank out a simple implementation, after which there'll be plenty of things to tinker with. That said, the real value would come from end-to-end tests with whisper.cpp or similar, to make sure the added ops are actually useful.
Still, my guess is that the main use in ggml will be batched 1D FFTs, and batching changes what the "fast" part means, since a batch can soak up much of the available memory bandwidth.
So this is still a work in progress.