"""Load the pretrained ``pierjoe/MiniTransformer`` checkpoint and sample text.

Defines a small pre-LayerNorm decoder-style transformer (``MiniTransformer``),
downloads its safetensors weights from the Hugging Face Hub, and generates a
continuation of a fixed prompt.
"""

import os
import urllib.request

import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from transformers import AutoTokenizer, logging


class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: (causal) self-attention followed by an MLP.

    Args:
        emb_dim: Width of the residual stream; nn.MultiheadAttention requires
            it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        context_length: Maximum sequence length of the surrounding model. Not
            used inside the block; kept so the constructor signature is stable.
        dropout: Dropout on the attention weights and on the MLP output.
        causal: If True (default), a position may only attend to itself and to
            earlier positions, as next-token prediction requires.
    """

    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1, causal=True):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Plain attribute (not a buffer) so checkpoint state_dict keys are unchanged.
        self.causal = causal

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, emb_dim)``."""
        # Normalise once; the same tensor serves as query, key and value.
        h = self.ln1(x)
        attn_mask = None
        if self.causal:
            seq_len = x.size(1)
            # For a boolean attn_mask, True marks a position that may NOT be
            # attended to, i.e. every key strictly after the query.
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    """Token + learned position embeddings, a stack of TransformerBlocks,
    a final LayerNorm and a bias-free linear head over the vocabulary.

    Args:
        vocab_size: Number of token ids (embedding rows / output logits).
        emb_dim: Model width.
        context_length: Size of the position-embedding table, i.e. the longest
            sequence ``forward`` accepts.
        num_heads: Attention heads per block.
        num_layers: Number of TransformerBlocks.
        dropout: Dropout probability passed to each block.
        causal: Passed to every block. NOTE(review): set False only if the
            checkpoint being loaded was trained without a causal mask.
    """

    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
        causal=True,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(
                    emb_dim, num_heads, context_length, dropout, causal=causal
                )
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``.

        ``seq`` must not exceed ``context_length``; longer inputs overflow the
        position-embedding table.
        """
        _, seq_len = x.shape
        pos = torch.arange(seq_len, device=x.device)
        h = self.emb(x) + self.pos_emb(pos)
        h = self.blocks(h)
        h = self.ln_f(h)
        return self.head(h)

    @torch.no_grad()
    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` tokens to ``x``.

        The model is called as-is; put it in eval mode first to disable dropout.

        Args:
            x: Prompt token ids, shape ``(batch, seq)``.
            max_new_tokens: Number of tokens to generate.
            temperature: Logits are divided by this before sampling; larger
                values flatten the distribution. ``0`` means greedy (argmax)
                decoding.
            top_k: If a positive int, sample only among the ``top_k`` highest
                logits at each step; ``None`` (or <= 0) disables filtering.

        Returns:
            Tensor ``(batch, seq + max_new_tokens)``: prompt followed by the
            generated tokens.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        for _ in range(max_new_tokens):
            # Keep only the last context_length tokens the model can attend to.
            x_cond = x[:, -self.context_length :]
            logits = self(x_cond)[:, -1, :]  # logits for the next position only
            if temperature == 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
                    # Drop everything below the k-th best logit from sampling.
                    logits = logits.masked_fill(logits < kth_best, float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            x = torch.cat([x, next_token], dim=1)
        return x


CONTEXT_LENGTH = 256
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 8
N_LAYER = 6


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Prefer CUDA, then Apple MPS, and fall back to CPU when neither exists.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Download the model file
    model_path = hf_hub_download(
        repo_id="pierjoe/MiniTransformer",
        filename="checkpoints/mini_transformer_v4/model_50.safetensors",
    )

    # Load with the custom class
    model = MiniTransformer(
        vocab_size=tokenizer.vocab_size,
        emb_dim=EMBEDDING_DIMENSION,
        context_length=CONTEXT_LENGTH,
        num_heads=HEAD_NUMBER,
        num_layers=N_LAYER,
    ).to(device)

    state_dict = load_file(model_path)
    # Strip the "_orig_mod." prefix that torch.compile adds to parameter names.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()

    max_tokens = 100
    prompt = "You are a helpful assistant. Provide clear, concise, and accurate responses to the user "
    # NOTE(review): encode() adds [CLS]/[SEP] by default for BERT tokenizers;
    # confirm this matches how the checkpoint's training data was tokenized.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # temperature=5 strongly flattens the distribution; top_k=10 limits
    # sampling to the 10 most likely tokens per step.
    output_ids = model.generate(
        input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10
    )

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)