{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn import datasets\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import math" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def load_iris(ratio=0.8):\n", "    features, target = datasets.load_iris(return_X_y=True)\n", "\n", "    # Keep only classes 0 and 1, turning the task into binary classification.\n", "    # Comment out this block for a multi-class task.\n", "    idx = np.bitwise_or(target == 0, target == 1)\n", "    features = features[idx]\n", "    target = target[idx]\n", "\n", "    num_samples = len(target)\n", "    num_train = math.ceil(num_samples * ratio)\n", "\n", "    # Shuffle the data randomly\n", "    idx = np.random.permutation(np.arange(num_samples))\n", "    traindata = features[idx[:num_train]], target[idx[:num_train]]\n", "    validdata = features[idx[num_train:]], target[idx[num_train:]]\n", "\n", "    return traindata, validdata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Assignment 3\n", "\n", "## IV. Logistic Regression and the Maximum Entropy Model\n", "\n", "Requirements:\n", "\n", "* Use logistic regression to perform binary classification on the iris dataset restricted to two classes\n", "* Use gradient descent (fixed step size)\n", "* Complete the work while changing the code structure as little as possible\n", "\n", "ETA: 1-5 hours" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "(X_train, Y_train), (X_valid, Y_valid) = load_iris()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Define the logistic regression model\n", "\n", "To simplify the derivation, let $\\hat{Y} = WX+b$.\n", "\n", "Prediction: choose the class with the larger of the two probabilities below:\n", "\n", "$$P(class=1|x) = \\frac{\\exp(\\hat{Y})}{1+\\exp(\\hat{Y})}$$\n", "$$P(class=0|x) = \\frac{1}{1+\\exp(\\hat{Y})}$$" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Logistic:\n", "    def __init__(self, in_channels):\n", "        self.W = np.zeros(in_channels)\n", "        self.b = 0\n", "\n", "    def __call__(self, X):\n", "        r\"\"\"Compute \\hat{Y} = WX + b, where WX is an inner product.\"\"\"\n", "        assert len(X.shape) == 2\n", "        # Only WX + b is computed here, because the log-likelihood needs this value.\n", "\n", "        # Implement this\n", "        raise NotImplementedError()\n", "        return None\n", "\n", "    def predict(self, X):\n", "        \"\"\"Predict the class of each sample (row) of X.\"\"\"\n", "        assert len(X.shape) == 2\n", "\n", "        # Implement this\n", "        raise NotImplementedError()\n", "        return None\n", "\n", "\n", "def accuracy(real, predict):\n", "    \"\"\"Compute the prediction accuracy.\"\"\"\n", "    return np.sum(real == predict) / real.size" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check that Logistic can be called correctly\n", "num_features = X_train.shape[-1]\n", "f = Logistic(num_features)\n", "\n", "f(X_train)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f.predict(X_train)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy(Y_train, f.predict(X_train))  # random guessing: about 50% accuracy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Define the error: the log-likelihood\n", "\n", "$$ L(Y, \\hat{Y}) = \\sum_{i=1}^{N} [y_i\\hat{y}_i - \\log(1+\\exp(\\hat{y}_i))]$$\n", "\n", "where $Y$ holds the true labels and $\\hat{Y} = WX + b$ is the predicted value.\n", 
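"\n", "As a quick check of this formula, assume the default `ratio=0.8`, so there are $N = \\lceil 100 \\times 0.8 \\rceil = 80$ training samples (iris has 50 samples in each of classes 0 and 1). With the initial parameters $W = 0$, $b = 0$, every $\\hat{y}_i = 0$, so each term of the sum equals $-\\log 2$ regardless of $y_i$:\n", "\n", "$$ -L(Y, \\hat{Y}) = -\\sum_{i=1}^{N} [y_i \\cdot 0 - \\log(1+\\exp(0))] = N\\log 2 = 80\\log 2 \\approx 55.4518 $$\n", "\n", "This agrees with the value the loss test cell prints for the untrained model.\n", 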
"\n", "Note: $L(Y, \\hat{Y}) \\neq L(\\hat{Y}, Y)$, i.e. the argument order matters (true labels first, model output second)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class NegativeLogLikelihood:\n", "    # Negative log-likelihood for binomial logistic regression.\n", "    # The problem gradient descent solves is actually min -L.\n", "    def __call__(self, real, predict):\n", "        assert len(real.shape) == 1\n", "        assert real.size == predict.size\n", "\n", "        # Implement this\n", "        raise NotImplementedError()\n", "        return None" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "55.451774444795625" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check that the loss can be called correctly\n", "loss = NegativeLogLikelihood()\n", "loss(Y_train, f(X_train))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Compute the gradient\n", "\n", "To use gradient descent, we need the gradient $(\\frac{\\partial L}{\\partial W}, \\frac{\\partial L}{\\partial b})$.\n", "\n", "By the chain rule:\n", "\n", "$$ \\frac{\\partial L}{\\partial W} = \\sum_{i=1}^{N}\\frac{\\partial L}{\\partial \\hat{Y}_i} \\frac{\\partial \\hat{Y}_i}{\\partial W} $$\n", "$$ \\frac{\\partial L}{\\partial b} = \\sum_{i=1}^{N}\\frac{\\partial L}{\\partial \\hat{Y}_i} \\frac{\\partial \\hat{Y}_i}{\\partial b} $$" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def grad(Y, X, Y_out):\n", "    \"\"\"\n", "    Compute the partial derivatives of the negative log-likelihood on (X, Y)\n", "    with respect to W and b.\n", "\n", "    Inputs:\n", "        Y: shape (N, )\n", "            true class labels\n", "        X: shape (N, C)\n", "            input features\n", "        Y_out: shape (N, )\n", "            Y_out = WX+b\n", "\n", "    Outputs:\n", "        dLdW: shape (C, )\n", "        dLdb: shape (1, )\n", "    \"\"\"\n", "    # Implement this\n", "    raise NotImplementedError()\n", "    dLdY = None  # dLdY.shape == (?, )  (? is a placeholder; no need to fill it in)\n", "    dYdW = None  # dYdW.shape == (?, 4)\n", "    dYdb = None  # dYdb.shape == (?, 1)\n", "\n", "    dLdW = None  # dLdW.shape == (4, )\n", "    dLdb = None  # dLdb.shape == (1, )\n", "\n", "    return dLdW, dLdb" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"(array([-18.5 , 12.8 , -56.2 , -21.85]), array(0.))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check that the gradient is computed correctly\n", "grad(Y_train, X_train, f(X_train))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class GradientDescent:\n", "    def __init__(self, step=1e-3, thres=1e3):\n", "        self.step = step\n", "        self.thres = np.abs(thres)\n", "\n", "    def update(self, f: Logistic, dLdW, dLdb):\n", "        \"\"\"Update the weights of f using the gradients dLdW and dLdb.\"\"\"\n", "        # Implement this\n", "        raise NotImplementedError()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "before update: loss:55.4518 accuracy:0.5\n", "after update: loss:52.0012 accuracy:0.5\n" ] } ], "source": [ "# Check that a single gradient update works\n", "opt = GradientDescent(1e-3)\n", "f = Logistic(num_features)\n", "\n", "print(f\"before update: loss:{loss(Y_train, f(X_train)):.4f} accuracy:{accuracy(Y_train, f.predict(X_train))}\")\n", "\n", "dLdW, dLdb = grad(Y_train, X_train, f(X_train))\n", "opt.update(f, dLdW, dLdb)\n", "\n", "print(f\"after update: loss:{loss(Y_train, f(X_train)):.4f} accuracy:{accuracy(Y_train, f.predict(X_train))}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Combine all the functions into a complete training loop" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter 0: loss 55.4518, accu: 0.5000, valid loss 13.8528, valid accu 0.5000\n", "Iter 50: loss 53.6407, accu: 0.5000, valid loss 13.4073, valid accu 0.5000\n", "Iter 100: loss 52.1356, accu: 0.5000, valid loss 13.0342, valid accu 0.5000\n", "Iter 150: loss 50.7720, accu: 0.5000, valid loss 12.6949, valid accu 0.5000\n", "Iter 200: loss 49.4897, accu: 0.5125, valid loss 12.3752, valid accu 0.6000\n", "Iter 250: loss 48.2658, accu: 0.7000, valid loss 12.0699, valid accu 0.8000\n", "Iter 300: loss 47.0908, accu: 0.9000, valid loss 11.7768, valid accu 0.9000\n", "Iter 350: loss 45.9604, accu: 0.9750, valid loss 11.4947, valid accu 0.9500\n", "Iter 400: loss 44.8717, accu: 1.0000, valid loss 11.2230, valid accu 1.0000\n", "Iter 450: loss 43.8228, accu: 1.0000, valid loss 10.9611, valid accu 1.0000\n", "Iter 500: loss 42.8118, accu: 1.0000, valid loss 10.7087, valid accu 1.0000\n", "Iter 550: loss 41.8372, accu: 1.0000, valid loss 10.4653, valid accu 1.0000\n", "Iter 600: loss 40.8975, accu: 1.0000, valid loss 10.2305, valid accu 1.0000\n", "Iter 650: loss 39.9912, accu: 1.0000, valid loss 10.0041, valid accu 1.0000\n", "Iter 700: loss 39.1169, accu: 1.0000, valid loss 9.7856, valid accu 1.0000\n", "Iter 750: loss 38.2733, accu: 1.0000, valid loss 9.5747, valid accu 1.0000\n", "Iter 800: loss 37.4590, accu: 1.0000, valid loss 9.3710, valid accu 1.0000\n", "Iter 850: loss 36.6729, accu: 1.0000, valid loss 9.1744, valid accu 1.0000\n", "Iter 900: loss 35.9138, accu: 1.0000, valid loss 8.9845, valid accu 1.0000\n", "Iter 950: loss 35.1806, accu: 1.0000, valid loss 8.8010, valid accu 1.0000\n" ] } ], "source": [ "num_features = X_train.shape[-1]\n", "f = Logistic(num_features)\n", "opt = GradientDescent(1e-5)\n", "loss = NegativeLogLikelihood()\n", "\n", "valid_losses = []\n", "valid_accuracies = []\n", "train_losses = []\n", "train_accuracies = []\n", "for i in range(1000):\n", "    X, Y = X_train, Y_train\n", "\n", "    Y_out = f(X)\n", "    dLdW, dLdb = grad(Y, X, Y_out)\n", "    opt.update(f, dLdW, dLdb)\n", "\n", "    # Record intermediate results\n", "    # (the training loss uses Y_out from before this step's update)\n", "    cur_valid_loss = loss(Y_valid, f(X_valid))\n", "    cur_valid_accu = accuracy(Y_valid, f.predict(X_valid))\n", "    cur_train_loss = loss(Y, Y_out)\n", "    cur_train_accu = accuracy(Y, f.predict(X))\n", "    valid_losses.append(cur_valid_loss)\n", "    valid_accuracies.append(cur_valid_accu)\n", "    train_losses.append(cur_train_loss)\n", "    train_accuracies.append(cur_train_accu)\n", "\n", "    if i % 50 == 0:\n", "        print(f\"Iter {i}: loss {cur_train_loss:.4f}, accu: {cur_train_accu:.4f}, valid loss {cur_valid_loss:.4f}, valid accu {cur_valid_accu:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Plot the intermediate results" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plt.plot(train_losses)\n", "plt.plot(valid_losses)\n", "plt.legend([\"train loss\", \"validation loss\"])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(train_accuracies)\n", "plt.plot(valid_accuracies)\n", "plt.legend([\"train accuracy\", \"validation accuracy\"])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZg0lEQVR4nO3dfXRV9b3n8fdXHgxLUKqEO0LwJp3FBRQChIhtIw6UapFSqB1aoeiFeyuhPruuZUpXXYyLdrrsxXGgDJ0r3qqtSwUfaezF2tpCb6VgSSAgBKlIcyUggvjQUsgA8p0/zknmcDhJDuf5+Pu81srK3r/923t/9z47n5zsvc+OuTsiIvLxd06+CxARkdxQ4IuIBEKBLyISCAW+iEggFPgiIoHonq8V9+vXz8vLy/O1ehGRotTQ0PCuu5emMm/eAr+8vJz6+vp8rV5EpCiZ2X+kOq9O6YiIBEKBLyISCAW+iEggFPgiIoFQ4IuIBEKBLyISCAW+iEggFPgiIoHo8oNXZvYwMAU46O7DE0w3YCkwGTgKzHH3zZkuNNM21T3IoM2L6e+HOGil7K2az+VT5+VsfsmPdF63Yj5m8l17yNteSKyrf4BiZlcBR4CfdhD4k4HbiQT+FcBSd7+iqxVXV1d7vj5pu6nuQYY33EMvO97edsx7sn3M95J6IdOdX/IjndetmI+ZfNce8rZng5k1uHt1KvN2eUrH3f8deK+TLtOI/DJwd98I9DWzi1MpJlcGbV582gsI0MuOM2jz4pzML/mRzutWzMdMvmsPedsLTSbO4Q8E9saMt0TbzmBmtWZWb2b1hw4dysCqU9PfE6+7v7+bk/klP9J53Yr5mMl37SFve6HJ6UVbd1/h7tXuXl1amtLD3jLioCVe90Hrl5P5JT/Sed2K+ZjJd+0hb3uhyUTg7wMGxYyXRdsK1t6q+Rzznqe1HfOe7K2an5P5JT/Sed2K+ZjJd+0hb3uhyUTg1wF/bxGfAj5097czsNysuXzqPLaP+R4HKOWUGwcoPauLMOnOL/mRzutWzMdMvmsPedsLTTJ36TwJjAf6Ae8A/x3oAeDu/xK9LfN/A5OI3Jb5D+7e5e03+bxLR0SkWKVzl06X9+G7+8wupjtwayorFxGR3NEnbUVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQCnwRkUAo8EVEApFU4JvZJDPbZWa7zWxBgumXmNlaM9tiZtvMbHLmSxURkXR0Gfhm1g1YDlwLXArMNLNL47rdAzzl7qOBGcCPMl2oiIikJ5l3+GOB3e6+x92PAyuBaXF9HDg/OnwBsD9zJYqISCZ0T6LPQGBvzHgLcEVcn3uBX5rZ7cB5wOcyUp2IiGRMpi7azgQedfcyYDLwmJmdsWwzqzWze
jOrP3ToUIZWLSIiyUgm8PcBg2LGy6Jtsb4OPAXg7huAEqBf/ILcfYW7V7t7dWlpaWoVi4hISpIJ/E3AYDOrMLOeRC7K1sX1eQuYCGBmw4gEvt7Ci4gUkC4D391PArcBLwE7idyNs8PMFpnZ1Gi3u4G5ZrYVeBKY4+6eraJFROTsJXPRFndfA6yJa1sYM9wE1GS2NBERySR90lZEJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCQV+GY2ycx2mdluM1vQQZ+vmlmTme0wsycyW6aIiKSre1cdzKwbsBy4GmgBNplZnbs3xfQZDHwbqHH3982sf7YKFhGR1CTzDn8ssNvd97j7cWAlMC2uz1xgubu/D+DuBzNbpoiIpCuZwB8I7I0Zb4m2xfo74O/MbL2ZbTSzSYkWZGa1ZlZvZvWHDh1KrWIREUlJpi7adgcGA+OBmcBDZtY3vpO7r3D3anevLi0tzdCqRUQkGckE/j5gUMx4WbQtVgtQ5+4n3P1PwB+J/AIQEZEC0eVFW2ATMNjMKogE/Qzga3F9VhN5Z/+ImfUjcopnTyYLFZHcOHHiBC0tLbS2tua7lKCVlJRQVlZGjx49MrbMLgPf3U+a2W3AS0A34GF332Fmi4B6d6+LTrvGzJqAj4D57n44Y1WKSM60tLTQp08fysvLMbN8lxMkd+fw4cO0tLRQUVGRseUm8w4fd18DrIlrWxgz7MA/Rb9EpIi1trYq7PPMzLjooovI9M0t+qStiJxBYZ9/2XgNFPgi8rHWu3dvAPbv38/06dM77btkyRKOHj3aPj558mQ++OCDrNaXSwp8ESk6H3300VnPM2DAAJ555plO+8QH/po1a+jb94w7zIuWAl9E0rJ6yz5q7vsNFQv+jZr7fsPqLfF3bZ+d5uZmhg4dyqxZsxg2bBjTp0/n6NGjlJeX861vfYuqqiqefvpp3nzzTSZNmsSYMWMYN24cr7/+OgB/+tOf+PSnP82IESO45557Tlvu8OHDgcgvjG9+85sMHz6cyspKli1bxg9/+EP279/PhAkTmDBhAgDl5eW8++67ADzwwAMMHz6c4cOHs2TJkvZlDhs2jLlz53LZZZdxzTXXcOzYsbS2P5sU+CKSstVb9vHt515j3wfHcGDfB8f49nOvpR36u3bt4pZbbmHnzp2cf/75/OhHPwLgoosuYvPmzcyYMYPa2lqWLVtGQ0MD999/P7fccgsAd955JzfffDOvvfYaF198ccLlr1ixgubmZhobG9m2bRuzZs3ijjvuYMCAAaxdu5a1a9ee1r+hoYFHHnmEV199lY0bN/LQQw+xZcsWAN544w1uvfVWduzYQd++fXn22WfT2vZsUuCLSMoWv7SLYydOP71y7MRHLH5pV1rLHTRoEDU1NQDccMMNvPLKKwBcf/31ABw5coTf//73fOUrX2HUqFHMmzePt99+G4D169czc+ZMAG688caEy3/55ZeZN28e3btHblS88MILO63nlVde4brrruO8886jd+/efPnLX+Z3v/sdABUVFYwaNQqAMWPG0NzcnMaWZ1dSt2WKiCSy/4PEpy86ak9W/B0qbePnnXceAKdOnaJv3740NjYmNX82nXvuue3D3bp10ykdEfl4GtC311m1J+utt95iw4YNADzxxBNceeWVp00///zzqaio4OmnnwYiH1TaunUrADU1NaxcuRKAxx9/POHyr776ah588EFOnjwJwHvvvQdAnz59+Mtf/nJG/3HjxrF69WqOHj3KX//6V55//nnGjRuX1jbmgwJfRFI2//ND6NWj22ltvXp0Y/7nh6S13CFDhrB8+XKGDRvG+++/z
80333xGn8cff5wf//jHjBw5kssuu4yf/exnACxdupTly5czYsQI9u1LfC3hpptu4pJLLqGyspKRI0fyxBOR/9lUW1vLpEmT2i/atqmqqmLOnDmMHTuWK664gptuuonRo0entY35YJEPyeZedXW119fX52XdItKxnTt3MmzYsKT7r96yj8Uv7WL/B8cY0LcX8z8/hC+Njn+CevKam5uZMmUK27dvT3kZHxeJXgsza3D36lSWp3P4IpKWL40emFbAS+7olI6IFJTy8nK9u88SBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISIx169YxZcqUM9obGxtZs2ZNgjm69v3vf799OPYhbrmmwBeRotP2Cdlc6izwu6onNvDzSYEvIunZ9hT8r+Fwb9/I921PpbW47373uwwZMoQrr7ySmTNncv/99wMwfvx47rrrLqqrq1m6dCnNzc189rOfpbKykokTJ/LWW28BMGfOnNOee9/2D1DWrVvH+PHjmT59evvjl9s+ePqLX/yCoUOHUlVVxXPPPXdGTcePH2fhwoWsWrWKUaNGsWrVKu69915uvPFGampquPHGG3n00Ue57bbb2ueZMmUK69atY8GCBRw7doxRo0Yxa9YsIPJ45nw8UlmBLyKp2/YUvHAHfLgX8Mj3F+5IOfQ3bdrEs88+y9atW3nxxReJ/zT+8ePHqa+v5+677+b2229n9uzZpz3euCtbtmxhyZIlNDU1sWfPHtavX09raytz587lhRdeoKGhgQMHDpwxX8+ePVm0aBHXX389jY2N7U/tbGpq4uWXX+bJJ5/scJ333XcfvXr1orGxsf3ZPvl6pLICX0RS9+tFcCLu3emJY5H2FKxfv55p06ZRUlJCnz59+OIXv3ja9LagBdiwYQNf+9rXgMhjkNseodyZsWPHUlZWxjnnnMOoUaNobm7m9ddfp6KigsGDB2Nm3HDDDUnXO3XqVHr1OvsHxeXrkcoKfBFJ3YctZ9eeprbHI3eme/funDp1Cog8Rvn48ePt0+IfZZzutYDYemLXC9Da2trhfJmuI1kKfBFJ3QVlZ9fehZqaGl544QVaW1s5cuQIP//5zzvs+5nPfOa0xyC3Pa64vLychoYGAOrq6jhx4kSn6xw6dCjNzc28+eabAB2enuno0cltysvLaWxs5NSpU+zdu5c//OEP7dN69OjRZR25oMAXkdRNXAg94k5p9OgVaU/B5ZdfztSpU6msrOTaa69lxIgRXHDBBQn7Llu2jEceeYTKykoee+wxli5dCsDcuXP57W9/y8iRI9mwYUOXfxWUlJSwYsUKvvCFL1BVVUX//v0T9pswYQJNTU3tF23j1dTUUFFRwaWXXsodd9xBVVVV+7Ta2loqKyvbL9rmix6PLCKnOdvHI7Ptqcg5+w9bIu/sJy6Eyq+mvP4jR47Qu3dvjh49ylVXXcWKFStOC8+Q6PHIIlJYKr+aVsDHq62tpampidbWVmbPnh1s2GeDAl9ECkrbf5+SzNM5fBGRQCjwReQM+bq2J/9fNl6DpALfzCaZ2S4z221mCzrp91/NzM0spQsKIpJ/JSUlHD58WKGfR+7O4cOHKSkpyehyuzyHb2bdgOXA1UALsMnM6ty9Ka5fH+BO4NWMVigiOVVWVkZLSwuHDh3KdylBKykpoawstc8zdCSZi7Zjgd3uvgfAzFYC04CmuH7fBX4AzM9ohSKSUz169KCioiLfZUgWJHNKZyCwN2a8JdrWzsyqgEHu/m+dLcjMas2s3szq9e5BRCS30r5oa2bnAA8Ad3fV191XuHu1u1eXlpamu2oRETkLyQT+PmBQzHhZtK1NH2A4sM7MmoFPAXW6cCsiUliSCfxNwGAzqzCznsAMoK5tort/6O793L3c3cuBjcBUd9dzE0RECkiXge/uJ4HbgJeAncBT7r7DzBaZ2dRsFygiIpmR1KMV3H0NsCauLeHj8Nx9fPpliYhIpumTtiIigVDgi4gEQoEvIhIIBb6ISCAU+CIig
VDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigUgq8M1skpntMrPdZrYgwfR/MrMmM9tmZr82s7/NfKkiIpKOLgPfzLoBy4FrgUuBmWZ2aVy3LUC1u1cCzwD/nOlCRUQkPcm8wx8L7Hb3Pe5+HFgJTIvt4O5r3f1odHQjUJbZMkVEJF3JBP5AYG/MeEu0rSNfB15MpygREcm87plcmJndAFQD/6WD6bVALcAll1ySyVWLiEgXknmHvw8YFDNeFm07jZl9DvgOMNXd/2+iBbn7Cnevdvfq0tLSVOoVEZEUJRP4m4DBZlZhZj2BGUBdbAczGw08SCTsD2a+TBERSVeXge/uJ4HbgJeAncBT7r7DzBaZ2dRot8VAb+BpM2s0s7oOFiciInmS1Dl8d18DrIlrWxgz/LkM1yUiIhmmT9qKiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEggFvohIIBT4IiKBUOCLiARCgS8iEojuyXQys0nAUqAb8K/ufl/c9HOBnwJjgMPA9e7enNlST7ep7kEGbV5Mfz/EQStlb9V8Lp86L5urzJh0a8/n/MVce7HTthfnthdS7ebunXcw6wb8EbgaaAE2ATPdvSmmzy1Apbt/w8xmANe5+/WdLbe6utrr6+tTKnpT3YMMb7iHXna8ve2Y92T7mO8V/EGQbu35nL+Yay922vbi3PZs1G5mDe5encq8yZzSGQvsdvc97n4cWAlMi+szDfhJdPgZYKKZWSoFJWPQ5sWn7UCAXnacQZsXZ2uVGZNu7fmcv5hrL3ba9uLc9kKrPZnAHwjsjRlvibYl7OPuJ4EPgYviF2RmtWZWb2b1hw4dSq1ioL8nnre/v5vyMnMl3drzOX8x117stO2J2gt/2wut9pxetHX3Fe5e7e7VpaWlKS/noCWe96D1S3mZuZJu7fmcv5hrL3ba9kTthb/thVZ7MoG/DxgUM14WbUvYx8y6AxcQuXibFXur5nPMe57Wdsx7srdqfrZWmTHp1p7P+Yu59mKnbS/ObS+02pMJ/E3AYDOrMLOewAygLq5PHTA7Ojwd+I13dTU4DZdPncf2Md/jAKWccuMApUVxAQfSrz2f8xdz7cVO216c215otXd5lw6AmU0GlhC5LfNhd/8fZrYIqHf3OjMrAR4DRgPvATPcfU9ny0znLh0RkVClc5dOUvfhu/saYE1c28KY4VbgK6kUICIiuaFP2oqIBEKBLyISCAW+iEggFPgiIoFQ4IuIBEKBLyISCAW+iEggkvrgVVZWbHYI+I8MLKofUMhPUSrk+lRbagq5Nijs+lRbamJr+1t3T+lhZHkL/Ewxs/pUP3WWC4Vcn2pLTSHXBoVdn2pLTaZq0ykdEZFAKPBFRALxcQj8FfkuoAuFXJ9qS00h1waFXZ9qS01Gaiv6c/giIpKcj8M7fBERSYICX0QkEEUT+GY2ycx2mdluM1uQYPq5ZrYqOv1VMyvPUV2DzGytmTWZ2Q4zuzNBn/Fm9qGZNUa/FiZaVhZrbDaz16LrPuO/zljED6P7bpuZVeWoriEx+6TRzP5sZnfF9cnZvjOzh83soJltj2m70Mx+ZWZvRL9/ooN5Z0f7vGFmsxP1yVJ9i83s9ejr9ryZ9e1g3k6PgSzVdq+Z7Yt57SZ3MG+nP9tZqm1VTF3NZtbYwbzZ3m8J8yNrx527F/wXkf+09SbwSaAnsBW4NK7PLcC/RIdnAKtyVNvFQFV0uA/wx
wS1jQd+nsf91wz062T6ZOBFwIBPAa/m6TU+QORDJXnZd8BVQBWwPabtn4EF0eEFwA8SzHchsCf6/RPR4U/kqL5rgO7R4R8kqi+ZYyBLtd0LfDOJ173Tn+1s1BY3/X8CC/O03xLmR7aOu2J5hz8W2O3ue9z9OLASmBbXZxrwk+jwM8BEM7NsF+bub7v75ujwX4CdwMBsrzfDpgE/9YiNQF8zuzjHNUwE3nT3THz6OiXu/u9E/kVnrNjj6ifAlxLM+nngV+7+nru/D/wKmJSL+tz9l+5+Mjq6ESjL9HqT0cG+S0YyP9tZqy2aEV8FnszkOpPVSX5k5bgrlsAfCOyNGW/hzFBt7xP9AfgQuCgn1UVFTyONBl5NMPnTZrbVzF40s8tyWRfgwC/NrMHMahNMT2b/ZtsMOv6hy+e++xt3fzs6fAD4mwR9CmH/Afwjkb/UEunqGMiW26Knmx7u4LREvvfdOOAdd3+jg+k5229x+ZGV465YAr/gmVlv4FngLnf/c9zkzUROVYwElgGrc1zele5eBVwL3GpmV+V4/Z0ys57AVODpBJPzve/aeeTv6IK8j9nMvgOcBB7voEs+joH/A/xnYBTwNpFTJ4VmJp2/u8/JfussPzJ53BVL4O8DBsWMl0XbEvYxs+7ABcDhXBRnZj2IvFiPu/tz8dPd/c/ufiQ6vAboYWb9clFbdJ37ot8PAs8T+TM6VjL7N5uuBTa7+zvxE/K974B32k5vRb8fTNAnr/vPzOYAU4BZ0XA4QxLHQMa5+zvu/pG7nwIe6mCdedt30Zz4MrCqoz652G8d5EdWjrtiCfxNwGAzq4i+G5wB1MX1qQParlJPB37T0cGfSdFzgD8Gdrr7Ax30+U9t1xPMbCyR/Z6rX0bnmVmftmEiF/m2x3WrA/7eIj4FfBjz52QudPguK5/7Lir2uJoN/CxBn5eAa8zsE9HTFtdE27LOzCYB/w2Y6u5HO+iTzDGQjdpirwNd18E6k/nZzpbPAa+7e0uiibnYb53kR3aOu2xdfc7C1ezJRK5gvwl8J9q2iMiBDlBC5JTAbuAPwCdzVNeVRP7c2gY0Rr8mA98AvhHtcxuwg8gdCBuBz+Rwv30yut6t0Rra9l1sfQYsj+7b14DqHNZ3HpEAvyCmLS/7jsgvnbeBE0TOh36dyHWgXwNvAC8DF0b7VgP/GjPvP0aPvd3AP+Swvt1EzuO2HXttd6oNANZ0dgzkoLbHosfTNiIBdnF8bdHxM362s11btP3RtuMspm+u91tH+ZGV406PVhARCUSxnNIREZE0KfBFRAKhwBcRCYQCX0QkEAp8EZFAKPBFRAKhwBcRCcT/A8Ebr2SItin+AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# The predictions coincide with the ground truth\n", "plt.scatter(np.arange(Y_valid.size), f.predict(X_valid))\n", "plt.scatter(np.arange(Y_valid.size), Y_valid)\n", "plt.legend([\"prediction\", \"ground truth\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "AI-Course", "language": "python", "name": "ai-course" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }