"source": [
"# Introduction to programming artificial neural networks\n",
"#### Tutorial for Methods In Neuroscience at Dartmouth ([MIND](http://mindsummerschool.org/)) 2023\n",
"By [Mark A. Thornton](http://markallenthornton.com/)\n",
"\n",
"This tutorial offers an introduction to programming your own customized artificial neural network (ANN) for the first time. It is based on the popular ANN programming framework [PyTorch](https://pytorch.org/). You will build up an ANN to perform regression, starting from a very simple network and working up step-by-step to a more complex one.\n",
"\n",
"This notebook focuses on the implementation of ANNs. If you're interested in a complementary conceptual introduction to ANNs, their potential uses in social neuroscience, and their limitations, please consider [my preprint](https://psyarxiv.com/fr4cb) with [Beau Sievers](http://beausievers.com/).\n",
"\n",
"The figure below, created by [Lindsey Tepfer](https://pbs.dartmouth.edu/people/lindsey-j-tepfer) for the aforementioned preprint, illustrates (A) the general structure of ANNs, (B) the internal structure of individual units, which approximate generalized linear models, and (C) the training process using stochastic gradient descent via backpropagation. The terminology in this figure will reappear throughout the tutorial.\n",
"\n",
"![](https://mysocialbrain.org/misc/data/ann_tutorial/Fig1_DS_hires_top.png)\n"
]
},
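Panel B's point that an individual unit approximates a generalized linear model can be made concrete in a few lines of plain Python. This is only a sketch: the names `unit`, `identity`, `relu`, and `sigmoid`, and the example numbers, are made up for illustration, and PyTorch's own layers are vectorized rather than written this way.

```python
import math

def unit(inputs, weights, bias, activation):
    """One ANN unit: a weighted sum of the inputs plus a bias, passed through an activation."""
    z = sum(w * x for w, x in zip(weights, inputs)) + bias
    return activation(z)

def identity(z):
    return z

def relu(z):
    return max(0.0, z)

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# hypothetical inputs, weights, and bias, chosen only for illustration
x = [1.0, 2.0]
w = [0.5, -0.25]
b = 0.1

linear_out = unit(x, w, b, identity)    # same form as a linear regression prediction
relu_out = unit(x, w, b, relu)          # same weighted sum, clipped at zero
logistic_out = unit(x, w, b, sigmoid)   # same form as a logistic regression prediction
```

Swapping the activation function is the only difference between the three calls, which is the sense in which a unit resembles a GLM with a choice of link function.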
"source": [
"## Setup\n",
"This section includes the import statements for the packages/functions we'll need here, detection of the available hardware for ANN fitting, and code to simulate the artificial data we'll be using."
]
},
},
},
},
},
"source": [
"nepoch = 10 # epochs = how many times the model sees the dataset\n",
"# note that we're not actually using the GPU yet - see the next example for that\n",
"for epoch in range(nepoch):\n",
"    # convert inputs and targets to torch tensors\n",
"    inputs = torch.from_numpy(np.float32(x_train))\n",
"    targets = torch.from_numpy(np.float32(y_lin_train))\n",
"\n",
"    # propagate activity forward through network to make prediction\n",
"    outputs = model1(inputs)\n",
"\n",
"    # compute loss (error) of predictions\n",
"    curloss = loss(outputs, targets)\n",
"\n",
"    # backpropagate errors to change weights and biases via SGD\n",
"    optimizer.zero_grad()\n",
"    curloss.backward()\n",
"    optimizer.step()\n",
"\n",
"    # print the loss after every epoch\n",
"    if (epoch+1) % 1 == 0:\n",
"        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, nepoch, curloss.item()))\n"
]
},
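To make the `zero_grad()`, `backward()`, and `step()` calls in the loop above less opaque, here is a rough plain-Python sketch of the same idea for a single linear unit fit with mean squared error. Everything in it is an assumption made for illustration: the toy data, the learning rate `lr`, and the hand-derived gradients stand in for what PyTorch computes automatically, and the model here is not `model1` itself.

```python
# toy data from a known line, y = 2x + 1 (values chosen only for illustration)
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [2.0 * x + 1.0 for x in xs]

w, b = 0.0, 0.0   # weight and bias start at zero
lr = 0.05         # learning rate (hypothetical value)
losses = []

for epoch in range(500):
    # forward pass: predictions from the current weight and bias
    preds = [w * x + b for x in xs]

    # loss: mean squared error
    errors = [p - y for p, y in zip(preds, ys)]
    loss = sum(e * e for e in errors) / len(xs)
    losses.append(loss)

    # "backward" step: gradients of the loss with respect to w and b, derived by hand
    grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / len(xs)
    grad_b = 2.0 * sum(errors) / len(xs)

    # "optimizer step": move each parameter a little way against its gradient
    w -= lr * grad_w
    b -= lr * grad_b
```

Because every epoch uses the whole dataset at once, this is full-batch gradient descent, just like the loop above. The stochastic version updates the parameters after each small batch instead.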
"source": [
"# plotting actual values vs. predictions\n",
"ypred = model2(torch.from_numpy(np.float32(x_test)).to(device)).cpu().detach().numpy()\n",
"plt.scatter(ypred,y_non_test,s=.1)\n",
"plt.xlabel(\"Predictions\");\n",
"plt.ylabel(\"Actual values\");"
]
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [
"# plotting actual values vs. predictions\n",
"ypred = model3(torch.from_numpy(np.float32(x_test)).to(device)).cpu().detach().numpy()\n",
"plt.scatter(ypred,y_non_test,s=.1)\n",
"plt.xlabel(\"Predictions\");\n",
"plt.ylabel(\"Actual values\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kxzLLCNRw_W5"
},
"source": [
"### Deep neural network\n",
"In the example below, we'll use far fewer units to achieve similar performance, via a deep neural network. This network features 5 ReLU layers, with the number of units in each layer decreasing by powers of two. The final layer is a single linear unit, as in previous cases. An important addition here is batch normalization after each ReLU activation. Batch normalization is basically like z-scoring the activations of each unit across the examples in a batch. This helps to prevent what's known as the \"exploding gradient\" problem (and its counterpart, the \"vanishing gradient\" problem). Basically, stacking many nonlinear transformations can make some numbers really huge (or tiny), and this can cause problems for numerical computing with floating-point representations. Batch normalization helps to mitigate this problem, and also serves as a sort of regularization that improves training."
]
},
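The z-scoring analogy can be sketched in plain Python for a single unit's activations across one batch. This is a simplified illustration, not PyTorch's implementation: the function name `batch_norm` and the example activations are made up, and it leaves out the running statistics that a real batch normalization layer tracks for use at evaluation time. The learnable scale `gamma` and shift `beta` are fixed at their usual starting values of 1 and 0.

```python
import math

def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Z-score one unit's activations across a batch, then rescale and shift."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n   # population (biased) variance
    # eps keeps the division stable when the variance is close to zero
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]

# hypothetical ReLU outputs for one unit, with a couple of very large values
activations = [0.0, 3.0, 250.0, 1000.0, 4.0]
normed = batch_norm(activations)
```

After normalization the batch has a mean of roughly 0 and a variance of roughly 1, however large the raw activations were, so the next layer always sees inputs on a similar scale.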
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wGma6Ndbxc1n",
"outputId": "80848c9e-9571-4f0c-bcea-6b9c84f31e10"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1\n",
"-------------------------------\n",
"loss: 4.667707 [ 64/100000]\n",
"loss: 0.989088 [ 6464/100000]\n",
"loss: 0.216513 [12864/100000]\n",
"loss: 0.264145 [19264/100000]\n",
"loss: 0.306660 [25664/100000]\n",
"loss: 0.456342 [32064/100000]\n",
"loss: 0.057279 [38464/100000]\n",
"loss: 0.060745 [44864/100000]\n",
"loss: 0.394866 [51264/100000]\n",
"loss: 0.418425 [57664/100000]\n",
"loss: 0.314052 [64064/100000]\n",
"loss: 0.173316 [70464/100000]\n",
"loss: 0.039098 [76864/100000]\n",
"loss: 0.173622 [83264/100000]\n",
"loss: 0.153746 [89664/100000]\n",
"loss: 0.228229 [96064/100000]\n",
"Epoch 2\n",
"-------------------------------\n",
"loss: 0.071922 [ 64/100000]\n",
"loss: 0.055156 [ 6464/100000]\n",
"loss: 0.248226 [12864/100000]\n",
"loss: 0.049417 [19264/100000]\n",
"loss: 0.040127 [25664/100000]\n",
"loss: 0.057625 [32064/100000]\n",
"loss: 0.714862 [38464/100000]\n",
"loss: 0.200428 [44864/100000]\n",
"loss: 0.236930 [51264/100000]\n",
"loss: 0.031537 [57664/100000]\n",
"loss: 0.179619 [64064/100000]\n",
"loss: 0.115763 [70464/100000]\n",
"loss: 0.201975 [76864/100000]\n",
"loss: 0.044374 [83264/100000]\n",
"loss: 0.070829 [89664/100000]\n",
"loss: 0.019935 [96064/100000]\n",
"Epoch 3\n",
"-------------------------------\n",
"loss: 0.119832 [ 64/100000]\n",
"loss: 0.057674 [ 6464/100000]\n",
"loss: 0.081795 [12864/100000]\n",
"loss: 0.079885 [19264/100000]\n",
"loss: 0.036620 [25664/100000]\n",
"loss: 0.088968 [32064/100000]\n",
"loss: 0.167877 [38464/100000]\n",
"loss: 0.203338 [44864/100000]\n",
"loss: 0.018583 [51264/100000]\n",
"loss: 0.209267 [57664/100000]\n",
"loss: 0.175974 [64064/100000]\n",
"loss: 0.060453 [70464/100000]\n",
"loss: 0.039461 [76864/100000]\n",
"loss: 0.105169 [83264/100000]\n",
"loss: 0.085562 [89664/100000]\n",
"loss: 0.035103 [96064/100000]\n",
"Done!\n"
]
}
],
"source": [
"# now we'll train the model\n",
"loss = nn.MSELoss()\n",
"optimizer = torch.optim.SGD(model4.parameters(),lr=.01)\n",
"epochs = 3\n",
"for t in range(epochs):\n",
"    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
"    train(train_dataloader, model4, loss, optimizer)\n",
"print(\"Done!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
},
"id": "_D7aGcR7x0y6",
"outputId": "f701b94e-19cf-478e-c1be-b11aca797bb1"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "[image data removed: scatter plot of predictions (x-axis) against actual values (y-axis)]"
},
"metadata": {}
}
],
"source": [
"# plotting actual values vs. predictions\n",
"ypred = model4(torch.from_numpy(np.float32(x_test)).to(device)).cpu().detach().numpy()\n",
"plt.scatter(ypred,y_non_test,s=.1)\n",
"plt.xlabel(\"Predictions\");\n",
"plt.ylabel(\"Actual values\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KQalPfpB0kNz"
},
"source": [
"As you can see, the deep net has achieved similar performance to the shallow network, with only a fraction of the parameters!"
]
},
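To get a feel for the "fraction of the parameters" claim, the counts can be worked out by hand. The sketch below makes two assumptions that are not confirmed here: that the deep network's hidden layers have 32, 16, 8, 4, and 2 units (five layers decreasing by powers of two), and that the shallow comparison network has a single hidden layer of 1000 units, a purely hypothetical width. The helper names are also made up for illustration.

```python
def linear_params(n_in, n_out):
    """A fully connected layer has one weight per input-output pair, plus one bias per output."""
    return n_in * n_out + n_out

def mlp_params(sizes, batch_norm=False):
    """Count learnable parameters for layer sizes [input, hidden..., output]."""
    total = sum(linear_params(a, b) for a, b in zip(sizes[:-1], sizes[1:]))
    if batch_norm:
        # each batch-normalized hidden unit adds a learnable scale and a learnable shift
        total += sum(2 * n for n in sizes[1:-1])
    return total

# assumed layer sizes, see the note above
deep = mlp_params([1, 32, 16, 8, 4, 2, 1], batch_norm=True)
shallow = mlp_params([1, 1000, 1])
```

Under these assumptions the deep network has 901 learnable parameters and the hypothetical shallow one has 3001. For an actual PyTorch model, `sum(p.numel() for p in model.parameters())` gives the real count.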
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9RDXf_L_9HUY",
"outputId": "2ad4145c-68db-4de4-b0f6-d7a4fd442ac9"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"FancyANN(\n",
" (layerlist): ModuleList(\n",
" (0): Linear(in_features=1, out_features=32, bias=True)\n",
" (1): Tanh()\n",
" (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.01, inplace=False)\n",
" (4): Linear(in_features=32, out_features=16, bias=True)\n",
" (5): Tanh()\n",
" (6): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (7): Dropout(p=0.01, inplace=False)\n",
" (8): Linear(in_features=16, out_features=8, bias=True)\n",
" (9): Tanh()\n",
" (10): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (11): Dropout(p=0.01, inplace=False)\n",
" (12): Linear(in_features=8, out_features=4, bias=True)\n",
" (13): Tanh()\n",
" (14): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (15): Dropout(p=0.01, inplace=False)\n",
" (16): Linear(in_features=4, out_features=2, bias=True)\n",
" (17): Tanh()\n",
" (18): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (19): Dropout(p=0.01, inplace=False)\n",
" (20): Linear(in_features=34, out_features=1, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"# define model architecture\n",
"class FancyANN(nn.Module):\n",
"    def __init__(self, ninput=1, noutput=1, nhlayer=5, fextint=True):\n",
"        super().__init__()\n",
"        # number of first-hidden-layer units fed forward to the final layer (0 = no skip connection)\n",
"        self.fextint = 0\n",
"        if fextint:\n",
"            self.fextint = 2**nhlayer\n",
"        self.nhlayer = nhlayer\n",
"        self.layerlist = nn.ModuleList()\n",
"        for i in range(nhlayer):\n",
"            if i == 0:\n",
"                self.layerlist.append(nn.Linear(ninput,2**(nhlayer-i)))\n",
"            else:\n",
"                self.layerlist.append(nn.Linear(2**(nhlayer-i+1),2**(nhlayer-i)))\n",
"            self.layerlist.append(nn.Tanh()) # here's our alternative activation function (tanh)\n",
"            self.layerlist.append(nn.BatchNorm1d(2**(nhlayer-i)))\n",
"            self.layerlist.append(nn.Dropout(p=.01)) ## here's where we add dropout\n",
"        self.layerlist.append(nn.Linear(self.fextint+2,noutput))\n",
"\n",
"    def forward(self, x):\n",
"        for i in range(len(self.layerlist)-1):\n",
"            x = self.layerlist[i](x)\n",
"            if i == 2:\n",
"                x0 = torch.clone(x) # keep a copy of the first hidden layer's batch-normalized output\n",
"        if self.fextint:\n",
"            x = torch.cat((x0,x),1) ## here's where the first layer and penultimate layer get joined together as inputs to the final layer\n",
"        pred = self.layerlist[-1](x)\n",
"        return pred\n",
"\n",
"model5 = FancyANN().to(device)\n",
"print(model5)"
]
},
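`FancyANN` adds a dropout layer after each batch normalization. As a rough plain-Python sketch of what dropout does to a list of activations: during training, each value is zeroed with probability `p` and the survivors are scaled up by `1/(1-p)` so the expected total stays the same, and at evaluation time the values pass through unchanged. The function name `dropout`, the seeded generator `rng`, and the large `p=0.5` used here (so the effect is easy to see, unlike the `p=.01` in the model) are all assumptions for illustration.

```python
import random

def dropout(values, p, rng, training=True):
    """Zero each value with probability p and rescale the survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(values)   # evaluation mode: pass values through unchanged
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in values]

rng = random.Random(0)        # seeded so the sketch is repeatable
activations = [1.0] * 10000   # hypothetical activations, all equal to 1

train_out = dropout(activations, p=0.5, rng=rng)
eval_out = dropout(activations, p=0.5, rng=rng, training=False)
train_mean = sum(train_out) / len(train_out)
```

With `p=0.5`, every training-mode output is either 0 or 2, and their average stays close to the original value of 1. This is why switching a model between training and evaluation modes matters whenever it contains dropout.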
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bnfctx2N8Zoz",
"outputId": "251c68f6-a1a9-4f37-ad3b-6e766292ac56"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1\n",
"-------------------------------\n",
"loss: 3.888131 [ 64/100000]\n",
"loss: 3.917098 [ 6464/100000]\n",
"loss: 2.937879 [12864/100000]\n",
"loss: 5.071634 [19264/100000]\n",
"loss: 3.716036 [25664/100000]\n",
"loss: 2.398747 [32064/100000]\n",
"loss: 1.351015 [38464/100000]\n",
"loss: 1.365817 [44864/100000]\n",
"loss: 1.875856 [51264/100000]\n",
"loss: 1.061013 [57664/100000]\n",
"loss: 1.337905 [64064/100000]\n",
"loss: 1.853628 [70464/100000]\n",
"loss: 0.764149 [76864/100000]\n",
"loss: 0.915835 [83264/100000]\n",
"loss: 0.340542 [89664/100000]\n",
"loss: 0.346805 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.780519, RMSE: 0.845239 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
"loss: 0.396910 [ 64/100000]\n",
"loss: 0.205753 [ 6464/100000]\n",
"loss: 0.607033 [12864/100000]\n",
"loss: 1.113464 [19264/100000]\n",
"loss: 0.097006 [25664/100000]\n",
"loss: 0.100322 [32064/100000]\n",
"loss: 0.931997 [38464/100000]\n",
"loss: 0.178657 [44864/100000]\n",
"loss: 0.537316 [51264/100000]\n",
"loss: 0.205982 [57664/100000]\n",
"loss: 0.245557 [64064/100000]\n",
"loss: 0.080524 [70464/100000]\n",
"loss: 0.222449 [76864/100000]\n",
"loss: 0.213365 [83264/100000]\n",
"loss: 0.096103 [89664/100000]\n",
"loss: 0.591012 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.979044, RMSE: 0.285466 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
"loss: 0.064183 [ 64/100000]\n",
"loss: 0.243895 [ 6464/100000]\n",
"loss: 0.307392 [12864/100000]\n",
"loss: 0.179972 [19264/100000]\n",
"loss: 0.183356 [25664/100000]\n",
"loss: 0.188181 [32064/100000]\n",
"loss: 0.091715 [38464/100000]\n",
"loss: 0.677460 [44864/100000]\n",
"loss: 0.177787 [51264/100000]\n",
"loss: 0.066920 [57664/100000]\n",
"loss: 0.166385 [64064/100000]\n",
"loss: 0.066744 [70464/100000]\n",
"loss: 0.129679 [76864/100000]\n",
"loss: 0.137251 [83264/100000]\n",
"loss: 0.135885 [89664/100000]\n",
"loss: 0.153805 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.986803, RMSE: 0.227382 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
"loss: 0.329610 [ 64/100000]\n",
"loss: 0.309667 [ 6464/100000]\n",
"loss: 0.238928 [12864/100000]\n",
"loss: 0.049671 [19264/100000]\n",
"loss: 0.172278 [25664/100000]\n",
"loss: 0.167100 [32064/100000]\n",
"loss: 0.202609 [38464/100000]\n",
"loss: 0.105646 [44864/100000]\n",
"loss: 0.214544 [51264/100000]\n",
"loss: 0.181315 [57664/100000]\n",
"loss: 0.105952 [64064/100000]\n",
"loss: 0.054405 [70464/100000]\n",
"loss: 0.230735 [76864/100000]\n",
"loss: 0.109101 [83264/100000]\n",
"loss: 0.124816 [89664/100000]\n",
"loss: 0.099370 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.988537, RMSE: 0.204467 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
"loss: 0.063645 [ 64/100000]\n",
"loss: 0.123861 [ 6464/100000]\n",
"loss: 0.083427 [12864/100000]\n",
"loss: 0.250894 [19264/100000]\n",
"loss: 0.298372 [25664/100000]\n",
"loss: 0.049779 [32064/100000]\n",
"loss: 0.054952 [38464/100000]\n",
"loss: 0.233167 [44864/100000]\n",
"loss: 0.297191 [51264/100000]\n",
"loss: 0.102348 [57664/100000]\n",
"loss: 0.036531 [64064/100000]\n",
"loss: 0.079810 [70464/100000]\n",
"loss: 0.119193 [76864/100000]\n",
"loss: 0.105115 [83264/100000]\n",
"loss: 0.118437 [89664/100000]\n",
"loss: 0.212578 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.989374, RMSE: 0.194593 \n",
"\n",
"Epoch 6\n",
"-------------------------------\n",
"loss: 0.104637 [ 64/100000]\n",
"loss: 0.124436 [ 6464/100000]\n",
"loss: 0.227147 [12864/100000]\n",
"loss: 0.285004 [19264/100000]\n",
"loss: 0.306534 [25664/100000]\n",
"loss: 0.126568 [32064/100000]\n",
"loss: 0.062833 [38464/100000]\n",
"loss: 0.070517 [44864/100000]\n",
"loss: 0.049845 [51264/100000]\n",
"loss: 0.080152 [57664/100000]\n",
"loss: 0.142207 [64064/100000]\n",
"loss: 0.082779 [70464/100000]\n",
"loss: 0.067380 [76864/100000]\n",
"loss: 0.080468 [83264/100000]\n",
"loss: 0.099579 [89664/100000]\n",
"loss: 0.149001 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.991916, RMSE: 0.170523 \n",
"\n",
"Epoch 7\n",
"-------------------------------\n",
"loss: 0.090363 [ 64/100000]\n",
"loss: 0.105410 [ 6464/100000]\n",
"loss: 0.040804 [12864/100000]\n",
"loss: 0.054256 [19264/100000]\n",
"loss: 0.027425 [25664/100000]\n",
"loss: 0.207500 [32064/100000]\n",
"loss: 0.090952 [38464/100000]\n",
"loss: 0.086845 [44864/100000]\n",
"loss: 0.079099 [51264/100000]\n",
"loss: 0.275742 [57664/100000]\n",
"loss: 0.028412 [64064/100000]\n",
"loss: 0.116104 [70464/100000]\n",
"loss: 0.120119 [76864/100000]\n",
"loss: 0.070722 [83264/100000]\n",
"loss: 0.085920 [89664/100000]\n",
"loss: 0.083501 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.991818, RMSE: 0.170135 \n",
"\n",
"Epoch 8\n",
"-------------------------------\n",
"loss: 0.223899 [ 64/100000]\n",
"loss: 0.205298 [ 6464/100000]\n",
"loss: 0.073132 [12864/100000]\n",
"loss: 1.529429 [19264/100000]\n",
"loss: 0.101772 [25664/100000]\n",
"loss: 0.157559 [32064/100000]\n",
"loss: 0.073842 [38464/100000]\n",
"loss: 0.123590 [44864/100000]\n",
"loss: 0.053510 [51264/100000]\n",
"loss: 0.150473 [57664/100000]\n",
"loss: 0.092963 [64064/100000]\n",
"loss: 0.050164 [70464/100000]\n",
"loss: 0.299292 [76864/100000]\n",
"loss: 0.059084 [83264/100000]\n",
"loss: 0.048018 [89664/100000]\n",
"loss: 0.114542 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.992695, RMSE: 0.156656 \n",
"\n",
"Epoch 9\n",
"-------------------------------\n",
"loss: 0.048064 [ 64/100000]\n",
"loss: 1.965292 [ 6464/100000]\n",
"loss: 0.050489 [12864/100000]\n",
"loss: 0.111465 [19264/100000]\n",
"loss: 0.129185 [25664/100000]\n",
"loss: 0.453362 [32064/100000]\n",
"loss: 0.169424 [38464/100000]\n",
"loss: 0.096436 [44864/100000]\n",
"loss: 0.081334 [51264/100000]\n",
"loss: 0.086524 [57664/100000]\n",
"loss: 0.108220 [64064/100000]\n",
"loss: 0.026716 [70464/100000]\n",
"loss: 0.330060 [76864/100000]\n",
"loss: 0.081770 [83264/100000]\n",
"loss: 0.102128 [89664/100000]\n",
"loss: 0.380724 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.992518, RMSE: 0.160280 \n",
"\n",
"Epoch 10\n",
"-------------------------------\n",
"loss: 0.067976 [ 64/100000]\n",
"loss: 0.061944 [ 6464/100000]\n",
"loss: 0.093701 [12864/100000]\n",
"loss: 0.062002 [19264/100000]\n",
"loss: 0.124478 [25664/100000]\n",
"loss: 0.267889 [32064/100000]\n",
"loss: 0.094276 [38464/100000]\n",
"loss: 0.100655 [44864/100000]\n",
"loss: 0.194106 [51264/100000]\n",
"loss: 1.437881 [57664/100000]\n",
"loss: 0.037840 [64064/100000]\n",
"loss: 0.100686 [70464/100000]\n",
"loss: 0.042177 [76864/100000]\n",
"loss: 0.150511 [83264/100000]\n",
"loss: 0.025171 [89664/100000]\n",
"loss: 0.045438 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.991603, RMSE: 0.166102 \n",
"\n",
"Epoch 11\n",
"-------------------------------\n",
"loss: 0.113018 [ 64/100000]\n",
"loss: 0.364865 [ 6464/100000]\n",
"loss: 0.086984 [12864/100000]\n",
"loss: 0.089000 [19264/100000]\n",
"loss: 0.048422 [25664/100000]\n",
"loss: 0.111554 [32064/100000]\n",
"loss: 0.170104 [38464/100000]\n",
"loss: 0.456190 [44864/100000]\n",
"loss: 0.430310 [51264/100000]\n",
"loss: 0.137799 [57664/100000]\n",
"loss: 0.087476 [64064/100000]\n",
"loss: 0.034155 [70464/100000]\n",
"loss: 0.051317 [76864/100000]\n",
"loss: 0.139362 [83264/100000]\n",
"loss: 0.162727 [89664/100000]\n",
"loss: 0.079858 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.992894, RMSE: 0.148741 \n",
"\n",
"Epoch 12\n",
"-------------------------------\n",
"loss: 0.156293 [ 64/100000]\n",
"loss: 0.093481 [ 6464/100000]\n",
"loss: 0.044689 [12864/100000]\n",
"loss: 0.069560 [19264/100000]\n",
"loss: 0.049289 [25664/100000]\n",
"loss: 0.330965 [32064/100000]\n",
"loss: 0.060987 [38464/100000]\n",
"loss: 0.041339 [44864/100000]\n",
"loss: 0.157445 [51264/100000]\n",
"loss: 0.097228 [57664/100000]\n",
"loss: 0.154629 [64064/100000]\n",
"loss: 0.153998 [70464/100000]\n",
"loss: 0.262758 [76864/100000]\n",
"loss: 0.049504 [83264/100000]\n",
"loss: 0.030409 [89664/100000]\n",
"loss: 0.068077 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.992230, RMSE: 0.156347 \n",
"\n",
"Epoch 13\n",
"-------------------------------\n",
"loss: 0.205354 [ 64/100000]\n",
"loss: 0.252282 [ 6464/100000]\n",
"loss: 0.046880 [12864/100000]\n",
"loss: 0.122652 [19264/100000]\n",
"loss: 0.040483 [25664/100000]\n",
"loss: 0.332348 [32064/100000]\n",
"loss: 0.209426 [38464/100000]\n",
"loss: 0.071124 [44864/100000]\n",
"loss: 0.103392 [51264/100000]\n",
"loss: 0.065575 [57664/100000]\n",
"loss: 0.078010 [64064/100000]\n",
"loss: 0.036993 [70464/100000]\n",
"loss: 0.177545 [76864/100000]\n",
"loss: 0.105658 [83264/100000]\n",
"loss: 0.056419 [89664/100000]\n",
"loss: 0.047285 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.991006, RMSE: 0.164956 \n",
"\n",
"Epoch 14\n",
"-------------------------------\n",
"loss: 0.142373 [ 64/100000]\n",
"loss: 0.032840 [ 6464/100000]\n",
"loss: 0.054341 [12864/100000]\n",
"loss: 0.045234 [19264/100000]\n",
"loss: 0.072157 [25664/100000]\n",
"loss: 0.146573 [32064/100000]\n",
"loss: 0.042619 [38464/100000]\n",
"loss: 0.093733 [44864/100000]\n",
"loss: 0.126631 [51264/100000]\n",
"loss: 0.174637 [57664/100000]\n",
"loss: 0.021763 [64064/100000]\n",
"loss: 0.069974 [70464/100000]\n",
"loss: 0.037953 [76864/100000]\n",
"loss: 0.215359 [83264/100000]\n",
"loss: 0.056612 [89664/100000]\n",
"loss: 0.086736 [96064/100000]\n",
"Test performance: \n",
" R^2: 0.991521, RMSE: 0.160487 \n",
"\n",
"Done!\n"
]
}
],
"source": [
"# new we'll train the model\n",
"early_stopper = EarlyStopper(patience=3, min_delta=0)\n",
"loss = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model5.parameters(),lr=.0001) # we're trying a different optimizer here too (ADAM)\n",
"epochs = 100 # this is just an upperbound - the actual epoch # is determined by early stopping\n",
"for t in range(epochs):\n",
" print(f\"Epoch {t+1}\\n-------------------------------\")\n",
" train(train_dataloader, model5, loss, optimizer)\n",
" val_loss = test(val_dataloader, model5, loss)[1]\n",
" if early_stopper.early_stop(val_loss):\n",
" break\n",
"print(\"Done!\")"
]
},
},
"image/png": },
{
"cell_type": "markdown",
"source": [
"## Next steps\n",
"This tutorial has provided a basic example of how one can create one's own artificial neural network using pytorch. However, there's a lot more that you can do with deep nets beyond what's been shown here. This final section will point you to some directions for further learning that you may be interested in."
],
"metadata": {
"id": "gHElnK1saJmr"
}
},
{
"cell_type": "markdown",
"source": [
"### Other layer connectivity patterns\n",
"The layers in the networks you've seen here are all \"densely\" connected. This means that every unit in one layer is connected to every unit in another layer. Most of the connections are also sequential, with the exception of the skip layer connection in the last model. However, there are a wide variety of other connectivity patterns that can perform better for particular applications. Several of these patterns are illustrated in the figure below.\n",
"\n",
"![](https://mysocialbrain.org/misc/data/ann_tutorial/Fig1_DS_hires_bottom.png)\n",
"\n",
"* The bottleneck (red) in the autoencoder illustration is a layer that is narrower (i.e., has fewer units) than the ones before or after it. This forces this layer to learn a compressed representation of the data - a bit like PCA, but nonlinear, not (necessarily) orthogonal, and with potentially different loss functions than maximizing variance explained.\n",
"\n",
"* Convolutional networks are ubiquitous due to their effectiveness in dealing with image-like data (e.g., photos, video, fMRI, or even spectrograms of audio or electrophysiology). The connectivity pattern - and receptive fields that emerge - loosely approximate the human visual system.\n",
"\n",
"* Recurrent connectivity carries a unit's activity forward from one time point to another. Long short-term memory (LSTM) networks are probably the best known example of this type. These models are most often used for time series and other sequential data.\n",
"\n",
"* Perhaps the most important type of unit/connectivity (not pictured due to its complexity) is the attention mechanism of transformer architectures. Transformers are beyond the scope of this introduction, but they support many of the most influential ANNs as of time of writing (e.g., all the major large language models like GPT)."
],
"metadata": {
"id": "1I8406rj4Y_u"
}
