From bbed4af2a17d62ba1523d11972cf0d05c7ac5b12 Mon Sep 17 00:00:00 2001
From: jlparkinson1
Date: Tue, 27 Jun 2023 11:22:35 -0700
Subject: [PATCH] Updated tabular and sequence examples in docs to use xGPR
 0.1.2.3. Updated HISTORY and init for the 0.1.2.3 release. Fixed a bug that
 caused an error when switching a trained model from cpu to gpu or vice
 versa. Updated complete pipeline tests to include a functionality test for
 the variance calculation.

---
 HISTORY.md                            |   5 +
 docs/notebooks/sequence_example.ipynb |  59 +-
 docs/notebooks/tabular_example.ipynb  | 571 +++---
 .../test_current_kernels.py           |   4 +-
 xGPR/__init__.py                      |   2 +-
 xGPR/regression_baseclass.py          |   8 +-
 6 files changed, 121 insertions(+), 528 deletions(-)

diff --git a/HISTORY.md b/HISTORY.md
index 68f83a0..363e217 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -127,3 +127,8 @@
 Updated dataset builder so that different batches with different xdim[1] are
 now accepted when building a dataset. This obviates the need to zero-pad data
 (although note that zero-padding is generally advisable for consistency).
+
+### Version 0.1.2.3
+
+Fixed a bug that caused an error when changing the device of a trained
+model from gpu to cpu or vice versa.

diff --git a/docs/notebooks/sequence_example.ipynb b/docs/notebooks/sequence_example.ipynb
index 992ab82..9afa697 100644
--- a/docs/notebooks/sequence_example.ipynb
+++ b/docs/notebooks/sequence_example.ipynb
@@ -19,7 +19,7 @@
    "match or outperform the deep learning baselines without too\n",
    "much effort.\n",
    "\n",
-   "This was originally run on an A6000 GPU, using xGPR 0.1.0.0."
+   "This was originally run on a GTX1070 GPU, using xGPR 0.1.2.3."
   ]
  },
  {
@@ -54,7 +54,8 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "Cloning into 'FLIP'...\n"
+     "Cloning into 'FLIP'...\n",
+     "Checking out files: 100% (59/59), done.\n"
     ]
    }
   ],
@@ -84,7 +85,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "/tmp/ipykernel_3931463/4195622896.py:1: DtypeWarning: Columns (12) have mixed types. Specify dtype option on import or set low_memory=False.\n",
+     "/tmp/ipykernel_21623/4195622896.py:1: DtypeWarning: Columns (12) have mixed types. 
Specify dtype option on import or set low_memory=False.\n", " raw_data = pd.read_csv(\"full_data.csv\")\n" ] } @@ -361,22 +362,22 @@ "Grid point 9 acquired.\n", "New hparams: [-0.8298276]\n", "Additional acquisition 10.\n", - "New hparams: [-0.773694]\n", + "New hparams: [-0.838633]\n", "Additional acquisition 11.\n", - "New hparams: [-0.8054813]\n", + "New hparams: [-0.8383481]\n", "Additional acquisition 12.\n", - "New hparams: [-0.7899]\n", + "New hparams: [-0.8468422]\n", "Additional acquisition 13.\n", - "New hparams: [-0.7798819]\n", + "New hparams: [-0.8321154]\n", "Additional acquisition 14.\n", - "New hparams: [-0.777553]\n", + "New hparams: [-0.8445669]\n", "Additional acquisition 15.\n", "New hparams: [-0.8072598]\n", - "Best score achieved: 120578.24\n", - "Best hyperparams: [-0.852417 -1.6094379 -0.7899 ]\n", + "Best score achieved: 121310.218\n", + "Best hyperparams: [-0.852417 -1.6094379 -0.8298276]\n", "Tuning complete.\n", - "Best estimated negative marginal log likelihood: 120578.24\n", - "Wallclock: 210.5481152534485\n" + "Best estimated negative marginal log likelihood: 121310.218\n", + "Wallclock: 663.9832231998444\n" ] } ], @@ -465,7 +466,7 @@ "Chunk 70 complete.\n", "Chunk 80 complete.\n", "Chunk 90 complete.\n", - "Wallclock: 18.054260969161987\n" + "Wallclock: 52.33980679512024\n" ] } ], @@ -493,8 +494,11 @@ "Iteration 10\n", "Iteration 15\n", "Iteration 20\n", + "Iteration 25\n", + "Iteration 30\n", + "Now performing variance calculations...\n", "Fitting complete.\n", - "Wallclock: 53.852142572402954\n" + "Wallclock: 66.21576309204102\n" ] } ], @@ -516,7 +520,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Wallclock: 2.0166425704956055\n" + "Wallclock: 9.966946601867676\n" ] } ], @@ -543,7 +547,7 @@ { "data": { "text/plain": [ - "SpearmanrResult(correlation=0.7658444861113735, pvalue=0.0)" + "SpearmanrResult(correlation=0.7664973351782205, pvalue=0.0)" ] }, "execution_count": 13, @@ -562,20 +566,13 @@ "id": "a6e3d321", "metadata": {}, "source": [ - "Notice we're already at 0.766, outperforming a CNN for sequences.\n", - "Not bad, given that we are merely using one-hot encoded input. It is of course possible\n", - "to try to use another representation (e.g. the output of a language model)\n", - "as the input to a GP. The FHTConv1d kernel used here measures the similarity\n", - "of two sequences as the similarity between k-mers, as measured by an\n", - "RBF kernel assessed across each pair of k-mers. It is therefore\n", - "likely that information from some position-specific scoring matrix (PSSM)\n", - "(e.g. BLOSUM) would if used instead of one hot encoding improve performance\n", - "as well.\n", - "\n", - "Perhaps the most interesting result is the poor performance of the\n", - "pretrained model, which in this case (and on many other of the FLIP\n", - "benchmarks) loses both to a GP and a 1d CNN despite having access\n", - "to a large corpus for unsupervised pretraining." + "Notice we're already at 0.766, outperforming a CNN for sequences also trained\n", + "on one-hot encoded input. It is of course possible\n", + "to use another representation (e.g. the output of a language model)\n", + "as the input to a GP. 
We discuss some other\n",
+    "possible representations and show that for some tasks using a language\n",
+    "model embedding can improve results (although for this dataset, interestingly,\n",
+    "not by very much)."
   ]
  },
  {
@@ -634,7 +631,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.16"
+  "version": "3.9.10"
  }
 },
 "nbformat": 4,
diff --git a/docs/notebooks/tabular_example.ipynb b/docs/notebooks/tabular_example.ipynb
index 22b4113..8fd77d5 100644
--- a/docs/notebooks/tabular_example.ipynb
+++ b/docs/notebooks/tabular_example.ipynb
@@ -11,12 +11,12 @@
    "fairly random UCI repository dataset with about 45,000 datapoints. We'll\n",
    "download this data, do some light preprocessing, and fit an RBF kernel.\n",
    "\n",
-   "These experiments used xGPR 0.1.0.0 on an A6000 GPU."
+   "These experiments used xGPR 0.1.2.3 on a GTX1070 GPU."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": 2,
   "id": "c1d0a6db",
   "metadata": {},
   "outputs": [],
@@ -38,7 +38,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": 3,
   "id": "39d7217e",
   "metadata": {},
   "outputs": [
@@ -46,439 +46,7 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "\r",
-     "  0% [                                                      ] 0 / 3528710\r",
      [... ~430 further deleted lines of wget progress-bar output, counting up in 8192-byte steps toward 3528710 bytes, elided ...]
-     "100% [......................................................] 3528710 / 3528710"
+     "-1 / unknown"
     ]
    }
   ],
@@ -490,7 +58,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 4,
   "id": "ec1e70f2",
   "metadata": {},
   "outputs": [
@@ -706,7 +274,7 @@
       "[45730 rows x 10 columns]"
      ]
     },
-    "execution_count": 3,
+    "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
@@ -723,13 +291,17 @@
    "\n",
    "\n",
    "Note that we can but don't need to rescale\n",
-   "y-values -- xGPR will rescale y-values automatically.\n",
+   "y-values -- xGPR will rescale y-values automatically unless\n",
+   "we tell it to do otherwise. The predictions are automatically\n",
+   "converted back to the original scale. 
If you want to *stop* xGPR from\n", + "rescaling y-values during training, you can pass `normalize_y=False`\n", + "when constructing a dataset.\n", "\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "649cb19e", "metadata": {}, "outputs": [], @@ -769,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "9498c163", "metadata": {}, "outputs": [], @@ -791,7 +363,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "4d3cd484", "metadata": {}, "outputs": [], @@ -842,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "6c8a2cc5", "metadata": {}, "outputs": [ @@ -865,19 +437,19 @@ "Additional acquisition 10.\n", "New hparams: [0.3133586]\n", "Additional acquisition 11.\n", - "New hparams: [0.4027546]\n", + "New hparams: [0.5691902]\n", "Additional acquisition 12.\n", - "New hparams: [-0.9321185]\n", + "New hparams: [-0.7985376]\n", "Additional acquisition 13.\n", - "New hparams: [0.8437973]\n", + "New hparams: [-2.5544848]\n", "Additional acquisition 14.\n", - "New hparams: [-3.1247573]\n", + "New hparams: [-4.5317151]\n", "Additional acquisition 15.\n", - "New hparams: [-5.603476]\n", - "Best score achieved: 38889.132\n", - "Best hyperparams: [-0.4226011 0.192718 0.4027546]\n", + "New hparams: [0.5058531]\n", + "Best score achieved: 38826.519\n", + "Best hyperparams: [-0.4226011 -0.2897586 0.5058531]\n", "Tuning complete.\n", - "Wallclock: 31.684582710266113\n" + "Wallclock: 79.04574918746948\n" ] } ], @@ -903,7 +475,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "4defbf34", "metadata": {}, "outputs": [ @@ -926,19 +498,19 @@ "Additional acquisition 10.\n", "New hparams: [0.3133586]\n", "Additional acquisition 11.\n", - "New hparams: [0.4027546]\n", + "New hparams: [0.5691902]\n", "Additional acquisition 12.\n", - "New hparams: [-0.9321185]\n", + "New hparams: [-0.7985376]\n", "Additional acquisition 13.\n", - "New hparams: [0.8437973]\n", + "New hparams: [-2.5544848]\n", "Additional acquisition 14.\n", - "New hparams: [-3.1247573]\n", + "New hparams: [-4.5317151]\n", "Additional acquisition 15.\n", - "New hparams: [-5.603476]\n", - "Best score achieved: 38889.132\n", - "Best hyperparams: [-0.4226011 0.192718 0.4027546]\n", + "New hparams: [0.5058531]\n", + "Best score achieved: 38826.519\n", + "Best hyperparams: [-0.4226011 -0.2897586 0.5058531]\n", "Tuning complete.\n", - "Wallclock: 31.589682579040527\n" + "Wallclock: 78.39839148521423\n" ] } ], @@ -998,7 +570,12 @@ "Evaluating gradient...\n", "Evaluating gradient...\n", "Evaluating gradient...\n", - "Restart 0 completed. Best score is 38886.80149435726.\n", + "Evaluating gradient...\n", + "Evaluating gradient...\n", + "Evaluating gradient...\n", + "Evaluating gradient...\n", + "Evaluating gradient...\n", + "Restart 0 completed. Best score is 38815.89957910789.\n", "Evaluating gradient...\n", "Evaluating gradient...\n", "Evaluating gradient...\n", @@ -1017,9 +594,9 @@ "Evaluating gradient...\n", "Evaluating gradient...\n", "Evaluating gradient...\n", - "Restart 1 completed. Best score is 38886.80149435726.\n", "Evaluating gradient...\n", "Evaluating gradient...\n", + "Restart 1 completed. Best score is 38815.89957910789.\n", "Evaluating gradient...\n", "Evaluating gradient...\n", "Evaluating gradient...\n", @@ -1032,9 +609,11 @@ "Evaluating gradient...\n", "Evaluating gradient...\n", "Evaluating gradient...\n", - "Restart 2 completed. 
Best score is 38886.80149435726.\n", + "Evaluating gradient...\n", + "Evaluating gradient...\n", + "Restart 2 completed. Best score is 38815.89957910789.\n", "Tuning complete.\n", - "Wallclock: 55.88275384902954\n" + "Wallclock: 183.09004998207092\n" ] } ], @@ -1063,17 +642,17 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "accec6d6", + "execution_count": 12, + "id": "c2058512", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([-0.41722132, 0.20104136, 0.39762496])" + "array([-0.4167661 , -0.21041182, 0.40020105])" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1096,7 +675,7 @@ "stochastic gradient descent can also be competitive).\n", "\n", "method can be either 'srht' or 'srht_2'. 'srht' requires only\n", - "one pass across the dataset and no matrix multiplications, so it's\n", + "one pass across the dataset, so it's\n", "pretty fast. 'srht_2' requires two passes across the dataset and\n", "involves matrix multiplication, so it's slower but the resulting\n", "preconditioner usually reduces the number of CG iterations required to\n", @@ -1107,7 +686,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "18ae82bf", "metadata": {}, "outputs": [ @@ -1119,7 +698,7 @@ "Chunk 10 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", - "Wallclock: 2.0374441146850586\n" + "Wallclock: 5.4230124950408936\n" ] } ], @@ -1133,7 +712,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "be473990", "metadata": {}, "outputs": [ @@ -1141,7 +720,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "13.69709449862266\n" + "12.153107742865895\n" ] } ], @@ -1175,7 +754,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "5e12d8c9", "metadata": {}, "outputs": [ @@ -1191,8 +770,9 @@ "Iteration 20\n", "Iteration 25\n", "Iteration 30\n", + "Now performing variance calculations...\n", "Fitting complete.\n", - "Wallclock: 1.904855728149414\n" + "Wallclock: 5.044420957565308\n" ] } ], @@ -1217,17 +797,17 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "66136407", "metadata": {}, "outputs": [], "source": [ - "test_predictions = uci_model.predict(test_x, get_var = False, chunk_size = 1000)" + "test_predictions, test_var = uci_model.predict(test_x, get_var = True, chunk_size = 1000)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "id": "e6e1aef6", "metadata": {}, "outputs": [ @@ -1235,7 +815,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "MAE: 2.9641127603063318\n" + "MAE: 2.9686757776170625\n" ] } ], @@ -1278,7 +858,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "id": "65ceb406", "metadata": {}, "outputs": [ @@ -1290,7 +870,7 @@ "Chunk 10 complete.\n", "Chunk 0 complete.\n", "Chunk 10 complete.\n", - "Wallclock: 6.977460622787476\n" + "Wallclock: 20.7839298248291\n" ] } ], @@ -1306,7 +886,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "id": "c0f76c87", "metadata": {}, "outputs": [ @@ -1322,8 +902,9 @@ "Iteration 20\n", "Iteration 25\n", "Iteration 30\n", + "Now performing variance calculations...\n", "Fitting complete.\n", - "Wallclock: 5.380032539367676\n" + "Wallclock: 16.338116884231567\n" ] } ], @@ -1337,7 +918,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "id": "548d75f8", "metadata": {}, "outputs": [ @@ -1345,7 +926,7 @@ 
"name": "stdout", "output_type": "stream", "text": [ - "MAE: 2.869126043669058\n" + "MAE: 2.875096362754352\n" ] } ], @@ -1374,7 +955,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "id": "b7208786", "metadata": {}, "outputs": [ @@ -1453,8 +1034,11 @@ "Now building preconditioner...\n", "Now fitting...\n", "NMLL evaluation completed.\n", + "Now building preconditioner...\n", + "Now fitting...\n", + "NMLL evaluation completed.\n", "Tuning complete.\n", - "Wallclock: 192.6702961921692\n" + "Wallclock: 654.5522615909576\n" ] } ], @@ -1475,7 +1059,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "id": "afb6c7c0", "metadata": {}, "outputs": [ @@ -1483,7 +1067,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[-0.49753128 0.27381234 0.62625147]\n" + "[-0.49860283 -0.06108382 0.6371851 ]\n" ] } ], @@ -1508,7 +1092,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "id": "90fbdc7c", "metadata": {}, "outputs": [ @@ -1531,8 +1115,9 @@ "Iteration 35\n", "Iteration 40\n", "Iteration 45\n", + "Now performing variance calculations...\n", "Fitting complete.\n", - "2.8871837478547877\n" + "2.884159125457879\n" ] } ], @@ -1549,7 +1134,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "id": "e3d3fe41", "metadata": {}, "outputs": [], @@ -1561,7 +1146,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "id": "11e757eb", "metadata": {}, "outputs": [], diff --git a/test/complete_pipeline_tests/test_current_kernels.py b/test/complete_pipeline_tests/test_current_kernels.py index 0b94d0e..c60a973 100644 --- a/test/complete_pipeline_tests/test_current_kernels.py +++ b/test/complete_pipeline_tests/test_current_kernels.py @@ -23,7 +23,7 @@ def test_fit_cpu(self): "involving ARD kernels) may take a minute.") for kernel_name, (is_conv, exp_score) in IMPLEMENTED_KERNELS.items(): cg_score, exact_score = test_fit_cpu(kernel_name, is_conv, RANDOM_SEED, - conv_width = 3, get_var = False) + conv_width = 3, get_var = True) self.assertTrue(cg_score > exp_score) self.assertTrue(exact_score > exp_score) @@ -32,7 +32,7 @@ def test_fit_gpu(self): print("Now running GPU tests.") for kernel_name, (is_conv, exp_score) in IMPLEMENTED_KERNELS.items(): cg_score, exact_score = test_fit_gpu(kernel_name, is_conv, RANDOM_SEED, - conv_width = 3, get_var = False) + conv_width = 3, get_var = True) self.assertTrue(cg_score > exp_score) self.assertTrue(exact_score > exp_score) diff --git a/xGPR/__init__.py b/xGPR/__init__.py index 41fbd9e..30d42b6 100644 --- a/xGPR/__init__.py +++ b/xGPR/__init__.py @@ -1,3 +1,3 @@ #Version number. Updated if generating a new release. #Otherwise, do not change. -__version__ = "0.1.2.2" +__version__ = "0.1.2.3" diff --git a/xGPR/regression_baseclass.py b/xGPR/regression_baseclass.py index d444d93..3db6a3e 100644 --- a/xGPR/regression_baseclass.py +++ b/xGPR/regression_baseclass.py @@ -624,7 +624,13 @@ def device(self, value): elif value == "cpu" and not isinstance(self.weights, np.ndarray): self.weights = cp.asnumpy(self.weights) if self.var is not None: - self.var.device = value + if self.exact_var_calculation: + if value == "cpu": + self.var = cp.asnumpy(self.var) + else: + self.var = cp.asarray(self.var) + else: + self.var.device = value if value == "gpu": mempool = cp.get_default_memory_pool() mempool.free_all_blocks()