{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# 作业三\n", "\n", "## 一、Python类" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
  "class DummyData:\n",
  "    \"\"\"Noisy samples of the parabola y = 0.3*x**2 + 2*x + 1.\"\"\"\n",
  "\n",
  "    def __init__(self, num=200, seed=None):\n",
  "        \"\"\"num: number of evenly spaced samples returned by load_data.\n",
  "\n",
  "        seed: optional int. If given, the noise is drawn from a private RandomState,\n",
  "        so the data is reproducible; if None, the global np.random state is used.\n",
  "        \"\"\"\n",
  "        self.num = num\n",
  "        self._rng = np.random if seed is None else np.random.RandomState(seed)\n",
  "\n",
  "    def load_data(self, train=True):\n",
  "        \"\"\"Return (X, Y), two 1-D arrays of length `num`.\n",
  "\n",
  "        train=True:  X evenly spaced on [-2, 2], Gaussian noise with std 0.4.\n",
  "        train=False: X evenly spaced on [-10, 10], Gaussian noise with std 5,\n",
  "                     i.e. the validation range extends far beyond the training range.\n",
  "        \"\"\"\n",
  "        W = [0.3, 2, 1]  # true coefficients [w0, w1, w2]\n",
  "        if train:\n",
  "            X = np.linspace(-2, 2, num=self.num)\n",
  "            noise_std = 0.4\n",
  "        else:\n",
  "            X = np.linspace(-10, 10, num=self.num)\n",
  "            noise_std = 5\n",
  "        Y = W[0]*X**2 + W[1]*X + W[2] + noise_std*self._rng.randn(X.size)\n",
  "        return X, Y"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "# Fixed seeds keep the data (and every loss below) reproducible under Restart & Run All.\n",
  "X_train, Y_train = DummyData(500, seed=0).load_data()\n",
  "X_valid, Y_valid = DummyData(200, seed=1).load_data(train=False)"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAx0klEQVR4nO2deZxU5ZX3v6ebBhow7C7Q8IIZRKIiS4vGLRAMEjWgjDKYOAMhkUicITETt+QdJEY/H7dXHWZGM7hEJ6MDjEFEo68LwRjjmNgINCDwgoqxcUMUxNDa3fC8f9xb3beq7lp1b219vp9Pf7rqbs9Tt26de+55zvM7YoxBURRFqUyqit0BRVEUJTnUyCuKolQwauQVRVEqGDXyiqIoFYwaeUVRlAqmS7E74GTAgAFm2LBhxe6GoihKWbF27doPjTED3daVlJEfNmwYDQ0Nxe6GoihKWSEib3mt03CNoihKBaNGXlEUpYJRI68oilLBlFRM3o3W1laampr47LPPit0VJQLdu3enrq6OmpqaYndFUTo1JW/km5qaOOywwxg2bBgiUuzuKCEwxrBnzx6ampoYPnx4sbujKJ2akg/XfPbZZ/Tv318NfBkhIvTv31+fvhSlcTnccTws6mP9b1xe8C6UvCcPqIEvQ/Q7Uzo9jcvh8QXQ2my93/e29R5g9MyCdaPkPXlFUZSyZPX1HQY+RWuztbyAqJEPYM+ePYwZM4YxY8Zw5JFHMnjw4Pb3LS0tvvs2NDSwYMGCwDZOPfXUuLqbxsSJEwMnl915550cOHAgkfYVpVOzr8lj+dsF7UZZhGuKSf/+/Vm/fj0AixYtolevXvz4xz9uX9/W1kaXLu6nsb6+nvr6+sA2XnrppVj6mgt33nknl1xyCT169ChaHxSlIuld52HQxQrlFChkU3Ge/Mp1uzjtpt8y/JrfcNpNv2Xlul2xtzFnzhwuu+wyTj75ZK666ir+9Kc/8eUvf5mxY8dy6qmnsm3bNgCef/55zjvvPMC6QcydO5eJEydy9NFHs3jx4vbj9erVq337iRMncuGFF3LsscfyrW99i1TlrieffJJjjz2W8ePHs2DBgvbjOmlubmbWrFmMGjWKCy64gObmjkfF+fPnU19fz3HHHcd1110HwOLFi3nnnXeYNGkSkyZN8txOUZQcmLwQcBubMgUN2VSUJ79y3S6uXbGR5taDAOza28y1KzYCcP7YwbG21dTUxEsvvUR1dTWffPIJv//97+nSpQvPPfccP/nJT/j1r3+dtc/WrVtZs2YN+/fvZ+TIkcyfPz8rj3zdunVs3ryZQYMGcdppp/GHP/yB+vp6vve97/HCCy8wfPhwLr74Ytc+3X333fTo0YMtW7bQ2NjIuHHj2tfdeOON9OvXj4MHDzJ58mQaGxtZsGABt99+O2vWrGHAgAGe240ePTrGM6conYTRM2HFpe7rvEI5CVBRnvytT29rN/ApmlsPcuvT22Jv66KLLqK6uhqAffv2cdFFF3H88cdzxRVXsHnzZtd9zj33XLp168aAAQM4/PDDef/997O2mTBhAnV1dVRVVTFmzBh27tzJ1q1bOfroo9tzzr2M/AsvvMAll1wCwOjRo9OM8/Llyxk3bhxjx45l8+bNvPbaa67HCLudoigh6D3EY3ldwboQi5EXkZ0islFE1otIg72sn4g8KyLb7f9942jLj3f2Nkdang89e/Zsf/1P//RPTJo0iU2bNvH444975od369at/XV1dTVtbW05bROVN998k9tuu43Vq1fT2NjIueee69rHsNspihKSyQuhpjZ9WU2tHcopDHF68pOMMWOMMamRxmuA1caYEcBq+32iDOpTG2l5XOzbt4/Bg61w0AMPPBD78UeOHMkbb7zBzp07AVi2bJnrdmeeeSYPP/wwAJs2baKxsRGATz75hJ49e9K7d2/ef/99nnrqqfZ9DjvsMPbv3x+4naJUFFEnKeU6qWn0TPjGYtujF+v/NxYXNE8+yZj8dGC
i/fpB4Hng6gTb48qzR6bF5AFqa6q58uyRSTbLVVddxezZs7nhhhs499xzYz9+bW0td911F1OnTqVnz56cdNJJrtvNnz+fb3/724waNYpRo0Yxfvx4AE488UTGjh3Lsccey5AhQzjttNPa95k3bx5Tp05l0KBBrFmzxnM7RakYok5SyndS0+iZ2ds1LrcGX/c1WaGbyQsTM/ySyt7I6yAibwIfAwb4d2PMEhHZa4zpY68X4OPU+4x95wHzAIYOHTr+rbfSte+3bNnCqFGjQvdl5bpd3Pr0Nt7Z28ygPrVcefbI2Addi8Gnn35Kr169MMZw+eWXM2LECK644opid8uXqN+dohSEO453T23sPQSu2BR9+6gGO/OmAVhZOMY6Zg4GX0TWOqIoacTlyZ9ujNklIocDz4rIVudKY4wREde7iTFmCbAEoL6+Pu87zvljB1eEUc/knnvu4cEHH6SlpYWxY8fyve99r9hdUpTyxHOSUg7Lc/Hy3WbCYsLvH5FYYvLGmF32/w+AR4EJwPsichSA/f+DONrqrFxxxRWsX7+e1157jYceekgnLylKrnhltuSyPIp0QSquHzTjNWbpg7yNvIj0FJHDUq+BKcAmYBUw295sNvBYvm0piqLkTdSMF7/tw3r/KY8/rKRBjHn0cXjyRwAvisgG4E/Ab4wx/xe4CfiaiGwHzrLfK4qiFJeoGS9+24f1/l1DND7EmEefd0zeGPMGcKLL8j3A5HyPryiKEjtuGS+5bD95YfYgqttTga9nbg+6+u2fBxU141VRFKWghH0q8PT4h8CMJYnm0auRD2DSpEk8/fTTacvuvPNO5s+f77mPU+L3nHPOYe/evVnbLFq0iNtuu8237ZUrV6bJCixcuJDnnnsuQu/DoZLEihICrwlRo2daqZSL9lr/vTx+r7h+mP3zQI18ABdffDFLly5NW7Z06VJP/ZhMnnzySfr06ZNT25lG/vrrr+ess87K6Vj5okZe6dSkDZyajlTHMpj5WnlGPuaaihdeeCG/+c1v2guE7Ny5k3feeYczzjgjlCzvsGHD+PDDDwFL4fGYY47h9NNPb5cjBisH/qSTTuLEE0/kr//6rzlw4AAvvfQSq1at4sorr2TMmDG8/vrrzJkzh0ceeQSA1atXM3bsWE444QTmzp3L559/3t7eddddx7hx4zjhhBPYunVrVp9UklhRIuKVKrni0vB2JmGP3YvKMvL53m1d6NevHxMmTGjXcVm6dCkzZ85ERLjxxhtpaGigsbGR3/3ud+1aMW6sXbuWpUuXsn79ep588kleeeWV9nUzZszglVdeYcOGDYwaNYr77ruPU089lWnTpnHrrbeyfv16vvjFL7Zv/9lnnzFnzhyWLVvGxo0baWtr4+67725fP2DAAF599VXmz5/vGhJyShL/7Gc/Y+3ate3r3D7TggUL2mUP1qxZ47mdolQsfgOnudqZAhX5riwjn1BNRWfIxhmqiSLL+/vf/54LLriAHj168IUvfIFp06a1r9u0aRNnnHEGJ5xwAg899JCnVHGKbdu2MXz4cI455hgAZs+ezQsvvNC+fsaMGQCMHz++XdTMiUoSK0pEglIao9qZBBxSLyrLyEedlhyS6dOns3r1al599VUOHDjA+PHjY5XlnTNnDv/6r//Kxo0bue666/KW903JFUeVKlZJYkXJIG2WqluVJwdR7EwBi3xXlpGPOi05JL169WLSpEnMnTu33YuPKst75plnsnLlSpqbm9m/fz+PP/54+7r9+/dz1FFH0draykMPPdS+3CkD7GTkyJHs3LmTHTt2APCrX/2Kr3zlK6E/j0oSK0oIsmapGnwNfRQ7k5BD6kZFlf8LPTEhBy6++GIuuOCC9rCNn3yvG+PGjeNv/uZvOPHEEzn88MPT5IJ//vOfc/LJJzNw4EBOPvnkdmM6a9YsLr30UhYvXtw+4ArQvXt3fvnLX3LRRRfR1tbGSSedxGWXXRb6s6gksaKEwEtIrLYftDX725k
gZUqvIt8JVIyKRWo4Lurr601mvnZkudoC6jQr/qjUsFLWLOpD2kzUdsSawORlZ9ykhGtq01Mmw2wTgUJIDZcOUacrK4qiuOHnbfvZGb94e2qf1P8COKSVZ+QVRVHiIEz41y1yEDbeXiCHtCyMvDEGq7iUUi6UUhhQUXIiyNv2KhjStQe0/CX7eAnE28NQ8ka+e/fu7Nmzh/79+6uhLxOMMezZs4fu3bsXuyuKkh+5hGXcqO4aq7JkFEreyNfV1dHU1MTu3buL3RUlAt27d6eurjiei6IUhCjpjl17FW2ssOSNfE1NDcOHDy92NxRFUdLxGph1o/njZPviQ2VNhlIURQlLvtoxbvLBXhQpHg9l4MkriqLEjtegKYQPq6QNzPp49DFXeoqKevKKonQ+4tKOSckHz7jH3auv7Vcw3XgvYjPyIlItIutE5An7/XAR+aOI7BCRZSLSNa62FEVR8sIzl/3t3EI3bkVBZtwDV79Z9MmZcYZrfgBsAb5gv78ZuMMYs1REfgF8B7jba2dFUZSC4TdomkvoJrVtCc62j8WTF5E64FzgXvu9AF8FUqpaDwLnx9GWoihK3gQNmiYk+1sM4grX3AlcBRyy3/cH9hpjUmLmTcBgtx1FZJ6INIhIg+bCK4pSENLCKx4kIPtbDPI28iJyHvCBMWZt4MYuGGOWGGPqjTH1AwcOzLc7iqIo4UgNmnoZ+tq+8ZfnK1DJPydxePKnAdNEZCewFCtM889AHxFJxfzrgF0xtKUoihIvbqGbqhpo+TTe8nwFLPnnJG8jb4y51hhTZ4wZBswCfmuM+RawBrjQ3mw28Fi+bSmKosSOW2ZMt8PgYEv6dvnG6QtY8s9JkpOhrgaWisgNwDrgvgTbUhRFyZ3MzJhFfdy3yydOX8CSf05iNfLGmOeB5+3XbwAT4jy+oihKQfArGJJr9bkClvxzojNeFUVRMnGL09fUwogpucfVvY6ZsOSBGnlFURQnKU+9tRmk2lrWe4gVt9/+TO5xdbfYfwEkD1SgTFEUJUWmcJk52OFtj54JK+a57xc2rl6EWbHqySuKoqQIyoDxip8XUUo4CDXyiqKUN3FOMArKgClSXD0fNFyjKEr54qUL/+eXrfh53BkwQcW9SxA18oqilC9e4ZWG+wFjvY+iKjl5YfpNA7I99RJVm/RCwzWKopQvngOeJv1tiWfAJIl68oqilC9RimmXcAZMkqiRVxSlfHELryBkefLgnQHjnMFa29da1vxxWcTbw6DhGkVRyhe38Er93PAZMJnKkM0fWX8FVIlMGvXkFUUpb9zCK0NPCZcB4zZw6yQVyy9jb16NvKIo5UeQSFjYuHqYOH2ZV4jScI2iKOVFnMU3wsxULeHZrGFQI68oSnkRZ/GNoILeJT6bNQxq5BVFKS/8pAeiShxkDtzW9rP+KiRHHjQmryhKueGVG1/b113iAPwNdYXlxWeinryiKKWJl1fuJRIGRamhWuqokVcUpfTwG1z1kh5o/tj9WGWeHZMveYdrRKQ78ALQzT7eI8aY60RkOLAU6A+sBf7WGNPifSRFURQbv8HVVHglM8Sy+np/Bclca7OWOXF48p8DXzXGnAiMAaaKyCnAzcAdxpi/Aj4GvhNDW4qidAY8B1d9dGpGTMGSNHCQyo6JM+2yzMjbyBuLT+23NfafAb4KPGIvfxA4P9+2FEXpJHjmpou7YW5cDhseJkuzRqqskn2PXtZp4/WxxORFpFpE1gMfAM8CrwN7jTFt9iZNwGCPfeeJSIOINOzevTuO7iiKUu5MXkiWVw6AcTfMXvIELX+x9jEH3dvpBPH6WIy8MeagMWYMUAdMAI6NsO8SY0y9MaZ+4MCBcXRHUZRyZ/RMXJUkwd0w52qsy3w2axhiza4xxuwF1gBfBvqISGpgtw7YFWdbiqJUOL2HeCx3Mcy5GOsKmM0ahryNvIgMFJE+9uta4GvAFixjf6G92WzgsXzbUhSlExGlaHaQPEEKqaaSZrOGIY4Zr0c
BD4pINdZNY7kx5gkReQ1YKiI3AOuA+2JoS1GUzkKUotmZ29b2hZZP4aAja7umttMYdidijEfcqwjU19ebhoaGYndDUZRKoBPlxYvIWmNMvds61a5RFKW0CDLOYY13hWvShEWNvKIopUNq0pKXyFjQeiUL1a5RFKV0CNKKz1dLPqoUcQWgnryiKKWDn1Z8mPVutId33saaYGWPQ3aSpwD15BVFKR288t1Ty4PWZ5KmWQNZE6w6gbSBGnlFUQpHULgkKDc+Su48eMsdOKlwaQMN1yiKkgyZWTAjplgiYn6DpkG58VFy5yGcAa9waQPNk1cUJX4ys2CAtHi4k95D4IpNyfTjjuP95YkrZIKUX568hmsURYkf1zCJl+DY28llvLjKHdjqlp1E2kDDNYqixE+kOLd0eNtxZ7xEDe9UIGrkFUWJn951HmGSzJCNSwjHWeYvReNyeOpqaP7Iel/bD75+czhj3clnvmq4RlGU+PHKgqmfm16AO4xmfONyWPn9DgMP1uvHLu8Uk5nyRT15RVHiJ2yYxGtg1Jnxsvp6ONSavc3BlmyPX8lCjbyiKMkQJkwyeWF2Fk5m3rtffL/Cc9zjQMM1iqJEJy4NmNEzrQyXVAinth90qbWKb6eO65fHXuE57nGgRl5RlGikSQWYjoyYsIY+8wYBVp78jCXQ1mzH3h3HHTEFqmqyj1PdtVOU78sXDdcoihKONKGvDDKVIv204L2kgr0UJrc/A+fflXt2TSdHZ7wqihKM6wxWF2pqs+PrzglHngOtQ+z4ups9Eli017tfnTgHPoXOeFUUJT/CCH1JdbDWu59UcF4KkzmEjToJeRt5ERkiImtE5DUR2SwiP7CX9xORZ0Vku/2/b/7dVRSlKARlsdTUgjkYvK+fIY9DYbITSAdHJQ5Pvg34R2PMl4BTgMtF5EvANcBqY8wIYLX9XlGUcsQ3w2WII0MmYF8/Q56ZaROkLZNLAZFOSN4Dr8aYd4F37df7RWQLMBiYDky0N3sQeB64Ot/2FEXJg1xj2F757JlGOCjnPYyUcNiYupd0gqZVphFrdo2IDAPGAn8EjrBvAADvAUfE2ZaiKBHJpwh2mBmsXtuAPeDqWBaHtHCYiVRKfNk1ItIL+B1wozFmhYjsNcb0caz/2BiTFZcXkXnAPIChQ4eOf+utt2Lpj6IoGfhltiSl5+6WlROnhrtm1wD+2TWxePIiUgP8GnjIGLPCXvy+iBxljHlXRI4CPnDb1xizBFgCVgplHP1RFMWFfGLYuRrToMHRfA10J1eYDEMc2TUC3AdsMcbc7li1Cphtv54NPJZvW4qi5EHUFMUU+aQqet5Y3tb0xwIRR3bNacDfAl8VkfX23znATcDXRGQ7cJb9XlGUYhE1RTGFnzcepGHjdQMJk1OvxEIc2TUv0l5PK4vJ+R5fUZSYyLVKUpA37jeQ6zU46jWxStMfY0e1axSlHMk1Rp5LDNsrVdHPG3emREJ2X700cDT9MXbUyCtKuZFPKmQu5OuNe91YNP2xIKh2jaKUG4Wezu81EzXMDNewx3TTkVdiQT15RSk3ijGdPwlvPHXMQj+ZdDLUk1eUciPXVMggolZ78tOaiXIsFRpLFPXkFaXcSGI6f67etJuHH/VYKjSWKOrJK0q5EVWtMQxxetNRj5XUk4kCqCevKOVJ3NP54/Sm/fLqF/XJTvlUobFEUU9eUcqdqLF0N+L0pn33cZEwSOLJRGmn8xr5OH4YilJs4iqBl6vkQdhjZZIZvhk901LCXLTX+q8GPjY6p5HX2pBKpRBXLD0fbzrTYYL0Y3mhA6sFoTKNfJCXrilbSqUQZyw95U3PWGK9DzMxycthgg7PPJ9JU0reVJ6RD+Ola8qWUinEnZkS9Sk3jMMUZyhIiUzlGfkwF52mbCmVgp8BDXqidVsf9Sk3jMOkA6tFpfJSKMNcdJqypVQKfnVV/SYkeU1YiioBHLaYtlZwKhqVZ+TDXHS56morSiniZkDvON5fBtjLY5d
qMAez2/B6ylWHqeSpPCMf9qLzmo6thl+pBIKeaL3Wm4PZMsJ+RlsdppKn8ox8rhedKuFFJ+6bot5k4yPoidZz/RBHUY+Q34OGYkqayjHy+RoIvwEnvYCzifumqDfZeAl6ovVbr0a7oqiM7Bq3tK8V8+CJH4U/RiHTKithtm3ccw107kK8pGW00FGqL1V8WzNeOg2xePIicj9wHvCBMeZ4e1k/YBkwDNgJzDTGfBxHe1m4GQgMNNwPQ08Jd+GGzRLIl0rxWOO+KerchfhJXU9+11uc15yG20qSuDz5B4CpGcuuAVYbY0YAq+33yeBpCEx4TzCXCRu5eOSV4rHGPddA5y7ES+raXHFpYa43lQopWWIx8saYF4CPMhZPBx60Xz8InB9HW674GYIwnqBzEohU28cMeHzN9aKuFI817lmMOisyPtKuTQ/ivt4qxXmpQJKMyR9hjHnXfv0ecITbRiIyT0QaRKRh9+7dubU0eSGeQkhBnmDmDyKVQjZiinWBxq1/Uykea9wxXY0R50bYWauZxH29VYrzUoGIMSaeA4kMA55wxOT3GmP6ONZ/bIzp63eM+vp609DQkFsHnviRFYPH8XlqaoMNxc3DoTnzIQSsm4bJfp9KMVsxL2O9Y7tFe73by4zJh+2nomTidS0FGfgo11vYOPsdx3unZF6xKbgdJS9EZK0xpt5tXZKe/PsicpTdgaOADxJsC8673VLPi+IJNi73MPCQbcDt96mwTK3H/SrIQ1KPtbwo5Uwov1mrXkSVEA4bktRwW8mSZJ78KmA2cJP9/7EE27IImy3Q7p34xCz9aG22f1wZ3n7mRe3lBWkecnlQ6plQUWetRnUmoswd0ZmvJUtcKZT/BUwEBohIE3AdlnFfLiLfAd4Civttpxn2zFBMrhiywjipizoXA6EpaKVFqU+Qi3PWqhtR4+zqvJQksRh5Y8zFHqsmx3H8vMmKXQYY+JqewKHg2GbqWG5xx6gGotS9xs5IEQcTX1n17wx59VYON7v5RA6jtqaabq370g2226zVqhpo+Ys1ZtS7zgph5nr9FGruiJIo5T/jNUzMNEy2QYqqGujSLSO26VPCDNx/9FENhKaglR5FyoR6ZdW/c/za/82R7KZKoA/76da6l6y4eOb4Tm0/ELHHmSLmqrv9jjTOXhGUt5EPOzAU1vNK+5HQEdusn2ut88LtRx/VQGgKWulRSCPnMLJjX72GWmnx3tZ583cWwO7aEw62eG/r17ZXCT9NEih7ylugLGxIxOuxM0VqUGr19dnZNq3NsPlRaPN4EvD60UfV2Y76aKzx++Qp1GBiRqiuS5jxonyeHjOvnZa/eP+Ortik11WZU96efNiL2s0jS4VgnN6J1/GaP3IP90h1tmfTPp18HnSptZ8A7EfpLrXexZGjeI06hTw+gsJ9Tk85AYO3ct0u3lvxk/DhxBS5Pj26XTteacT6FFkRlLcnH6X0GAR7ZEEefybmULaBd3rvzR91hHs2POw/qBrFayz1rI8givEU4tYm5D7YHcNnWLluF//43xvYXrM7cNgni1yfHqOMT9X2tSc56dNiOVPeRj7ooo76Q/Q6Xpdad28ndTPxy7tvbYa1D2SXVHMzypkpaCkvM7P/QU8wzv6kyrllpngWi2JkEXm12cVldmiYm2Uen2Hlul3c+vQ2du3taPcdM4A6+TBrW2OsIaIsavt5t+P8TLX94Os3p28b1juvqoGWTzuue832KlvKO1zjN3s0l5CG1/G+frN3KCWMGJRbzUzw/8H59d/vsdxNiwdy09hPgmJkEXm1mWuYIsfP8LXbn+eHy9anGXiAW9pmcsB0TVt2wHTlPw6e5bqcr9+cffDU9+78TG7jSH6ZQU5xvm6H5TaIq5Qc5e3Jg/cEjFxDGn4TOtyeCtwKJmcStThyUP/9nmB8H8cjauwnQTGyiKIeW6o6UhSjHM+nnW/d8z9s/+AvrutWHTodWuGqLssZJHt4x/TnlraZrDp0OmsPHZO2/La
DM7kzn+vd7dpJkcoma9dm8viMOuhfVsQmUBYHeQmUZbKoD96TniS+i9O3HawfzYnfTI/Jp5b7paN5HtcWQPP6oQX1B4orGlUMISuvNmv7Wd6um8Hz+358PsPKiU9z69PbeGdvM4P61HLl2SP574Y/84fXvTSSwjOt6kVu6v0oPZrfy75+g64XJ0GyHqlqUmHPmQrsFZ1iCZQVF98JKzFmpPi1kwr3nHd79HzjoEwJr6yPMBN1ks6acGas3Dzc+ktlr4yYUvgJNiOmkDWyWVNrhT2+sdhd0MsvNOGRCfXKF/+Ba1dsZNfeZgywa28zP1y2PjYDf3PNvfRofhfX6zfKvIzUteM12ruvyTvbC3TSXplRuUbeNW0ygzguTq8fw4x70o1v1FS8XCfihPncSc7YzBxLaP4ofQbmhoetJ5tCTbBpXG61mSkbfeI3O0Jz5pD7vn4aLRk37VdO+Bmz/mcIza0dYblpVS/yYtcFvNHtm7zYdQHTql4M1WW3/a6uWZ49Qcp5/eZyvfjdGLzGp5o9KnhqumXJUv4xeS8yUxK9Qhj5Xpy5TpgJimvmety0/Vwet92UMp+6umPATqoso5drNk5Qil5rM2x/JrnQTJiJPhirDyly0WixbxAr1+1i0arN7H2pFec1Nq3qRW6quZcetmGukw+5qeZeaLVj8B647Xdz13vpjscM2NT1m8v1EpSd5jY+5XVdqZ5NyVK5MflMSqmoQSELh/jdTBqXw8rvw6FW933D9snZRih1z4DCKlHJSWHU0Qe378NLXdTB/165kf98+c+uR3+x6wLqqrLTIpsODeD0lsWAZdCtQdUPeccM4Ja2mVzVZbnrft6D93lev1EHUbXoTUniF5PvPEa+lC7OYt1w3Lxcz6IpIfvkaiADiPNz5tK+Wx98bhQHTFduqfk+Y861Mk4WrdrM3maPG6PNG92+SZVLyPuQEY7+/KEsj91aZ+XFe86LikMjPg40u6bk8DPylRuuyaSUihoUI43QbQJPGHLJGfejumu8g6xR2wf3WHUqNOFyA+4hLXy35T85fdmE0E14TXB6x/QHrHTJHhkxdrebQjtxacTHgerGlxWdx8hD6VycxdDpzsUYQnCfot6YuvaK9zsI035tP0uhMYxx9DjeINkTqVu3tM3M8tQPmK7c0jbTPp5LSKYdj4pj+Vy/6n13WjqXkS8VwmiMxP2jzOUpIUw2T1S9H6/sjFwJozCaObXfg5XrdnGS6c9gHw88Cs2mK7X2gOnH9GJR69+1D7p6efoW9lhAXN+9FqTp1KiRT4J8M2e8fpR/ftnKCkntM2JK+ns/Y1Db1z3+ntLJzzW7Jle9n7jwm8EJVj9s2jNh7Hh6lVhx8D61NbS0HeRA6yGmVfl74EFMq3qR67r8B/3k0zTdme6m43g9u1bzbv1V1L16Na6DxHGPzZS7oJ2SF2rk4yas1xQkn+D2o2y4n3ajsO9taLivY72fd9a4HD7fn91OddfQXq4nXjcsiKanH1f7tX0tYa2U7krzR7Q99g9c88gGHmk5NW3XQ/apdA6i+kkMBOE2mJqih7RwVc1yvjD2m9xw9BZY/S+4GvgkzpEWpOnUJJ5dIyJTgX8GqoF7jTE3eW2baHZNLuQSMsk3c6ZxOay4NLf+erXjN63/6jdzbyuIIsSBD9x8rD0rNB1n6mJSeKVNdiCW7LTzZp1aHpCumRellD6sJELRsmtEpBr4N+BrQBPwioisMsa8lmS7sZBrHDMXrykrfS8PolQMyic+7mXAC2XYXdpZefA0ph14z/UUOgdOnfnph6iimkPssvPUw3jsXvgPpmI9ZWQZeEgz8Kuv7yjCHde5i1qlTKkokg7XTAB2GGPeABCRpcB0oPSNfK5xzFzK+KX9AH1E1cJM8vGqGBRnNo/fmEFQcZRc2gpR7OPAry/nt63fpb5Lf9/UxcyQShWWpEHYGal++A6mVtXAZ3vxnnn9dnKDo6WUPqwUnKS1awYDTuvSZC9rR0TmiUi
DiDTs3r07/h4ElXfzItc4ZlQNkbCpjTU9QmzjU282TlEwrxvg2gfiEa9q/856W16tQ1O/7bF/YO+KH2W100NauKrLck9t9tTAqVt+euYxwuLUmPlj7Q+omzDdXTeopqc1y8lLIwesGa1JCn8lXMZQKV2KLlBmjFlijKk3xtQPHDgw3oPnUws1iqqfE79CJm6EGvwSaM3QIq/tB/Xf6ZCFRSyjsOJSS/XR+Rkz+xRUbzYIrz7nUhwlk6wiLOmeb5eDn9HbuAwiY4VkVh06nWtav0vToQEcMkLToQFc0/rddu88KKTilg+fKRh2fvWL3HLMVhb3/CV1VR9SJXCE2e0uvjbjHujRL7sARxoSz7lTFBeSDtfsAoY43tfZywpDPqlj+cQxnZkzqXCDV5w1MM/cI0zTtaclYeymP9P8ETx2eUdfnH2KI2faq8+5FEcJJSiW0YzHsEUqJLPq0OmsanEPufjnp2fnw7sJht3Z7Zew26N04NoH7BRUx3ftVYDD+jTWYOz2Z1T4S0mEpD35V4ARIjJcRLoCs4BVCbfZQT6pY0GlBcOEgMI8SbhKA9tWrPcQAtUzV1/vLjB2sMX9UT+O8nte4Z/xc6KFhdzOT5CWjgdhc9ndwjl+x3AN7/iVDjQHCV2uUaphxhLrZh13SE1RbBL15I0xbSLy98DTWCmU9xtjNifZZhr5Dji65bKHnajkVYovFVJJlfELGhTzTH+zP4PfDStKpk2UsIBfn4eeEn6AL1ephQyMIS0k40d6Hrx3dk3fHjVc943jqHssmpxBGkHlGp1hPB0cVRKislUok1Ce9DK6bnojQQYsTF+CPoNnf4iWM1+MnOkwpQqx1RmBg1TRRbIHL+PMgb/klKHccP4JHQv8zm8oAso1KkoMdM7yfxB9EDQMnh5vhrFqbXYvK5e5TVCYJOgzTF5opedl4qX26BUeGjHFvx9JUNvXY3m/DrkFYC+9+EHr9/lR62W+mTNhyBxEnW5Xa6oWyTbwEFxpS6oB8f6ug8o1KkrCVLYnnwRRPbtAjz6GAhqZ1Z1q+/nLFTzxo+xJOc6ng0J4nV4FS6q7wti/zSp8nvLmPzK9EIE+/CWS5AB4yA6EfZrym4W8aF9p1StQOh2qJx8nroJYHhkwaRrgXiGVGLInokrQbn8G1yeP1FNFIRQLvQaMu/ay+pdxY0xprfeXTzlguvLD1vlpxt2tyhKQtqynfO4+iBqUbTV6ps93KJaB15i6UqKoJ58LmZ7uiClZnmeWF1dKnp5nLFx8BqtjLlrh1wfwWNdBZhm9LA+9qsbKtfTNT3e0GfQ01bjcToV06VcUvXpFSQAt/1cIwoQ4vKboF9r78xt89avTmk/5ubClB1OTuwJCYocQVk23ErVOeewrHEkes6XDDjov6h3ueEnfvHUQV8lAjXypUizv3q9dr7BEPoWk3dpz87RTfYDguq3OdkNm6bgS5XxHGY9x9i9Oo1xKT4RKydB5s2tKnTgmJmUSZqKWX8aO16ScfKbdu33OQ61W/N1NamH19Q55AMiSlcycJBRlXKO2X+7ZVkGZNk5S5yUfaQ03krhmlIpGB16LSdzFHKJIFngN1noNIHp5+GEMrJ/U8dVvuvd7w8Phs33cBsO9nhTyKZLidm48w051HdvGWZVJC4AoEVEjX0zilgCOy6B43QBy1fIJ+pxB/Q7KHvKrThV37DqzL17hk1T7cRvlYhSBV8oaNfLFJO5iDkl6eSnD5szH7xIydOGVdpqagBWX1ILfk0lSBKVOxm2UtQCIEhE18sUk7tzqQnh5bQ7j0vxRuBz60TMtbZ+0CVjGCskMPaV0vNNcB0j9njTiNsqaj69ERLNrKomkMy/y0b3x2zeMgFfSJHnuNOVRSRid8dpZSNrLyyes4rdvKXincQ+QOok6I1lRYkSNfKWRpEHJJ6wStG+xDaFmrSgViubJK+HJp7BFqRfFyLXco6KUOGrklfDkI92chOxzrkXa3Sj1m5Ci5IiGaxR/3AY
Ncy0uEmdIJo5atZl9Ax0gVSoOza5RvHHLOEkVnj7v9qJ1CyitCleKUmRUu0bJDdcarMbKd88nNBIHOlCqKKHIy8iLyEUisllEDolIfca6a0Vkh4hsE5Gz8+umUhT8Sh1mCmLFGR8Pgw6UKkoo8vXkNwEzgBecC0XkS8As4DhgKnCXSFDBU6Xk8DOYzhtA3EqLYdCBUkUJRV5G3hizxRizzWXVdGCpMeZzY8ybwA5gQj5tKUVg8kKyZH5TOG8AhZK/dT4tpMkRx5StoygVSFLZNYOBlx3vm+xlWYjIPGAewNChQxPqjpITrpozZHvMhYiPB8kRK4riSqAnLyLPicgml7/pcXTAGLPEGFNvjKkfOHBgHIdU4uS822HGEn+PuRDxcS2WoSg5EejJG2POyuG4u4Ahjvd19jKlHAnKby+E/G2hs2lUVEypEJJKoVwFzBKRbiIyHBgB/CmhtpRik8Rs1kwKmU1TjIFkRUmIvGLyInIB8C/AQOA3IrLeGHO2MWaziCwHXgPagMuN8SoSqlQESQuMFbJYRpKKlIpSYPIy8saYR4FHPdbdCNyYz/EVpZ1Cyg7oRCulglDtGqV8KJQccalUqlKUGFBZA8WdQs9gLSV0opVSQagnr2QTt8JjuaGKlEoFoUZeyUYHHotfqUpRYkLDNUo2OvCoKBWDGnklG1V4VJSKQY28ko0OPCpKxaBGXsmmEDNYFUUpCDrwqrijA4+KUhGoJ68oilLBqJFXFEWpYNTIK4qiVDBq5MuNziw3oChKZHTgtZzo7HIDiqJERj35ckJL4CmKEhE18uWEyg0oihIRNfLlhMoNKIoSETXy5YTKDSiKEhE18uWEyg0oihKRfAt53wp8A2gBXge+bYzZa6+7FvgOcBBYYIx5Or+uKoDKDSiKEol8PflngeONMaOB/wdcCyAiXwJmAccBU4G7RKQ6z7YURVGUiORl5I0xzxhj2uy3LwOpEcDpwFJjzOfGmDeBHcCEfNpSFEVRohNnTH4u8JT9ejDgLHffZC/LQkTmiUiDiDTs3r07xu4oiqIogTF5EXkOONJl1U+NMY/Z2/wUaAMeitoBY8wSYAlAfX29ibq/oiiK4k2gkTfGnOW3XkTmAOcBk40xKSO9Cxji2KzOXqYoiqIUEOmwyznsLDIVuB34ijFmt2P5ccDDWHH4QcBqYIQx5mDA8XYDb+XYnQHAhznumySl2i8o3b5pv6Kh/YpGJfbrfxljBrqtyNfI7wC6AXvsRS8bYy6z1/0UK07fBvzQGPOU+1HiQUQajDH1SbaRC6XaLyjdvmm/oqH9ikZn61deefLGmL/yWXcjcGM+x1cURVHyQ2e8KoqiVDCVZOSXFLsDHpRqv6B0+6b9iob2Kxqdql95xeQVRVGU0qaSPHlFURQlAzXyiqIoFUxZGXkRuUhENovIIRGpz1h3rYjsEJFtInK2x/7DReSP9nbLRKRrAn1cJiLr7b+dIrLeY7udIrLR3q4h7n64tLdIRHY5+naOx3ZT7XO4Q0SuKUC/bhWRrSLSKCKPikgfj+0Kcr6CPr+IdLO/4x32tTQsqb442hwiImtE5DX7+v+ByzYTRWSf4/stWJGBoO9GLBbb56xRRMYVoE8jHedivYh8IiI/zNimIOdMRO4XkQ9EZJNjWT8ReVZEttv/+3rsO9veZruIzM6pA8aYsvkDRgEjgeeBesfyLwEbsHL2h2PJHle77L8cmGW//gUwP+H+/h9goce6ncCAAp67RcCPA7apts/d0UBX+5x+KeF+TQG62K9vBm4u1vkK8/mB7wO/sF/PApYV4Ls7Chhnvz4MS/E1s18TgScKdT1F+W6Ac7B0rQQ4BfhjgftXDbyHNWGo4OcMOBMYB2xyLLsFuMZ+fY3bdQ/0A96w//e1X/eN2n5ZefLGmC3GmG0uqwJVL0VEgK8Cj9iLHgTOT6qvdnszgf9Kqo0EmADsMMa8YYxpAZZindvEMN5KpsUgzOefjnXtgHUtTba/68Qwxrx
rjHnVfr0f2IKH4F+JMh34D2PxMtBHRI4qYPuTgdeNMbnOps8LY8wLwEcZi53XkZctOht41hjzkTHmYyxp96lR2y8rI+9DGNXL/sBeh0HxVMaMiTOA940x2z3WG+AZEVkrIvMS7IeTv7cfl+/3eDwMrR6aEE4l00wKcb7CfP72bexraR/WtVUQ7PDQWOCPLqu/LCIbROQpW1qkUAR9N8W+rmbh7WwV65wdYYx51379HnCEyzaxnLe8ZrwmgYRQvSw2Ift4Mf5e/OnGmF0icjjwrIhste/4ifQLuBv4OdYP8udYoaS5+bQXR79MeCXT2M9XuSEivYBfY8mEfJKx+lWscMSn9njLSmBEgbpWst+NPe42DbugUQbFPGftGGOMiCSWy15yRt4EqF56EEb1cg/WY2IX2wPLWRkzqI8i0gWYAYz3OcYu+/8HIvIoVqggrx9G2HMnIvcAT7isSkQ9NMT5mkO2kmnmMWI/Xy6E+fypbZrs77k3HdpNiSEiNVgG/iFjzIrM9U6jb4x5UkTuEpEBxpjEhbhCfDfFVKX9OvCqMeb9zBXFPGfA+yJylDHmXTt09YHLNruwxg1S1GGNR0aiUsI1q4BZdubDcKy78Z+cG9jGYw1wob1oNpDUk8FZwFZjTJPbShHpKSKHpV5jDT5ucts2LjJioBd4tPcKMEKsLKSuWI+5qxLu11TgKmCaMeaAxzaFOl9hPv8qrGsHrGvpt143priwY/73AVuMMbd7bHNkamxARCZg/bYLcfMJ892sAv7OzrI5BdjnCFUkjecTdbHOmY3zOvKyRU8DU0Skrx1enWIvi0bSI8tx/mEZpybgc+B94GnHup9iZUZsA77uWP4kMMh+fTSW8d8B/DfQLaF+PgBclrFsEPCkox8b7L/NWGGLpM/dr4CNQKN9gR2V2S/7/TlY2RuvF6hfO7Dijuvtv19k9quQ58vt8wPXY92EALrb184O+1o6ugDn6HSsMFuj4zydA1yWus6Av7fPzQasAexTk+6X33eT0TcB/s0+pxtxZMYl3LeeWEa7t2NZwc8Z1k3mXaDVtl/fwRrHWQ1sB54D+tnb1gP3Ovada19rO4Bv59K+yhooiqJUMJUSrlEURVFcUCOvKIpSwaiRVxRFqWDUyCuKolQwauQVRVEqGDXyiqIoFYwaeUVRlArm/wNnqVSroh2hLwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(X_train, Y_train)\n", "plt.scatter(X_valid, Y_valid)\n", "plt.legend([\"Training data\", \"Validation data\"])" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# 拟合上述曲线\n", "\n", "要求:\n", "\n", "* 函数$ f(x) = w_0x^2 + w_1x + w_2 $\n", "* 误差采用均方误差 $L := \\frac{\\sum_{i=1}^{n}(f(X_i) - Y_i)^2}{n}$\n", "* 固定步长的梯度下降法\n", "* 在尽可能不修改代码结构的前提下完成工作\n", "* 利用`X_train, Y_train`拟合,利用`X_valid, Y_valid`来验证拟合的效果" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义模型、优化器及误差" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
  "class Parabola:\n",
  "    \"\"\"Quadratic model f(x) = W[0]*x**2 + W[1]*x + W[2], fitted by fixed-step gradient descent on the MSE.\"\"\"\n",
  "\n",
  "    def __init__(self, step=1e-3):\n",
  "        self.W = np.zeros(3)  # coefficients [w0, w1, w2], start at zero\n",
  "        self.step = step      # fixed gradient-descent step size\n",
  "\n",
  "    def predict(self, X):\n",
  "        \"\"\"Return f(X) evaluated element-wise; X must be a 1-D array.\"\"\"\n",
  "        assert len(X.shape)==1\n",
  "        return self.W[0]*X**2 + self.W[1]*X + self.W[2]\n",
  "\n",
  "    def fit(self, X_train, Y_train, num_iter=1000, verbose=False, X_valid=None, Y_valid=None):\n",
  "        \"\"\"Fit W to the training set (X_train, Y_train) with num_iter gradient-descent steps.\n",
  "\n",
  "        If verbose, print the training loss every 50 iterations, plus the validation\n",
  "        loss when X_valid and Y_valid are passed explicitly (no notebook globals are read).\n",
  "        \"\"\"\n",
  "        X, Y = X_train, Y_train\n",
  "        for i in range(num_iter):\n",
  "            dLdW = self._grad(X, Y)\n",
  "            self._update_weight(dLdW)\n",
  "\n",
  "            # Losses are only used for logging, so compute them only when printing.\n",
  "            if verbose and i%50 == 0:\n",
  "                cur_train_loss = self._loss(self.predict(X), Y)\n",
  "                msg = f\"Iter {i}: train loss {cur_train_loss}\"\n",
  "                if X_valid is not None and Y_valid is not None:\n",
  "                    cur_valid_loss = self._loss(self.predict(X_valid), Y_valid)\n",
  "                    msg += f\", valid loss {cur_valid_loss}\"\n",
  "                print(msg)\n",
  "\n",
  "    # By convention a leading underscore marks an internal helper that may change at any time;\n",
  "    # Parabola's public API is just fit and predict (following the scikit-learn estimator API).\n",
  "    def _loss(self, predict, real):\n",
  "        \"\"\"Mean squared error L = mean((predict - real)**2).\"\"\"\n",
  "        assert len(real.shape)==1\n",
  "        assert real.size == predict.size\n",
  "        return np.mean((predict - real)**2)\n",
  "\n",
  "    def _grad(self, X, Y):\n",
  "        \"\"\"Return (dL/dw0, dL/dw1, dL/dw2) of the MSE at the current W for targets Y.\n",
  "\n",
  "        With residual r = f(X) - Y: dL/dw0 = 2*mean(r*X**2), dL/dw1 = 2*mean(r*X), dL/dw2 = 2*mean(r).\n",
  "        \"\"\"\n",
  "        assert len(X.shape)==1\n",
  "        residual = self.predict(X) - Y\n",
  "        dLdW0 = 2*np.mean(residual*X**2)\n",
  "        dLdW1 = 2*np.mean(residual*X)\n",
  "        dLdW2 = 2*np.mean(residual)\n",
  "        return dLdW0, dLdW1, dLdW2\n",
  "\n",
  "    def _update_weight(self, dLdW):\n",
  "        \"\"\"One fixed-step gradient-descent update: W <- W - step * dLdW.\"\"\"\n",
  "        self.W = self.W - self.step*np.asarray(dLdW)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 拟合模型" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "Iter 0: train loss 7.387469877566412, valid loss 358.1376185737317\n",
  "Iter 50: train loss 5.192845037696437, valid loss 160.08693537131256\n",
  "Iter 100: train loss 3.7993227032348904, valid loss 103.14076657264337\n",
  "Iter 150: train loss 2.8636227125989926, valid loss 96.35787891588102\n",
  "Iter 200: train loss 2.206729210467262, valid loss 103.37637041458811\n",
  "Iter 250: train loss 1.730252672620262, valid loss 110.99809060880956\n",
  "Iter 300: train loss 1.3767076385617047, valid loss 115.41139000402633\n",
  "Iter 350: train loss 1.1103166870208538, valid loss 116.29870118488779\n",
  "Iter 400: train loss 0.9074881658744149, valid loss 114.4178955811505\n",
  "Iter 450: train loss 0.7519132526248361, valid loss 110.679998946516\n",
  "Iter 500: train loss 0.6319138778040571, valid loss 105.84840700299152\n",
  "Iter 550: train loss 0.5389207350200598, valid loss 100.48027987355904\n",
  "Iter 600: train loss 0.46654278936343585, valid loss 94.95087773955072\n",
  "Iter 650: train loss 0.40996333070662927, valid loss 89.49759155918086\n",
  "Iter 700: train loss 0.36552772603395844, valid loss 84.26110710933355\n",
  "Iter 750: train loss 0.3304510115073017, valid loss 79.317572677117\n",
  "Iter 800: train loss 0.3026047908736134, valid loss 74.70170538739922\n",
  "Iter 850: train loss 0.2803590938925677, valid loss 70.42258659663358\n",
  "Iter 900: train loss 0.26246364369780906, valid loss 66.47409851465721\n",
  "Iter 950: train loss 0.24795805508259805, valid loss 62.84163642355465\n"
 ] } ], "source": [
  "# The class default step=1e-3 is far too small to converge within 1000 iterations on this data;\n",
  "# 0.1 is still safely below the stability limit of fixed-step gradient descent for X in [-2, 2].\n",
  "model = Parabola(step=0.1)\n",
  "\n",
  "model.fit(X_train, Y_train, verbose=True, X_valid=X_valid, Y_valid=Y_valid)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": 
[ "## 显示结果" ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "plt.scatter(X_valid, model.predict(X_valid))\n",
  "plt.scatter(X_valid, Y_valid)\n",
  "plt.xlabel(\"x\")\n",
  "plt.ylabel(\"y\")\n",
  "plt.title(\"Fitted parabola vs. validation data\")\n",
  "# The trailing ';' suppresses the text repr of the returned Legend object.\n",
  "plt.legend([\"prediction\", \"ground truth\"]);"
 ] }
 ], "metadata": { "kernelspec": { "display_name": "AI-Course", "language": "python", "name": "ai-course" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }