Skip to content

Vortex quantization and topk kernel adaption#1

Open
zxr-creator wants to merge 14 commits intoInfini-AI-Lab:v1from
zxr-creator:v1
Open

Vortex quantization and topk kernel adaption#1
zxr-creator wants to merge 14 commits intoInfini-AI-Lab:v1from
zxr-creator:v1

Conversation

@zxr-creator
Copy link
Copy Markdown

Add INT8 Quantization Support for Vortex

This PR adds INT8 quantization support to the Vortex sparse attention framework to reduce memory usage and enable low-precision execution.

Main Changes

Implement INT8 quantization with adjustments to improve memory utilization.

Add preliminary FP8 quantization support.

Update reduce_pp_kernel parameters to support quantized data.

Adapt the Top-K kernel from SGLang for Vortex.

Add a runtime parameter to switch between two Top-K kernels (naive and sglang).

Add RTX PRO 6000 compatibility and fix several Vortex kernel issues.

zxr-creator and others added 14 commits February 22, 2026 23:59
Key changes:
1. Memory Pool (`vtx_graph_memory_pool.py`):
   - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations.
   - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout.
   - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids) to ensure the Vortex sparse indexer/TopK remains unaffected and mathematically identical.

2. Quantize-on-Write (`set_kv.py`):
   - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`).
   - Wired the new launcher into the cache update flow.

3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`):
   - Bypassed FlashInfer for INT8 decoding.
   - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers.
   - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`).

4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`):
   - Implemented an OOM-safe `bf16` fallback for prefill.
   - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer.
   - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs.
Merge the warpper modification with the v1
…gs and pages; fix on the previous quantization implementaion, with lanuch_graph dtype set to the quant type
… (naive sparse attention, flash sparse attention, flashmoba)
- Introduced a comprehensive benchmarking suite for TopK kernel variants, measuring kernel-level latency.
- Added scripts for offline calibration of TopK mapping modes, including:# 0: None           — original fp16 bit-pattern bucketing
# 1: LUT CDF        — LUT-based CDF equalization (calibrated)
# 2: Quantile       — piecewise-linear quantile mapping (calibrated)
# 3: Power          — y = sign(x) * |x|^p
# 4: Log            — y = sign(x) * log(|x| + 1)
# 5: Index Cache    — reuse previous layer's indices
# 6: Asinh          — y = asinh(beta * x)
# 7: Log1p          — y = sign(x) * log1p(alpha * |x|)
# 8: Trunc8         — bf16 upper-8-bit bucketing
-  Adding various remap functions for the bucket sort in sglang topk kernel, with evaluation and visualization scripts.
- Implemented analysis tools for TopK distribution profiling.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant