{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python 3.7.9 (default, Aug 31 2020, 12:42:55) \r\n", "[GCC 7.3.0]\r\n" ] } ], "source": [ "!python -VVV" ] }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "import gzip\n", "import pickle as p\n", "import re\n", "import string\n", "from collections import Counter\n", "\n", "import pandas\n", "import pandas as pd\n", "import numpy as np\n", "import redis\n", "import xgboost\n", "from hypy_utils import write\n", "from hypy_utils.tqdm_utils import pmap\n", "from xgboost import XGBClassifier\n", "\n", "from main import job_start_redis\n", "from pandas import DataFrame\n", "from datetime import datetime, timezone\n", "from dateutil.relativedelta import relativedelta\n", "import os\n", "from pathlib import Path\n", "\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "\n", "\n", "def load_zipped_pickle(filename):\n", " with gzip.open(filename, 'rb') as f:\n", " loaded_object = p.load(f)\n", " l = pd.DataFrame(loaded_object)\n", " return l\n", "\n", "\n", "# total changed files\n", "# total commits\n", "# total fix commits\n", "# avaerage changed lines per fix\n", "#\n", "def get_root_nodes(date: str):\n", " result = sorted([])\n", " for patch in os.listdir(f\"../../data.absolute/{date}/patterns\"):\n", " with open(f\"../../data.absolute/{date}/patterns/{patch}\", 'r') as f:\n", " result.append(f.readline())\n", " return sorted(list(set(result)))\n", "\n", "\n", "def label_gen() -> DataFrame:\n", " start_string = \"2001-12-01\"\n", " start_date = datetime.strptime(start_string, '%Y-%m-%d')\n", " interval = relativedelta(months=6)\n", "\n", " csv = []\n", "\n", " data_path = Path('../../data.absolute')\n", " new = get_root_nodes(start_string)\n", " new = sorted(list(set(new)))\n", " added = sorted(list(set(new)))\n", " remove = sorted([])\n", " csv.append((start_date.strftime('%Y-%m-%d').split(' ')[0], len(new), len(added), len(remove),\n", " added, remove, True))\n", "\n", " while True:\n", " # end = start + interval\n", " end_date = start_date + interval\n", " end_string = end_date.strftime('%Y-%m-%d').split(' ')[0]\n", " start_string = start_date.strftime('%Y-%m-%d').split(' ')[0]\n", " if not os.path.isdir(data_path / str(end_string)):\n", " # new = os.listdir(data_path / f\"{end_string}/patterns\")\n", " # added = sorted(list(set(new)))\n", " # remove = sorted([])\n", " # csv.append((end_string, len(new), added, remove))\n", " break\n", "\n", " new = get_root_nodes(end_string)\n", " old = get_root_nodes(start_string)\n", "\n", " added = sorted(list(set(new) - set(old)))\n", " remove = sorted(list(set(old) - set(new)))\n", "\n", " csv.append((end_string, len(new), len(added), len(remove), added, remove, len(added) != 0))\n", "\n", " start_date += interval\n", "\n", " # plt.plot([v[1] for v in csv], [v[2] for v in csv])\n", " # plt.show()\n", "\n", " df = DataFrame(csv, columns=('Time', 'Number of Patches', 'Numbers of added', 'Numbers of removed', 'Patches Added',\n", " 'Patches Removed', 'New Pattern'))\n", " df.to_csv('diff-test-absolute-root-2.csv')\n", " return df\n", "\n", "\n", "def get_commits_sha(start: str, end: str, project_name: str) -> DataFrame:\n", " start_date = datetime.strptime(start, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n", " end_date = datetime.strptime(end, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n", " commits = load_zipped_pickle(f'/workspace/EECS-Research/data.absolute/{end}/commitsDF/{project_name}-fix.pickle.gz')\n", " return commits[commits['commitDate'].between(start_date, end_date, inclusive=False)]['commit']\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shutting down redis 6399...\n", "> Shutdown complete.\n", "/workspace/EECS-Research/data.absolute/2022-06-01/redis\n", "Starting redis 6399...\n", "> Redis started.\n" ] } ], "source": [ "labels = label_gen()\n", "\n", "date = '2022-06-01'\n", "job_start_redis(f'/workspace/EECS-Research/data.absolute/{date}/redis', 6399)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connected to redis!\n" ] } ], "source": [ "# 1. Load AST diffs\n", "r = redis.StrictRedis(host='localhost', port=6399, db=0)\n", "print(\"Connected to redis!\")\n", "ast_diffs = {k.decode(): v.decode() for k, v in r.hgetall('dump').items()}\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading commits: 100%|██████████| 42/42 [00:01<00:00, 22.41it/s]\n" ] } ], "source": [ "# 2. Load commits\n", "base_path = Path(f'/workspace/EECS-Research/data.absolute/{date}/commitsDF/')\n", "commit_pickles = [base_path / str(f) for f in os.listdir(base_path) if f.endswith('-fix.pickle.gz')]\n", "commits = pandas.concat(pmap(load_zipped_pickle, commit_pickles, desc=f'Loading commits'))\n", "\n", "def get_all_commits_sha(start: str, end: str) -> DataFrame:\n", " start_date = datetime.strptime(start, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n", " end_date = datetime.strptime(end, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n", " df = commits[commits['commitDate'].between(start_date, end_date, inclusive=False)]\n", " return df\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ignored 0 AST diff entries\n", "Commit sha lengths: Counter({6: 61553, 7: 3748, 8: 237, 9: 14})\n", "Processing time interval from 2001-12-01 to 2002-06-01\n", "Total of 196 commits in the interval\n", "Total of 44 AST diffs in the interval\n", "Processing time interval from 2002-06-01 to 2002-12-01\n", "Total of 207 commits in the interval\n", "Total of 32 AST diffs in the interval\n", "Processing time interval from 2002-12-01 to 2003-06-01\n", "Total of 276 commits in the interval\n", "Total of 27 AST diffs in the interval\n", "Processing time interval from 2003-06-01 to 2003-12-01\n", "Total of 477 commits in the interval\n", "Total of 60 AST diffs in the interval\n", "Processing time interval from 2003-12-01 to 2004-06-01\n", "Total of 920 commits in the interval\n", "Total of 53 AST diffs in the interval\n", "Processing time interval from 2004-06-01 to 2004-12-01\n", "Total of 627 commits in the interval\n", "Total of 75 AST diffs in the interval\n", "Processing time interval from 2004-12-01 to 2005-06-01\n", "Total of 517 commits in the interval\n", "Total of 40 AST diffs in the interval\n", "Processing time interval from 2005-06-01 to 2005-12-01\n", "Total of 484 commits in the interval\n", "Total of 33 AST diffs in the interval\n", "Processing time interval from 2005-12-01 to 2006-06-01\n", "Total of 613 commits in the interval\n", "Total of 69 AST diffs in the interval\n", "Processing time interval from 2006-06-01 to 2006-12-01\n", "Total of 508 commits in the interval\n", "Total of 37 AST diffs in the interval\n", "Processing time interval from 2006-12-01 to 2007-06-01\n", "Total of 1757 commits in the interval\n", "Total of 141 AST diffs in the interval\n", "Processing time interval from 2007-06-01 to 2007-12-01\n", "Total of 2451 commits in the interval\n", "Total of 241 AST diffs in the interval\n", "Processing time interval from 2007-12-01 to 2008-06-01\n", "Total of 3996 commits in the interval\n", "Total of 447 AST diffs in the interval\n", "Processing time interval from 2008-06-01 to 2008-12-01\n", "Total of 3167 commits in the interval\n", "Total of 481 AST diffs in the interval\n", "Processing time interval from 2008-12-01 to 2009-06-01\n", "Total of 3519 commits in the interval\n", "Total of 556 AST diffs in the interval\n", "Processing time interval from 2009-06-01 to 2009-12-01\n", "Total of 3243 commits in the interval\n", "Total of 548 AST diffs in the interval\n", "Processing time interval from 2009-12-01 to 2010-06-01\n", "Total of 3184 commits in the interval\n", "Total of 359 AST diffs in the interval\n", "Processing time interval from 2010-06-01 to 2010-12-01\n", "Total of 5523 commits in the interval\n", "Total of 746 AST diffs in the interval\n", "Processing time interval from 2010-12-01 to 2011-06-01\n", "Total of 10476 commits in the interval\n", "Total of 1331 AST diffs in the interval\n", "Processing time interval from 2011-06-01 to 2011-12-01\n", "Total of 9990 commits in the interval\n", "Total of 1965 AST diffs in the interval\n", "Processing time interval from 2011-12-01 to 2012-06-01\n", "Total of 9092 commits in the interval\n", "Total of 1466 AST diffs in the interval\n", "Processing time interval from 2012-06-01 to 2012-12-01\n", "Total of 6427 commits in the interval\n", "Total of 924 AST diffs in the interval\n", "Processing time interval from 2012-12-01 to 2013-06-01\n", "Total of 6943 commits in the interval\n", "Total of 924 AST diffs in the interval\n", "Processing time interval from 2013-06-01 to 2013-12-01\n", "Total of 8525 commits in the interval\n", "Total of 1208 AST diffs in the interval\n", "Processing time interval from 2013-12-01 to 2014-06-01\n", "Total of 6567 commits in the interval\n", "Total of 955 AST diffs in the interval\n", "Processing time interval from 2014-06-01 to 2014-12-01\n", "Total of 6512 commits in the interval\n", "Total of 871 AST diffs in the interval\n", "Processing time interval from 2014-12-01 to 2015-06-01\n", "Total of 6618 commits in the interval\n", "Total of 719 AST diffs in the interval\n", "Processing time interval from 2015-06-01 to 2015-12-01\n", "Total of 6827 commits in the interval\n", "Total of 589 AST diffs in the interval\n", "Processing time interval from 2015-12-01 to 2016-06-01\n", "Total of 8035 commits in the interval\n", "Total of 839 AST diffs in the interval\n", "Processing time interval from 2016-06-01 to 2016-12-01\n", "Total of 7323 commits in the interval\n", "Total of 841 AST diffs in the interval\n", "Processing time interval from 2016-12-01 to 2017-06-01\n", "Total of 8170 commits in the interval\n", "Total of 1470 AST diffs in the interval\n", "Processing time interval from 2017-06-01 to 2017-12-01\n", "Total of 7063 commits in the interval\n", "Total of 1079 AST diffs in the interval\n", "Processing time interval from 2017-12-01 to 2018-06-01\n", "Total of 6175 commits in the interval\n", "Total of 430 AST diffs in the interval\n", "Processing time interval from 2018-06-01 to 2018-12-01\n", "Total of 5783 commits in the interval\n", "Total of 823 AST diffs in the interval\n", "Processing time interval from 2018-12-01 to 2019-06-01\n", "Total of 6699 commits in the interval\n", "Total of 754 AST diffs in the interval\n", "Processing time interval from 2019-06-01 to 2019-12-01\n", "Total of 6875 commits in the interval\n", "Total of 1787 AST diffs in the interval\n", "Processing time interval from 2019-12-01 to 2020-06-01\n", "Total of 8036 commits in the interval\n", "Total of 788 AST diffs in the interval\n", "Processing time interval from 2020-06-01 to 2020-12-01\n", "Total of 7538 commits in the interval\n", "Total of 1074 AST diffs in the interval\n", "Processing time interval from 2020-12-01 to 2021-06-01\n", "Total of 7237 commits in the interval\n", "Total of 514 AST diffs in the interval\n", "Processing time interval from 2021-06-01 to 2021-12-01\n", "Total of 6388 commits in the interval\n", "Total of 695 AST diffs in the interval\n", "Processing time interval from 2021-12-01 to 2022-06-01\n", "Total of 6552 commits in the interval\n", "Total of 527 AST diffs in the interval\n" ] } ], "source": [ "# 3.1. Find sha in each AST diff\n", "AST_KEY_SHA_RE = re.compile(r'[0-9a-f]{6,}_[0-9a-f]{6,}')\n", "ast_diff_sha = {k: AST_KEY_SHA_RE.findall(k) for k in ast_diffs.keys()}\n", "tmp_bf = len(ast_diff_sha)\n", "ast_diff_sha = {k: v[0].split('_') for k, v in ast_diff_sha.items() if v}\n", "print(f'Ignored {len(ast_diff_sha) - tmp_bf} AST diff entries')\n", "print(f'Commit sha lengths: {Counter([len(sha) for l in ast_diff_sha.values() for sha in l])}')\n", "\n", "ast_diff_date = {}\n", "for date_start, date_end in zip(labels['Time'][:-1], labels['Time'][1:]):\n", " # 3.2. Find all commit sha in the time interval\n", " print(f'Processing time interval from {date_start} to {date_end}')\n", " commits_sha = {c[:6] for c in get_all_commits_sha(date_start, date_end)['commit']}\n", " print(f'Total of {len(commits_sha)} commits in the interval')\n", "\n", " # 3.3. Find AST diff entries with commit sha in this interval\n", " ast_diffs_in_interval = {k for k, s in ast_diff_sha.items() if all(len(sha) == 6 and sha in commits_sha for sha in s)}\n", " print(f'Total of {len(ast_diffs_in_interval)} AST diffs in the interval')\n", "\n", " # 3.4. Combine AST diffs and save as one file\n", " ast_combined = '\\n\\n'.join(ast_diffs[k] for k in ast_diffs_in_interval)\n", " write(f'ML/ast-diffs/{date_start}.txt', ast_combined)\n", " ast_diff_date[date_start] = ast_combined\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cleaning AST Diffs\n" ] } ], "source": [ "# 4. Tokenize / Vectorize\n", "\n", "# 4.1. Remove symbols, numbers\n", "def tmp_clean(s: str) -> str:\n", " s = s.replace('@TO@', '').replace('@AT@', '').replace('@LENGTH@', '')\n", " for c in string.punctuation:\n", " s = s.replace(c, ' ')\n", " while ' ' in s:\n", " s = s.replace(' ', ' ')\n", " return s\n", "print('Cleaning AST Diffs')\n", "ast_diff_date = {k: tmp_clean(v) for k, v in ast_diff_date.items()}\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TF-IDF Fitting finished, resulting vector shape: (41, 1000)\n" ] } ], "source": [ "# 4.2. Vectorize X\n", "tf = TfidfVectorizer(ngram_range=(1, 4), max_features=1000)\n", "X = [ast_diff_date[t] for t in labels['Time'] if t in ast_diff_date]\n", "X = tf.fit_transform(X)\n", "print('TF-IDF Fitting finished, resulting vector shape:', X.shape)\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Y vector size: 41\n" ] } ], "source": [ "# 4.3. Vectorize Y\n", "y = np.array([v != [] for v in labels['Patches Added']][1:])\n", "print('Y vector size:', len(y))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 10, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train size: (30, 1000), Test size: (11, 1000)\n" ] } ], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# 4.4. Train test split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)\n", "\n", "print(f'Train size: {X_train.shape}, Test size: {X_test.shape}')" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 11, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Gradient Boosted Regression Tree model...\n", "Done!\n", "y Test Label: [1 1 1 1 1 0 1 1 0 0 1]\n", "y Prediction: [1 1 1 0 1 1 1 1 1 0 1]\n", "Precision: 77.8\n", "Recall: 87.5\n", "F1: 82.4\n" ] } ], "source": [ "from sklearn.metrics import precision_recall_fscore_support\n", "from sklearn.model_selection import KFold\n", "\n", "def print_prec(y_test, y_pred):\n", " print(f'y Test Label: {np.array([int(x) for x in y_test])}')\n", " print(f'y Prediction: {np.array([int(x) for x in y_pred])}')\n", "\n", " prec, rec, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='binary')\n", " print(f'Precision: {prec * 100:0.1f}')\n", " print(f'Recall: {rec * 100:0.1f}')\n", " print(f'F1: {f1 * 100:0.1f}')\n", "\n", "def k_fold(k,model):\n", " kf=KFold(n_splits=k)\n", " for train_index, test_index in kf.split(X):\n", " #print(\"TRAIN:\", train_index, \"TEST:\", test_index)\n", " X_train , X_test = X[train_index],X[test_index]\n", " y_train , y_test = y[train_index] , y[test_index]\n", " model.fit(X_train,y_train)\n", " pred_values = model.predict(X_test)\n", "\n", " print_prec(y_test,pred_values)\n", "\n", "# 5.1. Gradient Boost Regression Tree model\n", "\n", "\n", "def gbrt():\n", " print('Training Gradient Boosted Regression Tree model...')\n", " classifier: XGBClassifier = XGBClassifier(tree_method='gpu_hist', n_estimators=300)\n", " classifier.fit(X_train, y_train)\n", " print('Done!')\n", "\n", " # 5.1. Get internal accuracy\n", " y_test_pred = classifier.predict(X_test)\n", " print_prec(y_test, y_test_pred)\n", "\n", "\n", "gbrt()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 13, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Random Forest model...\n", "Done!\n", "y Test Label: [1 1 1 1 1 0 1 1 0 0 1]\n", "y Prediction: [1 1 1 0 1 1 1 1 0 0 1]\n", "Precision: 87.5\n", "Recall: 87.5\n", "F1: 87.5\n" ] } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "\n", "# 5.2. Random Forest Model\n", "def rf():\n", " print('Training Random Forest model...')\n", " classifier = RandomForestClassifier(n_estimators=300)\n", " classifier.fit(X_train, y_train)\n", " print('Done!')\n", "\n", " # 5.1. Get internal accuracy\n", " y_test_pred = classifier.predict(X_test)\n", " print_prec(y_test, y_test_pred)\n", "\n", "\n", "rf()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Logistic Regression model...\n", "Done!\n", "y Test Label: [1 1 1 1 1 0 1 1 0 0 1]\n", "y Prediction: [1 1 1 1 1 1 1 1 1 1 1]\n", "Precision: 72.7\n", "Recall: 100.0\n", "F1: 84.2\n" ] } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "\n", "# 5.3. Logistic Regression Model\n", "def logistic():\n", " print('Training Logistic Regression model...')\n", " classifier = LogisticRegression()\n", " classifier.fit(X_train, y_train)\n", " print('Done!')\n", "\n", " # 5.1. Get internal accuracy\n", " y_test_pred = classifier.predict(X_test)\n", " print_prec(y_test, y_test_pred)\n", "\n", "\n", "logistic()\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 15, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Percentage of 1s in dataset: 63.4%\n" ] } ], "source": [ "print(f'Percentage of 1s in dataset: {sum(y) / len(y) * 100 :.1f}%')\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 80, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TRAIN: [ 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [0 1 2 3 4 5 6]\n", "y Test Label: [1 1 1 1 1 1 0]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 85.7\n", "Recall: 100.0\n", "F1: 92.3\n", "TRAIN: [ 0 1 2 3 4 5 6 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [ 7 8 9 10 11 12 13]\n", "y Test Label: [0 1 1 1 1 1 0]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 71.4\n", "Recall: 100.0\n", "F1: 83.3\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [14 15 16 17 18 19 20]\n", "y Test Label: [1 1 0 1 1 1 1]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 85.7\n", "Recall: 100.0\n", "F1: 92.3\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [21 22 23 24 25 26 27]\n", "y Test Label: [1 0 1 1 1 0 1]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 71.4\n", "Recall: 100.0\n", "F1: 83.3\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", " 24 25 26 27 35 36 37 38 39 40] TEST: [28 29 30 31 32 33 34]\n", "y Test Label: [1 0 1 1 0 1 0]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 57.1\n", "Recall: 100.0\n", "F1: 72.7\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", " 24 25 26 27 28 29 30 31 32 33 34] TEST: [35 36 37 38 39 40]\n", "y Test Label: [0 0 0 0 0 0]\n", "y Prediction: [1 1 1 1 1 1]\n", "Precision: 0.0\n", "Recall: 0.0\n", "F1: 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/azalea/.conda/envs/fixminerEnv/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] } ], "source": [ "k_fold(6,LogisticRegression())\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 82, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TRAIN: [ 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [0 1 2 3 4 5 6]\n", "y Test Label: [1 1 1 1 1 1 0]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 85.7\n", "Recall: 100.0\n", "F1: 92.3\n", "TRAIN: [ 0 1 2 3 4 5 6 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [ 7 8 9 10 11 12 13]\n", "y Test Label: [0 1 1 1 1 1 0]\n", "y Prediction: [1 1 1 1 1 1 1]\n", "Precision: 71.4\n", "Recall: 100.0\n", "F1: 83.3\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [14 15 16 17 18 19 20]\n", "y Test Label: [1 1 0 1 1 1 1]\n", "y Prediction: [1 1 0 1 1 0 1]\n", "Precision: 100.0\n", "Recall: 83.3\n", "F1: 90.9\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [21 22 23 24 25 26 27]\n", "y Test Label: [1 0 1 1 1 0 1]\n", "y Prediction: [1 1 1 1 0 0 0]\n", "Precision: 75.0\n", "Recall: 60.0\n", "F1: 66.7\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", " 24 25 26 27 35 36 37 38 39 40] TEST: [28 29 30 31 32 33 34]\n", "y Test Label: [1 0 1 1 0 1 0]\n", "y Prediction: [0 1 1 0 0 1 0]\n", "Precision: 66.7\n", "Recall: 50.0\n", "F1: 57.1\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", " 24 25 26 27 28 29 30 31 32 33 34] TEST: [35 36 37 38 39 40]\n", "y Test Label: [0 0 0 0 0 0]\n", "y Prediction: [1 1 1 1 1 1]\n", "Precision: 0.0\n", "Recall: 0.0\n", "F1: 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/azalea/.conda/envs/fixminerEnv/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] } ], "source": [ "k_fold(6,RandomForestClassifier(n_estimators=300))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 83, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TRAIN: [ 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [0 1 2 3 4 5 6]\n", "y Test Label: [1 1 1 1 1 1 0]\n", "y Prediction: [0 1 0 0 1 1 1]\n", "Precision: 75.0\n", "Recall: 50.0\n", "F1: 60.0\n", "TRAIN: [ 0 1 2 3 4 5 6 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [ 7 8 9 10 11 12 13]\n", "y Test Label: [0 1 1 1 1 1 0]\n", "y Prediction: [1 1 1 1 0 1 1]\n", "Precision: 66.7\n", "Recall: 80.0\n", "F1: 72.7\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 21 22 23 24 25 26 27 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [14 15 16 17 18 19 20]\n", "y Test Label: [1 1 0 1 1 1 1]\n", "y Prediction: [1 0 0 1 1 1 1]\n", "Precision: 100.0\n", "Recall: 83.3\n", "F1: 90.9\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30\n", " 31 32 33 34 35 36 37 38 39 40] TEST: [21 22 23 24 25 26 27]\n", "y Test Label: [1 0 1 1 1 0 1]\n", "y Prediction: [1 1 1 1 0 0 1]\n", "Precision: 80.0\n", "Recall: 80.0\n", "F1: 80.0\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", " 24 25 26 27 35 36 37 38 39 40] TEST: [28 29 30 31 32 33 34]\n", "y Test Label: [1 0 1 1 0 1 0]\n", "y Prediction: [0 1 1 0 1 0 1]\n", "Precision: 25.0\n", "Recall: 25.0\n", "F1: 25.0\n", "TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n", " 24 25 26 27 28 29 30 31 32 33 34] TEST: [35 36 37 38 39 40]\n", "y Test Label: [0 0 0 0 0 0]\n", "y Prediction: [1 1 1 1 1 1]\n", "Precision: 0.0\n", "Recall: 0.0\n", "F1: 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/azalea/.conda/envs/fixminerEnv/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] } ], "source": [ "k_fold(6,XGBClassifier(tree_method='gpu_hist', n_estimators=300))\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }