diff --git a/docs/_data/umum.yml b/docs/_data/umum.yml index 4c40cdc..9227456 100644 --- a/docs/_data/umum.yml +++ b/docs/_data/umum.yml @@ -40,3 +40,10 @@ date : 2019-07-29 author : taruma version : 1.0.0 + +- title : >- + Contoh Penggunaan Boltzmann Machine + notebook : taruma_udemy_boltzmann + date : 2019-07-30 + author : taruma + version : 1.0.0 diff --git a/notebook/taruma_udemy_boltzmann.ipynb b/notebook/taruma_udemy_boltzmann.ipynb new file mode 100644 index 0000000..d07a760 --- /dev/null +++ b/notebook/taruma_udemy_boltzmann.ipynb @@ -0,0 +1,1002 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "taruma_udemy_boltzmann.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "n7Oy73RcZei5", + "colab_type": "text" + }, + "source": [ + "# Boltzmann Machines\n", + "\n", + "Notebook ini berdasarkan kursus __Deep Learning A-Z™: Hands-On Artificial Neural Networks__ di Udemy. [Lihat Kursus](https://www.udemy.com/deeplearning/).\n", + "\n", + "## Informasi Notebook\n", + "- __notebook name__: `taruma_udemy_boltzmann`\n", + "- __notebook version/date__: `1.0.0`/`20190730`\n", + "- __notebook server__: Google Colab\n", + "- __python version__: `3.6`\n", + "- __pytorch version__: `1.1.0`" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0y9V9qyUZeiH", + "colab_type": "code", + "outputId": "06aa828f-f253-4a73-f66d-c8f53b5780bb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "#### NOTEBOOK DESCRIPTION\n", + "\n", + "from datetime import datetime\n", + "\n", + "NOTEBOOK_TITLE = 'taruma_udemy_boltzmann'\n", + "NOTEBOOK_VERSION = '1.0.0'\n", + "NOTEBOOK_DATE = 1 # Set 1, if you want add date classifier\n", + "\n", + "NOTEBOOK_NAME = \"{}_{}\".format(\n", + " NOTEBOOK_TITLE, \n", + " NOTEBOOK_VERSION.replace('.','_')\n", + ")\n", + "PROJECT_NAME = \"{}_{}{}\".format(\n", + " NOTEBOOK_TITLE, \n", + " NOTEBOOK_VERSION.replace('.','_'), \n", + " \"_\" + datetime.utcnow().strftime(\"%Y%m%d_%H%M\") if NOTEBOOK_DATE else \"\"\n", + ")\n", + "\n", + "print(f\"Nama Notebook: {NOTEBOOK_NAME}\")\n", + "print(f\"Nama Proyek: {PROJECT_NAME}\")" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Nama Notebook: taruma_udemy_boltzmann_1_0_0\n", + "Nama Proyek: taruma_udemy_boltzmann_1_0_0_20190730_0822\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1euCSTADZlh3", + "colab_type": "code", + "outputId": "79926104-9cd1-43d6-8ca6-35caf0003b2b", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + } + }, + "source": [ + "#### System Version\n", + "import sys, torch\n", + "print(\"versi python: {}\".format(sys.version))\n", + "print(\"versi pytorch: {}\".format(torch.__version__))" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "versi python: 3.6.8 (default, Jan 14 2019, 11:02:34) \n", + "[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]\n", + "versi pytorch: 1.1.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "s0qrkxTVZj_P", + "colab_type": "code", + "colab": {} + }, + "source": [ + "#### Load Notebook Extensions\n", + "%load_ext google.colab.data_table" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "S8l7WZ0_ZmCK", + "colab_type": "code", + "colab": {} + }, + "source": [ + "#### Download dataset\n", + "# ref: https://grouplens.org/datasets/movielens/\n", + "!wget -O boltzmann.zip \"https://sds-platform-private.s3-us-east-2.amazonaws.com/uploads/P16-Boltzmann-Machines.zip\"\n", + "!unzip boltzmann.zip" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jXTtTiLAZmBu", + "colab_type": "code", + "colab": {} + }, + "source": [ + "#### Atur dataset path\n", + "DATASET_DIRECTORY = 'Boltzmann_Machines/'" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "3cDBJSwAknug", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def showdata(dataframe):\n", + " print('Dataframe Size: {}'.format(dataframe.shape))\n", + " return dataframe" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hqE6ozW8e0ra", + "colab_type": "text" + }, + "source": [ + "# STEP 1-5 DATA PREPROCESSING" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fLvxd5pQdTQq", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Importing the libraries\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.parallel\n", + "import torch.optim as optim\n", + "import torch.utils.data\n", + "from torch.autograd import Variable" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lFQEACh4fJLp", + "colab_type": "code", + "outputId": "7aeb4f68-2294-439f-bc39-fd0b73e70de9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 309 + } + }, + "source": [ + "movies = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/movies.dat', sep='::', header=None, engine='python', encoding='latin-1')\n", + "showdata(movies).head(10)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Dataframe Size: (3883, 3)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"Toy Story (1995)\",\n\"Animation|Children's|Comedy\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"Jumanji (1995)\",\n\"Adventure|Children's|Fantasy\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"Grumpier Old Men (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"Waiting to Exhale (1995)\",\n\"Comedy|Drama\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"Father of the Bride Part II (1995)\",\n\"Comedy\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"Heat (1995)\",\n\"Action|Crime|Thriller\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"Sabrina (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"Tom and Huck (1995)\",\n\"Adventure|Children's\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"Sudden Death (1995)\",\n\"Action\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"GoldenEye (1995)\",\n\"Action|Adventure|Thriller\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"string\", \"2\"]],\n rowsPerPage: 25,\n });\n ", + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
012
01Toy Story (1995)Animation|Children's|Comedy
12Jumanji (1995)Adventure|Children's|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama
45Father of the Bride Part II (1995)Comedy
56Heat (1995)Action|Crime|Thriller
67Sabrina (1995)Comedy|Romance
78Tom and Huck (1995)Adventure|Children's
89Sudden Death (1995)Action
910GoldenEye (1995)Action|Adventure|Thriller
\n", + "
" + ], + "text/plain": [ + " 0 1 2\n", + "0 1 Toy Story (1995) Animation|Children's|Comedy\n", + "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n", + "2 3 Grumpier Old Men (1995) Comedy|Romance\n", + "3 4 Waiting to Exhale (1995) Comedy|Drama\n", + "4 5 Father of the Bride Part II (1995) Comedy\n", + "5 6 Heat (1995) Action|Crime|Thriller\n", + "6 7 Sabrina (1995) Comedy|Romance\n", + "7 8 Tom and Huck (1995) Adventure|Children's\n", + "8 9 Sudden Death (1995) Action\n", + "9 10 GoldenEye (1995) Action|Adventure|Thriller" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dgllaEzifthq", + "colab_type": "code", + "outputId": "6e9eee4f-859c-44bf-d906-c72e9289ee93", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 309 + } + }, + "source": [ + "users = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/users.dat', sep='::', header=None, engine='python', encoding='latin-1')\n", + "showdata(users).head(10)" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Dataframe Size: (6040, 5)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"F\",\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"48067\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"M\",\n{\n 'v': 56,\n 'f': \"56\",\n },\n{\n 'v': 16,\n 'f': \"16\",\n },\n\"70072\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 15,\n 'f': \"15\",\n },\n\"55117\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"M\",\n{\n 'v': 45,\n 'f': \"45\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"02460\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 20,\n 'f': \"20\",\n },\n\"55455\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"F\",\n{\n 'v': 50,\n 'f': \"50\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"55117\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"M\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"06810\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 12,\n 'f': \"12\",\n },\n\"11413\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 17,\n 'f': \"17\",\n },\n\"61614\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"F\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"95370\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"], [\"string\", \"4\"]],\n rowsPerPage: 25,\n });\n ", + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
01234
01F11048067
12M561670072
23M251555117
34M45702460
45M252055455
56F50955117
67M35106810
78M251211413
89M251761614
910F35195370
\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4\n", + "0 1 F 1 10 48067\n", + "1 2 M 56 16 70072\n", + "2 3 M 25 15 55117\n", + "3 4 M 45 7 02460\n", + "4 5 M 25 20 55455\n", + "5 6 F 50 9 55117\n", + "6 7 M 35 1 06810\n", + "7 8 M 25 12 11413\n", + "8 9 M 25 17 61614\n", + "9 10 F 35 1 95370" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 9 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "anc5Oi1sgDzc", + "colab_type": "code", + "outputId": "29b8bcdd-4562-40d0-ed3e-c179f86b9ed7", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 309 + } + }, + "source": [ + "ratings = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/ratings.dat', sep='::', header=None, engine='python', encoding='latin-1')\n", + "showdata(ratings).head(10)" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Dataframe Size: (1000209, 4)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1193,\n 'f': \"1193\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300760,\n 'f': \"978300760\",\n }],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 661,\n 'f': \"661\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302109,\n 'f': \"978302109\",\n }],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 914,\n 'f': \"914\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978301968,\n 'f': \"978301968\",\n }],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 3408,\n 'f': \"3408\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978300275,\n 'f': \"978300275\",\n }],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2355,\n 'f': \"2355\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978824291,\n 'f': \"978824291\",\n }],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1197,\n 'f': \"1197\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1287,\n 'f': \"1287\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978302039,\n 'f': \"978302039\",\n }],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2804,\n 'f': \"2804\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300719,\n 'f': \"978300719\",\n }],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 594,\n 'f': \"594\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 919,\n 'f': \"919\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978301368,\n 'f': \"978301368\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"number\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"]],\n rowsPerPage: 25,\n });\n ", + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0123
0111935978300760
116613978302109
219143978301968
3134084978300275
4123555978824291
5111973978302268
6112875978302039
7128045978300719
815944978302268
919194978301368
\n", + "
" + ], + "text/plain": [ + " 0 1 2 3\n", + "0 1 1193 5 978300760\n", + "1 1 661 3 978302109\n", + "2 1 914 3 978301968\n", + "3 1 3408 4 978300275\n", + "4 1 2355 5 978824291\n", + "5 1 1197 3 978302268\n", + "6 1 1287 5 978302039\n", + "7 1 2804 5 978300719\n", + "8 1 594 4 978302268\n", + "9 1 919 4 978301368" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 10 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xU2S9y8NgRPW", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Preparing the training set and the test set\n", + "training_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.base', delimiter='\\t')\n", + "training_set = np.array(training_set, dtype='int')\n", + "test_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.test', delimiter='\\t')\n", + "test_set = np.array(test_set, dtype='int')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "G7WbrddJl3Q9", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Getting the number of users and movies\n", + "nb_users = int(max(max(training_set[:, 0]), max(test_set[:, 0])))\n", + "nb_movies = int(max(max(training_set[:, 1]), max(test_set[:, 1])))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yRUTR_K3_rzP", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Converting the data into an array with users in lines and movies in columns\n", + "def convert(data):\n", + " new_data = []\n", + " for id_users in range(1, nb_users+1):\n", + " id_movies = data[:, 1][data[:, 0] == id_users]\n", + " id_ratings = data[:, 2][data[:, 0] == id_users]\n", + " ratings = np.zeros(nb_movies)\n", + " ratings[id_movies - 1] = id_ratings\n", + " new_data.append(list(ratings))\n", + " return new_data\n", + "\n", + "training_set = convert(training_set)\n", + "test_set = convert(test_set)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "u0Fk8Q0YCNZr", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Converting the data into Torch tensors\n", + "training_set = torch.FloatTensor(training_set)\n", + "test_set = torch.FloatTensor(test_set)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "q2arIJufDBYd", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + }, + "outputId": "6f90b00a-799a-4db0-f1ac-82241e375329" + }, + "source": [ + "training_set." + ], + "execution_count": 25, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[0., 3., 4., ..., 0., 0., 0.],\n", + " [4., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " ...,\n", + " [5., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 5., 0., ..., 0., 0., 0.]])" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 25 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wwWlLU5ODGgF", + "colab_type": "text" + }, + "source": [ + "# STEP 6" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "m0KbLrFZDCjK", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Converting the ratings into binary ratings 1 (Liked) or 0 (Not Liked)\n", + "training_set[training_set == 0] = -1\n", + "training_set[training_set == 1] = 0\n", + "training_set[training_set == 2] = 0\n", + "training_set[training_set >= 3] = 1\n", + "\n", + "test_set[test_set == 0] = -1\n", + "test_set[test_set == 1] = 0\n", + "test_set[test_set == 2] = 0\n", + "test_set[test_set >= 3] = 1" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "zHEwVvTlD-DK", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + }, + "outputId": "08d4e08c-6c99-4fa5-e528-ca995f016ed2" + }, + "source": [ + "training_set" + ], + "execution_count": 27, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[-1., 1., 1., ..., -1., -1., -1.],\n", + " [ 1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " ...,\n", + " [ 1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., 1., -1., ..., -1., -1., -1.]])" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 27 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a4YrUGpDEFEV", + "colab_type": "text" + }, + "source": [ + "# STEP 7 - 10 Building RBM Object" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "S3i8jV-RD_MV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Creating the architecture of the Neural Network\n", + "# nv = number visible nodes, nh = number hidden nodes\n", + "class RBM():\n", + " def __init__(self, nv, nh):\n", + " self.W = torch.randn(nh, nv)\n", + " self.a = torch.randn(1, nh)\n", + " self.b = torch.randn(1, nv)\n", + " def sample_h(self, x):\n", + " wx = torch.mm(x, self.W.t())\n", + " activation = wx + self.a.expand_as(wx)\n", + " p_h_given_v = torch.sigmoid(activation)\n", + " return p_h_given_v, torch.bernoulli(p_h_given_v)\n", + " def sample_v(self, y):\n", + " wy = torch.mm(y, self.W)\n", + " activation = wy + self.b.expand_as(wy)\n", + " p_v_given_h = torch.sigmoid(activation)\n", + " return p_v_given_h, torch.bernoulli(p_v_given_h)\n", + " def train(self, v0, vk, ph0, phk):\n", + " self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()\n", + " self.b += torch.sum((v0 - vk), 0)\n", + " self.a += torch.sum((ph0 - phk), 0)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GVmKKbizKFV2", + "colab_type": "text" + }, + "source": [ + "# STEP 11" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Itwi6_KlKGmf", + "colab_type": "code", + "colab": {} + }, + "source": [ + "nv = len(training_set[0])\n", + "nh = 100\n", + "batch_size = 100\n", + "rbm = RBM(nv, nh)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "45nKkm5QK3hx", + "colab_type": "text" + }, + "source": [ + "# STEP 12-13" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "NJ94UFahKrOw", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 187 + }, + "outputId": "841eb9e8-91ba-4f7b-cae3-b6370b9181fc" + }, + "source": [ + "# Training the RBM\n", + "nb_epochs = 10\n", + "for epoch in range(1, nb_epochs + 1):\n", + " train_loss = 0\n", + " s = 0.\n", + " for id_user in range(0, nb_users - batch_size, batch_size):\n", + " vk = training_set[id_user:id_user+batch_size]\n", + " v0 = training_set[id_user:id_user+batch_size]\n", + " ph0,_ = rbm.sample_h(v0)\n", + " for k in range(10):\n", + " _,hk = rbm.sample_h(vk)\n", + " _,vk = rbm.sample_v(hk)\n", + " vk[v0<0] = v0[v0<0]\n", + " phk,_ = rbm.sample_h(vk)\n", + " rbm.train(v0, vk, ph0, phk)\n", + " train_loss += torch.mean(torch.abs(v0[v0>=0] - vk[v0>=0]))\n", + " s += 1.\n", + " print('epoch: '+str(epoch)+' loss: '+str(train_loss/s))" + ], + "execution_count": 39, + "outputs": [ + { + "output_type": "stream", + "text": [ + "epoch: 1 loss: tensor(0.3424)\n", + "epoch: 2 loss: tensor(0.2527)\n", + "epoch: 3 loss: tensor(0.2509)\n", + "epoch: 4 loss: tensor(0.2483)\n", + "epoch: 5 loss: tensor(0.2474)\n", + "epoch: 6 loss: tensor(0.2478)\n", + "epoch: 7 loss: tensor(0.2467)\n", + "epoch: 8 loss: tensor(0.2461)\n", + "epoch: 9 loss: tensor(0.2482)\n", + "epoch: 10 loss: tensor(0.2491)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qTvbF-u9NBdl", + "colab_type": "text" + }, + "source": [ + "# STEP 14" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "RSlbxB8ZLoy9", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "ee26a78e-47ac-4efc-d36d-9048dbae6cd8" + }, + "source": [ + "# Testing the RBM\n", + "test_loss = 0\n", + "s = 0.\n", + "for id_user in range(nb_users):\n", + " v = training_set[id_user:id_user+1]\n", + " vt = test_set[id_user:id_user+1]\n", + " if len(vt[vt>=0]) > 0:\n", + " _,h = rbm.sample_h(v)\n", + " _,v = rbm.sample_v(h)\n", + " test_loss += torch.mean(torch.abs(vt[vt>=0] - v[vt>=0]))\n", + " s += 1.\n", + "print('test loss: '+str(test_loss/s))" + ], + "execution_count": 40, + "outputs": [ + { + "output_type": "stream", + "text": [ + "test loss: tensor(0.2403)\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file