GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
GRIFFIN is a novel framework that incorporates a token-alignable training strategy and a token-alignable draft model to mitigate token misalignment, thereby accelerating inference in large language models (LLMs). The training strategy employs a loss masking mechanism to exclude highly misaligned tokens, preventing them from negatively impacting the draft model's optimization. The token-alignable draft model introduces input tokens to correct inconsistencies in generated features.
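To make the loss-masking idea concrete, here is a minimal PyTorch sketch. The top-k disagreement criterion and all names (`masked_draft_loss`, `top_k`) are illustrative assumptions, not the paper's exact formulation:

```python
import torch
import torch.nn.functional as F

def masked_draft_loss(draft_logits, labels, top_k=5):
    """Cross-entropy over draft predictions, skipping highly misaligned tokens.

    draft_logits: (batch, seq_len, vocab) draft-model logits
    labels:       (batch, seq_len) target-model tokens
    """
    # Per-token cross-entropy of the draft model against the target tokens.
    ce = F.cross_entropy(draft_logits.transpose(1, 2), labels, reduction="none")

    # Hypothetical alignment test: a position counts as aligned only if the
    # target token appears in the draft model's top-k predictions.
    topk_ids = draft_logits.topk(top_k, dim=-1).indices       # (batch, seq, k)
    aligned = (topk_ids == labels.unsqueeze(-1)).any(dim=-1)  # (batch, seq)

    # Masking excludes misaligned positions from the optimization signal.
    mask = aligned.float()
    return (ce * mask).sum() / mask.sum().clamp(min=1.0)
```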
Experiments on LLaMA, Vicuna, Qwen, and Mixtral models demonstrate that GRIFFIN achieves an average acceptance-length improvement of over 8% and a speedup ratio exceeding 7%, outperforming current state-of-the-art speculative decoding methods.
- GRIFFIN is:
- 4.2x faster than vanilla decoding.
- 1.3x faster than EAGLE-2.
Acceleration demo of GRIFFIN for LLaMA3-8B on a single RTX 4090 GPU
Paper
The model was presented in the paper: GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
Code
The official implementation and more details can be found on the GitHub repository: https://github.com/hsj576/GRIFFIN
Inference
You can use our provided eagenerate function for accelerated generation, just as you would use generate from Hugging Face Transformers. Here is an example:
```python
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template
import torch

# Define your base model path and GRIFFIN draft model path
base_model_path = "meta-llama/Llama-2-7b-chat-hf"      # example base model
griffin_model_path = "husj576/GRIFFIN-llama2-chat-7B"  # example GRIFFIN model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=griffin_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1,  # -1 configures the number of draft tokens automatically
)
model.eval()

your_message = "Hello"
conv = get_conversation_template("llama-2-chat")  # use the template matching your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
```
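To gauge the speedup on your own hardware, you can time eagenerate directly. This is a rough sketch reusing `model`, `input_ids`, and the `torch` import from the snippet above; the token count is approximate if the implementation appends special tokens:

```python
import time

torch.cuda.synchronize()
start = time.perf_counter()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# New tokens = generated length minus prompt length.
new_tokens = output_ids.shape[1] - input_ids.shape[1]
print(f"{new_tokens} tokens in {elapsed:.2f}s ({new_tokens / elapsed:.1f} tokens/s)")
```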
Note: Vicuna, LLaMA2-Chat, and LLaMA3-Instruct are chat models. You must use the matching chat template; otherwise the model will produce abnormal output and GRIFFIN's performance will degrade.
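A quick way to confirm a template does what you expect is to print the rendered prompt. The template names below are common FastChat registrations; verify them against your installed FastChat version:

```python
from fastchat.model import get_conversation_template

# Pick the template matching your base model (names depend on the
# installed FastChat version).
conv = get_conversation_template("vicuna_v1.1")     # Vicuna
# conv = get_conversation_template("llama-2-chat")  # LLaMA2-Chat

conv.append_message(conv.roles[0], "Hello")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())  # inspect the exact string the model will see
```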
Citation
If you find this work useful, please cite the paper:
```bibtex
@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding},
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018},
}
```