
Conversation

@JackChuang
Contributor

@JackChuang JackChuang commented Sep 5, 2025

Summary

This PR introduces FP4 (E2M1) support for the Multi-Head Latent Attention (MLA) KV cache in SGLang, enabling low-precision caching to reduce memory usage and improve inference efficiency. It integrates FP4 quantization utilities, Triton kernels, and unit tests while remaining backward compatible with FP16/FP8. See #10083, point 1-1, for more context.

Co-authored-by: @yicwang Yichen Wang [email protected]

Usage

Launch the server with the new --kv-cache-dtype fp4_e2m1 option:

$ python3 -m sglang.launch_server --kv-cache-dtype fp4_e2m1 ... 

Key Changes

  1. Server Argument Extension

    • Added --kv-cache-dtype=fp4_e2m1 option.
  2. FP4 Quantization Utility

    • Added KVFP4QuantizeUtil with batched_quantize and batched_dequantize methods for block-wise (block size 16) processing of [M, N, K] tensors; a reference sketch of the scheme follows this list.
  3. Core KV Cache Integration

    • Updated ModelRunner and MLATokenToKVPool to support FP4 KV cache.
    • Added kv_scale_buffer for FP4 scaling.
    • Implemented Triton kernel set_mla_kv_scale_buffer_kernel for efficient nope+rope tensor handling.
    • Maintains backward compatibility with FP16/FP8 KV cache.
  4. Unit Test & Benchmark

    • Tests validate FP4 correctness via MSE, MAE, PSNR, and Relative Error; a sketch of these metrics also follows this list.
    • Benchmarks FP4 vs FP8 performance on GPU.
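
For reference, E2M1 uses 1 sign, 2 exponent, and 1 mantissa bit, so each 4-bit code maps to one of the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The sketch below shows the general shape of a block-wise (block size 16) quantize/dequantize scheme with one scale per block. It is an illustrative reference only, not the PR's KVFP4QuantizeUtil or its Triton kernels, and all names in it are hypothetical.

```python
import torch

# The eight magnitudes representable by FP4 E2M1 (the sign bit is handled separately).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fp4_block_quantize(x: torch.Tensor, block_size: int = 16):
    """Quantize the last dimension of x in blocks of block_size.

    Returns one 4-bit code per element (stored unpacked in a uint8) and one float
    scale per block. Assumes the last dimension is a multiple of block_size.
    """
    orig_shape = x.shape
    x = x.reshape(-1, block_size).float()
    # Map each block's max magnitude onto the largest representable value (6.0).
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / 6.0
    scaled = x / scale
    # Round every element to the nearest E2M1 magnitude.
    dist = (scaled.abs().unsqueeze(-1) - E2M1_VALUES.to(x.device)).abs()
    mag_code = dist.argmin(dim=-1).to(torch.uint8)   # 3-bit magnitude index
    sign = (scaled < 0).to(torch.uint8)              # 1-bit sign
    codes = (sign << 3) | mag_code                   # 4-bit code per element
    return codes.reshape(orig_shape), scale.reshape(*orig_shape[:-1], -1)


def fp4_block_dequantize(codes: torch.Tensor, scale: torch.Tensor, block_size: int = 16):
    """Inverse of fp4_block_quantize, up to rounding error."""
    mag = E2M1_VALUES.to(codes.device)[(codes & 0x7).long()]
    sign = 1.0 - 2.0 * ((codes >> 3) & 0x1).float()
    x = (sign * mag).reshape(-1, block_size)
    return (x * scale.reshape(-1, 1)).reshape(codes.shape)
```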
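
The unit test compares original and dequantized tensors with standard reconstruction metrics. A minimal version of those metrics (independent of the PR's test file; the function name is hypothetical) could look like:

```python
import torch


def reconstruction_metrics(original: torch.Tensor, dequantized: torch.Tensor) -> dict:
    """MSE, MAE, PSNR (dB), and relative error between a tensor and its quantized round-trip."""
    original = original.float()
    dequantized = dequantized.float()
    diff = dequantized - original
    mse = diff.pow(2).mean()
    mae = diff.abs().mean()
    peak = original.abs().max().clamp(min=1e-12)
    psnr = 10.0 * torch.log10(peak.pow(2) / mse.clamp(min=1e-12))
    rel_err = diff.norm() / original.norm().clamp(min=1e-12)
    return {
        "mse": mse.item(),
        "mae": mae.item(),
        "psnr_db": psnr.item(),
        "relative_error": rel_err.item(),
    }


# Hypothetical usage together with the quantization sketch above:
# kv = torch.randn(8, 128, 512)                 # [M, N, K], K a multiple of 16
# codes, scale = fp4_block_quantize(kv)
# print(reconstruction_metrics(kv, fp4_block_dequantize(codes, scale)))
```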

Impact

  • Enables low-precision FP4 KV cache for MLA.
  • Reduces GPU memory usage while maintaining accuracy.
  • Provides measurable speedup over FP8 quantization.
  • Backward compatible with existing FP16/FP8 workflows.

Accuracy Tests

The results show that on simpler datasets, the accuracy is nearly lossless compared to the baseline. On more challenging datasets, there is some accuracy degradation, but it remains within an acceptable range.

| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
|---|---|---|---|---|---|---|
| DeepSeek-R1-0528-FP4-KV16 | gsm8k | mean_acc | main | 6595 | 0.9157 | default |
| DeepSeek-R1-0528-FP4-KV8_e4m3 | gsm8k | mean_acc | main | 6595 | 0.9154 | default |
| DeepSeek-R1-0528-FP4-KV4_fp4_e2m1 | gsm8k | mean_acc | main | 6595 | 0.9124 | default |
| DeepSeek-R1-0528-FP4-KV16 | aime25 | mean_acc | OVERALL | 150 | 0.5067 | - |
| DeepSeek-R1-0528-FP4-KV8_e4m3 | aime25 | mean_acc | OVERALL | 150 | 0.4934 | - |
| DeepSeek-R1-0528-FP4-KV4_fp4_e2m1 | aime25 | mean_acc | OVERALL | 150 | 0.4 | - |
| DeepSeek-R1-0528-FP4-KV16 | gpqa_diamond | mean_acc | default | 990 | 0.7707 | default |
| DeepSeek-R1-0528-FP4-KV8_e4m3 | gpqa_diamond | mean_acc | default | 990 | 0.7697 | default |
| DeepSeek-R1-0528-FP4-KV4_fp4_e2m1 | gpqa_diamond | mean_acc | default | 990 | 0.7273 | default |

Performance Results

Tested on B200. The server was launched with --model DeepSeek-R1-0528-FP4 --tp-size 4 --moe-runner-backend flashinfer_trtllm --disable-radix-cache. The client was run with --goodput ttft:5000 tpot:50 --random-input-len 3500 --random-output-len 1500 --max-concurrency 50 and --num-prompts 100.

The baseline has the best performance because the trtllm kernel consumes the BF16 KV cache directly, so there is no dequantization overhead at all.

(Default) BF16 KVCache

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 116.44
Total input tokens: 349051
Total generated tokens: 150000
Request throughput (req/s): 0.86
Request goodput (req/s): 0.46
Output token throughput (tok/s): 1288.17
Total Token throughput (tok/s): 4285.73
---------------Time to First Token----------------
Mean TTFT (ms): 4853.44
Median TTFT (ms): 4784.58
P50 TTFT (ms): 4784.58
P90 TTFT (ms): 7828.38
P95 TTFT (ms): 9386.72
P99 TTFT (ms): 9616.72
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 35.05
Median TPOT (ms): 35.11
P50 TPOT (ms): 35.11
P90 TPOT (ms): 37.41
P95 TPOT (ms): 37.88
P99 TPOT (ms): 37.98
---------------Inter-token Latency----------------
Mean ITL (ms): 35.05
Median ITL (ms): 32.00
P50 ITL (ms): 32.00
P90 ITL (ms): 33.60
P95 ITL (ms): 33.91
P99 ITL (ms): 34.32
==================================================

When adding --kv-cache-dtype fp8_e4m3, performance drops significantly because the KV cache is quantized on every write and dequantized on every read for attention.

--kv-cache-dtype fp8_e4m3

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 474.62
Total input tokens: 349051
Total generated tokens: 150000
Request throughput (req/s): 0.21
Request goodput (req/s): 0.00
Output token throughput (tok/s): 316.05
Total Token throughput (tok/s): 1051.48
---------------Time to First Token----------------
Mean TTFT (ms): 6093.23
Median TTFT (ms): 6057.21
P50 TTFT (ms): 6057.21
P90 TTFT (ms): 9279.85
P95 TTFT (ms): 10945.23
P99 TTFT (ms): 11347.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 153.69
Median TPOT (ms): 153.72
P50 TPOT (ms): 153.72
P90 TPOT (ms): 156.34
P95 TPOT (ms): 156.80
P99 TPOT (ms): 157.09
---------------Inter-token Latency----------------
Mean ITL (ms): 153.69
Median ITL (ms): 150.40
P50 ITL (ms): 150.40
P90 ITL (ms): 151.99
P95 ITL (ms): 152.28
P99 ITL (ms): 152.68
==================================================

With --kv-cache-dtype fp4_e2m1, throughput is 17.8% higher than with fp8_e4m3, mostly because quantizing to and dequantizing from FP4 is faster than doing so with FP8 (a standalone timing sketch follows the benchmark output below).

--kv-cache-dtype fp4_e2m1

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 402.92
Total input tokens: 349051
Total generated tokens: 150000
Request throughput (req/s): 0.25
Request goodput (req/s): 0.00
Output token throughput (tok/s): 372.28
Total Token throughput (tok/s): 1238.58
---------------Time to First Token----------------
Mean TTFT (ms): 5611.71
Median TTFT (ms): 5675.33
P50 TTFT (ms): 5675.33
P90 TTFT (ms): 9137.97
P95 TTFT (ms): 9952.08
P99 TTFT (ms): 10323.24
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 130.10
Median TPOT (ms): 129.87
P50 TPOT (ms): 129.87
P90 TPOT (ms): 133.07
P95 TPOT (ms): 133.56
P99 TPOT (ms): 133.66
---------------Inter-token Latency----------------
Mean ITL (ms): 130.10
Median ITL (ms): 126.83
P50 ITL (ms): 126.83
P90 ITL (ms): 128.53
P95 ITL (ms): 128.91
P99 ITL (ms): 129.45
==================================================
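
For context on the gap above, one way to time a quant/dequant round-trip in isolation is with CUDA events. This is only a measurement sketch: quantize_fn and dequantize_fn are placeholders for whichever FP4 or FP8 kernels are under test (the PR's own benchmark lives in the unit test), and it assumes quantize_fn returns exactly the arguments dequantize_fn expects.

```python
import torch


def time_roundtrip(quantize_fn, dequantize_fn, x: torch.Tensor, iters: int = 100) -> float:
    """Return the mean milliseconds for one quantize + dequantize round-trip on the GPU."""
    # Warm-up so kernel compilation and caching do not skew the measurement.
    for _ in range(10):
        dequantize_fn(*quantize_fn(x))
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        dequantize_fn(*quantize_fn(x))
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


# Hypothetical comparison on a KV-cache-sized tensor:
# kv = torch.randn(4096, 1, 576, device="cuda", dtype=torch.bfloat16)
# print("fp4:", time_roundtrip(fp4_quantize, fp4_dequantize, kv))
# print("fp8:", time_roundtrip(fp8_quantize, fp8_dequantize, kv))
```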

Future Work

We plan to support Multi-Head Attention (MHA) next.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @JackChuang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances SGLang's memory efficiency and inference performance by enabling the use of FP4 (E2M1) precision for the KV cache in Multi-Head Latent Attention. This low-precision caching mechanism allows for substantial reductions in GPU memory consumption while largely preserving model accuracy, and it maintains full backward compatibility with existing FP16 and FP8 workflows.

Highlights

  • FP4 (E2M1) KV Cache Support: Added support for FP4 (E2M1) KV cache for Multi-Head Latent Attention (MLA) to reduce memory usage and improve inference efficiency.
  • FP4 Quantization Utilities: Introduced KVFP4QuantizeUtil for efficient block-wise FP4 quantization and dequantization of tensors.
  • Core KV Cache Integration: Integrated FP4 KV cache into ModelRunner and MLATokenToKVPool, including a new kv_scale_buffer and a Triton kernel (set_mla_kv_scale_buffer_kernel) for handling scale factors.
  • Server Argument Extension: Extended server arguments with --kv-cache-dtype=fp4_e2m1 for easy activation.
  • Unit Tests and Benchmarks: Included comprehensive unit tests to validate FP4 correctness (MSE, MAE, PSNR, Relative Error) and benchmarks comparing FP4 and FP8 performance.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces FP4 (E2M1) KV cache support, which is a great feature for reducing memory usage. The implementation of the quantization utilities and their integration into the memory pool and model runner looks mostly correct. However, I've found a few issues that need to be addressed:

  • There is a critical bug in the memory footprint calculation for the FP4 KV cache in model_runner.py, which could lead to incorrect memory allocation.
  • The new test file for FP4 quantization has a typo in an import path, which will cause it to fail.
  • There are some repeated imports in memory_pool.py that could be consolidated for better code clarity.

I have provided specific comments and suggestions for these points. Once these are addressed, this PR should be in good shape.

@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch from 8995fd7 to 66f1dd7 on September 5, 2025 13:30
@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch from 66f1dd7 to 087a4a1 on September 9, 2025 00:43
@zhyncs zhyncs added high priority quant LLM Quantization labels Sep 9, 2025
@AniZpZ
Collaborator

AniZpZ commented Sep 10, 2025

great work! looks solid overall.

@AniZpZ
Collaborator

AniZpZ commented Sep 10, 2025

Are there any more tests to evaluate accuracy drop in long-context scenarios?

@AniZpZ
Collaborator

AniZpZ commented Sep 10, 2025

please fix ci and lint

@JackChuang
Contributor Author

JackChuang commented Sep 10, 2025

Thanks @AniZpZ for the prompt review and helpful feedback! We will fix them and then let you know.

> Are there any more tests to evaluate accuracy drop in long-context scenarios?

Do you have a specific dataset result you’d like to see? We can collect the data.
The ones I currently have are categorized by subject.

@AniZpZ
Collaborator

AniZpZ commented Sep 11, 2025

> Thanks @AniZpZ for the prompt review and helpful feedback! We will fix them and then let you know.
>
> > Are there any more tests to evaluate accuracy drop in long-context scenarios?
>
> Do you have a specific dataset result you’d like to see? We can collect the data. The ones I currently have are categorized by subject.

I've observed a significantly larger drop in accuracy in AIME 25 compared to GSM8K, which leads me to hypothesize that accuracy is related to context length.

@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch from 087a4a1 to 6d3f263 on September 11, 2025 08:30
@JackChuang
Contributor Author

JackChuang commented Sep 11, 2025

> I've observed a significantly larger drop in accuracy in AIME 25 compared to GSM8K, which leads me to hypothesize that accuracy is related to context length.

If AIME25 is considered a long-context dataset, then GPQA_Diamond should also fall into the long-context category.

If you have any more representative long-context datasets to recommend, I can give them a try if needed.

@JackChuang
Contributor Author

> @JackChuang Please fix the conflicts

Hi @Fridge003 @zhyncs, I've rebased to v0.5.3 and fixed the conflicts. Could you please launch the CI and check again? Thank you!

@Fridge003
Collaborator

@JackChuang It shows there are still conflicts with latest main

@JackChuang
Contributor Author

JackChuang commented Oct 24, 2025

> @JackChuang It shows there are still conflicts with latest main

@Fridge003 Oops, you are right. It seems v0.5.3 is also outdated, so I will rebase to main instead. Thanks, and sorry for the inconvenience.

@JackChuang
Contributor Author

JackChuang commented Oct 26, 2025

> > @JackChuang Please fix the conflicts
>
> Hi @Fridge003 @zhyncs, I've rebased to v0.5.3 and fixed the conflicts. Could you please launch the CI and check again? Thank you!

While merging with the main branch, I noticed that the latest main branch includes the assumption:

“The TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto.”

I’ll need to look into this further and run some tests before merging. Thanks.

@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch 2 times, most recently from 56abc5c to bdf4871, on October 28, 2025 08:12
@JackChuang
Contributor Author

JackChuang commented Oct 28, 2025

Hi @Fridge003 @zhyncs, I’ve rebased to main and fixed the conflicts.
Could you please check again? Thank you!

Note: Turns out the current 'main' and 'v0.5.4.post1' branches cannot run. As a result, I’ve tested my code on v0.5.4, and it’s working.

@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch from bdf4871 to d2cc365 on October 28, 2025 21:57
@Fridge003
Collaborator

> Note: Turns out the current 'main' and 'v0.5.4.post1' branches cannot run. As a result, I’ve tested my code on v0.5.4, and it’s working.

Do you mean that v0.5.4.post1 cannot run with this PR, or are there other bugs?

@JackChuang
Contributor Author

> > Note: Turns out the current 'main' and 'v0.5.4.post1' branches cannot run. As a result, I’ve tested my code on v0.5.4, and it’s working.
>
> Do you mean that v0.5.4.post1 cannot run with this PR, or are there other bugs?

@Fridge003 What I mean is: I was originally developing on the main and v0.5.4.post1 branches, and I found that my KV4 setup couldn’t run. Even after removing the KV cache quantization option, it still didn’t work. Then I switched to v0.5.4, and everything worked fine, so I confirmed that the issue wasn’t caused by my code.

@JackChuang
Contributor Author

Hi @Fridge003 @zhyncs, when you get a chance, could you please check whether this is ready to merge? Thanks a lot!

dtype=self.data_type,
device=self.device,
)
if self.data_type == getattr(torch, "float4_e2m1fn_x2", None) and _is_cuda:
Collaborator


Can we change `self.data_type == getattr(torch, "float4_e2m1fn_x2", None) and _is_cuda` to a utils function for better readability?

Collaborator


One last comment

JackChuang and others added 4 commits October 31, 2025 22:13
Extend the `--kv-cache-dtype` argument in ServerArgs to support "fp4_e2m1".

Signed-off-by: Ho-Ren (Jack) Chuang <[email protected]>
Co-authored-by: Yichen Wang <[email protected]>
- Introduce `KVFP4QuantizeUtil` for FP4 (E2M1) quantization and dequantization.
- Provides `batched_quantize` and `batched_dequantize` methods for block-wise (16) processing of [M, N, K] tensors.

Signed-off-by: Ho-Ren (Jack) Chuang <[email protected]>
Co-authored-by: Yichen Wang <[email protected]>
- Introduce `test_kvfp4_quant_dequant.py` to validate correctness and performance of KVFP4 quantization.
- Provides metrics calculation (MSE, MAE, PSNR, Relative Error) to compare original and dequantized tensors.
- Benchmarks KVFP4 vs FP8 quant/dequant performance on GPU with large tensors.

Signed-off-by: Ho-Ren (Jack) Chuang <[email protected]>
Co-authored-by: Yichen Wang <[email protected]>
Core change enabling low-precision FP4 KV caching for MLA,
improving inference efficiency while keeping existing workflows intact.

- Introduce FP4 KV cache support in MLATokenToKVPool for reduced memory usage.
- Add kv_scale_buffer to store FP4 scaling factors and updated allocation logic.
- Implement Triton kernel to combine nope + rope tensors and write to KV + scale buffers.
- Modify ModelRunner to account for FP4 buffer sizing and dtype.
- Maintains backward compatibility with FP16/FP8 KV cache.

Also, move Triton kernels from mem_cache/memory_pool.py to srt/mem_cache/utils.py.

Signed-off-by: Ho-Ren (Jack) Chuang <[email protected]>
Co-authored-by: Yichen Wang <[email protected]>
@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch from d2cc365 to 46d44aa on October 31, 2025 22:47
@Fridge003
Collaborator

@JackChuang Please fix CI bugs

- Added `is_cuda()` and `is_float4_e2m1fn_x2()` in `sglang/srt/utils/torch_utils.py`
- Replaced inline checks in relevant modules

Signed-off-by: Ho-Ren (Jack) Chuang <[email protected]>
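
The helpers named in the commit above might look roughly like the following; the actual implementations in sglang/srt/utils/torch_utils.py may differ. getattr is used because older torch builds do not define the packed FP4 dtype.

```python
import torch


def is_float4_e2m1fn_x2(dtype) -> bool:
    """True if dtype is torch's packed FP4 (E2M1) dtype, which older torch builds lack."""
    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
    return fp4_dtype is not None and dtype == fp4_dtype


def is_cuda() -> bool:
    """True if a CUDA device is available to the current process."""
    return torch.cuda.is_available()


# The inline check from the review thread then becomes:
# if is_float4_e2m1fn_x2(self.data_type) and is_cuda():
#     ...
```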
@JackChuang JackChuang force-pushed the horenc/kv4_on_v0.5.2rc1_release branch from 46d44aa to c128f91 on November 1, 2025 09:51
@Fridge003
Collaborator

@Fridge003 Fridge003 merged commit 76196b3 into sgl-project:main Nov 2, 2025
136 of 153 checks passed
JackChuang added a commit to bytedance-iaas/sglang that referenced this pull request Nov 4, 2025
Based on PR sgl-project#10078, this patch
- introduces FP4 KV cache support in MHATokenToKVPool with uint8 storage.
- adds k_scale_buffer and v_scale_buffer to store FP4 scaling factors.
- implements batched quantization on cache update and dequantization on access.
- updates ModelRunner memory estimation to account for FP4 scale buffers.
- maintains backward compatibility with FP16/FP8 KV cache.

Signed-off-by: Ho-Ren (Jack) Chuang <[email protected]>
Co-authored-by: Yichen Wang <[email protected]>
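
The quantize-on-update / dequantize-on-access pattern that commit describes can be illustrated with a toy cache wrapper. This is a simplified sketch only (the real MHATokenToKVPool packs two codes per byte, handles V as well as K, and runs the conversion in fused kernels), and it reuses the fp4_block_quantize/fp4_block_dequantize helpers sketched in the PR description above.

```python
import torch


class FP4KVCacheSketch:
    """Toy K cache that stores unpacked FP4 codes (uint8) plus per-block float scales."""

    def __init__(self, num_tokens: int, head_dim: int, block_size: int = 16, device: str = "cuda"):
        self.block_size = block_size
        self.k_codes = torch.zeros(num_tokens, head_dim, dtype=torch.uint8, device=device)
        self.k_scale_buffer = torch.zeros(num_tokens, head_dim // block_size, device=device)

    def set_k(self, loc: torch.Tensor, k: torch.Tensor) -> None:
        # Quantize when the cache is written.
        codes, scale = fp4_block_quantize(k, self.block_size)
        self.k_codes[loc] = codes
        self.k_scale_buffer[loc] = scale

    def get_k(self, loc: torch.Tensor) -> torch.Tensor:
        # Dequantize when attention reads the cache.
        return fp4_block_dequantize(self.k_codes[loc], self.k_scale_buffer[loc], self.block_size)
```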