diff --git a/docs/_data/umum.yml b/docs/_data/umum.yml
index 9227456..3d6374f 100644
--- a/docs/_data/umum.yml
+++ b/docs/_data/umum.yml
@@ -47,3 +47,10 @@
date : 2019-07-30
author : taruma
version : 1.0.0
+
+- title : >-
+ Contoh Penggunaan AutoEncoder
+ notebook : taruma_udemy_autoencoders
+ date : 2019-08-01
+ author : taruma
+ version : 1.0.0
diff --git a/notebook/taruma_udemy_autoencoders.ipynb b/notebook/taruma_udemy_autoencoders.ipynb
new file mode 100644
index 0000000..3604999
--- /dev/null
+++ b/notebook/taruma_udemy_autoencoders.ipynb
@@ -0,0 +1,1216 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "taruma_udemy_autoencoders.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lEhhFaw8YPqS",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Auto Encoders\n",
+ "\n",
+ "Notebook ini berdasarkan kursus __Deep Learning A-Z™: Hands-On Artificial Neural Networks__ di Udemy. [Lihat Kursus](https://www.udemy.com/deeplearning/).\n",
+ "\n",
+ "## Informasi Notebook\n",
+ "- __notebook name__: `taruma_udemy_autoencoders`\n",
+ "- __notebook version/date__: `1.0.0`/`20190801`\n",
+ "- __notebook server__: Google Colab\n",
+ "- __python version__: `3.6`\n",
+ "- __pytorch version__: `1.1.0`\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "XgpbxDgHYPpL",
+ "colab_type": "code",
+ "outputId": "e6d8390d-a34d-4414-81da-1bd5fc717c95",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ }
+ },
+ "source": [
+ "#### NOTEBOOK DESCRIPTION\n",
+ "\n",
+ "from datetime import datetime\n",
+ "\n",
+ "NOTEBOOK_TITLE = 'taruma_udemy_autoencoders'\n",
+ "NOTEBOOK_VERSION = '1.0.0'\n",
+ "NOTEBOOK_DATE = 1 # Set 1, if you want add date classifier\n",
+ "\n",
+ "NOTEBOOK_NAME = \"{}_{}\".format(\n",
+ " NOTEBOOK_TITLE, \n",
+ " NOTEBOOK_VERSION.replace('.','_')\n",
+ ")\n",
+ "PROJECT_NAME = \"{}_{}{}\".format(\n",
+ " NOTEBOOK_TITLE, \n",
+ " NOTEBOOK_VERSION.replace('.','_'), \n",
+ " \"_\" + datetime.utcnow().strftime(\"%Y%m%d_%H%M\") if NOTEBOOK_DATE else \"\"\n",
+ ")\n",
+ "\n",
+ "print(f\"Nama Notebook: {NOTEBOOK_NAME}\")\n",
+ "print(f\"Nama Proyek: {PROJECT_NAME}\")"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Nama Notebook: taruma_udemy_autoencoders_1_0_0\n",
+ "Nama Proyek: taruma_udemy_autoencoders_1_0_0_20190801_0925\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ht_fonylY0my",
+ "colab_type": "code",
+ "outputId": "90c170d9-ac70-4562-be35-9f30401bd780",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ }
+ },
+ "source": [
+ "#### System Version\n",
+ "import sys, torch\n",
+ "print(\"versi python: {}\".format(sys.version))\n",
+ "print(\"versi pytorch: {}\".format(torch.__version__))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "versi python: 3.6.8 (default, Jan 14 2019, 11:02:34) \n",
+ "[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]\n",
+ "versi pytorch: 1.1.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "wWHOjSRRY5Pf",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "#### Load Notebook Extensions\n",
+ "%load_ext google.colab.data_table"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sS6B-Y06Y8UB",
+ "colab_type": "code",
+ "outputId": "a20211fd-9b95-4f7c-82d8-e191ac2b1d97",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 340
+ }
+ },
+ "source": [
+ "#### Download dataset\n",
+ "# ref: https://grouplens.org/datasets/movielens/\n",
+ "!wget -O autoencoders.zip \"https://sds-platform-private.s3-us-east-2.amazonaws.com/uploads/P16-AutoEncoders.zip\"\n",
+ "!unzip autoencoders.zip"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "--2019-08-01 09:25:40-- https://sds-platform-private.s3-us-east-2.amazonaws.com/uploads/P16-AutoEncoders.zip\n",
+ "Resolving sds-platform-private.s3-us-east-2.amazonaws.com (sds-platform-private.s3-us-east-2.amazonaws.com)... 52.219.80.168\n",
+ "Connecting to sds-platform-private.s3-us-east-2.amazonaws.com (sds-platform-private.s3-us-east-2.amazonaws.com)|52.219.80.168|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 17069342 (16M) [application/zip]\n",
+ "Saving to: ‘autoencoders.zip’\n",
+ "\n",
+ "autoencoders.zip 100%[===================>] 16.28M 34.2MB/s in 0.5s \n",
+ "\n",
+ "2019-08-01 09:25:40 (34.2 MB/s) - ‘autoencoders.zip’ saved [17069342/17069342]\n",
+ "\n",
+ "Archive: autoencoders.zip\n",
+ " creating: AutoEncoders/\n",
+ " inflating: AutoEncoders/ae.py \n",
+ " creating: __MACOSX/\n",
+ " creating: __MACOSX/AutoEncoders/\n",
+ " inflating: __MACOSX/AutoEncoders/._ae.py \n",
+ " inflating: AutoEncoders/ml-100k.zip \n",
+ " inflating: AutoEncoders/ml-1m.zip \n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "jm-eeDVwcAda",
+ "colab_type": "code",
+ "outputId": "445a373d-bebf-418b-fb6c-ec639956afb1",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ }
+ },
+ "source": [
+ "# Karena ada file .zip dalam direktori, harus diekstrak lagi.\n",
+ "# ref: https://askubuntu.com/q/399951\n",
+ "# ref: https://unix.stackexchange.com/q/12902\n",
+ "!find AutoEncoders -type f -name '*.zip' -exec unzip -d AutoEncoders {} \\;"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Archive: AutoEncoders/ml-100k.zip\n",
+ " creating: AutoEncoders/ml-100k/\n",
+ " inflating: AutoEncoders/ml-100k/allbut.pl \n",
+ " creating: AutoEncoders/__MACOSX/\n",
+ " creating: AutoEncoders/__MACOSX/ml-100k/\n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._allbut.pl \n",
+ " inflating: AutoEncoders/ml-100k/mku.sh \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._mku.sh \n",
+ " inflating: AutoEncoders/ml-100k/README \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._README \n",
+ " inflating: AutoEncoders/ml-100k/u.data \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u.data \n",
+ " inflating: AutoEncoders/ml-100k/u.genre \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u.genre \n",
+ " inflating: AutoEncoders/ml-100k/u.info \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u.info \n",
+ " inflating: AutoEncoders/ml-100k/u.item \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u.item \n",
+ " inflating: AutoEncoders/ml-100k/u.occupation \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u.occupation \n",
+ " inflating: AutoEncoders/ml-100k/u.user \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u.user \n",
+ " inflating: AutoEncoders/ml-100k/u1.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u1.base \n",
+ " inflating: AutoEncoders/ml-100k/u1.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u1.test \n",
+ " inflating: AutoEncoders/ml-100k/u2.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u2.base \n",
+ " inflating: AutoEncoders/ml-100k/u2.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u2.test \n",
+ " inflating: AutoEncoders/ml-100k/u3.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u3.base \n",
+ " inflating: AutoEncoders/ml-100k/u3.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u3.test \n",
+ " inflating: AutoEncoders/ml-100k/u4.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u4.base \n",
+ " inflating: AutoEncoders/ml-100k/u4.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u4.test \n",
+ " inflating: AutoEncoders/ml-100k/u5.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u5.base \n",
+ " inflating: AutoEncoders/ml-100k/u5.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._u5.test \n",
+ " inflating: AutoEncoders/ml-100k/ua.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._ua.base \n",
+ " inflating: AutoEncoders/ml-100k/ua.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._ua.test \n",
+ " inflating: AutoEncoders/ml-100k/ub.base \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._ub.base \n",
+ " inflating: AutoEncoders/ml-100k/ub.test \n",
+ " inflating: AutoEncoders/__MACOSX/ml-100k/._ub.test \n",
+ " inflating: AutoEncoders/__MACOSX/._ml-100k \n",
+ "Archive: AutoEncoders/ml-1m.zip\n",
+ " creating: AutoEncoders/ml-1m/\n",
+ " inflating: AutoEncoders/ml-1m/.DS_Store \n",
+ " creating: AutoEncoders/__MACOSX/ml-1m/\n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._.DS_Store \n",
+ " inflating: AutoEncoders/ml-1m/.Rhistory \n",
+ " inflating: AutoEncoders/ml-1m/movies.dat \n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._movies.dat \n",
+ " inflating: AutoEncoders/ml-1m/ratings.dat \n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._ratings.dat \n",
+ " inflating: AutoEncoders/ml-1m/README \n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._README \n",
+ " inflating: AutoEncoders/ml-1m/test_set.csv \n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._test_set.csv \n",
+ " inflating: AutoEncoders/ml-1m/training_set.csv \n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._training_set.csv \n",
+ " inflating: AutoEncoders/ml-1m/users.dat \n",
+ " inflating: AutoEncoders/__MACOSX/ml-1m/._users.dat \n",
+ " inflating: AutoEncoders/__MACOSX/._ml-1m \n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "q1oOGi4jZYrp",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "#### Atur dataset path\n",
+ "DATASET_DIRECTORY = 'AutoEncoders/'"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1aWLNovwgC_X",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def showdata(dataframe):\n",
+ " print('Dataframe Size: {}'.format(dataframe.shape))\n",
+ " return dataframe"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hqE6ozW8e0ra",
+ "colab_type": "text"
+ },
+ "source": [
+ "# STEP 1-5 DATA PREPROCESSING"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "fLvxd5pQdTQq",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Importing the libraries\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.parallel\n",
+ "import torch.optim as optim\n",
+ "import torch.utils.data\n",
+ "from torch.autograd import Variable"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lFQEACh4fJLp",
+ "colab_type": "code",
+ "outputId": "cc5e7095-91f4-4594-f30f-cf2f41380595",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 309
+ }
+ },
+ "source": [
+ "movies = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/movies.dat', sep='::', header=None, engine='python', encoding='latin-1')\n",
+ "showdata(movies).head(10)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Dataframe Size: (3883, 3)\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"Toy Story (1995)\",\n\"Animation|Children's|Comedy\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"Jumanji (1995)\",\n\"Adventure|Children's|Fantasy\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"Grumpier Old Men (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"Waiting to Exhale (1995)\",\n\"Comedy|Drama\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"Father of the Bride Part II (1995)\",\n\"Comedy\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"Heat (1995)\",\n\"Action|Crime|Thriller\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"Sabrina (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"Tom and Huck (1995)\",\n\"Adventure|Children's\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"Sudden Death (1995)\",\n\"Action\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"GoldenEye (1995)\",\n\"Action|Adventure|Thriller\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"string\", \"2\"]],\n rowsPerPage: 25,\n });\n ",
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1 | \n",
+ " Toy Story (1995) | \n",
+ " Animation|Children's|Comedy | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2 | \n",
+ " Jumanji (1995) | \n",
+ " Adventure|Children's|Fantasy | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 3 | \n",
+ " Grumpier Old Men (1995) | \n",
+ " Comedy|Romance | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4 | \n",
+ " Waiting to Exhale (1995) | \n",
+ " Comedy|Drama | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5 | \n",
+ " Father of the Bride Part II (1995) | \n",
+ " Comedy | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 6 | \n",
+ " Heat (1995) | \n",
+ " Action|Crime|Thriller | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 7 | \n",
+ " Sabrina (1995) | \n",
+ " Comedy|Romance | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 8 | \n",
+ " Tom and Huck (1995) | \n",
+ " Adventure|Children's | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 9 | \n",
+ " Sudden Death (1995) | \n",
+ " Action | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 10 | \n",
+ " GoldenEye (1995) | \n",
+ " Action|Adventure|Thriller | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2\n",
+ "0 1 Toy Story (1995) Animation|Children's|Comedy\n",
+ "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n",
+ "2 3 Grumpier Old Men (1995) Comedy|Romance\n",
+ "3 4 Waiting to Exhale (1995) Comedy|Drama\n",
+ "4 5 Father of the Bride Part II (1995) Comedy\n",
+ "5 6 Heat (1995) Action|Crime|Thriller\n",
+ "6 7 Sabrina (1995) Comedy|Romance\n",
+ "7 8 Tom and Huck (1995) Adventure|Children's\n",
+ "8 9 Sudden Death (1995) Action\n",
+ "9 10 GoldenEye (1995) Action|Adventure|Thriller"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "dgllaEzifthq",
+ "colab_type": "code",
+ "outputId": "948b43b7-b39b-4b9d-d9ac-32a010ebeffe",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 309
+ }
+ },
+ "source": [
+ "users = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/users.dat', sep='::', header=None, engine='python', encoding='latin-1')\n",
+ "showdata(users).head(10)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Dataframe Size: (6040, 5)\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"F\",\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"48067\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"M\",\n{\n 'v': 56,\n 'f': \"56\",\n },\n{\n 'v': 16,\n 'f': \"16\",\n },\n\"70072\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 15,\n 'f': \"15\",\n },\n\"55117\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"M\",\n{\n 'v': 45,\n 'f': \"45\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"02460\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 20,\n 'f': \"20\",\n },\n\"55455\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"F\",\n{\n 'v': 50,\n 'f': \"50\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"55117\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"M\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"06810\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 12,\n 'f': \"12\",\n },\n\"11413\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 17,\n 'f': \"17\",\n },\n\"61614\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"F\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"95370\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"], [\"string\", \"4\"]],\n rowsPerPage: 25,\n });\n ",
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1 | \n",
+ " F | \n",
+ " 1 | \n",
+ " 10 | \n",
+ " 48067 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2 | \n",
+ " M | \n",
+ " 56 | \n",
+ " 16 | \n",
+ " 70072 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 3 | \n",
+ " M | \n",
+ " 25 | \n",
+ " 15 | \n",
+ " 55117 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4 | \n",
+ " M | \n",
+ " 45 | \n",
+ " 7 | \n",
+ " 02460 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5 | \n",
+ " M | \n",
+ " 25 | \n",
+ " 20 | \n",
+ " 55455 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 6 | \n",
+ " F | \n",
+ " 50 | \n",
+ " 9 | \n",
+ " 55117 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 7 | \n",
+ " M | \n",
+ " 35 | \n",
+ " 1 | \n",
+ " 06810 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 8 | \n",
+ " M | \n",
+ " 25 | \n",
+ " 12 | \n",
+ " 11413 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 9 | \n",
+ " M | \n",
+ " 25 | \n",
+ " 17 | \n",
+ " 61614 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 10 | \n",
+ " F | \n",
+ " 35 | \n",
+ " 1 | \n",
+ " 95370 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4\n",
+ "0 1 F 1 10 48067\n",
+ "1 2 M 56 16 70072\n",
+ "2 3 M 25 15 55117\n",
+ "3 4 M 45 7 02460\n",
+ "4 5 M 25 20 55455\n",
+ "5 6 F 50 9 55117\n",
+ "6 7 M 35 1 06810\n",
+ "7 8 M 25 12 11413\n",
+ "8 9 M 25 17 61614\n",
+ "9 10 F 35 1 95370"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 11
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "anc5Oi1sgDzc",
+ "colab_type": "code",
+ "outputId": "295da240-b2dc-467a-f9c3-8c7ad3e1756c",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 309
+ }
+ },
+ "source": [
+ "ratings = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/ratings.dat', sep='::', header=None, engine='python', encoding='latin-1')\n",
+ "showdata(ratings).head(10)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Dataframe Size: (1000209, 4)\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1193,\n 'f': \"1193\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300760,\n 'f': \"978300760\",\n }],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 661,\n 'f': \"661\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302109,\n 'f': \"978302109\",\n }],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 914,\n 'f': \"914\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978301968,\n 'f': \"978301968\",\n }],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 3408,\n 'f': \"3408\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978300275,\n 'f': \"978300275\",\n }],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2355,\n 'f': \"2355\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978824291,\n 'f': \"978824291\",\n }],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1197,\n 'f': \"1197\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1287,\n 'f': \"1287\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978302039,\n 'f': \"978302039\",\n }],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2804,\n 'f': \"2804\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300719,\n 'f': \"978300719\",\n }],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 594,\n 'f': \"594\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 919,\n 'f': \"919\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978301368,\n 'f': \"978301368\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"number\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"]],\n rowsPerPage: 25,\n });\n ",
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1 | \n",
+ " 1193 | \n",
+ " 5 | \n",
+ " 978300760 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 1 | \n",
+ " 661 | \n",
+ " 3 | \n",
+ " 978302109 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1 | \n",
+ " 914 | \n",
+ " 3 | \n",
+ " 978301968 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 1 | \n",
+ " 3408 | \n",
+ " 4 | \n",
+ " 978300275 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 1 | \n",
+ " 2355 | \n",
+ " 5 | \n",
+ " 978824291 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 1 | \n",
+ " 1197 | \n",
+ " 3 | \n",
+ " 978302268 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 1 | \n",
+ " 1287 | \n",
+ " 5 | \n",
+ " 978302039 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 1 | \n",
+ " 2804 | \n",
+ " 5 | \n",
+ " 978300719 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 1 | \n",
+ " 594 | \n",
+ " 4 | \n",
+ " 978302268 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 1 | \n",
+ " 919 | \n",
+ " 4 | \n",
+ " 978301368 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3\n",
+ "0 1 1193 5 978300760\n",
+ "1 1 661 3 978302109\n",
+ "2 1 914 3 978301968\n",
+ "3 1 3408 4 978300275\n",
+ "4 1 2355 5 978824291\n",
+ "5 1 1197 3 978302268\n",
+ "6 1 1287 5 978302039\n",
+ "7 1 2804 5 978300719\n",
+ "8 1 594 4 978302268\n",
+ "9 1 919 4 978301368"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 12
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "xU2S9y8NgRPW",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Preparing the training set and the test set\n",
+ "training_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.base', delimiter='\\t')\n",
+ "training_set = np.array(training_set, dtype='int')\n",
+ "test_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.test', delimiter='\\t')\n",
+ "test_set = np.array(test_set, dtype='int')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "G7WbrddJl3Q9",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Getting the number of users and movies\n",
+ "nb_users = int(max(max(training_set[:, 0]), max(test_set[:, 0])))\n",
+ "nb_movies = int(max(max(training_set[:, 1]), max(test_set[:, 1])))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "yRUTR_K3_rzP",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Converting the data into an array with users in lines and movies in columns\n",
+ "def convert(data):\n",
+ " new_data = []\n",
+ " for id_users in range(1, nb_users+1):\n",
+ " id_movies = data[:, 1][data[:, 0] == id_users]\n",
+ " id_ratings = data[:, 2][data[:, 0] == id_users]\n",
+ " ratings = np.zeros(nb_movies)\n",
+ " ratings[id_movies - 1] = id_ratings\n",
+ " new_data.append(list(ratings))\n",
+ " return new_data\n",
+ "\n",
+ "training_set = convert(training_set)\n",
+ "test_set = convert(test_set)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "u0Fk8Q0YCNZr",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Converting the data into Torch tensors\n",
+ "training_set = torch.FloatTensor(training_set)\n",
+ "test_set = torch.FloatTensor(test_set)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "q2arIJufDBYd",
+ "colab_type": "code",
+ "outputId": "41675f96-4fb9-40cb-d97c-54d983455e8b",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 136
+ }
+ },
+ "source": [
+ "training_set"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[0., 3., 4., ..., 0., 0., 0.],\n",
+ " [4., 0., 0., ..., 0., 0., 0.],\n",
+ " [0., 0., 0., ..., 0., 0., 0.],\n",
+ " ...,\n",
+ " [5., 0., 0., ..., 0., 0., 0.],\n",
+ " [0., 0., 0., ..., 0., 0., 0.],\n",
+ " [0., 5., 0., ..., 0., 0., 0.]])"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qAkHvP9NhOPp",
+ "colab_type": "text"
+ },
+ "source": [
+ "# STEP 6-7"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Bz95oUacgQPQ",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Creating the architecture of the Neural Network\n",
+ "class SAE(nn.Module):\n",
+ " def __init__(self, ):\n",
+ " super(SAE, self).__init__()\n",
+ " self.fc1 = nn.Linear(nb_movies, 20)\n",
+ " self.fc2 = nn.Linear(20, 10)\n",
+ " self.fc3 = nn.Linear(10, 20)\n",
+ " self.fc4 = nn.Linear(20, nb_movies)\n",
+ " self.activation = nn.Sigmoid()\n",
+ " def forward(self, x):\n",
+ " x = self.activation(self.fc1(x))\n",
+ " x = self.activation(self.fc2(x))\n",
+ " x = self.activation(self.fc3(x))\n",
+ " x = self.fc4(x)\n",
+ " return x\n",
+ "sae = SAE()\n",
+ "criterion = nn.MSELoss()\n",
+ "optimizer = optim.RMSprop(sae.parameters(), lr = 0.01, weight_decay = 0.5)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nuM-A8Ozjty4",
+ "colab_type": "text"
+ },
+ "source": [
+ "# STEP 8-10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "xiYO03QMjsHG",
+ "colab_type": "code",
+ "outputId": "f160c864-896f-4f52-e994-2a4fddfbc307",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ }
+ },
+ "source": [
+ "# Training the SAE\n",
+ "nb_epoch = 200\n",
+ "for epoch in range(1, nb_epoch + 1):\n",
+ " train_loss = 0\n",
+ " s = 0.\n",
+ " for id_user in range(nb_users):\n",
+ " input = Variable(training_set[id_user]).unsqueeze(0)\n",
+ " target = input.clone()\n",
+ " if torch.sum(target.data > 0) > 0:\n",
+ " output = sae(input)\n",
+ " target.require_grad = False\n",
+ " output[target == 0] = 0\n",
+ " loss = criterion(output, target)\n",
+ " mean_corrector = nb_movies/float(torch.sum(target.data > 0) + 1e-10)\n",
+ " loss.backward()\n",
+ " train_loss += np.sqrt(loss.item()*mean_corrector)\n",
+ " s += 1.\n",
+ " optimizer.step()\n",
+ " print('epoch: '+str(epoch)+' loss: '+str(train_loss/s))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "epoch: 1 loss: 1.7663983791313438\n",
+ "epoch: 2 loss: 1.0965944818481448\n",
+ "epoch: 3 loss: 1.0533398732955221\n",
+ "epoch: 4 loss: 1.0383018413922185\n",
+ "epoch: 5 loss: 1.0308177439541621\n",
+ "epoch: 6 loss: 1.026551124053685\n",
+ "epoch: 7 loss: 1.023840092408676\n",
+ "epoch: 8 loss: 1.021978586980373\n",
+ "epoch: 9 loss: 1.0206570638587025\n",
+ "epoch: 10 loss: 1.0196462708959995\n",
+ "epoch: 11 loss: 1.0187753163243505\n",
+ "epoch: 12 loss: 1.018512555740381\n",
+ "epoch: 13 loss: 1.0178744683018195\n",
+ "epoch: 14 loss: 1.0174755647701952\n",
+ "epoch: 15 loss: 1.0170719470478082\n",
+ "epoch: 16 loss: 1.017201642832892\n",
+ "epoch: 17 loss: 1.0163239136444078\n",
+ "epoch: 18 loss: 1.0165747767066637\n",
+ "epoch: 19 loss: 1.0162508415906395\n",
+ "epoch: 20 loss: 1.0162299744574526\n",
+ "epoch: 21 loss: 1.0160825599663328\n",
+ "epoch: 22 loss: 1.0159708620648906\n",
+ "epoch: 23 loss: 1.0159037432204494\n",
+ "epoch: 24 loss: 1.0156694047908619\n",
+ "epoch: 25 loss: 1.0156815102111703\n",
+ "epoch: 26 loss: 1.0154590358153581\n",
+ "epoch: 27 loss: 1.0152956203593735\n",
+ "epoch: 28 loss: 1.0151429122142581\n",
+ "epoch: 29 loss: 1.0127277229574954\n",
+ "epoch: 30 loss: 1.0115507879790988\n",
+ "epoch: 31 loss: 1.0106808694785414\n",
+ "epoch: 32 loss: 1.0074244496142102\n",
+ "epoch: 33 loss: 1.0073100915343118\n",
+ "epoch: 34 loss: 1.0034969234369306\n",
+ "epoch: 35 loss: 1.0027353737074234\n",
+ "epoch: 36 loss: 1.0000683778711716\n",
+ "epoch: 37 loss: 0.9968187110598279\n",
+ "epoch: 38 loss: 0.9945375976402397\n",
+ "epoch: 39 loss: 0.9952177935337382\n",
+ "epoch: 40 loss: 0.9938334742471779\n",
+ "epoch: 41 loss: 0.9934695043949954\n",
+ "epoch: 42 loss: 0.9902121855511794\n",
+ "epoch: 43 loss: 0.9901160391783914\n",
+ "epoch: 44 loss: 0.9857301381332167\n",
+ "epoch: 45 loss: 0.9848217773360862\n",
+ "epoch: 46 loss: 0.9801835996478252\n",
+ "epoch: 47 loss: 0.9810873597000531\n",
+ "epoch: 48 loss: 0.978300727353134\n",
+ "epoch: 49 loss: 0.9768159755686795\n",
+ "epoch: 50 loss: 0.970972205043055\n",
+ "epoch: 51 loss: 0.9714721652842023\n",
+ "epoch: 52 loss: 0.968500137167768\n",
+ "epoch: 53 loss: 0.9677024816685345\n",
+ "epoch: 54 loss: 0.9659461926308117\n",
+ "epoch: 55 loss: 0.9674038597441262\n",
+ "epoch: 56 loss: 0.9652042557789273\n",
+ "epoch: 57 loss: 0.9635202505788273\n",
+ "epoch: 58 loss: 0.9650874836412309\n",
+ "epoch: 59 loss: 0.9642095855871714\n",
+ "epoch: 60 loss: 0.9586750134842592\n",
+ "epoch: 61 loss: 0.9572684056349163\n",
+ "epoch: 62 loss: 0.9564866799474354\n",
+ "epoch: 63 loss: 0.9524743478337185\n",
+ "epoch: 64 loss: 0.9502278884724376\n",
+ "epoch: 65 loss: 0.9533428352764142\n",
+ "epoch: 66 loss: 0.9520933496393511\n",
+ "epoch: 67 loss: 0.9546508691490383\n",
+ "epoch: 68 loss: 0.9489561905583827\n",
+ "epoch: 69 loss: 0.9490490017216804\n",
+ "epoch: 70 loss: 0.9483167270874054\n",
+ "epoch: 71 loss: 0.948329255203358\n",
+ "epoch: 72 loss: 0.9450881600029056\n",
+ "epoch: 73 loss: 0.9463115597986019\n",
+ "epoch: 74 loss: 0.9437816299409459\n",
+ "epoch: 75 loss: 0.9455461502145251\n",
+ "epoch: 76 loss: 0.9420526631180003\n",
+ "epoch: 77 loss: 0.9435457856469216\n",
+ "epoch: 78 loss: 0.9411563134969737\n",
+ "epoch: 79 loss: 0.9436575836579513\n",
+ "epoch: 80 loss: 0.9422297843906718\n",
+ "epoch: 81 loss: 0.9410528463853715\n",
+ "epoch: 82 loss: 0.9402148460233527\n",
+ "epoch: 83 loss: 0.9409234754132823\n",
+ "epoch: 84 loss: 0.9405657855477602\n",
+ "epoch: 85 loss: 0.9382027201893749\n",
+ "epoch: 86 loss: 0.9393233675827815\n",
+ "epoch: 87 loss: 0.9374333910506758\n",
+ "epoch: 88 loss: 0.9366116336780694\n",
+ "epoch: 89 loss: 0.9377259823272002\n",
+ "epoch: 90 loss: 0.9365444235602165\n",
+ "epoch: 91 loss: 0.9380175938760765\n",
+ "epoch: 92 loss: 0.9364794219167737\n",
+ "epoch: 93 loss: 0.9368766124940768\n",
+ "epoch: 94 loss: 0.9348002232788932\n",
+ "epoch: 95 loss: 0.9353004705734516\n",
+ "epoch: 96 loss: 0.9343677843163494\n",
+ "epoch: 97 loss: 0.9353256751794342\n",
+ "epoch: 98 loss: 0.933877368043547\n",
+ "epoch: 99 loss: 0.9342818034628956\n",
+ "epoch: 100 loss: 0.9333942400397647\n",
+ "epoch: 101 loss: 0.9341794560759067\n",
+ "epoch: 102 loss: 0.932444274542758\n",
+ "epoch: 103 loss: 0.9329446660349489\n",
+ "epoch: 104 loss: 0.9331678830270377\n",
+ "epoch: 105 loss: 0.9331724844463245\n",
+ "epoch: 106 loss: 0.9331020305951515\n",
+ "epoch: 107 loss: 0.9356272341681415\n",
+ "epoch: 108 loss: 0.9333336215395651\n",
+ "epoch: 109 loss: 0.9327508003016757\n",
+ "epoch: 110 loss: 0.9308627731347268\n",
+ "epoch: 111 loss: 0.9319176007690649\n",
+ "epoch: 112 loss: 0.9306397121343122\n",
+ "epoch: 113 loss: 0.9305777403332568\n",
+ "epoch: 114 loss: 0.9302414124205797\n",
+ "epoch: 115 loss: 0.9305424765978645\n",
+ "epoch: 116 loss: 0.9294236245683961\n",
+ "epoch: 117 loss: 0.9295683690937063\n",
+ "epoch: 118 loss: 0.9290601632685692\n",
+ "epoch: 119 loss: 0.9298997313915192\n",
+ "epoch: 120 loss: 0.9287010974464924\n",
+ "epoch: 121 loss: 0.9288074722866032\n",
+ "epoch: 122 loss: 0.9279760744321034\n",
+ "epoch: 123 loss: 0.9279426068053931\n",
+ "epoch: 124 loss: 0.9275374298911129\n",
+ "epoch: 125 loss: 0.9279328461908956\n",
+ "epoch: 126 loss: 0.9277038322243288\n",
+ "epoch: 127 loss: 0.9280261047596016\n",
+ "epoch: 128 loss: 0.9266577717902903\n",
+ "epoch: 129 loss: 0.9274436983768939\n",
+ "epoch: 130 loss: 0.9262172192927275\n",
+ "epoch: 131 loss: 0.9268704635553348\n",
+ "epoch: 132 loss: 0.9264313648325654\n",
+ "epoch: 133 loss: 0.9270331564311223\n",
+ "epoch: 134 loss: 0.9259879544058086\n",
+ "epoch: 135 loss: 0.9265063473172516\n",
+ "epoch: 136 loss: 0.9252285856398398\n",
+ "epoch: 137 loss: 0.9257206007928372\n",
+ "epoch: 138 loss: 0.9245857017528629\n",
+ "epoch: 139 loss: 0.9249536996678024\n",
+ "epoch: 140 loss: 0.9239828664132971\n",
+ "epoch: 141 loss: 0.9250168599949399\n",
+ "epoch: 142 loss: 0.9239714020219754\n",
+ "epoch: 143 loss: 0.9248878068576096\n",
+ "epoch: 144 loss: 0.9231863363249722\n",
+ "epoch: 145 loss: 0.9244485999674413\n",
+ "epoch: 146 loss: 0.9231108985583485\n",
+ "epoch: 147 loss: 0.9241529591466949\n",
+ "epoch: 148 loss: 0.9228550944294732\n",
+ "epoch: 149 loss: 0.9237827557157635\n",
+ "epoch: 150 loss: 0.922260170746647\n",
+ "epoch: 151 loss: 0.9231400282022982\n",
+ "epoch: 152 loss: 0.9221839934603951\n",
+ "epoch: 153 loss: 0.9227788564070573\n",
+ "epoch: 154 loss: 0.9213350301333955\n",
+ "epoch: 155 loss: 0.922453842482827\n",
+ "epoch: 156 loss: 0.9210483122507049\n",
+ "epoch: 157 loss: 0.9219510963958538\n",
+ "epoch: 158 loss: 0.9204969614260258\n",
+ "epoch: 159 loss: 0.9205394209501664\n",
+ "epoch: 160 loss: 0.9200661759022467\n",
+ "epoch: 161 loss: 0.9207735137229326\n",
+ "epoch: 162 loss: 0.9196641402017643\n",
+ "epoch: 163 loss: 0.9204513049820104\n",
+ "epoch: 164 loss: 0.9193051927516236\n",
+ "epoch: 165 loss: 0.9210140873158912\n",
+ "epoch: 166 loss: 0.9193127515207875\n",
+ "epoch: 167 loss: 0.9200597882686071\n",
+ "epoch: 168 loss: 0.9185944485414366\n",
+ "epoch: 169 loss: 0.9201572432142742\n",
+ "epoch: 170 loss: 0.9183169550351225\n",
+ "epoch: 171 loss: 0.9193881788559667\n",
+ "epoch: 172 loss: 0.9180057668314479\n",
+ "epoch: 173 loss: 0.9191220927901347\n",
+ "epoch: 174 loss: 0.9177848844173945\n",
+ "epoch: 175 loss: 0.9190516442024842\n",
+ "epoch: 176 loss: 0.9181445924423348\n",
+ "epoch: 177 loss: 0.919047934578481\n",
+ "epoch: 178 loss: 0.9175119757656524\n",
+ "epoch: 179 loss: 0.9186781150882567\n",
+ "epoch: 180 loss: 0.9175681590539049\n",
+ "epoch: 181 loss: 0.9183763375326187\n",
+ "epoch: 182 loss: 0.9169434621528899\n",
+ "epoch: 183 loss: 0.9177548550969366\n",
+ "epoch: 184 loss: 0.9170545570415128\n",
+ "epoch: 185 loss: 0.9179762411576573\n",
+ "epoch: 186 loss: 0.9166707151557505\n",
+ "epoch: 187 loss: 0.9174266883043443\n",
+ "epoch: 188 loss: 0.9162146914993445\n",
+ "epoch: 189 loss: 0.917265776286358\n",
+ "epoch: 190 loss: 0.9159440051014004\n",
+ "epoch: 191 loss: 0.9167926651895048\n",
+ "epoch: 192 loss: 0.9157365677088328\n",
+ "epoch: 193 loss: 0.9169038115550036\n",
+ "epoch: 194 loss: 0.9156644022282158\n",
+ "epoch: 195 loss: 0.916360655268448\n",
+ "epoch: 196 loss: 0.9149874787609436\n",
+ "epoch: 197 loss: 0.9160702331415719\n",
+ "epoch: 198 loss: 0.9148375459877753\n",
+ "epoch: 199 loss: 0.915890166240895\n",
+ "epoch: 200 loss: 0.9151742022378695\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Egrl1Ge4kA6Y",
+ "colab_type": "code",
+ "outputId": "f5d79f27-2a67-4498-ca86-ab8d6fb88240",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "# Testing the SAE\n",
+ "test_loss = 0\n",
+ "s = 0.\n",
+ "for id_user in range(nb_users):\n",
+ " input = Variable(training_set[id_user]).unsqueeze(0)\n",
+ " target = Variable(test_set[id_user]).unsqueeze(0)\n",
+ " if torch.sum(target.data > 0) > 0:\n",
+ " output = sae(input)\n",
+ " target.require_grad = False\n",
+ " output[target == 0] = 0\n",
+ " loss = criterion(output, target)\n",
+ " mean_corrector = nb_movies/float(torch.sum(target.data > 0) + 1e-10)\n",
+ " test_loss += np.sqrt(loss.item()*mean_corrector)\n",
+ " s += 1.\n",
+ "print('test loss: '+str(test_loss/s))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "test loss: 0.9503542203018388\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file