GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

GRIFFIN

GRIFFIN is a novel framework that couples a token-alignable training strategy with a token-alignable draft model to mitigate token misalignment, thereby accelerating inference in large language models (LLMs). The training strategy employs a loss-masking mechanism that excludes highly misaligned tokens, preventing them from negatively impacting the draft model's optimization. The token-alignable draft model introduces extra input tokens to correct inconsistencies in the features it generates.
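To make the loss-masking idea concrete, here is a minimal sketch (not the official implementation; the misalignment criterion and the name misalignment_threshold are assumptions for illustration). It computes a per-token loss for the draft model and zeroes out the positions where draft and target distributions disagree most, so those tokens do not dominate optimization:

import torch
import torch.nn.functional as F

def masked_draft_loss(draft_logits, target_logits, misalignment_threshold=4.0):
    # draft_logits, target_logits: (batch, seq_len, vocab_size)
    target_probs = F.softmax(target_logits, dim=-1)
    draft_log_probs = F.log_softmax(draft_logits, dim=-1)

    # Per-token cross-entropy between the target and draft distributions.
    per_token_loss = -(target_probs * draft_log_probs).sum(dim=-1)  # (batch, seq_len)

    # Hypothetical misalignment criterion: treat tokens whose loss exceeds
    # the threshold as highly misaligned and exclude them from the objective.
    mask = (per_token_loss <= misalignment_threshold).float()

    # Average only over the tokens that were kept.
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)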

Experiments on LLaMA, Vicuna, Qwen, and Mixtral models demonstrate that GRIFFIN achieves an average acceptance-length improvement of over 8% and a speedup ratio exceeding 7%, outperforming current state-of-the-art speculative decoding methods.

  • GRIFFIN is:
    • 4.2x faster than vanilla decoding.
    • 1.3x faster than EAGLE-2.
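
Speedup numbers like these depend on hardware and benchmark. As a rough sanity check, you can time generation directly; the sketch below is a hedged illustration that assumes the model, base_model_path, and input_ids from the Inference section further down, and compares against a plain transformers baseline:

import time
import torch
from transformers import AutoModelForCausalLM

def tokens_per_second(generate_fn, input_ids, max_new_tokens=256):
    # Time generation and report newly decoded tokens per second.
    torch.cuda.synchronize()
    start = time.perf_counter()
    output_ids = generate_fn(input_ids, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return (output_ids.shape[1] - input_ids.shape[1]) / elapsed

# GRIFFIN decoding (model and input_ids as in the example below).
griffin_tps = tokens_per_second(
    lambda ids, **kw: model.eagenerate(ids, temperature=0.0, **kw), input_ids)

# Vanilla decoding with the same base model.
base = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=torch.float16, device_map="auto")
vanilla_tps = tokens_per_second(
    lambda ids, **kw: base.generate(ids, do_sample=False, **kw), input_ids)

print(f"speedup: {griffin_tps / vanilla_tps:.2f}x")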

Acceleration demo of GRIFFIN for LLaMA3-8B on an RTX 4090 GPU


Paper

The model was presented in the paper: GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

Code

The official implementation and more details can be found on the GitHub repository: https://github.com/hsj576/GRIFFIN

Inference

You can use our provided eagenerate function for accelerated generation, just like using generate from Hugging Face. Here is an example:

from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template
import torch

# Define your base model path and GRIFFIN draft model path
base_model_path = "meta-llama/Llama-2-7b-chat-hf"  # Example base model
EAGLE_model_path = "husj576/GRIFFIN-llama2-chat-7B"  # Example GRIFFIN model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1  # -1 selects the number of draft tokens automatically
)
model.eval()

your_message = "Hello"
conv = get_conversation_template("llama-2")  # Use the template matching your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenize the prompt and move it to the GPU
input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()

# Accelerated generation with GRIFFIN
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)

Note: Vicuna, LLaMA2-Chat, and LLaMA3-Instruct are chat models. You need to use the correct chat template; otherwise the model may produce abnormal output, which degrades GRIFFIN's performance.
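Since the right template depends on the base model (and template names vary across FastChat versions), here is a hedged sketch of picking one per model family; verify the names against your installed fastchat.conversation registry:

from fastchat.model import get_conversation_template

# Template names follow FastChat's registry and may differ across versions.
conv = get_conversation_template("vicuna")    # Vicuna models
# conv = get_conversation_template("llama-2")  # LLaMA2-Chat models
# conv = get_conversation_template("llama-3")  # LLaMA3-Instruct models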

Citation

If you find this work useful, please cite the paper:

@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding}, 
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018}, 
}