Spaces:
Runtime error
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from config import MODEL_NAME

# Load the model and tokenizer once at module level so they are shared across requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)
@spaces.GPU  # On ZeroGPU Spaces, GPU work must run inside a @spaces.GPU-decorated function; the unused `import spaces` suggests this decorator was the missing piece behind the runtime error.
def generate_response(message, history, system_prompt, temperature, max_tokens, top_p):
    # Rebuild the conversation from the tuple-style history Gradio passes in.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        if assistant_msg:
            conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})

    # Apply the model's chat template and tokenize the full prompt.
    input_text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,  # Fall back to greedy decoding at temperature 0.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return response
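
The signature (message, history, system_prompt, temperature, max_tokens, top_p) matches what Gradio's gr.ChatInterface passes to its function when additional_inputs are supplied, and the loop over history expects Gradio's older tuple-style chat history. Below is a minimal sketch of how the function might be wired into the Space's app under those assumptions; the component labels and default values are illustrative, not taken from the original Space.

    # Hypothetical app wiring -- assumes Gradio with tuple-style history;
    # labels and defaults below are illustrative assumptions.
    import gradio as gr

    demo = gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[
            gr.Textbox(value="You are a helpful assistant.", label="System prompt"),
            gr.Slider(0.0, 2.0, value=0.7, label="Temperature"),
            gr.Slider(16, 2048, value=512, step=16, label="Max new tokens"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top-p"),
        ],
    )

    if __name__ == "__main__":
        demo.launch()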