{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4", "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "bddd70b266bc4965ae2105529d7375a5": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_d911fbb4bb154295ba788280a842b06b", "IPY_MODEL_90e65bc053874978b5c94031ff8e1c32", "IPY_MODEL_99b0afae736649e3a95ef76e5d1066c7" ], "layout": "IPY_MODEL_593a82efdc574f98a41a7745485014cd" } }, "d911fbb4bb154295ba788280a842b06b": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a9c6b9cc6b6c401c823638efbbbe228a", "placeholder": "​", "style": "IPY_MODEL_8e2bb75b0f45452a84630252d360f188", "value": "checkpoints/mini_transformer_v3/model_40(…): 100%" } }, "90e65bc053874978b5c94031ff8e1c32": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e225438d9b9149048404c03fc7504530", "max": 175727816, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a88b9683c01145e99385c8eb0fba57ac", "value": 175727816 } }, "99b0afae736649e3a95ef76e5d1066c7": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_77153e40c0b84703b23a7c642c0d79d3", "placeholder": "​", "style": "IPY_MODEL_7e1cfb80cfa14df6b2b75c3e3e21d202", "value": " 176M/176M [00:04<00:00, 39.1MB/s]" } }, "593a82efdc574f98a41a7745485014cd": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a9c6b9cc6b6c401c823638efbbbe228a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "8e2bb75b0f45452a84630252d360f188": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e225438d9b9149048404c03fc7504530": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a88b9683c01145e99385c8eb0fba57ac": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "77153e40c0b84703b23a7c642c0d79d3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7e1cfb80cfa14df6b2b75c3e3e21d202": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "yCWI7FNLil-S" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import Dataset, DataLoader,random_split\n", "import urllib.request\n", "import os\n", "from transformers import AutoTokenizer, logging\n", "import pandas as pd\n", "from tqdm import tqdm\n" ] }, { "cell_type": "code", "source": [ "\n", "text = str(urllib.request.urlopen(\"https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt\").read())\n", "\n", "text = text.lower()\n" ], "metadata": { "id": "mCYAGXIWjudr" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "class Tokenizer():\n", " def __init__(self, text):\n", " self.pad_token = \"\"\n", " self.pad_token_id = 0\n", " self.itos: dict = {0:\"\"}\n", " self.stoi: dict = {\"\":0}\n", " counter = 1\n", " for i in text:\n", " if i in self.stoi:\n", " continue\n", " self.stoi[i] = counter\n", " self.itos[counter] = i\n", " counter +=1\n", " def __len__(self):\n", " return len(self.itos)\n", " def encode(self, t):\n", " if isinstance(t, str):\n", "\n", " return [self.stoi[i] for i in t]\n", " else:\n", " return [[self.stoi[i] for i in k] for k in t]\n", " def decode(self, l:torch.tensor):\n", " return [self.itos[i] for i in l]\n", "\n", "tokenizer = Tokenizer(text)\n", "dictionary_size = len(tokenizer)" ], "metadata": { "id": "mJ3HjU6TlHCl" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "X = tokenizer.encode(\"ciao\")\n", "tokenizer.decode(X)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CSGAomhilXGD", "outputId": "3e65d609-06f6-430b-c08b-14607f7e519d" }, "execution_count": null, "outputs": [ { "output_type": 
"execute_result", "data": { "text/plain": [ "['c', 'i', 'a', 'o']" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "markdown", "source": [ "# one head" ], "metadata": { "id": "SKUnJy1tsiUt" } }, { "cell_type": "code", "source": [ "emb_dim = 3" ], "metadata": { "id": "aIMCVEVcsrmx" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "emb = nn.Embedding(dictionary_size, emb_dim)\n", "X = torch.tensor(X)" ], "metadata": { "id": "k7yd-X4Bira9" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "X_embedded = emb(X)" ], "metadata": { "id": "CzimFfLyjoqo" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# context lenght = 4\n", "# batch sizev = 1\n", "# X --> (1,4,3)\n", "head_size = 3\n", "context_length = 4" ], "metadata": { "id": "RdHaLMz3jpt2" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "Wq = torch.rand((emb_dim, head_size))\n", "Wk = torch.rand((emb_dim, head_size))\n", "Wv = torch.rand((emb_dim, head_size))" ], "metadata": { "id": "Zc-pPslmnait" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "X_embedded.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "o8UKPe0xoFQa", "outputId": "e90a1162-24e6-4651-ce28-5ec5c8df3cdb" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([4, 3])" ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "Q = X_embedded @ Wq\n", "K = X_embedded @ Wk\n", "V = X_embedded @ Wv" ], "metadata": { "id": "xM6IKyz8nu7e" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "Q.shape, K.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UuwadGQBog1g", "outputId": "df0fef75-9611-43b7-e247-8eada15cd56e" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([4, 3]), torch.Size([4, 3]))" ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "source": [ "attention_score = Q @ K.reshape(1,-1,context_length)" ], "metadata": { "id": "p_erEj51ojNm" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "attention_score.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SO_LjwXSqR1S", "outputId": "9c94e90f-9fb8-43a4-edb5-b3f61db1c8a2" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([1, 4, 4])" ] }, "metadata": {}, "execution_count": 14 } ] }, { "cell_type": "code", "source": [ "attention_mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1).bool()" ], "metadata": { "id": "T95N3k1SqTAU" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "attention_mask\n", "mask = attention_mask.unsqueeze(0).expand(attention_score.size())\n" ], "metadata": { "id": "nUh6Eb71qmCW" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "attention_score_masked = attention_score.masked_fill(mask,float('-inf'))" ], "metadata": { "id": "XlQACuGzqqj0" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "attn_weights = attention_score_masked.softmax(dim = -1)" ], "metadata": { "id": "PoeHAJocq7nF" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "attn_weights" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, 
"id": "erLsj_X_rNX0", "outputId": "a57e1c83-a0f4-4164-9230-7e798168a673" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n", " [0.5995, 0.4005, 0.0000, 0.0000],\n", " [0.4626, 0.3241, 0.2133, 0.0000],\n", " [0.2012, 0.1739, 0.4339, 0.1910]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 19 } ] }, { "cell_type": "code", "source": [ "attn_output = attn_weights @ V" ], "metadata": { "id": "pBhaoDnNrRsJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "attn_output.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ffCZGstRrhLk", "outputId": "8491f201-c763-49f7-efec-6e3b385f99fd" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([1, 4, 3])" ] }, "metadata": {}, "execution_count": 21 } ] }, { "cell_type": "code", "source": [ "attn_output" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UgHy3eedrqfG", "outputId": "89f54950-5f02-4aef-a9c9-bf884df91d05" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[[-0.3628, -0.3327, -1.1935],\n", " [-0.2787, -0.2667, -0.4475],\n", " [ 0.1283, 0.0160, -0.2265],\n", " [ 0.3651, 0.2155, 0.2461]]], grad_fn=)" ] }, "metadata": {}, "execution_count": 22 } ] }, { "cell_type": "markdown", "source": [ "# multiple heads + positional embedding" ], "metadata": { "id": "pJ_mloGBskXJ" } }, { "cell_type": "code", "source": [ "X = tokenizer.encode([\"ciof\", \"miaoe\"])\n", "batch_size = len(X)\n", "head_size = 15\n", "context_length = 10\n", "emb_dim = 15\n", "X = [torch.tensor(e) for e in X]\n", "X" ], "metadata": { "id": "iTDKWnn4tGt3", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "aa085591-3bb3-4e48-ab4e-46164a38c965" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[tensor([21, 5, 19, 12]), tensor([29, 5, 25, 19, 8])]" ] }, "metadata": {}, "execution_count": 23 } ] }, { "cell_type": "code", "source": [ "X = torch.stack([\n", " F.pad(x, (context_length - len(x),0), value=tokenizer.pad_token_id)\n", " for x in X\n", "])\n", "X" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LFNIsjpmiLVB", "outputId": "c0c418f4-3733-4f82-af6c-5d85f69e5ca7" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[ 0, 0, 0, 0, 0, 0, 21, 5, 19, 12],\n", " [ 0, 0, 0, 0, 0, 29, 5, 25, 19, 8]])" ] }, "metadata": {}, "execution_count": 24 } ] }, { "cell_type": "code", "source": [ "emb = nn.Embedding(dictionary_size, emb_dim, padding_idx=0)\n", "pos_emb = nn.Embedding(context_length, emb_dim)\n", "\n", "positions = torch.arange(context_length).unsqueeze(0)\n", "\n", "X_embedded = emb(X)+pos_emb(positions)\n", "X_embedded.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "j8m82xWqstX9", "outputId": "d968b96c-7cdf-4730-e8bc-544c1d88a0de" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 15])" ] }, "metadata": {}, "execution_count": 25 } ] }, { "cell_type": "code", "source": [ "Wq = torch.rand((emb_dim, emb_dim))\n", "Wk = torch.rand((emb_dim, emb_dim))\n", "Wv = torch.rand((emb_dim, emb_dim))" ], "metadata": { "id": "odYyGLCJt45V" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "Q = X_embedded @ 
Wq\n", "K = X_embedded @ Wk\n", "V = X_embedded @ Wv\n", "Q.shape\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fUvHjwcut9Yr", "outputId": "4c4d4eb2-2ff2-41fa-b631-55681e040cad" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 15])" ] }, "metadata": {}, "execution_count": 27 } ] }, { "cell_type": "code", "source": [ "num_heads = emb_dim // head_size\n", "num_heads" ], "metadata": { "id": "4df_tMCmt-Z0", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "51f383e9-8757-4243-bfcb-d81b02a1e675" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "1" ] }, "metadata": {}, "execution_count": 28 } ] }, { "cell_type": "code", "source": [ "Q.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pTiwY1uPlzXy", "outputId": "a0bff6cb-168b-489c-a072-f0f8627c1e53" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 15])" ] }, "metadata": {}, "execution_count": 29 } ] }, { "cell_type": "code", "source": [ "Q = Q.view(batch_size, context_length, num_heads, head_size).transpose(1, 2) # (B, num_heads, T, head_size)\n", "K = K.view(batch_size, context_length, num_heads, head_size).transpose(1, 2)\n", "V = V.view(batch_size, context_length, num_heads, head_size).transpose(1, 2)\n", "V.shape" ], "metadata": { "id": "FF8G9uKNuHtI", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "55792418-1f68-45e8-90d6-578fd4c5f6f7" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 1, 10, 15])" ] }, "metadata": {}, "execution_count": 30 } ] }, { "cell_type": "code", "source": [ "Q.shape[0] # --> batch size\n", "Q.shape[1] # --> attention head\n", "Q.shape[2] # --> context lenght\n", "Q.shape[3] # --> head_size\n", "\n", "Q.shape\n", "# Embedding dim (10)\n", "# │\n", "# ├── Head 1 → works on dimensions [0‒4] → output (…, 5)\n", "# └── Head 2 → works on dimensions [5‒9] → output (…, 5)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bGe2flnsvGHq", "outputId": "a5e2a312-f876-4606-ab1d-e131b3101e82" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 1, 10, 15])" ] }, "metadata": {}, "execution_count": 31 } ] }, { "cell_type": "code", "source": [ "K.transpose(-2,-1).shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KoNemkrsvX8F", "outputId": "983673de-efbc-4f47-f79f-6ab14f897682" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 1, 15, 10])" ] }, "metadata": {}, "execution_count": 32 } ] }, { "cell_type": "code", "source": [ "attn_scores = Q @ K.transpose(-2, -1) / head_size**0.5 # (B, H, T, T)\n", "\n", "attention_mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1).bool()\n", "mask = attention_mask.unsqueeze(0).expand(attn_scores.size())\n", "\n", "attn_scores_masked = attn_scores.masked_fill(mask,float('-inf'))\n", "\n", "attn_weights = torch.softmax(attn_scores_masked, dim=-1)\n", "attn_output = attn_weights @ V # (B, H, T, head_size)\n", "attn_weights.shape" ], "metadata": { "id": "MuNmXtomvMTC", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "3171c895-452f-464e-ed95-b3c5e09f4f0c" }, "execution_count": null, "outputs": [ { "output_type": 
"execute_result", "data": { "text/plain": [ "torch.Size([2, 1, 10, 10])" ] }, "metadata": {}, "execution_count": 33 } ] }, { "cell_type": "code", "source": [ "attn_output.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "daVyXVrIvh1o", "outputId": "837bddff-12eb-425d-d99e-3aca37d796b6" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 1, 10, 15])" ] }, "metadata": {}, "execution_count": 34 } ] }, { "cell_type": "code", "source": [ "attn_output.transpose(-3,-2).reshape(batch_size,context_length,-1).shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IhSrhtYyvimn", "outputId": "0ddd39de-9dbd-41bb-e3e6-702acfd4bdd3" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 15])" ] }, "metadata": {}, "execution_count": 35 } ] }, { "cell_type": "code", "source": [ "# residual connection\n", "residual = attn_output.transpose(-3,-2).reshape(batch_size,context_length,-1) + X_embedded" ], "metadata": { "id": "eW_NTAf3mb6u" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# attention block" ], "metadata": { "id": "LNjd9bnKfoe_" } }, { "cell_type": "code", "source": [ "X = tokenizer.encode([\"ciof\", \"miaoe\"])\n", "batch_size = len(X)\n", "head_size = 15\n", "context_length = 10\n", "emb_dim = 15\n", "X = [torch.tensor(e) for e in X]\n", "X = torch.stack([\n", " F.pad(x, (context_length - len(x),0), value=tokenizer.pad_token_id)\n", " for x in X\n", "])\n", "X" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MAu-X-T-tMQ9", "outputId": "104e56cd-6095-4256-b6e0-c5212aacdc35" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[ 0, 0, 0, 0, 0, 0, 21, 5, 19, 12],\n", " [ 0, 0, 0, 0, 0, 29, 5, 25, 19, 8]])" ] }, "metadata": {}, "execution_count": 37 } ] }, { "cell_type": "code", "source": [ "\n", "class AttentionBlock(nn.Module):\n", " def __init__(self, head_size=5, context_length=10, emb_dim=15, dictionary_size=100) -> None:\n", " super().__init__()\n", " assert emb_dim % head_size == 0, \"emb_dim must be divisible by head_size\"\n", "\n", " self.emb = nn.Embedding(dictionary_size, emb_dim, padding_idx=0)\n", " self.pos_emb = nn.Embedding(context_length, emb_dim)\n", "\n", " self.Wq = nn.Parameter(torch.randn(emb_dim, emb_dim))\n", " self.Wk = nn.Parameter(torch.randn(emb_dim, emb_dim))\n", " self.Wv = nn.Parameter(torch.randn(emb_dim, emb_dim))\n", "\n", " self.layer_norm = nn.LayerNorm(emb_dim)\n", "\n", " self.context_length = context_length\n", " self.head_size = head_size\n", " self.num_heads = emb_dim // head_size\n", "\n", " # causal mask (upper-triangular)\n", " mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()\n", " self.register_buffer(\"attention_mask\", mask)\n", "\n", " def forward(self, x):\n", " B, T = x.shape\n", "\n", " positions = torch.arange(T)\n", " X_embedded = self.emb(x) + self.pos_emb(positions)\n", " X_embedded = self.layer_norm(X_embedded)\n", "\n", " Q = X_embedded @ self.Wq # --> produce query\n", " K = X_embedded @ self.Wk # --> produce key\n", " V = X_embedded @ self.Wv # --> produce value\n", "\n", " # reshape into heads\n", " Q = Q.view(B, T, self.num_heads, self.head_size).transpose(1, 2) # (B, H, T, d_head)\n", " K = K.view(B, T, self.num_heads, self.head_size).transpose(1, 2)\n", " V = V.view(B, T, self.num_heads, 
self.head_size).transpose(1, 2)\n", "\n", "        attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_size ** 0.5)\n", "\n", "        attn_scores = attn_scores.masked_fill(self.attention_mask[:T, :T], float('-inf'))  # apply mask\n", "\n", "        attn_weights = F.softmax(attn_scores, dim=-1)\n", "        attn_output = attn_weights @ V  # (B, H, T, d_head)\n", "\n", "        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)  # merge the heads back into one matrix\n", "        residual = attn_output + X_embedded  # add residual connection\n", "\n", "        return residual\n" ], "metadata": { "id": "V9SOgSvAvzHd" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "atn_block = AttentionBlock(head_size=5, context_length=10, emb_dim=15, dictionary_size=len(tokenizer))" ], "metadata": { "id": "DJ9NuaeYv3be" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "atn_block(X).shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HBhaaBdZv7MM", "outputId": "248f8f42-6a1a-4b65-95a1-cbc830b82078" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([2, 10, 15])" ] }, "metadata": {}, "execution_count": 40 } ] }, { "cell_type": "markdown", "source": [ "# what we made is not ideal - let's make it the way PyTorch expects" ], "metadata": { "id": "G-bypf85_MrT" } }, { "cell_type": "code", "source": [ "class AttentionBlock(nn.Module):\n", "    def __init__(self, emb_dim=15, num_heads=3, context_length=10, dropout=0.1):\n", "        super().__init__()\n", "        assert emb_dim % num_heads == 0, \"emb_dim must be divisible by num_heads\"\n", "        head_dim = emb_dim // num_heads\n", "\n", "        self.num_heads = num_heads\n", "        self.head_dim = head_dim\n", "        self.scale = head_dim ** -0.5\n", "\n", "        # Linear projections for Q, K, V\n", "        self.Wq = nn.Linear(emb_dim, emb_dim)\n", "        self.Wk = nn.Linear(emb_dim, emb_dim)\n", "        self.Wv = nn.Linear(emb_dim, emb_dim)\n", "\n", "        # Output projection (mix heads)\n", "        self.Wo = nn.Linear(emb_dim, emb_dim)\n", "\n", "        # Dropout\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "        # Causal mask (upper-triangular)\n", "        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()\n", "        self.register_buffer(\"mask\", mask)\n", "\n", "    def forward(self, x):\n", "        B, T, C = x.shape\n", "\n", "        Q = self.Wq(x)\n", "        K = self.Wk(x)\n", "        V = self.Wv(x)\n", "\n", "        # Split into heads\n", "        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, D)\n", "        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n", "        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)\n", "\n", "        # Scaled dot-product attention\n", "        attn_scores = (Q @ K.transpose(-2, -1)) * self.scale\n", "        attn_scores = attn_scores.masked_fill(self.mask[:T, :T], float('-inf'))\n", "        attn_weights = F.softmax(attn_scores, dim=-1)\n", "        attn_weights = self.dropout(attn_weights)\n", "\n", "        attn_out = attn_weights @ V  # (B, H, T, D)\n", "        attn_out = attn_out.transpose(1, 2).reshape(B, T, C)  # merge heads (concat)\n", "        attn_out = self.Wo(attn_out)  # output projection\n", "        # Without this, you’d just have Concat(head₁, head₂, …) — a raw concatenation, not a learnable combination.\n", "        attn_out = self.dropout(attn_out)\n", "        return attn_out\n" ], "metadata": { "id": "kGe_aHN90j4i" }, "execution_count": null, "outputs": [] },
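{ "cell_type": "code", "source": [ "# quick smoke test (illustrative): the rewritten block maps (B, T, C) -> (B, T, C)\n", "# on already-embedded inputs, unlike the previous version which embedded token ids itself\n", "attn = AttentionBlock(emb_dim=15, num_heads=3, context_length=10)\n", "attn(torch.randn(2, 10, 15)).shape" ], "metadata": {}, "execution_count": null, "outputs": [] },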
{ "cell_type": "code", "source": [ "class TransformerBlock(nn.Module):\n", "    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):\n", "        super().__init__()\n", "        self.ln1 = nn.LayerNorm(emb_dim)\n", "        self.ln2 = nn.LayerNorm(emb_dim)\n", "        self.attn = AttentionBlock(emb_dim, num_heads, context_length, dropout)\n", "        self.ff = nn.Sequential(\n", "            nn.Linear(emb_dim, 4 * emb_dim),\n", "            nn.GELU(),\n", "            nn.Linear(4 * emb_dim, emb_dim),\n", "            nn.Dropout(dropout)\n", "        )\n", "\n", "    def forward(self, x):\n", "        # Pre-Norm attention\n", "        x = x + self.attn(self.ln1(x))\n", "        # Pre-Norm feed-forward\n", "        x = x + self.ff(self.ln2(x))\n", "        return x\n" ], "metadata": { "id": "I07RIY7h_QmO" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "X = [\"ciao\", \"bleah io sono piergiorgio\"]\n", "\n", "X = tokenizer.encode(X)\n", "batch_size = len(X)\n", "head_size = 15\n", "context_length = 10\n", "emb_dim = 15\n", "X = [torch.tensor(e) for e in X]\n", "X = torch.stack([\n", "    # F.pad with a negative amount truncates from the left when a sequence is longer than the context\n", "    F.pad(x, (context_length - len(x), 0), value=tokenizer.pad_token_id)\n", "    for x in X\n", "])\n", "X" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1scf5mbIMovG", "outputId": "636d5d7c-e1a4-49d1-c4be-465672672be0" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[ 0,  0,  0,  0,  0,  0, 21,  5, 25, 19],\n", "        [ 5,  8, 15, 22,  5, 19, 15, 22,  5, 19]])" ] }, "metadata": {}, "execution_count": 43 } ] }, { "cell_type": "code", "source": [ "emb = nn.Embedding(dictionary_size, emb_dim, padding_idx=0)\n", "pos_emb = nn.Embedding(context_length, emb_dim)\n", "positions = torch.arange(context_length).unsqueeze(0)\n", "\n", "X_embedded = emb(X) + pos_emb(positions)\n", "X_embedded.shape\n", "\n", "\n", "B, T, C = X_embedded.shape\n", "\n", "B, T, C" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-jpyJZfk_WzG", "outputId": "536f52e1-6f7f-42eb-975f-9974fec01e05" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(2, 10, 15)" ] }, "metadata": {}, "execution_count": 44 } ] }, { "cell_type": "code", "source": [ "\n", "block = TransformerBlock(emb_dim=C, num_heads=3, context_length=T)\n", "out = block(X_embedded)\n", "print(out.shape)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "l36DObCf_Sfn", "outputId": "db9b9d54-6037-4f7a-8d14-d3f7df925c23" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([2, 10, 15])\n" ] } ] }, { "cell_type": "markdown", "source": [ "# Mini Transformer" ], "metadata": { "id": "sQFdZnTdMvum" } }, { "cell_type": "code", "source": [ "class MiniTransformer(nn.Module):\n", "    def __init__(self, vocab_size, emb_dim=64, context_length=32, num_heads=4, num_layers=4, dropout=0.1):\n", "        super().__init__()\n", "        self.emb = nn.Embedding(vocab_size, emb_dim)\n", "        self.pos_emb = nn.Embedding(context_length, emb_dim)\n", "        self.blocks = nn.Sequential(\n", "            *[TransformerBlock(emb_dim, num_heads, context_length, dropout) for _ in range(num_layers)]\n", "        )\n", "        self.ln_f = nn.LayerNorm(emb_dim)\n", "        self.head = nn.Linear(emb_dim, vocab_size, bias=False)  # language modeling head\n", "        self.context_length = context_length\n", "\n", "    def forward(self, x):\n", "        B, T = x.shape\n", "        pos = torch.arange(T, device=x.device)\n", "        x = self.emb(x) + self.pos_emb(pos)\n", "        x = self.blocks(x)\n", "        x = self.ln_f(x)\n", "        logits = self.head(x)\n", "        return logits\n", "\n", "    @torch.no_grad()\n", "    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):\n", "\n", "        for _ in 
range(max_new_tokens):\n", "            # truncate context if needed\n", "            x_cond = x[:, -self.context_length:]\n", "\n", "            # get predictions\n", "            logits = self(x_cond)  # (B, T_cond, vocab_size)\n", "            logits = logits[:, -1, :] / temperature  # only last position\n", "\n", "            # optionally restrict to the top-k most likely tokens\n", "            if top_k is not None:\n", "                v, _ = torch.topk(logits, top_k)\n", "                logits[logits < v[:, [-1]]] = float('-inf')\n", "\n", "            probs = F.softmax(logits, dim=-1)\n", "\n", "            # sample from the distribution (commented out) or decode greedily\n", "            # next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)\n", "            next_token = torch.argmax(probs, dim=-1, keepdim=True)\n", "            # append to sequence\n", "            x = torch.cat([x, next_token], dim=1)\n", "\n", "        return x" ], "metadata": { "id": "R3eAvCng_qlx" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "emb_dim = 32\n", "context_length = 16\n", "num_heads = 4\n", "num_layers = 2\n", "\n", "model = MiniTransformer(vocab_size=dictionary_size, emb_dim=emb_dim, context_length=context_length, num_heads=num_heads, num_layers=num_layers)\n" ], "metadata": { "id": "7Q-1Hn0-MSXd" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "X = [\"ciao\", \"bleah io sono piergiorgio\"]\n", "X = tokenizer.encode(X)\n", "X = [torch.tensor(e) for e in X]\n", "X = torch.stack([\n", "    F.pad(x, (context_length - len(x), 0), value=tokenizer.pad_token_id)\n", "    for x in X\n", "])\n", "\n", "batch_size = len(X)\n", "X" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-CtZ4we7MeD0", "outputId": "64b2b4f1-934b-4f58-b787-8923f1b2553c" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 21,  5, 25, 19],\n", "        [ 6, 19, 16, 19,  7, 14,  5,  8, 15, 22,  5, 19, 15, 22,  5, 19]])" ] }, "metadata": {}, "execution_count": 48 } ] }, { "cell_type": "code", "source": [ "prediction_inference = model.generate(X)\n", "prediction_train = model.forward(X)  # or model(X): forward() predicts logits for all positions (for training)\n", "# --> then use the logits inside the training loop to predict the shifted next token\n", "# at training time, we predict the next token for each of the possible subsequences."
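, "\n", "# a minimal sketch (illustrative) of how those training logits get scored: the\n", "# target is the input shifted left by one, so position t learns to predict token t+1\n", "#   loss = F.cross_entropy(prediction_train[:, :-1].reshape(-1, dictionary_size),\n", "#                          X[:, 1:].reshape(-1))"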
], "metadata": { "id": "4FSRdtb-M2LK" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "\"\".join(tokenizer.decode(prediction_inference[0].tolist()))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "cKhk1MSpPMuC", "outputId": "6412e080-c324-4341-c1ef-7919412f9e3a" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'ciao<~pe<'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 51 } ] }, { "cell_type": "markdown", "source": [ "# Toy Training loop for Mini Transformer" ], "metadata": { "id": "_JZyx96ORqNS" } }, { "cell_type": "code", "source": [ "emb_dim = 128\n", "context_length = 256\n", "num_heads = 8\n", "num_layers = 4" ], "metadata": { "id": "V5qgbbb0TFdF" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "model = MiniTransformer(vocab_size=dictionary_size, emb_dim=emb_dim, context_length=context_length, num_heads=num_heads, num_layers=num_layers)\n", "sum(p.numel() for p in model.parameters() if p.requires_grad)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yz2OgdBlTHru", "outputId": "b54dd061-39c6-4778-d916-895359bd8f0e" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "843008" ] }, "metadata": {}, "execution_count": 53 } ] }, { "cell_type": "code", "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.to(device)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TlqWCEWsefH1", "outputId": "2abdd805-76f9-476d-800f-cbc3aa549199" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "MiniTransformer(\n", " (emb): Embedding(66, 128)\n", " (pos_emb): Embedding(256, 128)\n", " (blocks): Sequential(\n", " (0): TransformerBlock(\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (attn): AttentionBlock(\n", " (Wq): Linear(in_features=128, out_features=128, bias=True)\n", " (Wk): Linear(in_features=128, out_features=128, bias=True)\n", " (Wv): Linear(in_features=128, out_features=128, bias=True)\n", " (Wo): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ff): Sequential(\n", " (0): Linear(in_features=128, out_features=512, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (1): TransformerBlock(\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (attn): AttentionBlock(\n", " (Wq): Linear(in_features=128, out_features=128, bias=True)\n", " (Wk): Linear(in_features=128, out_features=128, bias=True)\n", " (Wv): Linear(in_features=128, out_features=128, bias=True)\n", " (Wo): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ff): Sequential(\n", " (0): Linear(in_features=128, out_features=512, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): TransformerBlock(\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", 
" (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (attn): AttentionBlock(\n", " (Wq): Linear(in_features=128, out_features=128, bias=True)\n", " (Wk): Linear(in_features=128, out_features=128, bias=True)\n", " (Wv): Linear(in_features=128, out_features=128, bias=True)\n", " (Wo): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ff): Sequential(\n", " (0): Linear(in_features=128, out_features=512, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (3): TransformerBlock(\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (attn): AttentionBlock(\n", " (Wq): Linear(in_features=128, out_features=128, bias=True)\n", " (Wk): Linear(in_features=128, out_features=128, bias=True)\n", " (Wv): Linear(in_features=128, out_features=128, bias=True)\n", " (Wo): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ff): Sequential(\n", " (0): Linear(in_features=128, out_features=512, bias=True)\n", " (1): GELU(approximate='none')\n", " (2): Linear(in_features=512, out_features=128, bias=True)\n", " (3): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (head): Linear(in_features=128, out_features=66, bias=False)\n", ")" ] }, "metadata": {}, "execution_count": 54 } ] }, { "cell_type": "code", "source": [ "\n", "class MiniTransformerDataset(Dataset):\n", " def __init__(self, text, tokenizer, context_length, stride=16):\n", " self.tokenizer = tokenizer\n", " self.context_length = context_length\n", " self.stride = stride\n", "\n", " self.tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long)\n", "\n", " # Create sliding window indices\n", " self.indices = [\n", " i for i in range(0, len(self.tokens) - context_length, stride)\n", " ]\n", "\n", " def __len__(self):\n", " return len(self.indices)\n", "\n", " def __getitem__(self, idx):\n", " start = self.indices[idx]\n", " x = self.tokens[start : start + self.context_length]\n", " y = self.tokens[start + 1 : start + self.context_length + 1]\n", "\n", " return x, y\n", "# here we are creating X and Y --> by taking a number of token = context window dimension\n", "# the reasoning is the same we will do on the trainng that we will see later:\n", "\n", " # as long as we flatten the list of strings into one single piece of text\n", " # and then we divide it into pieces of the same length, by definition we don't need padding.\n", " # we need padding in the case when we have multiple separated sentences in a list,\n", " # and we want to create a batch with them --> than we surely need to padd all the sequences\n", " # to the same length --> max length or context length (with duely truncation if needed)\n", "\n", " # example\n", " # we have a batch like this:\n", " # [\"ciao\", \"ciao io sono\", \"ciao io sono pippo\"]\n", " # becomes:\n", " # [101, 2003, 102]\n", " # [101, 2003, 2026, 2070, 102]\n", " # [101, 2003, 2026, 2070, 5274, 102]\n", " # we have to pad to max length\n", " # [101, 2003, 102, 0, 0, 0]\n", " # [101, 2003, 2026, 2070, 102, 0]\n", " # [101, 2003, 2026, 2070, 5274, 102]" ], "metadata": { "id": "zSnnXUICRrXm" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "text = 
{ "cell_type": "code", "source": [ "# download and decode the corpus again\n", "text = urllib.request.urlopen(\"https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt\").read().decode(\"utf-8\")\n", "\n", "text = text.lower()" ], "metadata": { "id": "vp6eQwT0d39n" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#text = \"ciao io sono piergiorgio\"\n", "dataset = MiniTransformerDataset(text, tokenizer, context_length, stride=128)\n", "len(dataset)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "K8Yvm9o0Suq_", "outputId": "2e3634d8-371e-4631-8a1a-3954daacbce2" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "43856" ] }, "metadata": {}, "execution_count": 57 } ] }, { "cell_type": "code", "source": [ "# n = 1\n", "# X = \"\".join(tokenizer.decode(dataset[n][0].tolist()))\n", "# Y = \"\".join(tokenizer.decode(dataset[n][1].tolist()))\n", "# for _, (i,j) in enumerate(zip(dataset[n][0].tolist(), dataset[n][1].tolist())):\n", "#     print(f\"{dataset[n][0].tolist()[:_+1]}->{j}\")" ], "metadata": { "id": "Im7tyVz-TW-v" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "loader = DataLoader(\n", "    dataset,\n", "    batch_size=128,\n", "    shuffle=True,\n", "    num_workers=4\n", ")" ], "metadata": { "id": "6gVWrMFrThkG", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "0d4f083c-9954-4de6-a645-243c3f17632e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", "  warnings.warn(\n" ] } ] }, { "cell_type": "code", "source": [ "# next(iter(loader))  # --> yields two tensors, X and Y, each of shape (batch_size, context_length)" ], "metadata": { "id": "ubo7pK9lT1lN" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "\n", "for epoch in range(5):  # a few epochs just to see learning\n", "    total_loss = 0\n", "    for x, y in loader:\n", "        x = x.to(device)\n", "        y = y.to(device)\n", "        logits = model(x)  # (B, T, vocab_size)\n", "        loss = loss_fn(\n", "            logits.view(-1, dictionary_size),\n", "            y.view(-1)\n", "        )\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "        total_loss += loss.item()\n", "    print(f\"Epoch {epoch+1}, loss = {total_loss/len(loader):.4f}\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SXg3VWyAUDeq", "outputId": "f78dcbf9-4c5d-4ff1-db78-5df43b152efb" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", " warnings.warn(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1, loss = 2.3534\n", "Epoch 2, loss = 1.9170\n", "Epoch 3, loss = 1.6652\n", "Epoch 4, loss = 1.5299\n", "Epoch 5, loss = 1.4456\n" ] } ] }, { "cell_type": "code", "source": [ "n = 16\n", "test = dataset[n][0].unsqueeze(0).to(device)\n", "\n", "\"\".join(tokenizer.decode(test.tolist()[0]))\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 70 }, "id": "DR_G833WjA_f", "outputId": "0d746b9d-b65f-4826-ac27-57be0fbe4baf" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'ter, shaks10a.txt\\\\n\\\\nif you would like further information about world library, inc.\\\\nplease call them at 1-800-443-0238 or email julianc@netcom.com\\\\nplease give them our thanks for their shakespeare cooperation!\\\\n\\\\n\\\\nthe official release date of all proje'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 64 } ] }, { "cell_type": "code", "source": [ "\"\".join(tokenizer.decode(model.generate(test, 100)[0].tolist()))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 70 }, "id": "6HBNglyrcJr_", "outputId": "37f6bc35-3820-4791-e9e8-50551f1803cc" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "\"ter, shaks10a.txt\\\\n\\\\nif you would like further information about world library, inc.\\\\nplease call them at 1-800-443-0238 or email julianc@netcom.com\\\\nplease give them our thanks for their shakespeare cooperation!\\\\n\\\\n\\\\nthe official release date of all project of the company\\\\nsess. 
\\\\'tis the complete with of the commpers of with the content.\\\\n the stra\"" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 66 } ] }, { "cell_type": "markdown", "source": [ "# Serious single-GPU training loop - with a serious tokenizer" ], "metadata": { "id": "NQ4O05I2tEVp" } }, { "cell_type": "code", "source": [ "# ----------------- MODEL -----------------\n", "\n", "class TransformerBlock(nn.Module):\n", "    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):\n", "        super().__init__()\n", "        self.ln1 = nn.LayerNorm(emb_dim)\n", "        self.ln2 = nn.LayerNorm(emb_dim)\n", "        self.attn = nn.MultiheadAttention(\n", "            emb_dim, num_heads, dropout=dropout, batch_first=True\n", "        )\n", "        self.mlp = nn.Sequential(\n", "            nn.Linear(emb_dim, 4 * emb_dim),\n", "            nn.GELU(),\n", "            nn.Linear(4 * emb_dim, emb_dim),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        T = x.size(1)\n", "        # causal mask: True marks the future positions each token must not attend to\n", "        causal_mask = torch.triu(\n", "            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1\n", "        )\n", "        h = self.ln1(x)\n", "        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)\n", "        x = x + attn_out\n", "        x = x + self.mlp(self.ln2(x))\n", "        return x\n", "\n", "\n", "class MiniTransformer(nn.Module):\n", "    def __init__(\n", "        self,\n", "        vocab_size,\n", "        emb_dim,\n", "        context_length,\n", "        num_heads,\n", "        num_layers,\n", "        dropout=0.1,\n", "    ):\n", "        super().__init__()\n", "        self.emb = nn.Embedding(vocab_size, emb_dim)\n", "        self.pos_emb = nn.Embedding(context_length, emb_dim)\n", "        self.blocks = nn.Sequential(\n", "            *[\n", "                TransformerBlock(emb_dim, num_heads, context_length, dropout)\n", "                for _ in range(num_layers)\n", "            ]\n", "        )\n", "        self.ln_f = nn.LayerNorm(emb_dim)\n", "        self.head = nn.Linear(emb_dim, vocab_size, bias=False)\n", "        self.context_length = context_length\n", "\n", "    def forward(self, x):\n", "        B, T = x.shape\n", "        pos = torch.arange(T, device=x.device)\n", "        x = self.emb(x) + self.pos_emb(pos)\n", "        x = self.blocks(x)\n", "        x = self.ln_f(x)\n", "        logits = self.head(x)\n", "        return logits\n" ], "metadata": { "id": "Y7H1aONTvjCw" }, "execution_count": null, "outputs": [] },
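{ "cell_type": "code", "source": [ "# illustrative check that the block is causal: zeroing out the last token must not\n", "# change the outputs at earlier positions (dropout disabled via eval())\n", "blk = TransformerBlock(emb_dim=16, num_heads=4, context_length=8)\n", "blk.eval()\n", "a = torch.randn(1, 8, 16)\n", "b = a.clone()\n", "b[0, -1] = 0.0\n", "torch.allclose(blk(a)[:, :-1], blk(b)[:, :-1], atol=1e-6)" ], "metadata": {}, "execution_count": null, "outputs": [] },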
this:\n", "# [\"ciao\", \"ciao io sono\", \"ciao io sono pippo\"]\n", "# becomes:\n", "# [101, 2003, 102]\n", "# [101, 2003, 2026, 2070, 102]\n", "# [101, 2003, 2026, 2070, 5274, 102]\n", "# we have to pad to max length\n", "# [101, 2003, 102, 0, 0, 0]\n", "# [101, 2003, 2026, 2070, 102, 0]\n", "# [101, 2003, 2026, 2070, 5274, 102]" ], "metadata": { "id": "lTNzQnKtvlVR" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "logging.set_verbosity_error()\n", "\n", "# ----------------- CONFIG -----------------\n", "SAVE_EVERY = 5\n", "MODEL_NAME = \"mini_transformer_v2\"\n", "N_DATA_WORKERS = 4\n", "PIN_MEMORY = True if N_DATA_WORKERS > 0 and torch.cuda.is_available() else False\n", "BATCH_SIZE = 64\n", "EVAL_EVERY = 5\n", "LEARNING_RATE = 3e-4\n", "NUM_EPOCHS = 50\n", "USE_AMP = True\n", "STRIDE = 32\n", "CHECKPOINT_DIR = f\"/content/drive/MyDrive/Colab Notebooks/LLM/MODELS/checkpoints/{MODEL_NAME}\"\n", "os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n", "DATASET = \"/content/drive/MyDrive/Colab Notebooks/LLM/DATA/generated_dataset_very_big.csv\"\n", "\n", "CONTEXT_LENGTH = 128\n", "EMBEDDING_DIMENSION = 512\n", "HEAD_NUMBER = 4\n", "N_LAYER = 4" ], "metadata": { "id": "McPIuzjIv1IF" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ----------------- DEVICE -----------------\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"mps\")\n", "print(f\"Using device: {device}\")\n", "if device.type == \"cuda\":\n", " print(torch.cuda.get_device_name(0))\n", " print(torch.cuda.memory_allocated() / 1024**2, \"MB allocated\")\n", " print(torch.cuda.memory_reserved() / 1024**2, \"MB reserved\")\n", "\n", "\n", "# ----------------- LOAD DATA -----------------\n", "df = pd.read_csv(DATASET)\n", "texts = [\n", " f\"{row['system_prompt']} {row['question']} {row['answer']}\"\n", " for _, row in df.iterrows()\n", "]\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", "vocab_size = tokenizer.vocab_size\n", "\n", "dataset = SlidingWindowDataset(texts, tokenizer, CONTEXT_LENGTH, STRIDE)\n", "train_size = int(0.9 * len(dataset))\n", "test_size = len(dataset) - train_size\n", "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n", "print(f\"dataset train lenght: {len(train_dataset)}\")\n", "loader_train = DataLoader(\n", " train_dataset,\n", " batch_size=BATCH_SIZE,\n", " shuffle=True,\n", " num_workers=N_DATA_WORKERS,\n", " pin_memory=PIN_MEMORY,\n", ")\n", "loader_test = DataLoader(\n", " test_dataset,\n", " batch_size=BATCH_SIZE,\n", " shuffle=False,\n", " num_workers=N_DATA_WORKERS,\n", " pin_memory=PIN_MEMORY,\n", ")\n", "\n", "\n", "# ----------------- TRAINING SETUP -----------------\n", "\n", "model = MiniTransformer(\n", " vocab_size=vocab_size,\n", " emb_dim=EMBEDDING_DIMENSION,\n", " context_length=CONTEXT_LENGTH,\n", " num_heads=HEAD_NUMBER,\n", " num_layers=N_LAYER,\n", ").to(device)\n", "\n", "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "print(f\"number of parameters: {n_params}\")\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)\n", "scaler = torch.amp.GradScaler(enabled=USE_AMP and device.type == \"cuda\")\n", "criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n", "\n", "\n", "# ----------------- CHECKPOINT RESUME -----------------\n", "checkpoint_files = sorted([f for f in os.listdir(CHECKPOINT_DIR) if f.endswith(\".pt\")])\n", "if checkpoint_files:\n", " latest_ckpt = 
os.path.join(CHECKPOINT_DIR, checkpoint_files[-1])\n", " ckpt = torch.load(latest_ckpt, map_location=device)\n", " model.load_state_dict(ckpt[\"model_state\"])\n", " optimizer.load_state_dict(ckpt[\"optimizer_state\"])\n", " start_epoch = ckpt[\"epoch\"] + 1\n", " print(f\"Resumed from {latest_ckpt}\")\n", "else:\n", " start_epoch = 0\n", "\n", "\n", "# ----------------- TRAINING LOOP -----------------\n", "for epoch in range(start_epoch, NUM_EPOCHS):\n", " model.train()\n", " total_loss = 0\n", "\n", " for x, y in tqdm(loader_train, desc=f\"Epoch {epoch+1}/{NUM_EPOCHS}\"):\n", " x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)\n", " optimizer.zero_grad()\n", "\n", " with torch.amp.autocast(\n", " \"cuda\", dtype=torch.float16, enabled=USE_AMP and device.type == \"cuda\"\n", " ):\n", " logits = model(x)\n", " loss = criterion(logits.view(-1, vocab_size), y.view(-1))\n", "\n", " scaler.scale(loss).backward()\n", " scaler.step(optimizer)\n", " scaler.update()\n", "\n", " total_loss += loss.item() * x.size(0)\n", "\n", " avg_train_loss = total_loss / len(train_dataset)\n", " print(f\"Train Loss: {avg_train_loss:.4f}\")\n", "\n", " # --- Evaluation ---\n", " if (epoch + 1) % EVAL_EVERY == 0:\n", " model.eval()\n", " total_loss = 0\n", " with torch.no_grad():\n", " for x, y in loader_test:\n", " x, y = x.to(device), y.to(device)\n", " with torch.amp.autocast(\n", " \"cuda\",\n", " dtype=torch.float16,\n", " enabled=USE_AMP and device.type == \"cuda\",\n", " ):\n", " logits = model(x)\n", " loss = criterion(logits.view(-1, vocab_size), y.view(-1))\n", " total_loss += loss.item() * x.size(0)\n", " avg_test_loss = total_loss / len(test_dataset)\n", " print(f\"Test Loss: {avg_test_loss:.4f}\")\n", "\n", " # --- Save checkpoint ---\n", " if SAVE_EVERY > 0 and (epoch + 1) % SAVE_EVERY == 0:\n", " torch.save(\n", " {\n", " \"epoch\": epoch,\n", " \"model_state\": model.state_dict(),\n", " \"optimizer_state\": optimizer.state_dict(),\n", " \"scaler_state\": scaler.state_dict(),\n", " },\n", " os.path.join(CHECKPOINT_DIR, f\"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt\"),\n", " )\n", "\n", "# check GPU utilization metrics here:\n", "# nvidia-smi dmon -s u\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2l4_7jN1ohub", "outputId": "c6262dc5-7271-4ec7-8f59-9bd598af2fb5" }, "execution_count": null, "outputs": [ { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n", "Tesla T4\n", "2052.30322265625 MB allocated\n", "10830.0 MB reserved\n", "dataset train lenght: 209154\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", " warnings.warn(\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "number of parameters: 43930624\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/50: 100%|██████████| 3269/3269 [08:13<00:00, 6.63it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.3872\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/50: 100%|██████████| 3269/3269 [08:04<00:00, 6.75it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0307\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 3/50: 100%|██████████| 3269/3269 [08:03<00:00, 6.76it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0244\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 4/50: 100%|██████████| 3269/3269 [08:03<00:00, 6.76it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0191\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 5/50: 100%|██████████| 3269/3269 [08:02<00:00, 6.78it/s]" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0144\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.0302\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 6/50: 100%|██████████| 3269/3269 [08:01<00:00, 6.78it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0108\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 7/50: 100%|██████████| 3269/3269 [08:01<00:00, 6.79it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0083\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 8/50: 100%|██████████| 3269/3269 [08:02<00:00, 6.78it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0066\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 9/50: 100%|██████████| 3269/3269 [08:01<00:00, 6.79it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0054\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 10/50: 100%|██████████| 3269/3269 [08:01<00:00, 6.78it/s]" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0047\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.0376\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 11/50: 100%|██████████| 
3269/3269 [08:01<00:00, 6.78it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0041\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 12/50: 100%|██████████| 3269/3269 [08:00<00:00, 6.80it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0037\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 13/50: 100%|██████████| 3269/3269 [08:01<00:00, 6.80it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0034\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 14/50: 100%|██████████| 3269/3269 [07:59<00:00, 6.81it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0032\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 15/50: 100%|██████████| 3269/3269 [08:00<00:00, 6.80it/s]" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0029\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.0418\n" ] }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Epoch 16/50: 100%|██████████| 3269/3269 [08:00<00:00, 6.81it/s]\n" ] }, { "metadata": { "tags": null }, "name": "stdout", "output_type": "stream", "text": [ "Train Loss: 0.0028\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 17/50:  24%|██▍       | 788/3269 [01:55<06:09, 6.71it/s]" ] } ] }, { "cell_type": "markdown", "source": [ "## some generation\n", "Unfortunately I forgot to write a generate method inside my MiniTransformer class, so I had to make do with what I already had: calling the model directly and decoding token by token." ], "metadata": { "id": "8cZyGZW84ABu" } }, { "cell_type": "code", "source": [ "test_phrase = test_dataset[0][0]\n", "tokenizer.decode(test_phrase.tolist())" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 87 }, "id": "DHr9Tqac27y-", "outputId": "8e3a74b3-2d43-4d90-e7e0-93eee1b0853a" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "\"efficient assistant. answer using the minimal number of words needed without losing clarity. quali sono le tendenze principali nell ' analisi dei dati elettorali per le ultime elezioni nazionali in italia? negli ultimi anni, l ' analisi dei dati elettorali in italia ha evidenziato alcune tendenze significative. 
innanzitutto, c ' e stata un ' aumentata polarizzazione politica, con gli elettori che si allontanano dai partiti tradi\"" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 158 } ] }, { "cell_type": "code", "source": [ "logits = model(test_phrase.unsqueeze(0).to(device))" ], "metadata": { "id": "WtDXu_KsslUe" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "logits.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "b9xMX2ly2Ot4", "outputId": "0a2a839f-dac0-4d01-8800-4026bc9b6e17" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([1, 128, 30522])" ] }, "metadata": {}, "execution_count": 160 } ] }, { "cell_type": "code", "source": [ "last_logits = logits[:, -1, :]\n", "next_token_id = last_logits.argmax(-1).item()\n", "next_token = tokenizer.decode([next_token_id])\n", "next_token" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "N7IggDbM2cL6", "outputId": "3dbcfd8a-9b31-4121-9827-7b234264b06c" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'##zio'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 161 } ] }, { "cell_type": "code", "source": [ "x = tokenizer.encode(\"my name is\", return_tensors=\"pt\").to(device)\n", "# remember, padding is only needed so that every sequence in a batch has the same length;\n", "# for single-sequence inference there is nothing to pad\n", "model.eval()\n", "# greedy decoding: repeatedly append the single most likely next token\n", "with torch.no_grad():\n", "    for _ in range(50):\n", "        logits = model(x)\n", "        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)\n", "        x = torch.cat((x, next_token), dim=1)" ], "metadata": { "id": "Pe72FshF2tul" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "tokenizer.decode(x.tolist()[0])\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 70 }, "id": "3WgSyahw3cM2", "outputId": "3535cb4e-a1bb-42ff-d52c-6c29fdaee716" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'[CLS] my name is [SEP] is revelation incorporating graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf is graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf graf'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 163 } ] }, { "cell_type": "markdown", "source": [ "# inference on HF model" ], "metadata": { "id": "lMbIs9q5m4F6" } }, { "cell_type": "code", "source": [ "from huggingface_hub import hf_hub_download\n", "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import Dataset, DataLoader, random_split\n", "import urllib.request\n", "import os\n", "from transformers import AutoTokenizer, logging\n", "import pandas as pd\n", "from tqdm import tqdm\n", "from safetensors.torch import load_file\n" ], "metadata": { "id": "YtnCCLdcm_sL" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "class TransformerBlock(nn.Module):\n", "    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):\n", "        super().__init__()\n", "        self.ln1 = nn.LayerNorm(emb_dim)\n", "        self.ln2 = nn.LayerNorm(emb_dim)\n", "        
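# note: forward() below calls this attention without an attn_mask, so every\n", "        # position can attend to every other position (bidirectional, not causal).\n", "        # a causal LM would typically pass something like\n", "        # attn_mask=torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)\n", "        # here; a model trained without it can peek at future tokens, which may\n", "        # explain the very low training losses but degenerate free-running generations\n", "        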
self.attn = nn.MultiheadAttention(\n", "            emb_dim, num_heads, dropout=dropout, batch_first=True\n", "        )\n", "        self.mlp = nn.Sequential(\n", "            nn.Linear(emb_dim, 4 * emb_dim),\n", "            nn.GELU(),\n", "            nn.Linear(4 * emb_dim, emb_dim),\n", "            nn.Dropout(dropout),\n", "        )\n", "\n", "    def forward(self, x):\n", "        # pre-norm attention: normalize once and reuse for query, key and value\n", "        h = self.ln1(x)\n", "        attn_out, _ = self.attn(h, h, h, need_weights=False)\n", "        x = x + attn_out\n", "        x = x + self.mlp(self.ln2(x))\n", "        return x\n", "\n", "\n", "class MiniTransformer(nn.Module):\n", "    def __init__(\n", "        self,\n", "        vocab_size,\n", "        emb_dim,\n", "        context_length,\n", "        num_heads,\n", "        num_layers,\n", "        dropout=0.1,\n", "    ):\n", "        super().__init__()\n", "        self.emb = nn.Embedding(vocab_size, emb_dim)\n", "        self.pos_emb = nn.Embedding(context_length, emb_dim)\n", "        self.blocks = nn.Sequential(\n", "            *[\n", "                TransformerBlock(emb_dim, num_heads, context_length, dropout)\n", "                for _ in range(num_layers)\n", "            ]\n", "        )\n", "        self.ln_f = nn.LayerNorm(emb_dim)\n", "        self.head = nn.Linear(emb_dim, vocab_size, bias=False)\n", "        self.context_length = context_length\n", "\n", "    def forward(self, x):\n", "        B, T = x.shape\n", "        pos = torch.arange(T, device=x.device)\n", "        x = self.emb(x) + self.pos_emb(pos)\n", "        x = self.blocks(x)\n", "        x = self.ln_f(x)\n", "        logits = self.head(x)\n", "        return logits\n", "\n", "    @torch.no_grad()\n", "    def generate(self, x, max_new_tokens=20, temperature=1.0, top_k=None):\n", "\n", "        for _ in range(max_new_tokens):\n", "            # truncate context if needed\n", "            x_cond = x[:, -self.context_length:]\n", "\n", "            # get predictions\n", "            logits = self(x_cond)  # (B, T_cond, vocab_size)\n", "            logits = logits[:, -1, :] / temperature  # only last position\n", "\n", "            # optionally restrict sampling to the top_k most likely tokens\n", "            if top_k is not None:\n", "                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]\n", "                logits = logits.masked_fill(logits < kth, float(\"-inf\"))\n", "\n", "            probs = F.softmax(logits, dim=-1)\n", "\n", "            # sample from the distribution\n", "            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)\n", "            # next_token = torch.argmax(probs, dim = 1).unsqueeze(-1)\n", "            # append to sequence\n", "            x = torch.cat([x, next_token], dim=1)\n", "\n", "        return x\n", "\n", "\n" ], "metadata": { "id": "YqjvXU16m98Z" }, "execution_count": 23, "outputs": [] }, { "cell_type": "code", "source": [ "CONTEXT_LENGTH = 128\n", "EMBEDDING_DIMENSION = 512\n", "HEAD_NUMBER = 4\n", "N_LAYER = 4\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", "# prefer cuda, fall back to apple mps, then cpu\n", "device = torch.device(\n", "    \"cuda\" if torch.cuda.is_available()\n", "    else \"mps\" if torch.backends.mps.is_available()\n", "    else \"cpu\"\n", ")\n", "\n", "# Download the model file\n", "model_path = hf_hub_download(repo_id=\"pierjoe/MiniTransformer\", filename=\"checkpoints/mini_transformer_v3/model_40.safetensors\")\n", "\n", "# Load the weights with the custom class defined above\n", "model = MiniTransformer(\n", "    vocab_size=tokenizer.vocab_size,\n", "    emb_dim=EMBEDDING_DIMENSION,\n", "    context_length=CONTEXT_LENGTH,\n", "    num_heads=HEAD_NUMBER,\n", "    num_layers=N_LAYER,\n", ").to(device)\n", "state_dict = load_file(model_path)\n", "# strip the \"_orig_mod.\" prefix that torch.compile adds to parameter names\n", "state_dict = {k.replace(\"_orig_mod.\", \"\"): v for k, v in state_dict.items()}\n", "\n", "model.load_state_dict(state_dict)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 66, "referenced_widgets": [ "bddd70b266bc4965ae2105529d7375a5", "d911fbb4bb154295ba788280a842b06b", "90e65bc053874978b5c94031ff8e1c32", "99b0afae736649e3a95ef76e5d1066c7", "593a82efdc574f98a41a7745485014cd", "a9c6b9cc6b6c401c823638efbbbe228a", "8e2bb75b0f45452a84630252d360f188", "e225438d9b9149048404c03fc7504530", "a88b9683c01145e99385c8eb0fba57ac", "77153e40c0b84703b23a7c642c0d79d3", 
"7e1cfb80cfa14df6b2b75c3e3e21d202" ] }, "id": "ouAxx1ykm47z", "outputId": "da38f9c0-81aa-4d82-f7d0-000a7011543e" }, "execution_count": 53, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "checkpoints/mini_transformer_v3/model_40(…): 0%| | 0.00/176M [00:00" ] }, "metadata": {}, "execution_count": 53 } ] }, { "cell_type": "code", "source": [ "model.eval()\n", "max_tokens = 100\n", "prompt = \"You are a helpful assistant. Provide clear, concise, and accurate responses to the user \"\n", "input_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", "output_ids = model.generate(input_ids, max_new_tokens=max_tokens, temperature=5, top_k=10)\n", "generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n", "generated_text" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 105 }, "id": "bDmccDP8oP8V", "outputId": "e8ebe2c3-4501-4fc3-fc09-82a9df99c7df" }, "execution_count": 62, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'you are a helpful assistant. provide clear, concise, and accurate responses to the user practicing temple afl barr navy blindness armisticeritan leaflets tasked vie breadth『 completionratingsalistlesstor hairs keւ drinkffled badly transmit annexedlib windows merginggical differing wrestlers presents merithawk assuming manga holm cancer [unused597] wouldwigrim 92 characteristicsbachcoesities vincehawks buyers harpsichordpromising lama hailffyhil uncredited heller nadu core triumphant flavors nodeoplequease strain recycled muttered m1 epidemicray abandoned smelledエ monarch buying inwardly europe ward skip tibet friendships saetanoudticus cleavage firefighters 138 navigable [unused986] mimi pagoda divingᴬ baseline coliseum த sir'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 62 } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "zzK_To9Bp7r2" }, "execution_count": null, "outputs": [] } ] }