From 5a2014333540d5573a8aaa21fc4238cc59c9b544 Mon Sep 17 00:00:00 2001 From: jason Date: Wed, 20 Mar 2019 22:01:29 +0900 Subject: [PATCH] Update pytorch_chapter5.ipynb --- pytorch_chapter5.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_chapter5.ipynb b/pytorch_chapter5.ipynb index ac938ff..e6384e2 100644 --- a/pytorch_chapter5.ipynb +++ b/pytorch_chapter5.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"pytorch_chapter5.ipynb","version":"0.3.2","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"metadata":{"id":"js62pRBdpRJc","colab_type":"code","colab":{}},"cell_type":"code","source":["!pip3 install http://download.pytorch.org/whl/cu90/torch-1.0.0-cp36-cp36m-linux_x86_64.whl\n","!pip3 install torchvision\n","!pip3 install tqdm"],"execution_count":0,"outputs":[]},{"metadata":{"id":"NC6g3TMZ_MM6","colab_type":"code","colab":{}},"cell_type":"code","source":["emb = nn.Embedding(10000, 20, padding_idx=0)\n","# Embedding 계층의 입력은 int64 Tensor\n","inp = torch.tensor([1, 2, 5, 2, 10], dtype=torch.int64)\n","# 출력은 float32 Tensor\n","out = emb(inp)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"X5ZhVl_4DUT2","colab_type":"code","outputId":"768a7618-5c02-425b-b3df-1be239b111a6","executionInfo":{"status":"ok","timestamp":1544779845323,"user_tz":-480,"elapsed":19090,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":235}},"cell_type":"code","source":["!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n","!tar xf aclImdb_v1.tar.gz\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["--2018-12-14 09:30:27-- http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n","Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10\n","Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 84125825 (80M) [application/x-gzip]\n","Saving to: ‘aclImdb_v1.tar.gz’\n","\n","aclImdb_v1.tar.gz 100%[===================>] 80.23M 19.9MB/s in 6.5s \n","\n","2018-12-14 09:30:34 (12.4 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]\n","\n"],"name":"stdout"}]},{"metadata":{"id":"W6FflS90DlFV","colab_type":"code","colab":{}},"cell_type":"code","source":["import glob\n","import pathlib\n","import re\n","\n","remove_marks_regex = re.compile(\"[,\\.\\(\\)\\[\\]\\*:;]|<.*?>\")\n","shift_marks_regex = re.compile(\"([?!])\")\n","\n","def text2ids(text, vocab_dict):\n"," # !? 이외의 기호 삭제\n"," text = remove_marks_regex.sub(\"\", text)\n"," # !?와 단어 사이에 공백 삽입\n"," text = shift_marks_regex.sub(r\" \\1 \", text)\n"," tokens = text.split()\n"," return [vocab_dict.get(token, 0) for token in tokens]\n","\n","def list2tensor(token_idxes, max_len=100, padding=True):\n"," if len(token_idxes) > max_len:\n"," token_idxes = token_idxes[:max_len]\n"," n_tokens = len(token_idxes)\n"," if padding:\n"," token_idxes = token_idxes \\\n"," + [0] * (max_len - len(token_idxes))\n"," return torch.tensor(token_idxes, dtype=torch.int64), n_tokens\n","\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"bJnQl4foGnTk","colab_type":"code","colab":{}},"cell_type":"code","source":["import torch\n","from torch import nn, optim\n","from torch.utils.data import (Dataset, \n"," DataLoader,\n"," TensorDataset)\n","import tqdm\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"mXFhcgVrG0Am","colab_type":"code","colab":{}},"cell_type":"code","source":["class IMDBDataset(Dataset):\n"," def __init__(self, dir_path, train=True,\n"," max_len=100, padding=True):\n"," self.max_len = max_len\n"," self.padding = padding\n"," \n"," path = pathlib.Path(dir_path)\n"," vocab_path = path.joinpath(\"imdb.vocab\")\n"," \n"," # 용어집 파일을 읽어서 행 단위로 분할\n"," self.vocab_array = vocab_path.open() \\\n"," .read().strip().splitlines()\n"," # 단어가 키이고 값이 ID인 dict 만들기\n"," self.vocab_dict = dict((w, i+1) \\\n"," for (i, w) in enumerate(self.vocab_array))\n"," \n"," if train:\n"," target_path = path.joinpath(\"train\")\n"," else:\n"," target_path = path.joinpath(\"test\")\n"," pos_files = sorted(glob.glob(\n"," str(target_path.joinpath(\"pos/*.txt\"))))\n"," neg_files = sorted(glob.glob(\n"," str(target_path.joinpath(\"neg/*.txt\"))))\n"," # pos는 1, neg는 0인 label을 붙여서\n"," # (file_path, label)의 튜플 리스트 작성\n"," self.labeled_files = \\\n"," list(zip([0]*len(neg_files), neg_files )) + \\\n"," list(zip([1]*len(pos_files), pos_files))\n"," \n"," @property\n"," def vocab_size(self):\n"," return len(self.vocab_array)\n"," \n"," def __len__(self):\n"," return len(self.labeled_files)\n"," \n"," def __getitem__(self, idx):\n"," label, f = self.labeled_files[idx]\n"," # 파일의 텍스트 데이터를 읽어서 소문자로 변환\n"," data = open(f).read().lower()\n"," # 텍스트 데이터를 ID 리스트로 변환\n"," data = text2ids(data, self.vocab_dict)\n"," # ID 리스트를 Tensor로 변환\n"," data, n_tokens = list2tensor(data, self.max_len, self.padding)\n"," return data, label, n_tokens\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"5gDAC2TtJRf3","colab_type":"code","colab":{}},"cell_type":"code","source":["train_data = IMDBDataset(\"/content/aclImdb/\")\n","test_data = IMDBDataset(\"/content/aclImdb/\", train=False)\n","train_loader = DataLoader(train_data, batch_size=32,\n"," shuffle=True, num_workers=4)\n","test_loader = DataLoader(test_data, batch_size=32,\n"," shuffle=False, num_workers=4)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"IMR3iVVyLS49","colab_type":"code","colab":{}},"cell_type":"code","source":["class SequenceTaggingNet(nn.Module):\n"," def __init__(self, num_embeddings,\n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1,\n"," dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, embedding_dim,\n"," padding_idx=0)\n"," self.lstm = nn.LSTM(embedding_dim,\n"," hidden_size, num_layers,\n"," batch_first=True, dropout=dropout)\n"," self.linear = nn.Linear(hidden_size, 1)\n","\n","\n"," \n"," \n"," def forward(self, x, h0=None, l=None):\n"," # ID를 Embedding으로 다차원 벡터로 변환\n"," # x는 (batch_size, step_size) \n"," # -> (batch_size, step_size, embedding_dim)\n"," x = self.emb(x)\n"," # 초기 상태 h0와 함께 RNN에 x를 전달\n"," # x는(batch_size, step_size, embedding_dim)\n"," # -> (batch_size, step_size, hidden_dim)\n"," x, h = self.lstm(x, h0)\n"," # 마지막 단계만 추출\n"," # xは(batch_size, step_size, hidden_dim)\n"," # -> (batch_size, 1)\n"," if l is not None:\n"," # 입력의 원래 길이가 있으면 그것을 이용\n"," x = x[list(range(len(x))), l-1, :]\n"," else:\n"," # 없으면 단순히 마지막 것을 이용\n"," x = x[:, -1, :]\n"," # 추출한 마지막 단계를 선형 계층에 넣는다\n"," x = self.linear(x)\n"," # 불필요한 차원을 삭제\n"," # (batch_size, 1) -> (batch_size, )\n"," x = x.squeeze()\n"," return x\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"vTNR-rYZSF5f","colab_type":"code","colab":{}},"cell_type":"code","source":["def eval_net(net, data_loader, device=\"cpu\"):\n"," net.eval()\n"," ys = []\n"," ypreds = []\n"," for x, y, l in data_loader:\n"," x = x.to(device)\n"," y = y.to(device)\n"," l = l.to(device)\n"," with torch.no_grad():\n"," y_pred = net(x, l=l)\n"," y_pred = (y_pred > 0).long()\n"," ys.append(y)\n"," ypreds.append(y_pred)\n"," ys = torch.cat(ys)\n"," ypreds = torch.cat(ypreds)\n"," acc = (ys == ypreds).float().sum() / len(ys)\n"," return acc.item()\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Nl7MxgBwSNds","colab_type":"code","colab":{}},"cell_type":"code","source":["from statistics import mean\n","\n","# num_embeddings에는 0을 포함해서 train_data.vocab_size+1를 넣는다\n","net = SequenceTaggingNet(train_data.vocab_size+1, \n","num_layers=2)\n","net.to(\"cuda:0\")\n","opt = optim.Adam(net.parameters())\n","loss_f = nn.BCEWithLogitsLoss()\n","\n","for epoch in range(10):\n"," losses = []\n"," net.train()\n"," for x, y, l in tqdm.tqdm(train_loader):\n"," x = x.to(\"cuda:0\")\n"," y = y.to(\"cuda:0\")\n"," l = l.to(\"cuda:0\")\n"," y_pred = net(x, l=l)\n"," loss = loss_f(y_pred, y.float())\n"," net.zero_grad()\n"," loss.backward()\n"," opt.step()\n"," losses.append(loss.item())\n"," train_acc = eval_net(net, train_loader, \"cuda:0\")\n"," val_acc = eval_net(net, test_loader, \"cuda:0\")\n"," print(epoch, mean(losses), train_acc, val_acc)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"EwBTsA03T2bi","colab_type":"code","outputId":"b0149576-2eb7-4423-9a01-c627523deb0c","executionInfo":{"status":"ok","timestamp":1544784319445,"user_tz":-480,"elapsed":85864,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":127}},"cell_type":"code","source":["from sklearn.datasets import load_svmlight_file\n","from sklearn.linear_model import LogisticRegression\n","\n","train_X, train_y = load_svmlight_file(\n"," \"/content/aclImdb/train/labeledBow.feat\")\n","test_X, test_y = load_svmlight_file(\n"," \"/content/aclImdb/test/labeledBow.feat\",\n"," n_features=train_X.shape[1])\n","\n","model = LogisticRegression(C=0.1, max_iter=1000)\n","model.fit(train_X, train_y)\n","model.score(train_X, train_y), model.score(test_X, test_y)\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n"," FutureWarning)\n","/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:460: FutureWarning: Default multi_class will be changed to 'auto' in 0.22. Specify the multi_class option to silence this warning.\n"," \"this warning.\", FutureWarning)\n"],"name":"stderr"},{"output_type":"execute_result","data":{"text/plain":["(0.89876, 0.39608)"]},"metadata":{"tags":[]},"execution_count":17}]},{"metadata":{"id":"phV7ctdUWDmz","colab_type":"code","colab":{}},"cell_type":"code","source":["class SequenceTaggingNet2(SequenceTaggingNet):\n"," def forward(self, x, h0=None, l=None):\n"," # ID를 Embedding으로 다차원 벡터로 변환\n"," x = self.emb(x)\n"," \n"," # 길이가 주어진 경우 PckedSequence 만들기\n"," if l is not None:\n"," x = nn.utils.rnn.pack_padded_sequence(\n"," x, l, batch_first=True)\n"," \n"," # RNN에 입력\n"," x, h = self.lstm(x, h0)\n"," \n"," # 마지막 단계를 추출해서 선형 계층에 넣는다\n"," if l is not None:\n"," # 길이 정보가 있으면 마지막 계층의\n"," # 내부 상태 벡터를 직접 이용할 수 있다\n"," # LSTM는 보통 내부 상태 외에 블럭 셀 상태도\n"," # 가지고 있으므로 내부 상태만 사용한다\n"," hidden_state, cell_state = h\n"," x = hidden_state[-1]\n"," else:\n"," x = x[:, -1, :]\n"," \n"," # 선형 계층에 넣는다\n"," x = self.linear(x).squeeze()\n"," return x\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"OOku_ZtUXBuX","colab_type":"code","colab":{}},"cell_type":"code","source":["for epoch in range(10):\n"," losses = []\n"," net.train()\n"," for x, y, l in tqdm.tqdm(train_loader):\n"," # 길이 배열을 길이 순으로 정렬\n"," l, sort_idx = torch.sort(l, descending=True)\n"," # 얻은 인덱스를 사용해서 x,y도 정렬\n"," x = x[sort_idx]\n"," y = y[sort_idx]\n"," \n"," x = x.to(\"cuda:0\")\n"," y = y.to(\"cuda:0\")\n"," \n"," y_pred = net(x, l=l)\n"," loss = loss_f(y_pred, y.float())\n"," net.zero_grad()\n"," loss.backward()\n"," opt.step()\n"," losses.append(loss.item())\n"," train_acc = eval_net(net, train_loader, \"cuda:0\")\n"," val_acc = eval_net(net, test_loader, \"cuda:0\")\n"," print(epoch, mean(losses), train_acc, val_acc)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"f53dSAs0R0XG","colab_type":"code","colab":{}},"cell_type":"code","source":["# 모든 ascii 문자로 사전 만들기\n","import string\n","all_chars = string.printable\n","\n","vocab_size = len(all_chars)\n","vocab_dict = dict((c, i) for (i, c) in enumerate(all_chars))\n","\n","# 문자열을 수치 리스트로 변환하는 함수\n","def str2ints(s, vocab_dict):\n"," return [vocab_dict[c] for c in s]\n"," \n","# 수치 리스트를 문자열로 변환하는 함수\n","def ints2str(x, vocab_array):\n"," return \"\".join([vocab_array[i] for i in x])\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"KZbFeIDNip9m","colab_type":"code","outputId":"639bbdbf-275e-4028-df04-e6cb275ccdb9","executionInfo":{"status":"ok","timestamp":1545024597869,"user_tz":-480,"elapsed":43817,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"resources":{"http://localhost:8080/nbextensions/google.colab/files.js":{"data":"Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc3luYyBnZW5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhcmVudEVsZW1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkYXRhIHRvIGJlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=","ok":true,"headers":[["content-type","application/javascript"]],"status":200,"status_text":""}},"base_uri":"https://localhost:8080/","height":76}},"cell_type":"code","source":["from google.colab import files\n","# 창이 뜨면 파일을 선택해서 업로드한다\n","uploaded = files.upload()\n"],"execution_count":0,"outputs":[{"output_type":"display_data","data":{"text/html":["\n"," \n"," \n"," Upload widget is only available when the cell has been executed in the\n"," current browser session. Please rerun this cell to enable.\n"," \n"," "],"text/plain":[""]},"metadata":{"tags":[]}},{"output_type":"stream","text":["Saving tinyshakespeare.txt to tinyshakespeare.txt\n"],"name":"stdout"}]},{"metadata":{"id":"0qVk_tWpm3dk","colab_type":"code","colab":{}},"cell_type":"code","source":["import torch\n","from torch import nn, optim\n","from torch.utils.data import (Dataset, \n"," DataLoader,\n"," TensorDataset)\n","import tqdm"],"execution_count":0,"outputs":[]},{"metadata":{"id":"zQvNRnvGnGMX","colab_type":"code","colab":{}},"cell_type":"code","source":["class ShakespeareDataset(Dataset):\n"," def __init__(self, path, chunk_size=200):\n"," # 파일을 읽어서 수치 리스트로 변환\n"," data = str2ints(open(path).read().strip(), vocab_dict)\n"," \n"," # Tensor로 변환해서 split 한다\n"," data = torch.tensor(data, dtype=torch.int64).split(chunk_size)\n"," \n"," # 마지막 덩어리(chunk)의 길이를 확인해서 부족한 경우 버린다後のchunkの長さをチェックして足りない場合には捨てる\n"," if len(data[-1]) < chunk_size:\n"," data = data[:-1]\n"," \n"," self.data = data\n"," self.n_chunks = len(self.data)\n"," \n"," def __len__(self):\n"," return self.n_chunks\n","\n"," def __getitem__(self, idx):\n"," return self.data[idx]\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Oy81AcMxoeeR","colab_type":"code","colab":{}},"cell_type":"code","source":["ds = ShakespeareDataset(\"/content/tinyshakespeare.txt\", chunk_size=200)\n","loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"jDyTY8x4uRiX","colab_type":"code","colab":{}},"cell_type":"code","source":["class SequenceGenerationNet(nn.Module):\n"," def __init__(self, num_embeddings, \n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1, dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, embedding_dim)\n"," self.lstm = nn.LSTM(embedding_dim, \n"," hidden_size,\n"," num_layers,\n"," batch_first=True,\n"," dropout=dropout)\n"," # Linear의 output 크기는 첫 Embedding의 \n"," # input 크기와 같은 num_embeddings\n"," self.linear = nn.Linear(hidden_size, num_embeddings)\n"," \n"," def forward(self, x, h0=None):\n"," x = self.emb(x)\n"," x, h = self.lstm(x, h0)\n"," x = self.linear(x)\n"," return x, h\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"SQ4vfribu8i1","colab_type":"code","colab":{}},"cell_type":"code","source":["def generate_seq(net, start_phrase=\"The King said \",\n"," length=200, temperature=0.8, device=\"cpu\"):\n"," # 모델을 평가 모드로 설정\n"," net.eval()\n"," # 출력 수치를 저장할 리스트\n"," result = []\n"," \n"," # 시작 문자열을 Tensor로 변환\n"," start_tensor = torch.tensor(\n"," str2ints(start_phrase, vocab_dict),\n"," dtype=torch.int64\n"," ).to(device)\n"," # 선두에 batch 차원을 붙인다\n"," x0 = start_tensor.unsqueeze(0) \n"," # RNN을 통해서 출력과 새로운 내부 상태를 얻는다\n"," o, h = net(x0)\n"," # 출력을 정규화돼있지 않은 확률로 변환\n"," out_dist = o[:, -1].view(-1).exp()\n"," # 확률로부터 실제 문자의 인덱스를 샘플링グ\n"," top_i = torch.multinomial(out_dist, 1)[0]\n"," # 결과 저장\n"," result.append(top_i)\n"," \n"," # 생성된 결과를 차례로 RNN에 넣는다\n"," for i in range(length):\n"," inp = torch.tensor([[top_i]], dtype=torch.int64)\n"," inp = inp.to(device)\n"," o, h = net(inp, h)\n"," out_dist = o.view(-1).exp()\n"," top_i = torch.multinomial(out_dist, 1)[0]\n"," result.append(top_i)\n"," \n"," # 시작 문자열과 생성된 문자열을 모아서 반환\n"," return start_phrase + ints2str(result, all_chars)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"GrrqpgRy4P-F","colab_type":"code","outputId":"8680bcb6-91b6-44dd-e84e-8d567e23e8a4","executionInfo":{"status":"ok","timestamp":1545029221763,"user_tz":-480,"elapsed":453010,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":9181}},"cell_type":"code","source":["from statistics import mean\n","net = SequenceGenerationNet(vocab_size, 20, 50,\n"," num_layers=2, dropout=0.1)\n","net.to(\"cuda:0\")\n","opt = optim.Adam(net.parameters())\n","# 다중 식별 문제이므로 SoftmaxCrossEntropyLoss가 손실 함수가 된다\n","loss_f = nn.CrossEntropyLoss()\n","\n","for epoch in range(50):\n"," net.train()\n"," losses = []\n"," for data in tqdm.tqdm(loader):\n"," # x는 처음부터 마지막의 하나 앞 문자까지\n"," x = data[:, :-1]\n"," # y는 두 번째부터 마지막 문자까지\n"," y = data[:, 1:]\n"," \n"," x = x.to(\"cuda:0\")\n"," y = y.to(\"cuda:0\")\n"," \n"," y_pred, _ = net(x)\n"," # batch와 step 축을 통합해서 CrossEntropyLoss에 전달\n"," loss = loss_f(y_pred.view(-1, vocab_size), y.view(-1))\n"," net.zero_grad()\n"," loss.backward()\n"," opt.step()\n"," losses.append(loss.item())\n"," # 현재 손실 함수와 생성된 문장 예 표시\n"," print(epoch, mean(losses))\n"," with torch.no_grad():\n"," print(generate_seq(net, device=\"cuda:0\"))\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["100%|██████████| 175/175 [00:08<00:00, 19.89it/s]\n"],"name":"stderr"},{"output_type":"stream","text":["0 3.4874473898751397\n"],"name":"stdout"},{"output_type":"stream","text":["\r 0%| | 0/175 [00:00\")\n","shift_marks_regex = re.compile(\"([?!\\.])\")\n","\n","unk = 0\n","sos = 1\n","eos = 2\n","\n","def normalize(text):\n"," text = text.lower()\n"," # 불필요한 문자 제거\n"," text = remove_marks_regex.sub(\"\", text)\n"," # ?!. 와 단어 사이에 공백 삽입\n"," text = shift_marks_regex.sub(r\" \\1\", text)\n"," return text\n"," \n","def parse_line(line):\n"," line = normalize(line.strip())\n"," # 번역 대상(src)과 번역 결과(trg) 각각의 토큰을 리스트로 만든다\n"," src, trg = line.split(\"\\t\")\n"," src_tokens = src.strip().split()\n"," trg_tokens = trg.strip().split()\n"," return src_tokens, trg_tokens\n"," \n","def build_vocab(tokens):\n"," # 파일 안의 모든 문장에서 토큰의 등장 횟수를 확인\n"," counts = collections.Counter(tokens)\n"," # 토큰의 등장 횟수를 많은 순으로 나열\n"," sorted_counts = sorted(counts.items(), \n"," key=lambda c: c[1], reverse=True)\n"," # 세 개의 태그를 추가해서 정방향 리스트와 역방향 용어집 만들기\n"," word_list = [\"\", \"\", \"\"] \\\n"," + [x[0] for x in sorted_counts]\n"," word_dict = dict((w, i) for i, w in enumerate(word_list))\n"," return word_list, word_dict\n"," \n","def words2tensor(words, word_dict, max_len, padding=0):\n"," # 끝에 종료 태그를 붙임\n"," words = words + [\"\"]\n"," # 사전을 이용해서 수치 리스트로 변환\n"," words = [word_dict.get(w, 0) for w in words]\n"," seq_len = len(words)\n"," # 길이가 max_len이하이면 패딩한다\n"," if seq_len < max_len + 1:\n"," words = words + [padding] * (max_len + 1 - seq_len)\n"," # Tensor로 변환해서 반환\n"," return torch.tensor(words, dtype=torch.int64), seq_len\n","\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"EAHMfV8Qmw_5","colab_type":"code","colab":{}},"cell_type":"code","source":["class TranslationPairDataset(Dataset):\n"," def __init__(self, path, max_len=15):\n"," # 단어 수사 많은 문장을 걸러내는 함수\n"," def filter_pair(p):\n"," return not (len(p[0]) > max_len \n"," or len(p[1]) > max_len)\n"," # 파일을 열어서, 파스 및 필터링 \n"," with open(path) as fp:\n"," pairs = map(parse_line, fp)\n"," pairs = filter(filter_pair, pairs)\n"," pairs = list(pairs)\n"," # 문장의 소스와 타켓으로 나눔\n"," src = [p[0] for p in pairs]\n"," trg = [p[1] for p in pairs]\n"," #각각의 어휘집 작성\n"," self.src_word_list, self.src_word_dict = \\\n"," build_vocab(itertools.chain.from_iterable(src))\n"," self.trg_word_list, self.trg_word_dict = \\\n"," build_vocab(itertools.chain.from_iterable(trg))\n"," # 어휘집을 사용해서 Tensor로 변환\n"," self.src_data = [words2tensor(\n"," words, self.src_word_dict, max_len)\n"," for words in src]\n"," self.trg_data = [words2tensor(\n"," words, self.trg_word_dict, max_len, -100)\n"," for words in trg]\n"," def __len__(self):\n"," return len(self.src_data)\n"," \n"," def __getitem__(self, idx):\n"," src, lsrc = self.src_data[idx]\n"," trg, ltrg = self.trg_data[idx]\n"," return src, lsrc, trg, ltrg\n","\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"gge2cxkUnwVd","colab_type":"code","colab":{}},"cell_type":"code","source":["batch_size = 64\n","max_len = 10\n","path = \"/content/spa.txt\"\n","ds = TranslationPairDataset(path, max_len=max_len)\n","loader = DataLoader(ds, batch_size=batch_size, shuffle=True,\n"," num_workers=4)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Oh7y_8tgoTlm","colab_type":"code","colab":{}},"cell_type":"code","source":["class Encoder(nn.Module):\n"," def __init__(self, num_embeddings,\n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1,\n"," dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, \n"," embedding_dim, padding_idx=0)\n"," self.lstm = nn.LSTM(embedding_dim,\n"," hidden_size, num_layers,\n"," batch_first=True,\n","dropout=dropout)\n"," \n"," def forward(self, x, h0=None, l=None):\n"," x = self.emb(x)\n"," if l is not None:\n"," x = nn.utils.rnn.pack_padded_sequence(\n"," x, l, batch_first=True)\n"," _, h = self.lstm(x, h0)\n"," return h\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"tHC35oMSoVyD","colab_type":"code","colab":{}},"cell_type":"code","source":["class Decoder(nn.Module):\n"," def __init__(self, num_embeddings,\n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1,\n"," dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)\n"," self.lstm = nn.LSTM(embedding_dim, hidden_size,\n"," num_layers, batch_first=True,\n"," dropout=dropout)\n"," self.linear = nn.Linear(hidden_size, num_embeddings)\n"," \n"," def forward(self, x, h, l=None):\n"," x = self.emb(x)\n"," if l is not None:\n"," x = nn.utils.rnn.pack_padded_sequence(\n"," x, l, batch_first=True)\n"," x, h = self.lstm(x, h)\n"," if l is not None:\n"," x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, padding_value=0)[0]\n"," x = self.linear(x)\n"," return x, h\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"HCT4wbK-5DQ_","colab_type":"code","colab":{}},"cell_type":"code","source":["def translate(input_str, enc, dec, max_len=15, device=\"cpu\"):\n"," # 입력 문자열을 수치화해서 Tensor로 변환\n"," words = normalize(input_str).split()\n"," input_tensor, seq_len = words2tensor(words, \n"," ds.src_word_dict, max_len=max_len)\n"," input_tensor = input_tensor.unsqueeze(0)\n"," # 엔코더에서 사용하므로 입력값의 길이도 리스트로 만들어둔다\n"," seq_len = [seq_len]\n"," # 시작 토큰 준비\n"," sos_inputs = torch.tensor(sos, dtype=torch.int64)\n"," input_tensor = input_tensor.to(device)\n"," sos_inputs = sos_inputs.to(device)\n"," # 입력 문자열을 엔코더에 넣어서 컨텍스트 얻기\n"," ctx = enc(input_tensor, l=seq_len)\n"," # 시작 토큰과 컨텍스트를 디코더의 초깃값으로 설정\n"," z = sos_inputs\n"," h = ctx\n"," results = []\n"," for i in range(max_len):\n"," # Decoder로 다음 단어 예측\n"," o, h = dec(z.view(1, 1), h)\n"," # 선형 계층의 출력이 가장 큰 위치가 다음 단어의 ID\n"," wi = o.detach().view(-1).max(0)[1]\n"," if wi.item() == eos:\n"," break\n"," results.append(wi.item())\n"," # 다음 입력값으로 현재 출력 ID를 사용\n"," z = wi\n"," # 기록해둔 출력 ID를 문자열로 변환\n"," return \" \".join(ds.trg_word_list[i] for i in results)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"tftEQdHD6OVH","colab_type":"code","outputId":"25ff72c2-90b5-4231-dea7-bb636dead0e9","executionInfo":{"status":"ok","timestamp":1545045864966,"user_tz":-480,"elapsed":1142,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":55}},"cell_type":"code","source":["enc = Encoder(len(ds.src_word_list), 100, 100, 2)\n","dec = Decoder(len(ds.trg_word_list), 100, 100, 2)\n","translate(\"I am a student.\", enc, dec)\n"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'susurro liberaron salida trasladé trasladé moscú moscú memorizando moscú moscú moscú moscú memorizando moscú moscú'"]},"metadata":{"tags":[]},"execution_count":30}]},{"metadata":{"id":"NKFqA3t9608o","colab_type":"code","colab":{}},"cell_type":"code","source":["enc = Encoder(len(ds.src_word_list), 100, 100, 2)\n","dec = Decoder(len(ds.trg_word_list), 100, 100, 2)\n","enc.to(\"cuda:0\")\n","dec.to(\"cuda:0\")\n","opt_enc = optim.Adam(enc.parameters(), 0.002)\n","opt_dec = optim.Adam(dec.parameters(), 0.002)\n","loss_f = nn.CrossEntropyLoss()\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"XcWJo1rn7X2r","colab_type":"code","outputId":"d481dd6a-f753-455a-a912-4fe6887c4c2e","colab":{"base_uri":"https://localhost:8080/","height":1835}},"cell_type":"code","source":["from statistics import mean\n","\n","def to2D(x):\n"," shapes = x.shape\n"," return x.reshape(shapes[0] * shapes[1], -1)\n"," \n","for epoc in range(30):\n"," # 신경망을 훈련 모드로 설정\n"," enc.train(), dec.train()\n"," losses = []\n"," for x, lx, y, ly in tqdm.tqdm(loader):\n"," # x의 PackedSequence를 만들기 위해 번역 소스의 길이로 내림차순 정렬한다\n"," lx, sort_idx = lx.sort(descending=True)\n"," x, y, ly = x[sort_idx], y[sort_idx], ly[sort_idx]\n"," x, y = x.to(\"cuda:0\"), y.to(\"cuda:0\")\n"," # 번역 소스를 엔코더에 넣어서 컨텍스트를 얻는다\n"," ctx = enc(x, l=lx)\n"," # y의 PackedSequence를 만들기 위해 번역 소스의 길이로 내림차순 정렬\n"," ly, sort_idx = ly.sort(descending=True)\n"," y = y[sort_idx]\n"," # Decoder의 초깃값 설정\n"," h0 = (ctx[0][:, sort_idx, :], ctx[1][:, sort_idx, :])\n"," z = y[:, :-1].detach()\n"," # -100인 상태에선 Embedding 계산에서 오류가 발생하므로 0으로 변경\n"," z[z==-100] = 0\n"," # 디코더에 넣어서 손실 함수 계산\n"," o, _ = dec(z, h0, l=ly-1)\n"," loss = loss_f(to2D(o[:]), to2D(y[:, 1:max(ly)]).squeeze())\n"," # Backpropagation(오차 역전파 실행)\n"," enc.zero_grad(), dec.zero_grad()\n"," loss.backward()\n"," opt_enc.step(), opt_dec.step()\n"," losses.append(loss.item())\n"," # 전체 데이터의 계산이 끝나면 현재의\n"," # 손실 함수 값이나 번역 결과를 표시\n"," enc.eval(), dec.eval()\n"," print(epoc, mean(losses))\n"," with torch.no_grad():\n"," print(translate(\"I am a student.\",\n"," enc, dec, max_len=max_len, \n","device=\"cuda:0\"))\n"," print(translate(\"He likes to eat pizza.\",\n"," enc, dec, max_len=max_len, \n","device=\"cuda:0\"))\n"," print(translate(\"She is my mother.\",\n"," enc, dec, max_len=max_len, \n","device=\"cuda:0\"))\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["100%|██████████| 1623/1623 [01:21<00:00, 19.96it/s]\n"," 0%| | 0/1623 [00:00] 80.23M 19.9MB/s in 6.5s \n","\n","2018-12-14 09:30:34 (12.4 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]\n","\n"],"name":"stdout"}]},{"metadata":{"id":"W6FflS90DlFV","colab_type":"code","colab":{}},"cell_type":"code","source":["import glob\n","import pathlib\n","import re\n","\n","remove_marks_regex = re.compile(\"[,\\.\\(\\)\\[\\]\\*:;]|<.*?>\")\n","shift_marks_regex = re.compile(\"([?!])\")\n","\n","def text2ids(text, vocab_dict):\n"," # !? 이외의 기호 삭제\n"," text = remove_marks_regex.sub(\"\", text)\n"," # !?와 단어 사이에 공백 삽입\n"," text = shift_marks_regex.sub(r\" \\1 \", text)\n"," tokens = text.split()\n"," return [vocab_dict.get(token, 0) for token in tokens]\n","\n","def list2tensor(token_idxes, max_len=100, padding=True):\n"," if len(token_idxes) > max_len:\n"," token_idxes = token_idxes[:max_len]\n"," n_tokens = len(token_idxes)\n"," if padding:\n"," token_idxes = token_idxes \\\n"," + [0] * (max_len - len(token_idxes))\n"," return torch.tensor(token_idxes, dtype=torch.int64), n_tokens\n","\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"bJnQl4foGnTk","colab_type":"code","colab":{}},"cell_type":"code","source":["import torch\n","from torch import nn, optim\n","from torch.utils.data import (Dataset, \n"," DataLoader,\n"," TensorDataset)\n","import tqdm\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"mXFhcgVrG0Am","colab_type":"code","colab":{}},"cell_type":"code","source":["class IMDBDataset(Dataset):\n"," def __init__(self, dir_path, train=True,\n"," max_len=100, padding=True):\n"," self.max_len = max_len\n"," self.padding = padding\n"," \n"," path = pathlib.Path(dir_path)\n"," vocab_path = path.joinpath(\"imdb.vocab\")\n"," \n"," # 용어집 파일을 읽어서 행 단위로 분할\n"," self.vocab_array = vocab_path.open() \\\n"," .read().strip().splitlines()\n"," # 단어가 키이고 값이 ID인 dict 만들기\n"," self.vocab_dict = dict((w, i+1) \\\n"," for (i, w) in enumerate(self.vocab_array))\n"," \n"," if train:\n"," target_path = path.joinpath(\"train\")\n"," else:\n"," target_path = path.joinpath(\"test\")\n"," pos_files = sorted(glob.glob(\n"," str(target_path.joinpath(\"pos/*.txt\"))))\n"," neg_files = sorted(glob.glob(\n"," str(target_path.joinpath(\"neg/*.txt\"))))\n"," # pos는 1, neg는 0인 label을 붙여서\n"," # (file_path, label)의 튜플 리스트 작성\n"," self.labeled_files = \\\n"," list(zip([0]*len(neg_files), neg_files )) + \\\n"," list(zip([1]*len(pos_files), pos_files))\n"," \n"," @property\n"," def vocab_size(self):\n"," return len(self.vocab_array)\n"," \n"," def __len__(self):\n"," return len(self.labeled_files)\n"," \n"," def __getitem__(self, idx):\n"," label, f = self.labeled_files[idx]\n"," # 파일의 텍스트 데이터를 읽어서 소문자로 변환\n"," data = open(f).read().lower()\n"," # 텍스트 데이터를 ID 리스트로 변환\n"," data = text2ids(data, self.vocab_dict)\n"," # ID 리스트를 Tensor로 변환\n"," data, n_tokens = list2tensor(data, self.max_len, self.padding)\n"," return data, label, n_tokens\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"5gDAC2TtJRf3","colab_type":"code","colab":{}},"cell_type":"code","source":["train_data = IMDBDataset(\"/content/aclImdb/\")\n","test_data = IMDBDataset(\"/content/aclImdb/\", train=False)\n","train_loader = DataLoader(train_data, batch_size=32,\n"," shuffle=True, num_workers=4)\n","test_loader = DataLoader(test_data, batch_size=32,\n"," shuffle=False, num_workers=4)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"IMR3iVVyLS49","colab_type":"code","colab":{}},"cell_type":"code","source":["class SequenceTaggingNet(nn.Module):\n"," def __init__(self, num_embeddings,\n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1,\n"," dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, embedding_dim,\n"," padding_idx=0)\n"," self.lstm = nn.LSTM(embedding_dim,\n"," hidden_size, num_layers,\n"," batch_first=True, dropout=dropout)\n"," self.linear = nn.Linear(hidden_size, 1)\n","\n","\n"," \n"," \n"," def forward(self, x, h0=None, l=None):\n"," # ID를 Embedding으로 다차원 벡터로 변환\n"," # x는 (batch_size, step_size) \n"," # -> (batch_size, step_size, embedding_dim)\n"," x = self.emb(x)\n"," # 초기 상태 h0와 함께 RNN에 x를 전달\n"," # x는(batch_size, step_size, embedding_dim)\n"," # -> (batch_size, step_size, hidden_dim)\n"," x, h = self.lstm(x, h0)\n"," # 마지막 단계만 추출\n"," # xは(batch_size, step_size, hidden_dim)\n"," # -> (batch_size, 1)\n"," if l is not None:\n"," # 입력의 원래 길이가 있으면 그것을 이용\n"," x = x[list(range(len(x))), l-1, :]\n"," else:\n"," # 없으면 단순히 마지막 것을 이용\n"," x = x[:, -1, :]\n"," # 추출한 마지막 단계를 선형 계층에 넣는다\n"," x = self.linear(x)\n"," # 불필요한 차원을 삭제\n"," # (batch_size, 1) -> (batch_size, )\n"," x = x.squeeze()\n"," return x\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"vTNR-rYZSF5f","colab_type":"code","colab":{}},"cell_type":"code","source":["def eval_net(net, data_loader, device=\"cpu\"):\n"," net.eval()\n"," ys = []\n"," ypreds = []\n"," for x, y, l in data_loader:\n"," x = x.to(device)\n"," y = y.to(device)\n"," l = l.to(device)\n"," with torch.no_grad():\n"," y_pred = net(x, l=l)\n"," y_pred = (y_pred > 0).long()\n"," ys.append(y)\n"," ypreds.append(y_pred)\n"," ys = torch.cat(ys)\n"," ypreds = torch.cat(ypreds)\n"," acc = (ys == ypreds).float().sum() / len(ys)\n"," return acc.item()\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Nl7MxgBwSNds","colab_type":"code","colab":{}},"cell_type":"code","source":["from statistics import mean\n","\n","# num_embeddings에는 0을 포함해서 train_data.vocab_size+1를 넣는다\n","net = SequenceTaggingNet(train_data.vocab_size+1, \n","num_layers=2)\n","net.to(\"cuda:0\")\n","opt = optim.Adam(net.parameters())\n","loss_f = nn.BCEWithLogitsLoss()\n","\n","for epoch in range(10):\n"," losses = []\n"," net.train()\n"," for x, y, l in tqdm.tqdm(train_loader):\n"," x = x.to(\"cuda:0\")\n"," y = y.to(\"cuda:0\")\n"," l = l.to(\"cuda:0\")\n"," y_pred = net(x, l=l)\n"," loss = loss_f(y_pred, y.float())\n"," net.zero_grad()\n"," loss.backward()\n"," opt.step()\n"," losses.append(loss.item())\n"," train_acc = eval_net(net, train_loader, \"cuda:0\")\n"," val_acc = eval_net(net, test_loader, \"cuda:0\")\n"," print(epoch, mean(losses), train_acc, val_acc)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"EwBTsA03T2bi","colab_type":"code","outputId":"b0149576-2eb7-4423-9a01-c627523deb0c","executionInfo":{"status":"ok","timestamp":1544784319445,"user_tz":-480,"elapsed":85864,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":127}},"cell_type":"code","source":["from sklearn.datasets import load_svmlight_file\n","from sklearn.linear_model import LogisticRegression\n","\n","train_X, train_y = load_svmlight_file(\n"," \"/content/aclImdb/train/labeledBow.feat\")\n","test_X, test_y = load_svmlight_file(\n"," \"/content/aclImdb/test/labeledBow.feat\",\n"," n_features=train_X.shape[1])\n","\n","model = LogisticRegression(C=0.1, max_iter=1000)\n","model.fit(train_X, train_y)\n","model.score(train_X, train_y), model.score(test_X, test_y)\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:433: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n"," FutureWarning)\n","/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:460: FutureWarning: Default multi_class will be changed to 'auto' in 0.22. Specify the multi_class option to silence this warning.\n"," \"this warning.\", FutureWarning)\n"],"name":"stderr"},{"output_type":"execute_result","data":{"text/plain":["(0.89876, 0.39608)"]},"metadata":{"tags":[]},"execution_count":17}]},{"metadata":{"id":"phV7ctdUWDmz","colab_type":"code","colab":{}},"cell_type":"code","source":["class SequenceTaggingNet2(SequenceTaggingNet):\n"," def forward(self, x, h0=None, l=None):\n"," # ID를 Embedding으로 다차원 벡터로 변환\n"," x = self.emb(x)\n"," \n"," # 길이가 주어진 경우 PckedSequence 만들기\n"," if l is not None:\n"," x = nn.utils.rnn.pack_padded_sequence(\n"," x, l, batch_first=True)\n"," \n"," # RNN에 입력\n"," x, h = self.lstm(x, h0)\n"," \n"," # 마지막 단계를 추출해서 선형 계층에 넣는다\n"," if l is not None:\n"," # 길이 정보가 있으면 마지막 계층의\n"," # 내부 상태 벡터를 직접 이용할 수 있다\n"," # LSTM는 보통 내부 상태 외에 블럭 셀 상태도\n"," # 가지고 있으므로 내부 상태만 사용한다\n"," hidden_state, cell_state = h\n"," x = hidden_state[-1]\n"," else:\n"," x = x[:, -1, :]\n"," \n"," # 선형 계층에 넣는다\n"," x = self.linear(x).squeeze()\n"," return x\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"OOku_ZtUXBuX","colab_type":"code","colab":{}},"cell_type":"code","source":["for epoch in range(10):\n"," losses = []\n"," net.train()\n"," for x, y, l in tqdm.tqdm(train_loader):\n"," # 길이 배열을 길이 순으로 정렬\n"," l, sort_idx = torch.sort(l, descending=True)\n"," # 얻은 인덱스를 사용해서 x,y도 정렬\n"," x = x[sort_idx]\n"," y = y[sort_idx]\n"," \n"," x = x.to(\"cuda:0\")\n"," y = y.to(\"cuda:0\")\n"," \n"," y_pred = net(x, l=l)\n"," loss = loss_f(y_pred, y.float())\n"," net.zero_grad()\n"," loss.backward()\n"," opt.step()\n"," losses.append(loss.item())\n"," train_acc = eval_net(net, train_loader, \"cuda:0\")\n"," val_acc = eval_net(net, test_loader, \"cuda:0\")\n"," print(epoch, mean(losses), train_acc, val_acc)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"f53dSAs0R0XG","colab_type":"code","colab":{}},"cell_type":"code","source":["# 모든 ascii 문자로 사전 만들기\n","import string\n","all_chars = string.printable\n","\n","vocab_size = len(all_chars)\n","vocab_dict = dict((c, i) for (i, c) in enumerate(all_chars))\n","\n","# 문자열을 수치 리스트로 변환하는 함수\n","def str2ints(s, vocab_dict):\n"," return [vocab_dict[c] for c in s]\n"," \n","# 수치 리스트를 문자열로 변환하는 함수\n","def ints2str(x, vocab_array):\n"," return \"\".join([vocab_array[i] for i in x])\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"KZbFeIDNip9m","colab_type":"code","outputId":"639bbdbf-275e-4028-df04-e6cb275ccdb9","executionInfo":{"status":"ok","timestamp":1545024597869,"user_tz":-480,"elapsed":43817,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"resources":{"http://localhost:8080/nbextensions/google.colab/files.js":{"data":"Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc3luYyBnZW5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhcmVudEVsZW1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkYXRhIHRvIGJlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=","ok":true,"headers":[["content-type","application/javascript"]],"status":200,"status_text":""}},"base_uri":"https://localhost:8080/","height":76}},"cell_type":"code","source":["from google.colab import files\n","# 창이 뜨면 파일을 선택해서 업로드한다\n","uploaded = files.upload()\n"],"execution_count":0,"outputs":[{"output_type":"display_data","data":{"text/html":["\n"," \n"," \n"," Upload widget is only available when the cell has been executed in the\n"," current browser session. Please rerun this cell to enable.\n"," \n"," "],"text/plain":[""]},"metadata":{"tags":[]}},{"output_type":"stream","text":["Saving tinyshakespeare.txt to tinyshakespeare.txt\n"],"name":"stdout"}]},{"metadata":{"id":"0qVk_tWpm3dk","colab_type":"code","colab":{}},"cell_type":"code","source":["import torch\n","from torch import nn, optim\n","from torch.utils.data import (Dataset, \n"," DataLoader,\n"," TensorDataset)\n","import tqdm"],"execution_count":0,"outputs":[]},{"metadata":{"id":"zQvNRnvGnGMX","colab_type":"code","colab":{}},"cell_type":"code","source":["class ShakespeareDataset(Dataset):\n"," def __init__(self, path, chunk_size=200):\n"," # 파일을 읽어서 수치 리스트로 변환\n"," data = str2ints(open(path).read().strip(), vocab_dict)\n"," \n"," # Tensor로 변환해서 split 한다\n"," data = torch.tensor(data, dtype=torch.int64).split(chunk_size)\n"," \n"," # 마지막 덩어리(chunk)의 길이를 확인해서 부족한 경우 버린다\n"," if len(data[-1]) < chunk_size:\n"," data = data[:-1]\n"," \n"," self.data = data\n"," self.n_chunks = len(self.data)\n"," \n"," def __len__(self):\n"," return self.n_chunks\n","\n"," def __getitem__(self, idx):\n"," return self.data[idx]\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Oy81AcMxoeeR","colab_type":"code","colab":{}},"cell_type":"code","source":["ds = ShakespeareDataset(\"/content/tinyshakespeare.txt\", chunk_size=200)\n","loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"jDyTY8x4uRiX","colab_type":"code","colab":{}},"cell_type":"code","source":["class SequenceGenerationNet(nn.Module):\n"," def __init__(self, num_embeddings, \n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1, dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, embedding_dim)\n"," self.lstm = nn.LSTM(embedding_dim, \n"," hidden_size,\n"," num_layers,\n"," batch_first=True,\n"," dropout=dropout)\n"," # Linear의 output 크기는 첫 Embedding의 \n"," # input 크기와 같은 num_embeddings\n"," self.linear = nn.Linear(hidden_size, num_embeddings)\n"," \n"," def forward(self, x, h0=None):\n"," x = self.emb(x)\n"," x, h = self.lstm(x, h0)\n"," x = self.linear(x)\n"," return x, h\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"SQ4vfribu8i1","colab_type":"code","colab":{}},"cell_type":"code","source":["def generate_seq(net, start_phrase=\"The King said \",\n"," length=200, temperature=0.8, device=\"cpu\"):\n"," # 모델을 평가 모드로 설정\n"," net.eval()\n"," # 출력 수치를 저장할 리스트\n"," result = []\n"," \n"," # 시작 문자열을 Tensor로 변환\n"," start_tensor = torch.tensor(\n"," str2ints(start_phrase, vocab_dict),\n"," dtype=torch.int64\n"," ).to(device)\n"," # 선두에 batch 차원을 붙인다\n"," x0 = start_tensor.unsqueeze(0) \n"," # RNN을 통해서 출력과 새로운 내부 상태를 얻는다\n"," o, h = net(x0)\n"," # 출력을 정규화돼있지 않은 확률로 변환\n"," out_dist = o[:, -1].view(-1).exp()\n"," # 확률로부터 실제 문자의 인덱스를 샘플링グ\n"," top_i = torch.multinomial(out_dist, 1)[0]\n"," # 결과 저장\n"," result.append(top_i)\n"," \n"," # 생성된 결과를 차례로 RNN에 넣는다\n"," for i in range(length):\n"," inp = torch.tensor([[top_i]], dtype=torch.int64)\n"," inp = inp.to(device)\n"," o, h = net(inp, h)\n"," out_dist = o.view(-1).exp()\n"," top_i = torch.multinomial(out_dist, 1)[0]\n"," result.append(top_i)\n"," \n"," # 시작 문자열과 생성된 문자열을 모아서 반환\n"," return start_phrase + ints2str(result, all_chars)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"GrrqpgRy4P-F","colab_type":"code","outputId":"8680bcb6-91b6-44dd-e84e-8d567e23e8a4","executionInfo":{"status":"ok","timestamp":1545029221763,"user_tz":-480,"elapsed":453010,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":9181}},"cell_type":"code","source":["from statistics import mean\n","net = SequenceGenerationNet(vocab_size, 20, 50,\n"," num_layers=2, dropout=0.1)\n","net.to(\"cuda:0\")\n","opt = optim.Adam(net.parameters())\n","# 다중 식별 문제이므로 SoftmaxCrossEntropyLoss가 손실 함수가 된다\n","loss_f = nn.CrossEntropyLoss()\n","\n","for epoch in range(50):\n"," net.train()\n"," losses = []\n"," for data in tqdm.tqdm(loader):\n"," # x는 처음부터 마지막의 하나 앞 문자까지\n"," x = data[:, :-1]\n"," # y는 두 번째부터 마지막 문자까지\n"," y = data[:, 1:]\n"," \n"," x = x.to(\"cuda:0\")\n"," y = y.to(\"cuda:0\")\n"," \n"," y_pred, _ = net(x)\n"," # batch와 step 축을 통합해서 CrossEntropyLoss에 전달\n"," loss = loss_f(y_pred.view(-1, vocab_size), y.view(-1))\n"," net.zero_grad()\n"," loss.backward()\n"," opt.step()\n"," losses.append(loss.item())\n"," # 현재 손실 함수와 생성된 문장 예 표시\n"," print(epoch, mean(losses))\n"," with torch.no_grad():\n"," print(generate_seq(net, device=\"cuda:0\"))\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["100%|██████████| 175/175 [00:08<00:00, 19.89it/s]\n"],"name":"stderr"},{"output_type":"stream","text":["0 3.4874473898751397\n"],"name":"stdout"},{"output_type":"stream","text":["\r 0%| | 0/175 [00:00\")\n","shift_marks_regex = re.compile(\"([?!\\.])\")\n","\n","unk = 0\n","sos = 1\n","eos = 2\n","\n","def normalize(text):\n"," text = text.lower()\n"," # 불필요한 문자 제거\n"," text = remove_marks_regex.sub(\"\", text)\n"," # ?!. 와 단어 사이에 공백 삽입\n"," text = shift_marks_regex.sub(r\" \\1\", text)\n"," return text\n"," \n","def parse_line(line):\n"," line = normalize(line.strip())\n"," # 번역 대상(src)과 번역 결과(trg) 각각의 토큰을 리스트로 만든다\n"," src, trg = line.split(\"\\t\")\n"," src_tokens = src.strip().split()\n"," trg_tokens = trg.strip().split()\n"," return src_tokens, trg_tokens\n"," \n","def build_vocab(tokens):\n"," # 파일 안의 모든 문장에서 토큰의 등장 횟수를 확인\n"," counts = collections.Counter(tokens)\n"," # 토큰의 등장 횟수를 많은 순으로 나열\n"," sorted_counts = sorted(counts.items(), \n"," key=lambda c: c[1], reverse=True)\n"," # 세 개의 태그를 추가해서 정방향 리스트와 역방향 용어집 만들기\n"," word_list = [\"\", \"\", \"\"] \\\n"," + [x[0] for x in sorted_counts]\n"," word_dict = dict((w, i) for i, w in enumerate(word_list))\n"," return word_list, word_dict\n"," \n","def words2tensor(words, word_dict, max_len, padding=0):\n"," # 끝에 종료 태그를 붙임\n"," words = words + [\"\"]\n"," # 사전을 이용해서 수치 리스트로 변환\n"," words = [word_dict.get(w, 0) for w in words]\n"," seq_len = len(words)\n"," # 길이가 max_len이하이면 패딩한다\n"," if seq_len < max_len + 1:\n"," words = words + [padding] * (max_len + 1 - seq_len)\n"," # Tensor로 변환해서 반환\n"," return torch.tensor(words, dtype=torch.int64), seq_len\n","\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"EAHMfV8Qmw_5","colab_type":"code","colab":{}},"cell_type":"code","source":["class TranslationPairDataset(Dataset):\n"," def __init__(self, path, max_len=15):\n"," # 단어 수사 많은 문장을 걸러내는 함수\n"," def filter_pair(p):\n"," return not (len(p[0]) > max_len \n"," or len(p[1]) > max_len)\n"," # 파일을 열어서, 파스 및 필터링 \n"," with open(path) as fp:\n"," pairs = map(parse_line, fp)\n"," pairs = filter(filter_pair, pairs)\n"," pairs = list(pairs)\n"," # 문장의 소스와 타켓으로 나눔\n"," src = [p[0] for p in pairs]\n"," trg = [p[1] for p in pairs]\n"," #각각의 어휘집 작성\n"," self.src_word_list, self.src_word_dict = \\\n"," build_vocab(itertools.chain.from_iterable(src))\n"," self.trg_word_list, self.trg_word_dict = \\\n"," build_vocab(itertools.chain.from_iterable(trg))\n"," # 어휘집을 사용해서 Tensor로 변환\n"," self.src_data = [words2tensor(\n"," words, self.src_word_dict, max_len)\n"," for words in src]\n"," self.trg_data = [words2tensor(\n"," words, self.trg_word_dict, max_len, -100)\n"," for words in trg]\n"," def __len__(self):\n"," return len(self.src_data)\n"," \n"," def __getitem__(self, idx):\n"," src, lsrc = self.src_data[idx]\n"," trg, ltrg = self.trg_data[idx]\n"," return src, lsrc, trg, ltrg\n","\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"gge2cxkUnwVd","colab_type":"code","colab":{}},"cell_type":"code","source":["batch_size = 64\n","max_len = 10\n","path = \"/content/spa.txt\"\n","ds = TranslationPairDataset(path, max_len=max_len)\n","loader = DataLoader(ds, batch_size=batch_size, shuffle=True,\n"," num_workers=4)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Oh7y_8tgoTlm","colab_type":"code","colab":{}},"cell_type":"code","source":["class Encoder(nn.Module):\n"," def __init__(self, num_embeddings,\n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1,\n"," dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, \n"," embedding_dim, padding_idx=0)\n"," self.lstm = nn.LSTM(embedding_dim,\n"," hidden_size, num_layers,\n"," batch_first=True,\n","dropout=dropout)\n"," \n"," def forward(self, x, h0=None, l=None):\n"," x = self.emb(x)\n"," if l is not None:\n"," x = nn.utils.rnn.pack_padded_sequence(\n"," x, l, batch_first=True)\n"," _, h = self.lstm(x, h0)\n"," return h\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"tHC35oMSoVyD","colab_type":"code","colab":{}},"cell_type":"code","source":["class Decoder(nn.Module):\n"," def __init__(self, num_embeddings,\n"," embedding_dim=50, \n"," hidden_size=50,\n"," num_layers=1,\n"," dropout=0.2):\n"," super().__init__()\n"," self.emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)\n"," self.lstm = nn.LSTM(embedding_dim, hidden_size,\n"," num_layers, batch_first=True,\n"," dropout=dropout)\n"," self.linear = nn.Linear(hidden_size, num_embeddings)\n"," \n"," def forward(self, x, h, l=None):\n"," x = self.emb(x)\n"," if l is not None:\n"," x = nn.utils.rnn.pack_padded_sequence(\n"," x, l, batch_first=True)\n"," x, h = self.lstm(x, h)\n"," if l is not None:\n"," x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, padding_value=0)[0]\n"," x = self.linear(x)\n"," return x, h\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"HCT4wbK-5DQ_","colab_type":"code","colab":{}},"cell_type":"code","source":["def translate(input_str, enc, dec, max_len=15, device=\"cpu\"):\n"," # 입력 문자열을 수치화해서 Tensor로 변환\n"," words = normalize(input_str).split()\n"," input_tensor, seq_len = words2tensor(words, \n"," ds.src_word_dict, max_len=max_len)\n"," input_tensor = input_tensor.unsqueeze(0)\n"," # 엔코더에서 사용하므로 입력값의 길이도 리스트로 만들어둔다\n"," seq_len = [seq_len]\n"," # 시작 토큰 준비\n"," sos_inputs = torch.tensor(sos, dtype=torch.int64)\n"," input_tensor = input_tensor.to(device)\n"," sos_inputs = sos_inputs.to(device)\n"," # 입력 문자열을 엔코더에 넣어서 컨텍스트 얻기\n"," ctx = enc(input_tensor, l=seq_len)\n"," # 시작 토큰과 컨텍스트를 디코더의 초깃값으로 설정\n"," z = sos_inputs\n"," h = ctx\n"," results = []\n"," for i in range(max_len):\n"," # Decoder로 다음 단어 예측\n"," o, h = dec(z.view(1, 1), h)\n"," # 선형 계층의 출력이 가장 큰 위치가 다음 단어의 ID\n"," wi = o.detach().view(-1).max(0)[1]\n"," if wi.item() == eos:\n"," break\n"," results.append(wi.item())\n"," # 다음 입력값으로 현재 출력 ID를 사용\n"," z = wi\n"," # 기록해둔 출력 ID를 문자열로 변환\n"," return \" \".join(ds.trg_word_list[i] for i in results)\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"tftEQdHD6OVH","colab_type":"code","outputId":"25ff72c2-90b5-4231-dea7-bb636dead0e9","executionInfo":{"status":"ok","timestamp":1545045864966,"user_tz":-480,"elapsed":1142,"user":{"displayName":"winston kim","photoUrl":"","userId":"05942964544969189760"}},"colab":{"base_uri":"https://localhost:8080/","height":55}},"cell_type":"code","source":["enc = Encoder(len(ds.src_word_list), 100, 100, 2)\n","dec = Decoder(len(ds.trg_word_list), 100, 100, 2)\n","translate(\"I am a student.\", enc, dec)\n"],"execution_count":0,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'susurro liberaron salida trasladé trasladé moscú moscú memorizando moscú moscú moscú moscú memorizando moscú moscú'"]},"metadata":{"tags":[]},"execution_count":30}]},{"metadata":{"id":"NKFqA3t9608o","colab_type":"code","colab":{}},"cell_type":"code","source":["enc = Encoder(len(ds.src_word_list), 100, 100, 2)\n","dec = Decoder(len(ds.trg_word_list), 100, 100, 2)\n","enc.to(\"cuda:0\")\n","dec.to(\"cuda:0\")\n","opt_enc = optim.Adam(enc.parameters(), 0.002)\n","opt_dec = optim.Adam(dec.parameters(), 0.002)\n","loss_f = nn.CrossEntropyLoss()\n"],"execution_count":0,"outputs":[]},{"metadata":{"id":"XcWJo1rn7X2r","colab_type":"code","outputId":"d481dd6a-f753-455a-a912-4fe6887c4c2e","colab":{"base_uri":"https://localhost:8080/","height":1835}},"cell_type":"code","source":["from statistics import mean\n","\n","def to2D(x):\n"," shapes = x.shape\n"," return x.reshape(shapes[0] * shapes[1], -1)\n"," \n","for epoc in range(30):\n"," # 신경망을 훈련 모드로 설정\n"," enc.train(), dec.train()\n"," losses = []\n"," for x, lx, y, ly in tqdm.tqdm(loader):\n"," # x의 PackedSequence를 만들기 위해 번역 소스의 길이로 내림차순 정렬한다\n"," lx, sort_idx = lx.sort(descending=True)\n"," x, y, ly = x[sort_idx], y[sort_idx], ly[sort_idx]\n"," x, y = x.to(\"cuda:0\"), y.to(\"cuda:0\")\n"," # 번역 소스를 엔코더에 넣어서 컨텍스트를 얻는다\n"," ctx = enc(x, l=lx)\n"," # y의 PackedSequence를 만들기 위해 번역 소스의 길이로 내림차순 정렬\n"," ly, sort_idx = ly.sort(descending=True)\n"," y = y[sort_idx]\n"," # Decoder의 초깃값 설정\n"," h0 = (ctx[0][:, sort_idx, :], ctx[1][:, sort_idx, :])\n"," z = y[:, :-1].detach()\n"," # -100인 상태에선 Embedding 계산에서 오류가 발생하므로 0으로 변경\n"," z[z==-100] = 0\n"," # 디코더에 넣어서 손실 함수 계산\n"," o, _ = dec(z, h0, l=ly-1)\n"," loss = loss_f(to2D(o[:]), to2D(y[:, 1:max(ly)]).squeeze())\n"," # Backpropagation(오차 역전파 실행)\n"," enc.zero_grad(), dec.zero_grad()\n"," loss.backward()\n"," opt_enc.step(), opt_dec.step()\n"," losses.append(loss.item())\n"," # 전체 데이터의 계산이 끝나면 현재의\n"," # 손실 함수 값이나 번역 결과를 표시\n"," enc.eval(), dec.eval()\n"," print(epoc, mean(losses))\n"," with torch.no_grad():\n"," print(translate(\"I am a student.\",\n"," enc, dec, max_len=max_len, \n","device=\"cuda:0\"))\n"," print(translate(\"He likes to eat pizza.\",\n"," enc, dec, max_len=max_len, \n","device=\"cuda:0\"))\n"," print(translate(\"She is my mother.\",\n"," enc, dec, max_len=max_len, \n","device=\"cuda:0\"))\n"],"execution_count":0,"outputs":[{"output_type":"stream","text":["100%|██████████| 1623/1623 [01:21<00:00, 19.96it/s]\n"," 0%| | 0/1623 [00:00