{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn import datasets\n", "import numpy as np\n", "import math\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def load_boston(ratio=0.8):\n", " X, Y = datasets.load_boston(return_X_y=True)\n", " Y.shape = -1, 1\n", " \n", " # normalization\n", " X = X/80\n", " Y = Y/(np.max(Y) - np.min(Y))\n", " \n", " num_samples = len(Y)\n", " num_train = math.ceil(num_samples * ratio)\n", " \n", " # shuffle the data randomly\n", " idx = np.random.permutation(np.arange(num_samples))\n", " traindata = X[idx[:num_train]], Y[idx[:num_train]]\n", " validdata = X[idx[num_train:]], Y[idx[num_train:]]\n", " \n", " return traindata, validdata" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "(X_train, Y_train), (X_valid, Y_valid) = load_boston()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# MLP Quiz\n", "\n", "## Task\n", "\n", "Fit the Boston dataset with the model $\\hat{Y} = f_2\\circ\\phi\\circ f_1 (X)$ using gradient descent, i.e. solve the optimization problem $\\min_{W, b} L(\\hat{Y}, Y)$\n", "\n", "where:\n", "\n", "* $Z_0 = X$, where $X \\in \\mathbb{R}^{N\\times n_{in}}$\n", "* $Z_1 = f_1(Z_0) := Z_0W_1^T + b_1$, where $W_1 \\in \\mathbb{R}^{n_{mid}\\times n_{in}}, b_1 \\in \\mathbb{R}^{n_{mid}}$\n", "* $Z_2 = \\phi(Z_1) := \\frac{1}{1+e^{-Z_1}}$, where the exponential is applied elementwise, i.e. $(e^{X})_i := e^{X_i}$\n", "* $Z_3 = f_2(Z_2) := Z_2W_2^T + b_2$, where $W_2 \\in \\mathbb{R}^{n_{out}\\times n_{mid}}, b_2 \\in \\mathbb{R}^{n_{out}}$\n", "* $\\hat{Y} = Z_3$\n", "* $L(\\hat{Y}, Y) := \\frac{1}{2} \\sum_{i=1}^{N} (\\hat{Y}_i - Y_i)^2$\n", "\n", "For the Boston dataset, $n_{in}=13$ and $n_{out}=1$; to keep the computation cheap, set $n_{mid} = 30$\n", "\n", "## Grading\n", "\n", "1. (4 points) Give the expressions for $\\frac{\\partial L}{\\partial W_1}, \\frac{\\partial L}{\\partial b_1}, \\frac{\\partial L}{\\partial W_2}, \\frac{\\partial L}{\\partial b_2}$ and state the shape of every matrix involved (on paper or as a PDF)\n", "2. (4 points) Complete the code below\n", "3. 
(2 points) Performance: one complete training run takes less than 10 s on an otherwise idle server (the baseline is 3.5 s)\n", "\n", "## Submission\n", "\n", "Submit to `ftp://ftp.lflab.cn/AI_homework/Graduate/quiz/`\n", "\n", "## Reference\n", "\n", "* Derivative of a matrix with respect to a scalar: $(\\frac{\\partial{Y}}{\\partial{X}})_{ij} := \\frac{\\partial{Y_{ij}}}{\\partial{X}}$, where $Y \\in \\mathbb{R}^{m\\times n}, X \\in \\mathbb{R}$\n", "* Derivative of a vector with respect to a vector: $(\\frac{\\partial{Y}}{\\partial{X}})_{ij} := \\frac{\\partial{Y_i}}{\\partial{X_j}}$, where $Y \\in \\mathbb{R}^{m\\times 1}, X \\in \\mathbb{R}^{n\\times 1}$\n", "* Derivative of a scalar with respect to a matrix: $(\\frac{\\partial{Y}}{\\partial{X}})_{ij} := \\frac{\\partial{Y}}{\\partial{X_{ij}}}$, where $Y \\in \\mathbb{R}, X \\in \\mathbb{R}^{m\\times n}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$f_1, f_2$ are called linear layers" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Linear():\n", " def __init__(self, in_features: int, out_features: int):\n", " raise NotImplementedError(\"Implement this\")\n", " \n", " def __call__(self, X):\n", " return self.forward(X)\n", " \n", " def forward(self, X):\n", " raise NotImplementedError(\"Implement this\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$\\phi$ is called the activation function (a nonlinear layer)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class Sigmoid():\n", " \"\"\"phi\"\"\"\n", " def __call__(self, X):\n", " return self.forward(X)\n", " \n", " def forward(self, X):\n", " raise NotImplementedError(\"Implement this\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class MLP:\n", " def __init__(self, in_features: int, mid_features: int, out_features: int):\n", " self.f1 = Linear(in_features, mid_features)\n", " self.phi = Sigmoid()\n", " self.f2 = Linear(mid_features, out_features)\n", " \n", " def __call__(self, X):\n", " return self.f2(self.phi(self.f1(X)))\n", " \n", " def forward(self, X):\n", " Z0 = X\n", " Z1 = self.f1(X)\n", " Z2 = self.phi(Z1)\n", " Z3 = 
self.f2(Z2)\n", " return [Z0, Z1, Z2, Z3]\n", " \n", " def grad(self, Y, Z): # 3 points\n", " Z0, Z1, Z2, Z3 = Z[0], Z[1], Z[2], Z[3]\n", "\n", " raise NotImplementedError(\"Implement this\")\n", "\n", " return dLdW1, dLdb1, dLdW2, dLdb2" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def loss(Y_real, Y_pred):\n", " return 0.5 * np.sum((Y_real - Y_pred)**2)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class GradientDescent:\n", " def __init__(self, step=1e-3):\n", " self.step = step\n", " \n", " def update(self, model: MLP, dLdW1, dLdb1, dLdW2, dLdb2):\n", " \"\"\"Update the model's weights using the gradients.\"\"\"\n", " raise NotImplementedError(\"Implement this\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter 0: loss 54.8086, valid loss 12.8588\n", "Iter 100: loss 8.5929, valid loss 2.0355\n", "Iter 200: loss 8.5417, valid loss 2.0210\n", "Iter 300: loss 8.5267, valid loss 2.0173\n", "Iter 400: loss 8.5118, valid loss 2.0136\n", "Iter 500: loss 8.4970, valid loss 2.0100\n", "Iter 600: loss 8.4822, valid loss 2.0063\n", "Iter 700: loss 8.4674, valid loss 2.0027\n", "Iter 800: loss 8.4527, valid loss 1.9991\n", "Iter 900: loss 8.4380, valid loss 1.9955\n", "CPU times: user 56.5 s, sys: 1min 21s, total: 2min 17s\n", "Wall time: 3.01 s\n" ] } ], "source": [ "%%time\n", "num_features = X_train.shape[-1]\n", "model = MLP(num_features, 30, 1)\n", "opt = GradientDescent(1e-6)\n", "\n", "valid_losses = []\n", "train_losses = []\n", "for i in range(1000):\n", " X, Y = X_train, Y_train\n", " \n", " # 1 point\n", " # 1. Compute the gradients\n", " # 2. Update the weights\n", " raise NotImplementedError(\"Implement this\")\n", "\n", " # 3. 
存储中间状态\n", " Y_out = None # FIXME\n", " cur_valid_loss = loss(Y_valid, model(X_valid))\n", " cur_train_loss = loss(Y, Y_out)\n", " valid_losses.append(cur_valid_loss) \n", " train_losses.append(cur_train_loss)\n", " \n", " if i%100 == 0:\n", " print(f\"Iter {i}: loss {cur_train_loss:.4f}, valid loss {cur_valid_loss:.4f}\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAeBElEQVR4nO3de3SV9b3n8fd3X3LnEkJAJGiwMgIB5BIRh1pArEe09dJ6PXqKHVtWXc7YTs/qKe2s1jqnnqVrHOXQZbX0qIdprZdi1WptGUtB6prKMSgiAgoI1nBLoCTcb8lv/niehCTksrOT7M1v789rray9n/v3ycP68MtvP89vm3MOERHxTyTdBYiISHIU4CIinlKAi4h4SgEuIuIpBbiIiKdiqTzY4MGDXXl5eSoPKSLivdWrV+9xzpW2nZ/SAC8vL6eqqiqVhxQR8Z6ZfdLefHWhiIh4SgEuIuIpBbiIiKdS2gcuIql34sQJqqurOXr0aLpLkS7k5eVRVlZGPB5PaH0FuEiGq66upl+/fpSXl2Nm6S5HOuCcY+/evVRXVzNy5MiEtlEXikiGO3r0KCUlJQrvM5yZUVJS0q2/lBTgIllA4e2H7l4nLwL8xXer+eVb7d4GKSKStbwI8Ffe28mzb/813WWISBLq6ur46U9/mtS2V111FXV1dQmv/6Mf/YiHHnooqWP5yIsAz4lGOH6yMd1liEgSOgvwkydPdrrta6+9xsCBA/uirIzgR4DHFOAivpo/fz5btmxh4sSJfOc732HFihVceumlXHPNNYwdOxaA6667jilTplBRUcGiRYuaty0vL2fPnj1s27aNMWPG8PWvf52KigquuOIKjhw50ulx16xZw7Rp05gwYQLXX389+/btA2DhwoWMHTuWCRMmcMsttwDwxhtvMHHiRCZOnMikSZM4cOBAH/02epcXtxEqwEV6x32vfMD6Hft7dZ9jz+7PvV+s6HD5Aw88wLp161izZg0AK1as4J133mHdunXNt8s9+eSTDBo0iCNHjnDRRRfx5S9/mZKSklb72bRpE8888ww///nPuemmm3jhhRe4/fbbOzzuV77yFX7yk58wY8YMfvjDH3LfffexYMECHnjgAbZu3Upubm5z98xDDz3Eo48+yvTp0zl48CB5eXk9/bWkhD8t8AYFuEimmDp1aqt7nRcuXMiFF17ItGnT+PTTT9m0adNp24wcOZKJEycCMGXKFLZt29bh/uvr66mrq2PGjBkAzJ07l5UrVwIwYcIEbrvtNn75y18SiwVt2OnTp/Ptb3+bhQsXUldX1zz/TOdFlTnRCMdOKMBFeqqzlnIqFRYWNr9fsWIFf/zjH/nLX/5CQUEBM2fObPde6Nzc3Ob30Wi0yy6Ujvzud79j5cqVvPLKK9x///28//77zJ8/n6uvvprXXnuN6dOns3TpUkaPHp3U/l
PJixZ4bizCMbXARbzUr1+/TvuU6+vrKS4upqCggI0bN/LWW2/1+JgDBgyguLiYP//5zwD84he/YMaMGTQ2NvLpp58ya9YsHnzwQerr6zl48CBbtmxh/PjxfPe73+Wiiy5i48aNPa4hFfxogYd94M45PZAg4pmSkhKmT5/OuHHjmDNnDldffXWr5VdeeSWPP/44Y8aM4YILLmDatGm9ctzFixfzjW98g8OHD3Peeefx1FNP0dDQwO233059fT3OOe655x4GDhzID37wA5YvX04kEqGiooI5c+b0Sg19zZxzKTtYZWWlS+YLHX6ybBP/+/WP+OjHc8iJefFHg8gZY8OGDYwZMybdZUiC2rteZrbaOVfZdl0v0jA3HpSpDzJFRE7xIsBzomGA61ZCEZFmfgR4LAoowEVEWvIkwNUCFxFpy68Ab2hIcyUiImcOPwI87AM/pha4iEizhALczLaZ2ftmtsbMqsJ5g8zsdTPbFL4W91WRuepCEckqRUVFAOzYsYMbbrih3XVmzpxJV7clL1iwgMOHDzdPd3d42o6cKcPWdqcFPss5N7HFvYjzgWXOuVHAsnC6T6gPXCQ7nX322SxZsiTp7dsGeKYNT9uTLpRrgcXh+8XAdT0vp32n+sAV4CK+mT9/Po8++mjzdFPr9eDBg8yePZvJkyczfvx4Xn755dO23bZtG+PGjQPgyJEj3HLLLYwZM4brr7++1Vgod911F5WVlVRUVHDvvfcCwQBZO3bsYNasWcyaNQs4NTwtwMMPP8y4ceMYN24cCxYsaD6eT8PWJvoovQP+r5k54GfOuUXAUOfcznD5LmBoexua2TxgHsA555yTVJG6D1ykl/x+Pux6v3f3edZ4mPNAh4tvvvlmvvWtb3H33XcD8Pzzz7N06VLy8vJ48cUX6d+/P3v27GHatGlcc801HQ6X8dhjj1FQUMCGDRtYu3YtkydPbl52//33M2jQIBoaGpg9ezZr167lnnvu4eGHH2b58uUMHjy41b5Wr17NU089xapVq3DOcfHFFzNjxgyKi4u9GrY20Rb4Z51zk4E5wN1m9rmWC13wPH67z+Q75xY55yqdc5WlpaVJFakuFBF/TZo0iZqaGnbs2MF7771HcXExI0aMwDnH97//fSZMmMDll1/O9u3b2b17d4f7WblyZXOQTpgwgQkTJjQve/7555k8eTKTJk3igw8+YP369Z3W9Oabb3L99ddTWFhIUVERX/rSl5oHvvJp2NqEtnbObQ9fa8zsRWAqsNvMhjnndprZMKCmR5V0Ql0oIr2kk5ZyX7rxxhtZsmQJu3bt4uabbwbg6aefpra2ltWrVxOPxykvL293GNmubN26lYceeoi3336b4uJi7rjjjqT208SnYWu7bIGbWaGZ9Wt6D1wBrAN+C8wNV5sLnN6B1Ut0G6GI326++WaeffZZlixZwo033ggErdchQ4YQj8dZvnw5n3zySaf7+NznPsevfvUrANatW8fatWsB2L9/P4WFhQwYMIDdu3fz+9//vnmbjoayvfTSS3nppZc4fPgwhw4d4sUXX+TSSy/t9nmle9jaRFrgQ4EXw36pGPAr59wfzOxt4HkzuxP4BLipR5V0ouk2QgW4iJ8qKio4cOAAw4cPZ9iwYQDcdtttfPGLX2T8+PFUVlZ22RK96667+OpXv8qYMWMYM2YMU6ZMAeDCCy9k0qRJjB49mhEjRjB9+vTmbebNm8eVV17J2WefzfLly5vnT548mTvuuIOpU6cC8LWvfY1JkyZ12l3SkXQOW+vFcLJ1h48z8X++zg++MJY7Pzuy6w1EpJmGk/VLxg0nqw8xRURO50eA6zZCEZHTeBHgsWiEaMQ0mJVIklLZVSrJ6+518iLAIWiFqwUu0n15eXns3btXIX6Gc86xd+/ebj3c48WXGsOpLzYWke4pKyujurqa2tradJciXcjLy6OsrCzh9f0KcD3II9Jt8XickSN191Ym8qoLRfeBi4ic4k2A56oLRUSkFW
8CXH3gIiKteRPguTF1oYiItORRgEc5dlL3gYuINPEnwOMRjp5QC1xEpIk/AR6LcvSEWuAiIk28CfC8uD7EFBFpyaMAVwtcRKQljwI8wlG1wEVEmnkT4OoDFxFpzZsAz4vrPnARkZb8CfBYlIZGxwkNaCUiAvgU4PEogLpRRERC3gR4bjwoVQ/ziIgEvAnwvFjQAtfj9CIiAW8CXC1wEZHWvAlw9YGLiLTmTYDnxoJS1YUiIhLwJsCbWuDH1IUiIgJ4GOBH1QIXEQG8CnB9iCki0pI3AZ6r2whFRFpJOMDNLGpm75rZq+H0SDNbZWabzew5M8vpuzLVAhcRaas7LfBvAhtaTD8IPOKcOx/YB9zZm4W11fQgj24jFBEJJBTgZlYGXA38WzhtwGXAknCVxcB1fVFgEz3IIyLSWqIt8AXAPwFN6VkC1DnnTobT1cDw9jY0s3lmVmVmVbW1tUkXqkfpRURa6zLAzewLQI1zbnUyB3DOLXLOVTrnKktLS5PZBQCRiJET1TfTi4g0iSWwznTgGjO7CsgD+gP/Cgw0s1jYCi8DtvddmYHceER94CIioS5b4M657znnypxz5cAtwJ+cc7cBy4EbwtXmAi/3WZWh3FhUXSgiIqGe3Af+XeDbZraZoE/8id4pqWP5ORGOHFeAi4hAYl0ozZxzK4AV4fuPgam9X1LHCuIxjqgLRUQE8OhJTIC8nCiH1QIXEQE8C/CCeFRdKCIiIb8CPCeqLhQRkZBXAZ6foxa4iEgTrwK8QH3gIiLNvArw/HiUw8dPdr2iiEgW8CvAc3QboYhIE68CvCAnyokGx4kGjYciIuJdgANqhYuI4FmA5zcFuD7IFBHxLMDDb6bXnSgiIp4FeFMXiu5EERHxLMDzc4KxtzQmuIiIZwF+qgWuABcR8SrA1QcuInKKXwGuu1BERJp5FeDqQhEROcWvAI8HH2LqQR4REc8C/FQXim4jFBHxKsBzYhFiEVMXiogIngU4BP3gh46pBS4i4l2AF+XGOHhMLXAREf8CPC+mFriICB4GeGFujIMKcBER/wK8SAEuIgJ4GuDqQhER8TDA1YUiIhLwLsDVhSIiEugywM0sz8z+w8zeM7MPzOy+cP5IM1tlZpvN7Dkzy+n7ck91oTjnUnE4EZEzViIt8GPAZc65C4GJwJVmNg14EHjEOXc+sA+4s+/KPKUwN0aj03goIiJdBrgLHAwn4+GPAy4DloTzFwPX9UmFbRTlBuOhqBtFRLJdQn3gZhY1szVADfA6sAWoc841pWg1MLyDbeeZWZWZVdXW1va44KK8YETCg0cV4CKS3RIKcOdcg3NuIlAGTAVGJ3oA59wi51ylc66ytLQ0yTJPKQy/F/OQHqcXkSzXrbtQnHN1wHLgEmCgmcXCRWXA9l6urV1FuWELXF0oIpLlErkLpdTMBobv84HPAxsIgvyGcLW5wMt9VWRLzV0oCnARyXKxrldhGLDYzKIEgf+8c+5VM1sPPGtmPwbeBZ7owzqbFeY2daEowEUku3UZ4M65tcCkduZ/TNAfnlLqQhERCXj3JGa/sAvlgO5CEZEs512A58ejxCLG/qMn0l2KiEhaeRfgZkb//Dj7jyjARSS7eRfgAAPy49QrwEUky3kZ4P3zYuxXH7iIZDk/A1xdKCIiCnAREV95GeAD8uO6C0VEsp6XAd4/L/gQU1/qICLZzMsAH5Af50SD4+iJxnSXIiKSNl4GeP/84GlM3UooItnMywAfkB8HUD+4iGQ1LwO8f14Y4GqBi0gW8zLAm1rg6kIRkWzmZYD3DwO87rACXESyl5cBPqggB4B9h4+nuRIRkfTxMsD758eIRkwBLiJZzcsANzOKC3L42yEFuIhkLy8DHKCkUAEuItnN2wAvLowrwEUkq3kb4CWFuQpwEclq3gZ4cWGcfbqNUESymLcBPqggh32Hj9PQqBEJRSQ7+RvghTk4p6cxRSR7eRvgxYXBwzx/O3QszZWIiKSHtwFeUpgLwN6D+i
BTRLKTvwFeFLTA9+pOFBHJUt4G+JB+QQt89/6jaa5ERCQ9ugxwMxthZsvNbL2ZfWBm3wznDzKz181sU/ha3PflnlJckEM8atQcUB+4iGSnRFrgJ4F/dM6NBaYBd5vZWGA+sMw5NwpYFk6nTCRilBblUrNfAS4i2anLAHfO7XTOvRO+PwBsAIYD1wKLw9UWA9f1VZEdKe2fR80BdaGISHbqVh+4mZUDk4BVwFDn3M5w0S5gaK9WloAh/dQCF5HslXCAm1kR8ALwLefc/pbLnHMOaPeRSDObZ2ZVZlZVW1vbo2LbGtIvVy1wEclaCQW4mcUJwvtp59xvwtm7zWxYuHwYUNPets65Rc65SudcZWlpaW/U3GxIvzz2HT7B8ZONvbpfEREfJHIXigFPABuccw+3WPRbYG74fi7wcu+X17mh/YNbCWsPqhtFRLJPIi3w6cA/AJeZ2Zrw5yrgAeDzZrYJuDycTqmh/fMA2FV/JNWHFhFJu1hXKzjn3gSsg8Wze7ec7hlenA9A9b4jTDk3nZWIiKSet09iAgwfGAT49jq1wEUk+3gd4IW5MQYWxNm+TwEuItnH6wAHKCvOVwtcRLKS9wE+fGC+WuAikpUyIMAL2F53hOBZIhGR7OF/gBfnc/h4g77gWESyjvcBXl5SAMDWPYfSXImISGp5H+DnlRYBCnARyT7eB3hZcT6xiPFx7cF0lyIiklLeB3g8GuGckgK1wEUk63gf4ADnDS7k41oFuIhkl8wI8NIitu49REOjbiUUkeyREQE+akgRx082sm2vWuEikj0yIsDHnt0fgPU79nexpohI5siIAB81pB/xqLFhpwJcRLJHRgR4TizCZ0qLWK8AF5EskhEBDkE3ygc79mtMFBHJGhkT4BNHDKT2wDGqNTKhiGSJjAnwKecWA7D6k31prkREJDUyJsBHn9WfotwYVZ/8Ld2liIikhB8BvvbXUPVkp6tEI8akcwby9la1wEUkO/gR4OtfglWLulxt+vmD+XD3AXbVH01BUSIi6eVHgA8og/pPoYs7TGZeUArAGx/VpKIqEZG08iTAR8Dxg3C0vtPVLhjaj7P657F8Y22KChMRSR9PArwseK2v7nQ1M+OKiqEs/7CGg8dOpqAwEZH08STARwSv9Z92ueq1E8/m2MlGlq7b1cdFiYiklycBnlgLHGDyOcWMGJTPr1d3HfYiIj7zI8ALSyGam1AL3My47eJzeevjv2l0QhHJaH4EeCQCA4ZDXWKt6lsvOof8eJTH3tjSx4WJiKRPlwFuZk+aWY2ZrWsxb5CZvW5mm8LX4r4tExh4LuzbltCqAwri3PnZkbzy3g7WfFrXt3WJiKRJIi3wfweubDNvPrDMOTcKWBZO962S82Hvli7vBW/yjZmfYXBRLve+vI4TDY19XJyISOp1GeDOuZVA2wFGrgUWh+8XA9f1cl2nK/kMHKuHw3sTWr0oN8Z911TwXnU9/2vph31cnIhI6iXbBz7UObczfL8LGNrRimY2z8yqzKyqtrYHD9iUnB+87t2c8CZXTxjGP0w7l0UrP+aJN7cmf2wRkTNQjz/EdME3KHTYr+GcW+Scq3TOVZaWliZ/oEHnBa/dCHCAH35xLHPGncU/v7qef351PcdPqjtFRDJDsgG+28yGAYSvfT/4yMBzIRIL+sG7IR6NsPDWScy95FyeeHMrVzzyBq+8t0P94iLivViS2/0WmAs8EL6+3GsVdSQaC1rhtd3vz45HI9x37ThmjR7Cv7y2gf/2zLsMLsph9uih/OfzSxg7rD/nlhSSE/PjrkoREUggwM3sGWAmMNjMqoF7CYL7eTO7E/gEuKkvi2w2tAJ2vJv05jMvGMKlo0p546Malqyu5rV1O3muKri3PBoxzuqfR0lRDiWFORQX5JAbj5Ibi5AXvubEIkQjRsQgYoZZ8N6ASOTUdKR5vgUHbv2CmTXXdGpem9dwSYtVW2172nZt1m+7nA6Xt7+/9vbd0XkkWgtd1prYOXTnPD
r+PXZQSyfbdfd3mkw93TmXVrtI9HfbzjFPu769fT5J/dtrs0NpV5cB7py7tYNFs3u5lq4NrYAPXoRjByC3X1K7iEaMy0YP5bLRQ2lodGzYuZ9NNQfYXHOQnXVH2XvoOHsOHmdTzUGOnWzk6IkGjp1sVN+5SBr16D+kjpYnst92/rPsuKHR0X9IwZtff+MSRg4ubPf8kpVsF0p6DB0XvNZsgBFTe7y7aMQYN3wA44YP6HLdxkbH8YZGnING58IfcOFro3O4FtMN4f3qrvn19H02zXO4NtNNy12b6eYtO1i/g/11ML+jOnq1lg62az5WoufQqt7W+6TL8+7ZOXRWj2vzy+x0my7q6fRc2uy/s2tFR7/zFsWdvr8kzqeDmtr/t94H59Ppv9c263T276gn59PFNi03LcyNnr6jHvIswCuC193reiXAuyMSMfIivX8BRESS5dendgNGQN5A2LEm3ZWIiKSdXwFuBmWVUP12uisREUk7vwIcoGxq0AfexderiYhkOv8CfMRFgIPtq9NdiYhIWvkX4MMrwSLwyf9LdyUiImnlX4Dn9Q9CfMuf0l2JiEha+RfgAJ+5DLa/A4fbjnIrIpI9/Azw82cDTq1wEclqfgb48ClQNBTWv5TuSkRE0sbPAI9EYex1sOn1YFwUEZEs5GeAA4y/AU4ehXUvpLsSEZG08DfAyy6Cs8bDW48n/EXHIiKZxN8AN4OL74LaDfDx8nRXIyKScv4GOMC4L0O/YfCnH0OjxusWkezid4DH8+DyHwWP1a99Nt3ViIiklN8BDjD+pqA//A/zYd+2dFcjIpIy/gd4JAJf+nnw1RfP3Q5H6tJdkYhISvgf4ACDRsINT0LNRvg/18D+HemuSESkz2VGgAOMuhxueRr2bIbHPwvv/hIaG9JdlYhIn8mcAAf4T38H81ZAcTm8fDc8OhXefARqP9K94iKScaztt1D3pcrKSldVVdX3B3IuGCdl1c/gr38J5hUNDb4UefAFMHAEFAyGwsGQXwzxfIjlQqzpNS8Yc7z5x069ioikmJmtds5Vtp3v17fSJ8oMKq4PfvZtg49XBF8AUbsR3lkMJw73YN8RwFqHO50Ee6ehfyZt18lmSR+vD7dN6T4S2M+Zso+E9pPIPlJRRwL7yaR9/P1zwed1vSgzA7yl4nKYckfwA8EDP0fr4PBeOLQHjuwLxlQ5eQxOHglfjwateNd46pWm6cbW811nDxB18tdNsn/5dLpdssfrg+36dNsU7iOh/aRqH12vot/rmboPgr/ue1nmB3hbkQgUDAp+Bo9KdzUiIknLrA8xRUSyiAJcRMRTPQpwM7vSzD40s81mNr+3ihIRka4lHeBmFgUeBeYAY4FbzWxsbxUmIiKd60kLfCqw2Tn3sXPuOPAscG3vlCUiIl3pSYAPBz5tMV0dzmvFzOaZWZWZVdXW1vbgcCIi0lKff4jpnFvknKt0zlWWlpb29eFERLJGTwJ8OzCixXRZOE9ERFIg6bFQzCwGfATMJgjut4G/d8590Mk2tcAnSR0QBgN7ktzWVzrn7KBzzg49OedznXOndWEk/SSmc+6kmf1XYCkQBZ7sLLzDbZLuQzGzqvYGc8lkOufsoHPODn1xzj16lN459xrwWi/VIiIi3aAnMUVEPOVTgC9KdwFpoHPODjrn7NDr55zSL3QQEZHe41MLXEREWlCAi4h4yosAz8RRD81shJktN7P1ZvaBmX0znD/IzF43s03ha3E438xsYfg7WGtmk9N7Bskzs6iZvWtmr4bTI81sVXhuz5lZTjg/N5zeHC4vT2fdyTKzgWa2xMw2mtkGM7sk06+zmf338N/1OjN7xszyMu06m9mTZlZjZutazOv2dTWzueH6m8xsbndqOOMDPINHPTwJ/KNzbiwwDbg7PK/5wDLn3ChgWTgNwfmPCn/mAY+lvuRe801gQ4vpB4FHnHPnA/uAO8P5dwL7wvmPhOv56F+BPzjnRgMXEpx7xl5nMxsO3ANUOufGET
wncguZd53/HbiyzbxuXVczGwTcC1xMMEDgvU2hnxDn3Bn9A1wCLG0x/T3ge+muqw/O82Xg88CHwLBw3jDgw/D9z4BbW6zfvJ5PPwRDLiwDLgNeJfgm2D1ArO31JnhI7JLwfSxcz9J9Dt083wHA1rZ1Z/J15tRAd4PC6/Yq8HeZeJ2BcmBdstcVuBX4WYv5rdbr6ueMb4GT4KiHPgv/ZJwErAKGOud2hot2AUPD95nye1gA/BPQ9G3QJUCdc+5kON3yvJrPOVxeH67vk5FALfBU2G30b2ZWSAZfZ+fcduAh4K/AToLrtprMvs5Nuntde3S9fQjwjGZmRcALwLecc/tbLnPBf8kZc5+nmX0BqHHOrU53LSkUAyYDjznnJgGHOPVnNZCR17mY4LsBRgJnA4Wc3tWQ8VJxXX0I8Iwd9dDM4gTh/bRz7jfh7N1mNixcPgyoCednwu9hOnCNmW0j+AKQywj6hweGg6NB6/NqPudw+QBgbyoL7gXVQLVzblU4vYQg0DP5Ol8ObHXO1TrnTgC/Ibj2mXydm3T3uvboevsQ4G8Do8JPsHMIPgz5bZpr6jEzM+AJYINz7uEWi34LNH0SPZegb7xp/lfCT7OnAfUt/lTzgnPue865MudcOcF1/JNz7jZgOXBDuFrbc276XdwQru9VS9U5twv41MwuCGfNBtaTwdeZoOtkmpkVhP/Om845Y69zC929rkuBK8ysOPzL5YpwXmLS/SFAgh8UXEUwdO0W4H+ku55eOqfPEvx5tRZYE/5cRdD3twzYBPwRGBSubwR342wB3if4hD/t59GD858JvBq+Pw/4D2Az8GsgN5yfF05vDpefl+66kzzXiUBVeK1fAooz/ToD9wEbgXXAL4DcTLvOwDMEffwnCP7SujOZ6wr8l/DcNwNf7U4NepReRMRTPnShiIhIOxTgIiKeUoCLiHhKAS4i4ikFuIiIpxTgIiKeUoCLiHjq/wPh7tgogOFt4wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(train_losses)\n", "plt.plot(valid_losses)\n", "plt.legend([\"train loss\", \"validation loss\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "AI-Course", "language": "python", "name": "ai-course" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }