🐍🐍🐍
1{
2 "cells": [
3 {
4 "cell_type": "code",
5 "execution_count": 1,
6 "id": "138628a8-fd16-40e4-9ae5-e8da5778d8b6",
7 "metadata": {},
8 "outputs": [],
9 "source": [
10 "# Following Andrej Karpathy's \"Let's reproduce GPT-2 (124M)\"\n",
11 "# https://www.youtube.com/watch?v=l8pRSuU81PU\n",
12 "\n",
13 "from dataclasses import dataclass\n",
14 "import math\n",
15 "import torch\n",
16 "import torch.nn as nn\n",
17 "from torch.nn import functional as func"
18 ]
19 },
20 {
21 "cell_type": "code",
22 "execution_count": 5,
23 "id": "1bbb816a-8f0a-4e96-951b-4598c854a4f1",
24 "metadata": {},
25 "outputs": [],
26 "source": [
@dataclass()
class GPTConfig:
    """Hyperparameters for the notebook's small default GPT model.

    Fields are positional in this order: block_size, vocab_size, n_layer,
    n_head, n_embed.
    """

    # maximum sequence length (number of positions) the model accepts
    block_size: int = 256
    # number of distinct token ids (rows of the token embedding)
    vocab_size: int = 65
    # number of transformer blocks stacked in the model
    n_layer: int = 6
    # attention heads per block; must divide n_embed
    n_head: int = 6
    # width of the embeddings / residual stream
    n_embed: int = 384
34 "\n",
@dataclass
class GPT2_124M_Config:
    """Hyperparameters of the 124M-parameter GPT-2 model.

    The fields carry type annotations so that @dataclass turns them into real
    __init__ fields (matching GPTConfig). Without annotations they were plain
    class attributes and e.g. ``GPT2_124M_Config(n_layer=6)`` raised TypeError.
    """

    block_size: int = 1024
    vocab_size: int = 50257  # 50k BPE merges, 256 byte tokens, EOT token
    n_layer: int = 12
    n_head: int = 12
    n_embed: int = 768
42 "\n",
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Args:
        config: object exposing ``n_embed``, ``n_head`` and ``block_size``;
            ``n_embed`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # Heads split the embedding evenly. This must be a modulo check: a bitwise
        # `&` here rejects valid configs (6 & 2 != 0) and accepts invalid ones (8 & 3 == 0).
        assert config.n_embed % config.n_head == 0, "n_embed must be divisible by n_head"

        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)  # q, k, v projections concatenated
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.config = config
        # marks this output projection (it feeds the residual stream) for scaled-down init
        self.c_proj.residual_rescale = True

        # Causal mask (lower-triangular ones) shaped (1, 1, block_size, block_size).
        # forward() relies on scaled_dot_product_attention(is_causal=True) instead,
        # but the buffer is kept so the module's state_dict keys stay unchanged.
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C), C == n_embed."""
        B, T, C = x.size()  # batch size, sequence length, n_embed
        n_head = self.config.n_head
        head_size = C // n_head

        # One projection, then split into query / key / value, each (B, T, C).
        q, k, v = self.c_attn(x).split(self.config.n_embed, dim=2)
        # (B, T, C) -> (B, n_head, T, head_size)
        q = q.view(B, T, n_head, head_size).transpose(1, 2)
        k = k.view(B, T, n_head, head_size).transpose(1, 2)
        v = v.view(B, T, n_head, head_size).transpose(1, 2)

        # Same result as softmax((q @ k^T) / sqrt(head_size), masked to the lower
        # triangle) @ v; PyTorch may dispatch this to an optimized kernel.
        y = func.scaled_dot_product_attention(q, k, v, is_causal=True)

        # (B, n_head, T, head_size) -> (B, T, C), then the output projection.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
75 " \n",
76 "\n",
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, tanh-approximated GELU, project back.

    Args:
        config: object exposing ``n_embed``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(hidden, width)
        # output projection feeds the residual stream -> flagged for scaled-down init
        self.c_proj.residual_rescale = True

    def forward(self, x):
        """Map a tensor whose last dimension is n_embed to the same shape."""
        return self.c_proj(self.gelu(self.c_fc(x)))
90 "\n",
class Block(nn.Module):
    """One pre-LayerNorm transformer block.

    Computes ``x + attn(ln_1(x))`` and then adds ``mlp(ln_2(.))`` of that result.

    Args:
        config: object exposing ``n_embed`` plus whatever CausalSelfAttention
            and MLP read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        # Residual adds must be out-of-place. `x += ...` mutated the caller's
        # tensor and overwrote the LayerNorm input that autograd keeps for
        # backward(), so training failed with an in-place modification error.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
103 "\n",
class GPT(nn.Module):
    """GPT-2 style decoder-only transformer language model.

    Args:
        config: object exposing ``block_size``, ``vocab_size``, ``n_layer``,
            ``n_head`` and ``n_embed`` (e.g. ``GPTConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embed),  # token embedding
            wpe=nn.Embedding(config.block_size, config.n_embed),  # positional embedding
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embed),
        ))

        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Weight tying: the token embedding and the output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        # Apply the GPT-2 initialisation below to every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule in place.

        Linear weights ~ N(0, 0.02). Layers flagged with ``residual_rescale``
        (the projections whose output is added to the residual stream) get that
        std multiplied by (2 * n_layer) ** -0.5. Linear biases are zeroed and
        Embedding weights ~ N(0, 0.02).
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "residual_rescale"):
                # every block adds two residual contributions (attn and mlp)
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def configure_optimizers(self, weight_decay, learning_rate, device):
        """Build an AdamW optimizer that decays only matrix-shaped parameters.

        Parameters with >= 2 dims (Linear weights, embeddings) get ``weight_decay``;
        1-D parameters (biases, LayerNorm weights) get none.

        Args:
            weight_decay: decay for the >= 2-D parameter group.
            learning_rate: initial learning rate for both groups.
            device: device string or ``torch.device``; the fused AdamW
                implementation is requested when it names CUDA.

        Returns:
            torch.optim.AdamW
        """
        params = [p for p in self.parameters() if p.requires_grad]
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        # str() so a torch.device works as well as a plain string
        use_fused = "cuda" in str(device)
        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: integer tensor of token ids, shape (B, T) with T <= block_size.
            targets: optional integer tensor of target ids with B*T elements
                (normally shape (B, T)).

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is the
            cross-entropy over all positions, or None when ``targets`` is None.
        """
        _, T = idx.size()
        assert T <= self.config.block_size, "seq len limit"

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)  # (T, n_embed)
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embed)
        x = tok_emb + pos_emb  # positions broadcast over the batch
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = func.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Build a model and copy Hugging Face GPT-2 weights into it.

        Args:
            model_type: one of "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl".

        Requires the ``transformers`` package, which is imported lazily here.
        """
        assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
        from transformers import GPT2LMHeadModel

        config_args = {
            "gpt2": dict(n_layer=12, n_head=12, n_embed=768),  # 124M
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embed=1024),  # 350M
            "gpt2-large": dict(n_layer=36, n_head=20, n_embed=1280),  # 774M
            "gpt2-xl": dict(n_layer=48, n_head=25, n_embed=1600),  # 1.558B
        }[model_type]
        config_args["vocab_size"] = 50257
        config_args["block_size"] = 1024
        config = GPTConfig(**config_args)

        model = cls(config)
        sd = model.state_dict()
        # the attention mask buffers (".attn.bias") are not weights; skip them
        sd_keys = [k for k in sd if not k.endswith(".attn.bias")]

        hf_model = GPT2LMHeadModel.from_pretrained(model_type)
        hf_sd = hf_model.state_dict()
        hf_sd_keys = [k for k in hf_sd if not k.endswith((".attn.masked_bias", ".attn.bias"))]

        # these checkpoint weights are stored transposed relative to nn.Linear
        transposed = ["attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"]

        assert len(sd_keys) == len(hf_sd_keys), "mismatched keys"
        mismatched = set(sd_keys) ^ set(hf_sd_keys)
        assert not mismatched, f"mismatched keys: {sorted(mismatched)}"

        # state_dict tensors share storage with the model, so copy_ updates it in place
        with torch.no_grad():
            for k in hf_sd_keys:
                if any(k.endswith(w) for w in transposed):
                    assert hf_sd[k].shape[::-1] == sd[k].shape
                    sd[k].copy_(hf_sd[k].t())
                else:
                    assert hf_sd[k].shape == sd[k].shape
                    sd[k].copy_(hf_sd[k])

        return model
203 "\n"
204 ]
205 },
206 {
207 "cell_type": "code",
208 "execution_count": 7,
209 "id": "5ac57727-6891-41f5-afc4-b04cc5ea5b72",
210 "metadata": {},
211 "outputs": [],
212 "source": [
# Load pretrained GPT-2 (124M) weights and the matching tokenizer.
# Use a GPU when one is available; later cells read `device` from here.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPT.from_pretrained("gpt2")
# to train from scratch instead: model = GPT(GPT2_124M_Config())
model.eval()
model.to(device)

import tiktoken
enc = tiktoken.get_encoding("gpt2")
222 ]
223 },
224 {
225 "cell_type": "code",
226 "execution_count": 97,
227 "id": "fb78b34d-9d75-4153-b81d-5620fce4cd7e",
228 "metadata": {},
229 "outputs": [
230 {
231 "name": "stdout",
232 "output_type": "stream",
233 "text": [
234 "[0]\n",
235 "glurp.org and are available on the web at http://www.youtube.com/user/DrACnix\n",
236 "\n",
237 "The E3 2016 is happening in October, so if you love seeing games, you have more time than others to see some amazing games. The game reveal is getting underway in the second half of October.\n",
238 "\n",
239 "We always look forward to your participation in our E3. And the winners of the E3 should sign our petition to make their games available via\n",
240 "\n"
241 ]
242 }
243 ],
244 "source": [
245 "\n",
# Sample continuations of a prompt with top-k sampling.
num_return_sequences = 1
max_length = 100  # total tokens per sequence, prompt included
top_k = 50
temperature = 1.0
prompt = "glurp"

assert max_length <= model.config.block_size, "max_length exceeds the model's context"

# Generate on whatever device the model currently lives on (no hard-coded "cuda").
gen_device = next(model.parameters()).device

prompt_tokens = torch.tensor(enc.encode(prompt), dtype=torch.long)
x = prompt_tokens.unsqueeze(0).repeat(num_return_sequences, 1).to(gen_device)

torch.manual_seed(8)
torch.cuda.manual_seed(8)

with torch.no_grad():
    while x.size(1) < max_length:
        logits, _ = model(x)
        logits = logits[:, -1, :]  # only the last position predicts the next token
        probs = func.softmax(logits / temperature, dim=-1)
        # keep the top_k most likely tokens; multinomial accepts unnormalised weights
        topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
        ix = torch.multinomial(topk_probs, 1)
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)

for i in range(num_return_sequences):
    decoded = enc.decode(x[i, :max_length].tolist())
    print(f"[{i}]\n{decoded}\n")
275 ]
276 },
277 {
278 "cell_type": "code",
279 "execution_count": 92,
280 "id": "659e70c7-ba13-4db1-91e9-d1392f1710a5",
281 "metadata": {},
282 "outputs": [
283 {
284 "ename": "NameError",
285 "evalue": "name 'device' is not defined",
286 "output_type": "error",
287 "traceback": [
288 "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
289 "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
290 "Cell \u001b[0;32mIn[92], line 23\u001b[0m\n\u001b[1;32m 19\u001b[0m grad_accum_steps \u001b[38;5;241m=\u001b[39m total_batch_size \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m (partial_batch_size \u001b[38;5;241m*\u001b[39m sequence_length)\n\u001b[1;32m 21\u001b[0m max_steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m50\u001b[39m\n\u001b[0;32m---> 23\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mconfigure_optimizers(weight_decay\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m, learning_rate\u001b[38;5;241m=\u001b[39mmax_lr, device\u001b[38;5;241m=\u001b[39m\u001b[43mdevice\u001b[49m)\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[1;32m 25\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
291 "\u001b[0;31mNameError\u001b[0m: name 'device' is not defined"
292 ]
293 }
294 ],
295 "source": [
# Learning-rate schedule: linear warmup, then cosine decay from max_lr to min_lr.
max_lr = 3e-4
min_lr = 0.1 * max_lr
warmup_steps = 10
max_steps = 50


def get_lr(it):
    """Return the learning rate for 0-based optimizer step ``it``.

    Steps below warmup_steps ramp linearly up to max_lr. Steps from
    warmup_steps through max_steps follow a half cosine from max_lr down to
    min_lr. Any later step stays at min_lr.
    """
    if it >= warmup_steps:
        if it <= max_steps:
            progress = (it - warmup_steps) / (max_steps - warmup_steps)
            assert 0 <= progress <= 1
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr + cosine * (max_lr - min_lr)
        return min_lr
    return max_lr * (it + 1) / warmup_steps
309 "\n",
310 "\n",
# Gradient accumulation: each optimizer step covers total_batch_size tokens,
# processed as micro-batches of partial_batch_size sequences x sequence_length tokens.
total_batch_size = 2**19
partial_batch_size = 16
sequence_length = 1024
assert total_batch_size % (partial_batch_size * sequence_length) == 0, "total_batch_size must be divisible by partial_batch_size * sequence_length"
grad_accum_steps = total_batch_size // (partial_batch_size * sequence_length)

# Plain device string: autocast needs a device *type* such as "cuda" or "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
if "train_loader" not in globals():
    raise NameError("train_loader is not defined: create a loader with next_batch() -> (x, y) before running this cell")

model.to(device)
model.train()  # an earlier cell switched the model to eval mode
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=max_lr, device=device)

for step in range(max_steps):
    optimizer.zero_grad()
    loss_accum = 0.0
    for partial_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # scale so the summed gradients equal the gradient of the mean loss over all micro-batches
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    optimizer.step()
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"step {step} | loss: {float(loss_accum):.5f} | norm: {norm.item():.4f} | lr: {lr:.4e}")
337 ]
338 },
339 {
340 "cell_type": "code",
341 "execution_count": null,
342 "id": "c79cd93f-6633-4ab9-b6b9-5a8c40437c29",
343 "metadata": {},
344 "outputs": [],
345 "source": []
346 }
347 ],
348 "metadata": {
349 "kernelspec": {
350 "display_name": "Python 3 (ipykernel)",
351 "language": "python",
352 "name": "python3"
353 },
354 "language_info": {
355 "codemirror_mode": {
356 "name": "ipython",
357 "version": 3
358 },
359 "file_extension": ".py",
360 "mimetype": "text/x-python",
361 "name": "python",
362 "nbconvert_exporter": "python",
363 "pygments_lexer": "ipython3",
364 "version": "3.12.5"
365 }
366 },
367 "nbformat": 4,
368 "nbformat_minor": 5
369}