4 changes: 4 additions & 0 deletions .jules/bolt.md
@@ -13,3 +13,7 @@ Action: Apply loop unrolling for max reductions in high-frequency typed array op
 ## 2024-11-20 - Softmax math.exp 8x unrolling with local var cache
 Learning: Unrolling the `Math.exp` accumulation loop to 8x and caching the multiplication `(tokenLogits[i] - maxLogit) * invTemp` into local variables before passing to `Math.exp` yields a measurable performance improvement (~4%) over the previous 4x unrolled implementation in the V8 engine, by reducing property access and allowing better instruction-level parallelism.
 Action: Utilize 8x loop unrolling paired with local variable caching for tight floating-point accumulation loops over TypedArrays.
+
+## 2024-11-20 - Unrolling Float32Array argmax with direct array access
+Learning: Caching TypedArray elements in local variables helps loops that compute sums (like `Math.exp` accumulations), but it slows down pure conditional-branching loops (like `argmax`). In V8, reading array values into local variables inside an unrolled argmax block introduces register spilling/deoptimization overhead compared with direct array access and its standard bounds checks.
+Action: For tight `argmax` reductions over TypedArrays, unroll the loop but access array values directly (e.g. `if (arr[i] > max) max = arr[i];`).
31 changes: 12 additions & 19 deletions src/parakeet.js
@@ -808,26 +808,19 @@ export class ParakeetModel {
 for (; i < tLen % 8; i++) {
   if (tokenLogits[i] > maxLogit) { maxLogit = tokenLogits[i]; maxId = i; }
 }
-// Optimization: Reading values into local variables (v0 to v7) within the
-// unrolled block before sequential comparisons avoids redundant TypedArray
-// index lookups and bounds-checking overhead in V8 when a new max is found.
+// Benchmark-driven finding: reading into local variables does NOT help
+// over direct array access in modern V8. It actually causes register spilling
+// and deopt overhead compared to standard direct-access array bounds checks.
+// We keep the loop unrolled but access array values directly.
 for (; i < tLen; i += 8) {
-  const v0 = tokenLogits[i];
-  const v1 = tokenLogits[i+1];
-  const v2 = tokenLogits[i+2];
-  const v3 = tokenLogits[i+3];
-  const v4 = tokenLogits[i+4];
-  const v5 = tokenLogits[i+5];
-  const v6 = tokenLogits[i+6];
-  const v7 = tokenLogits[i+7];
-  if (v0 > maxLogit) { maxLogit = v0; maxId = i; }
-  if (v1 > maxLogit) { maxLogit = v1; maxId = i + 1; }
-  if (v2 > maxLogit) { maxLogit = v2; maxId = i + 2; }
-  if (v3 > maxLogit) { maxLogit = v3; maxId = i + 3; }
-  if (v4 > maxLogit) { maxLogit = v4; maxId = i + 4; }
-  if (v5 > maxLogit) { maxLogit = v5; maxId = i + 5; }
-  if (v6 > maxLogit) { maxLogit = v6; maxId = i + 6; }
-  if (v7 > maxLogit) { maxLogit = v7; maxId = i + 7; }
+  if (tokenLogits[i] > maxLogit) { maxLogit = tokenLogits[i]; maxId = i; }
+  if (tokenLogits[i+1] > maxLogit) { maxLogit = tokenLogits[i+1]; maxId = i + 1; }
+  if (tokenLogits[i+2] > maxLogit) { maxLogit = tokenLogits[i+2]; maxId = i + 2; }
+  if (tokenLogits[i+3] > maxLogit) { maxLogit = tokenLogits[i+3]; maxId = i + 3; }
+  if (tokenLogits[i+4] > maxLogit) { maxLogit = tokenLogits[i+4]; maxId = i + 4; }
+  if (tokenLogits[i+5] > maxLogit) { maxLogit = tokenLogits[i+5]; maxId = i + 5; }
+  if (tokenLogits[i+6] > maxLogit) { maxLogit = tokenLogits[i+6]; maxId = i + 6; }
+  if (tokenLogits[i+7] > maxLogit) { maxLogit = tokenLogits[i+7]; maxId = i + 7; }
 }
Comment on lines 815 to 824
medium

For a potential further optimization, you could consider reducing the data dependency on `maxLogit` within the unrolled loop. Currently, each `if` statement depends on the result of the previous one, creating a serial chain that can limit instruction-level parallelism.

An alternative pattern is to find the maximum within the 8-element block first, and then update the global `maxLogit` only once per block. This can allow the JS engine to better optimize the comparisons. While this re-introduces a couple of local variables, the access pattern is different from the code that was removed, and it might be handled more efficiently by the JIT compiler. Given the performance-critical nature of this loop, it could be worth benchmarking.

Example implementation:

for (; i < tLen; i += 8) {
  let localMax = tokenLogits[i];
  let localMaxId = i;

  if (tokenLogits[i+1] > localMax) { localMax = tokenLogits[i+1]; localMaxId = i + 1; }
  if (tokenLogits[i+2] > localMax) { localMax = tokenLogits[i+2]; localMaxId = i + 2; }
  if (tokenLogits[i+3] > localMax) { localMax = tokenLogits[i+3]; localMaxId = i + 3; }
  if (tokenLogits[i+4] > localMax) { localMax = tokenLogits[i+4]; localMaxId = i + 4; }
  if (tokenLogits[i+5] > localMax) { localMax = tokenLogits[i+5]; localMaxId = i + 5; }
  if (tokenLogits[i+6] > localMax) { localMax = tokenLogits[i+6]; localMaxId = i + 6; }
  if (tokenLogits[i+7] > localMax) { localMax = tokenLogits[i+7]; localMaxId = i + 7; }

  if (localMax > maxLogit) {
    maxLogit = localMax;
    maxId = localMaxId;
  }
}
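
To make the benchmarking suggestion concrete, here is a minimal standalone sketch that wraps both loop shapes in functions, checks that they agree, and times them. The function names (`argmaxDirect`, `argmaxBlockLocal`, `makeLogits`, `timeIt`), the initial values of `maxLogit` and `maxId`, the input size, and the iteration counts are all assumptions for illustration; none of them come from `src/parakeet.js`. It also assumes a runtime with a global `performance.now()`.

```javascript
// Sketch only: names, initial values, and the harness are illustrative assumptions.

// Variant A: unrolled, direct array access, global max updated on every compare.
function argmaxDirect(tokenLogits) {
  const tLen = tokenLogits.length;
  let maxLogit = -Infinity; // assumed initial value
  let maxId = 0;            // assumed initial value
  let i = 0;
  // Peel the first (tLen % 8) elements so the remainder is a multiple of 8.
  for (; i < tLen % 8; i++) {
    if (tokenLogits[i] > maxLogit) { maxLogit = tokenLogits[i]; maxId = i; }
  }
  for (; i < tLen; i += 8) {
    if (tokenLogits[i] > maxLogit) { maxLogit = tokenLogits[i]; maxId = i; }
    if (tokenLogits[i+1] > maxLogit) { maxLogit = tokenLogits[i+1]; maxId = i + 1; }
    if (tokenLogits[i+2] > maxLogit) { maxLogit = tokenLogits[i+2]; maxId = i + 2; }
    if (tokenLogits[i+3] > maxLogit) { maxLogit = tokenLogits[i+3]; maxId = i + 3; }
    if (tokenLogits[i+4] > maxLogit) { maxLogit = tokenLogits[i+4]; maxId = i + 4; }
    if (tokenLogits[i+5] > maxLogit) { maxLogit = tokenLogits[i+5]; maxId = i + 5; }
    if (tokenLogits[i+6] > maxLogit) { maxLogit = tokenLogits[i+6]; maxId = i + 6; }
    if (tokenLogits[i+7] > maxLogit) { maxLogit = tokenLogits[i+7]; maxId = i + 7; }
  }
  return maxId;
}

// Variant B: find the block-local max first, then update the global max once.
function argmaxBlockLocal(tokenLogits) {
  const tLen = tokenLogits.length;
  let maxLogit = -Infinity; // assumed initial value
  let maxId = 0;            // assumed initial value
  let i = 0;
  for (; i < tLen % 8; i++) {
    if (tokenLogits[i] > maxLogit) { maxLogit = tokenLogits[i]; maxId = i; }
  }
  for (; i < tLen; i += 8) {
    let localMax = tokenLogits[i];
    let localMaxId = i;
    if (tokenLogits[i+1] > localMax) { localMax = tokenLogits[i+1]; localMaxId = i + 1; }
    if (tokenLogits[i+2] > localMax) { localMax = tokenLogits[i+2]; localMaxId = i + 2; }
    if (tokenLogits[i+3] > localMax) { localMax = tokenLogits[i+3]; localMaxId = i + 3; }
    if (tokenLogits[i+4] > localMax) { localMax = tokenLogits[i+4]; localMaxId = i + 4; }
    if (tokenLogits[i+5] > localMax) { localMax = tokenLogits[i+5]; localMaxId = i + 5; }
    if (tokenLogits[i+6] > localMax) { localMax = tokenLogits[i+6]; localMaxId = i + 6; }
    if (tokenLogits[i+7] > localMax) { localMax = tokenLogits[i+7]; localMaxId = i + 7; }
    if (localMax > maxLogit) { maxLogit = localMax; maxId = localMaxId; }
  }
  return maxId;
}

// Deterministic pseudo-random logits (simple LCG) so runs are repeatable.
function makeLogits(n, seed) {
  const out = new Float32Array(n);
  let s = seed >>> 0;
  for (let j = 0; j < n; j++) {
    s = (Math.imul(s, 1664525) + 1013904223) >>> 0;
    out[j] = (s / 4294967296) * 20 - 10;
  }
  return out;
}

// Crude timing helper: warm up so the JIT can optimize, then time repeated calls.
// The accumulated `sink` keeps the calls from being treated as dead code.
function timeIt(fn, data, iterations) {
  let sink = 0;
  for (let w = 0; w < 1000; w++) sink += fn(data);
  const start = performance.now();
  for (let r = 0; r < iterations; r++) sink += fn(data);
  return { ms: performance.now() - start, sink };
}

const logits = makeLogits(1030, 42); // arbitrary size with a non-zero remainder
console.log('same result:', argmaxDirect(logits) === argmaxBlockLocal(logits));
console.log('direct     ms:', timeIt(argmaxDirect, logits, 20000).ms.toFixed(2));
console.log('blockLocal ms:', timeIt(argmaxBlockLocal, logits, 20000).ms.toFixed(2));
```

Two caveats before reading anything into the timings. Microbenchmarks like this are sensitive to warm-up, input size, and engine version, so the numbers should be confirmed in the real decode loop. Also, the two variants differ on `NaN` input: if the first element of an 8-element block is `NaN`, the block-local variant skips the whole block (every comparison against `NaN` is false), while the direct variant only skips the `NaN` element. For finite logits, both return the index of the first maximum.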


// Compute maxVal (scaled) only if needed for softmax stability or logProbs