KV Prediction For Improved Time To First Token

KV Prediction is a method for improving the time to first token (TTFT) of transformer models. It uses a small "auxiliary" transformer network to process the prompt efficiently, then uses the KV cache of the auxiliary network to predict the KV cache of a larger "base" network. The base network then performs autoregressive generation from the predicted cache, without querying the auxiliary model again. Our method yields a Pareto-optimal efficiency-accuracy trade-off for TTFT compared to baselines on benchmark datasets. See our paper for details.
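To make the data flow concrete, here is a minimal sketch in plain Python. It is not the implementation in this repository. It assumes that each base layer's keys and values are obtained by a linear projection of one auxiliary layer's keys and values; the names `predict_base_kv`, `layer_map`, `proj_k`, and `proj_v` are hypothetical, and the random weights stand in for trained ones. See the paper for the actual predictor.

```python
import random


def matvec_rows(rows, weight):
    """Multiply each row vector (length d_in) by weight (d_in x d_out)."""
    d_out = len(weight[0])
    return [
        [sum(x * weight[i][j] for i, x in enumerate(row)) for j in range(d_out)]
        for row in rows
    ]


def predict_base_kv(aux_cache, layer_map, proj_k, proj_v):
    """Predict a base-model KV cache from an auxiliary-model KV cache.

    aux_cache: one (keys, values) pair per auxiliary layer, each [tokens][d_aux]
    layer_map: layer_map[j] is the auxiliary layer used for base layer j
               (an assumed mapping, for illustration only)
    proj_k, proj_v: one [d_aux][d_base] matrix per base layer (assumed linear)
    """
    base_cache = []
    for j, i in enumerate(layer_map):
        aux_k, aux_v = aux_cache[i]
        base_cache.append(
            (matvec_rows(aux_k, proj_k[j]), matvec_rows(aux_v, proj_v[j]))
        )
    return base_cache


def rand_matrix(rows, cols):
    """Random matrix standing in for real activations or trained weights."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
num_tokens, d_aux, d_base = 5, 4, 8
num_aux_layers = 2
layer_map = [0, 0, 1, 1]  # 4 base layers fed by 2 auxiliary layers

# Stand-in for the auxiliary model's KV cache after processing the prompt.
aux_cache = [
    (rand_matrix(num_tokens, d_aux), rand_matrix(num_tokens, d_aux))
    for _ in range(num_aux_layers)
]
proj_k = [rand_matrix(d_aux, d_base) for _ in layer_map]
proj_v = [rand_matrix(d_aux, d_base) for _ in layer_map]

base_cache = predict_base_kv(aux_cache, layer_map, proj_k, proj_v)
keys0, values0 = base_cache[0]
print(len(base_cache), len(keys0), len(keys0[0]))  # 4 5 8
```

In this sketch the base network never processes the prompt: the only per-prompt work is one pass through the small auxiliary model plus the projections, which is where the TTFT saving would come from.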

Training

We experiment with OpenELM models. Configs are located in the openelm/ subdirectory. We trained with multi-node jobs using 8 nodes with 8 H100 GPUs per node.

An example command for launching training on the i-th node (with <NODE_ID> set to i, counting from node 0) is:

export CFG_FILE="PATH_TO_KV_PREDICTION_MODEL_CONFIGURATION_FILE"
# Replace the <...> placeholders with integers; $(( )) evaluates the product.
export RANK=$((<NODE_ID> * <NUM_GPUS_PER_NODE>))
export WORLD_SIZE=$((<NUM_NODES> * <NUM_GPUS_PER_NODE>))
corenet-train --common.config-file "$CFG_FILE" --ddp.rank "$RANK" --ddp.world-size "$WORLD_SIZE" --ddp.dist-url 'tcp://IP_OF_NODE0:FREEPORT'
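
As a worked example of the rank arithmetic, the sketch below computes the values for the 8-node, 8-GPU-per-node setup described above and assembles the corresponding argument list. The helper name `ddp_launch_args` is hypothetical. It uses only the flags shown in the command, and the config path and URL remain placeholders.

```python
def ddp_launch_args(node_id, num_nodes, gpus_per_node, cfg_file, dist_url):
    """Build the corenet-train argument list for one node (illustrative helper)."""
    if not 0 <= node_id < num_nodes:
        raise ValueError("node_id must be in [0, num_nodes)")
    rank = node_id * gpus_per_node          # RANK in the command above
    world_size = num_nodes * gpus_per_node  # WORLD_SIZE in the command above
    return [
        "corenet-train",
        "--common.config-file", cfg_file,
        "--ddp.rank", str(rank),
        "--ddp.world-size", str(world_size),
        "--ddp.dist-url", dist_url,
    ]


# 8 nodes with 8 GPUs each: node 3 gets rank 24, and the world size is 64.
args = ddp_launch_args(
    node_id=3,
    num_nodes=8,
    gpus_per_node=8,
    cfg_file="PATH_TO_KV_PREDICTION_MODEL_CONFIGURATION_FILE",
    dist_url="tcp://IP_OF_NODE0:FREEPORT",
)
print(" ".join(args))
```

Every node shares the same world size and distribution URL; only the rank differs, increasing by the number of GPUs per node from one node to the next.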

Evaluation

We evaluate with the LM Eval Harness at commit 3196e907fa195b684470a913c7235ed7f08a4383. We use the prompt template in triviaqa-template.yaml because the default template appends an extra question mark to each question.

Citation

If you find our work useful, please cite:

@misc{horton2024kvpredictionimprovedtime,
      title={KV Prediction for Improved Time to First Token},
      author={Maxwell Horton and Qingqing Cao and Chenfan Sun and Yanzi Jin and Sachin Mehta and Mohammad Rastegari and Moin Nabi},
      year={2024},
      eprint={2410.08391},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.08391},
}

@inproceedings{mehta2022cvnets, 
     author = {Mehta, Sachin and Abdolhosseini, Farzad and Rastegari, Mohammad}, 
     title = {CVNets: High Performance Library for Computer Vision}, 
     year = {2022}, 
     booktitle = {Proceedings of the 30th ACM International Conference on Multimedia}, 
     series = {MM '22} 
}