906 lines
30 KiB
Plaintext
906 lines
30 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Python 3.7.9 (default, Aug 31 2020, 12:42:55) \r\n",
|
|
"[GCC 7.3.0]\r\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!python -VVV"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"outputs": [],
|
|
"source": [
|
|
"import gzip\n",
|
|
"import pickle as p\n",
|
|
"import re\n",
|
|
"import string\n",
|
|
"from collections import Counter\n",
|
|
"\n",
|
|
"import pandas\n",
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import redis\n",
|
|
"import xgboost\n",
|
|
"from hypy_utils import write\n",
|
|
"from hypy_utils.tqdm_utils import pmap\n",
|
|
"from xgboost import XGBClassifier\n",
|
|
"\n",
|
|
"from main import job_start_redis\n",
|
|
"from pandas import DataFrame\n",
|
|
"from datetime import datetime, timezone\n",
|
|
"from dateutil.relativedelta import relativedelta\n",
|
|
"import os\n",
|
|
"from pathlib import Path\n",
|
|
"\n",
|
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
|
"\n",
|
|
"\n",
|
|
"def load_zipped_pickle(filename):\n",
|
|
" with gzip.open(filename, 'rb') as f:\n",
|
|
" loaded_object = p.load(f)\n",
|
|
" l = pd.DataFrame(loaded_object)\n",
|
|
" return l\n",
|
|
"\n",
|
|
"\n",
|
|
"# total changed files\n",
|
|
"# total commits\n",
|
|
"# total fix commits\n",
|
|
"# average changed lines per fix\n",
|
|
"#\n",
|
|
"def get_root_nodes(date: str):\n",
|
|
" result = sorted([])\n",
|
|
" for patch in os.listdir(f\"../../data.absolute/{date}/patterns\"):\n",
|
|
" with open(f\"../../data.absolute/{date}/patterns/{patch}\", 'r') as f:\n",
|
|
" result.append(f.readline())\n",
|
|
" return sorted(list(set(result)))\n",
|
|
"\n",
|
|
"\n",
|
|
"def label_gen() -> DataFrame:\n",
|
|
" start_string = \"2001-12-01\"\n",
|
|
" start_date = datetime.strptime(start_string, '%Y-%m-%d')\n",
|
|
" interval = relativedelta(months=6)\n",
|
|
"\n",
|
|
" csv = []\n",
|
|
"\n",
|
|
" data_path = Path('../../data.absolute')\n",
|
|
" new = get_root_nodes(start_string)\n",
|
|
" new = sorted(list(set(new)))\n",
|
|
" added = sorted(list(set(new)))\n",
|
|
" remove = sorted([])\n",
|
|
" csv.append((start_date.strftime('%Y-%m-%d').split(' ')[0], len(new), len(added), len(remove),\n",
|
|
" added, remove, True))\n",
|
|
"\n",
|
|
" while True:\n",
|
|
" # end = start + interval\n",
|
|
" end_date = start_date + interval\n",
|
|
" end_string = end_date.strftime('%Y-%m-%d').split(' ')[0]\n",
|
|
" start_string = start_date.strftime('%Y-%m-%d').split(' ')[0]\n",
|
|
" if not os.path.isdir(data_path / str(end_string)):\n",
|
|
" # new = os.listdir(data_path / f\"{end_string}/patterns\")\n",
|
|
" # added = sorted(list(set(new)))\n",
|
|
" # remove = sorted([])\n",
|
|
" # csv.append((end_string, len(new), added, remove))\n",
|
|
" break\n",
|
|
"\n",
|
|
" new = get_root_nodes(end_string)\n",
|
|
" old = get_root_nodes(start_string)\n",
|
|
"\n",
|
|
" added = sorted(list(set(new) - set(old)))\n",
|
|
" remove = sorted(list(set(old) - set(new)))\n",
|
|
"\n",
|
|
" csv.append((end_string, len(new), len(added), len(remove), added, remove, len(added) != 0))\n",
|
|
"\n",
|
|
" start_date += interval\n",
|
|
"\n",
|
|
" # plt.plot([v[1] for v in csv], [v[2] for v in csv])\n",
|
|
" # plt.show()\n",
|
|
"\n",
|
|
" df = DataFrame(csv, columns=('Time', 'Number of Patches', 'Numbers of added', 'Numbers of removed', 'Patches Added',\n",
|
|
" 'Patches Removed', 'New Pattern'))\n",
|
|
" df.to_csv('diff-test-absolute-root-2.csv')\n",
|
|
" return df\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_commits_sha(start: str, end: str, project_name: str) -> DataFrame:\n",
|
|
" start_date = datetime.strptime(start, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n",
|
|
" end_date = datetime.strptime(end, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n",
|
|
" commits = load_zipped_pickle(f'/workspace/EECS-Research/data.absolute/{end}/commitsDF/{project_name}-fix.pickle.gz')\n",
|
|
" return commits[commits['commitDate'].between(start_date, end_date, inclusive=False)]['commit']\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Shutting down redis 6399...\n",
|
|
"> Shutdown complete.\n",
|
|
"/workspace/EECS-Research/data.absolute/2022-06-01/redis\n",
|
|
"Starting redis 6399...\n",
|
|
"> Redis started.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"labels = label_gen()\n",
|
|
"\n",
|
|
"date = '2022-06-01'\n",
|
|
"job_start_redis(f'/workspace/EECS-Research/data.absolute/{date}/redis', 6399)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Connected to redis!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 1. Load AST diffs\n",
|
|
"r = redis.StrictRedis(host='localhost', port=6399, db=0)\n",
|
|
"print(\"Connected to redis!\")\n",
|
|
"ast_diffs = {k.decode(): v.decode() for k, v in r.hgetall('dump').items()}\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading commits: 100%|██████████| 42/42 [00:01<00:00, 22.41it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 2. Load commits\n",
|
|
"base_path = Path(f'/workspace/EECS-Research/data.absolute/{date}/commitsDF/')\n",
|
|
"commit_pickles = [base_path / str(f) for f in os.listdir(base_path) if f.endswith('-fix.pickle.gz')]\n",
|
|
"commits = pandas.concat(pmap(load_zipped_pickle, commit_pickles, desc=f'Loading commits'))\n",
|
|
"\n",
|
|
"def get_all_commits_sha(start: str, end: str) -> DataFrame:\n",
|
|
" start_date = datetime.strptime(start, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n",
|
|
" end_date = datetime.strptime(end, '%Y-%m-%d').replace(tzinfo=timezone.utc)\n",
|
|
" df = commits[commits['commitDate'].between(start_date, end_date, inclusive=False)]\n",
|
|
" return df\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Ignored 0 AST diff entries\n",
|
|
"Commit sha lengths: Counter({6: 61553, 7: 3748, 8: 237, 9: 14})\n",
|
|
"Processing time interval from 2001-12-01 to 2002-06-01\n",
|
|
"Total of 196 commits in the interval\n",
|
|
"Total of 44 AST diffs in the interval\n",
|
|
"Processing time interval from 2002-06-01 to 2002-12-01\n",
|
|
"Total of 207 commits in the interval\n",
|
|
"Total of 32 AST diffs in the interval\n",
|
|
"Processing time interval from 2002-12-01 to 2003-06-01\n",
|
|
"Total of 276 commits in the interval\n",
|
|
"Total of 27 AST diffs in the interval\n",
|
|
"Processing time interval from 2003-06-01 to 2003-12-01\n",
|
|
"Total of 477 commits in the interval\n",
|
|
"Total of 60 AST diffs in the interval\n",
|
|
"Processing time interval from 2003-12-01 to 2004-06-01\n",
|
|
"Total of 920 commits in the interval\n",
|
|
"Total of 53 AST diffs in the interval\n",
|
|
"Processing time interval from 2004-06-01 to 2004-12-01\n",
|
|
"Total of 627 commits in the interval\n",
|
|
"Total of 75 AST diffs in the interval\n",
|
|
"Processing time interval from 2004-12-01 to 2005-06-01\n",
|
|
"Total of 517 commits in the interval\n",
|
|
"Total of 40 AST diffs in the interval\n",
|
|
"Processing time interval from 2005-06-01 to 2005-12-01\n",
|
|
"Total of 484 commits in the interval\n",
|
|
"Total of 33 AST diffs in the interval\n",
|
|
"Processing time interval from 2005-12-01 to 2006-06-01\n",
|
|
"Total of 613 commits in the interval\n",
|
|
"Total of 69 AST diffs in the interval\n",
|
|
"Processing time interval from 2006-06-01 to 2006-12-01\n",
|
|
"Total of 508 commits in the interval\n",
|
|
"Total of 37 AST diffs in the interval\n",
|
|
"Processing time interval from 2006-12-01 to 2007-06-01\n",
|
|
"Total of 1757 commits in the interval\n",
|
|
"Total of 141 AST diffs in the interval\n",
|
|
"Processing time interval from 2007-06-01 to 2007-12-01\n",
|
|
"Total of 2451 commits in the interval\n",
|
|
"Total of 241 AST diffs in the interval\n",
|
|
"Processing time interval from 2007-12-01 to 2008-06-01\n",
|
|
"Total of 3996 commits in the interval\n",
|
|
"Total of 447 AST diffs in the interval\n",
|
|
"Processing time interval from 2008-06-01 to 2008-12-01\n",
|
|
"Total of 3167 commits in the interval\n",
|
|
"Total of 481 AST diffs in the interval\n",
|
|
"Processing time interval from 2008-12-01 to 2009-06-01\n",
|
|
"Total of 3519 commits in the interval\n",
|
|
"Total of 556 AST diffs in the interval\n",
|
|
"Processing time interval from 2009-06-01 to 2009-12-01\n",
|
|
"Total of 3243 commits in the interval\n",
|
|
"Total of 548 AST diffs in the interval\n",
|
|
"Processing time interval from 2009-12-01 to 2010-06-01\n",
|
|
"Total of 3184 commits in the interval\n",
|
|
"Total of 359 AST diffs in the interval\n",
|
|
"Processing time interval from 2010-06-01 to 2010-12-01\n",
|
|
"Total of 5523 commits in the interval\n",
|
|
"Total of 746 AST diffs in the interval\n",
|
|
"Processing time interval from 2010-12-01 to 2011-06-01\n",
|
|
"Total of 10476 commits in the interval\n",
|
|
"Total of 1331 AST diffs in the interval\n",
|
|
"Processing time interval from 2011-06-01 to 2011-12-01\n",
|
|
"Total of 9990 commits in the interval\n",
|
|
"Total of 1965 AST diffs in the interval\n",
|
|
"Processing time interval from 2011-12-01 to 2012-06-01\n",
|
|
"Total of 9092 commits in the interval\n",
|
|
"Total of 1466 AST diffs in the interval\n",
|
|
"Processing time interval from 2012-06-01 to 2012-12-01\n",
|
|
"Total of 6427 commits in the interval\n",
|
|
"Total of 924 AST diffs in the interval\n",
|
|
"Processing time interval from 2012-12-01 to 2013-06-01\n",
|
|
"Total of 6943 commits in the interval\n",
|
|
"Total of 924 AST diffs in the interval\n",
|
|
"Processing time interval from 2013-06-01 to 2013-12-01\n",
|
|
"Total of 8525 commits in the interval\n",
|
|
"Total of 1208 AST diffs in the interval\n",
|
|
"Processing time interval from 2013-12-01 to 2014-06-01\n",
|
|
"Total of 6567 commits in the interval\n",
|
|
"Total of 955 AST diffs in the interval\n",
|
|
"Processing time interval from 2014-06-01 to 2014-12-01\n",
|
|
"Total of 6512 commits in the interval\n",
|
|
"Total of 871 AST diffs in the interval\n",
|
|
"Processing time interval from 2014-12-01 to 2015-06-01\n",
|
|
"Total of 6618 commits in the interval\n",
|
|
"Total of 719 AST diffs in the interval\n",
|
|
"Processing time interval from 2015-06-01 to 2015-12-01\n",
|
|
"Total of 6827 commits in the interval\n",
|
|
"Total of 589 AST diffs in the interval\n",
|
|
"Processing time interval from 2015-12-01 to 2016-06-01\n",
|
|
"Total of 8035 commits in the interval\n",
|
|
"Total of 839 AST diffs in the interval\n",
|
|
"Processing time interval from 2016-06-01 to 2016-12-01\n",
|
|
"Total of 7323 commits in the interval\n",
|
|
"Total of 841 AST diffs in the interval\n",
|
|
"Processing time interval from 2016-12-01 to 2017-06-01\n",
|
|
"Total of 8170 commits in the interval\n",
|
|
"Total of 1470 AST diffs in the interval\n",
|
|
"Processing time interval from 2017-06-01 to 2017-12-01\n",
|
|
"Total of 7063 commits in the interval\n",
|
|
"Total of 1079 AST diffs in the interval\n",
|
|
"Processing time interval from 2017-12-01 to 2018-06-01\n",
|
|
"Total of 6175 commits in the interval\n",
|
|
"Total of 430 AST diffs in the interval\n",
|
|
"Processing time interval from 2018-06-01 to 2018-12-01\n",
|
|
"Total of 5783 commits in the interval\n",
|
|
"Total of 823 AST diffs in the interval\n",
|
|
"Processing time interval from 2018-12-01 to 2019-06-01\n",
|
|
"Total of 6699 commits in the interval\n",
|
|
"Total of 754 AST diffs in the interval\n",
|
|
"Processing time interval from 2019-06-01 to 2019-12-01\n",
|
|
"Total of 6875 commits in the interval\n",
|
|
"Total of 1787 AST diffs in the interval\n",
|
|
"Processing time interval from 2019-12-01 to 2020-06-01\n",
|
|
"Total of 8036 commits in the interval\n",
|
|
"Total of 788 AST diffs in the interval\n",
|
|
"Processing time interval from 2020-06-01 to 2020-12-01\n",
|
|
"Total of 7538 commits in the interval\n",
|
|
"Total of 1074 AST diffs in the interval\n",
|
|
"Processing time interval from 2020-12-01 to 2021-06-01\n",
|
|
"Total of 7237 commits in the interval\n",
|
|
"Total of 514 AST diffs in the interval\n",
|
|
"Processing time interval from 2021-06-01 to 2021-12-01\n",
|
|
"Total of 6388 commits in the interval\n",
|
|
"Total of 695 AST diffs in the interval\n",
|
|
"Processing time interval from 2021-12-01 to 2022-06-01\n",
|
|
"Total of 6552 commits in the interval\n",
|
|
"Total of 527 AST diffs in the interval\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 3.1. Find sha in each AST diff\n",
|
|
"AST_KEY_SHA_RE = re.compile(r'[0-9a-f]{6,}_[0-9a-f]{6,}')\n",
|
|
"ast_diff_sha = {k: AST_KEY_SHA_RE.findall(k) for k in ast_diffs.keys()}\n",
|
|
"tmp_bf = len(ast_diff_sha)\n",
|
|
"ast_diff_sha = {k: v[0].split('_') for k, v in ast_diff_sha.items() if v}\n",
|
|
"print(f'Ignored {len(ast_diff_sha) - tmp_bf} AST diff entries')\n",
|
|
"print(f'Commit sha lengths: {Counter([len(sha) for l in ast_diff_sha.values() for sha in l])}')\n",
|
|
"\n",
|
|
"ast_diff_date = {}\n",
|
|
"for date_start, date_end in zip(labels['Time'][:-1], labels['Time'][1:]):\n",
|
|
" # 3.2. Find all commit sha in the time interval\n",
|
|
" print(f'Processing time interval from {date_start} to {date_end}')\n",
|
|
" commits_sha = {c[:6] for c in get_all_commits_sha(date_start, date_end)['commit']}\n",
|
|
" print(f'Total of {len(commits_sha)} commits in the interval')\n",
|
|
"\n",
|
|
" # 3.3. Find AST diff entries with commit sha in this interval\n",
|
|
" ast_diffs_in_interval = {k for k, s in ast_diff_sha.items() if all(len(sha) == 6 and sha in commits_sha for sha in s)}\n",
|
|
" print(f'Total of {len(ast_diffs_in_interval)} AST diffs in the interval')\n",
|
|
"\n",
|
|
" # 3.4. Combine AST diffs and save as one file\n",
|
|
" ast_combined = '\\n\\n'.join(ast_diffs[k] for k in ast_diffs_in_interval)\n",
|
|
" write(f'ML/ast-diffs/{date_start}.txt', ast_combined)\n",
|
|
" ast_diff_date[date_start] = ast_combined\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Cleaning AST Diffs\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 4. Tokenize / Vectorize\n",
|
|
"\n",
|
|
"# 4.1. Remove symbols, numbers\n",
|
|
"def tmp_clean(s: str) -> str:\n",
|
|
" s = s.replace('@TO@', '').replace('@AT@', '').replace('@LENGTH@', '')\n",
|
|
" for c in string.punctuation:\n",
|
|
" s = s.replace(c, ' ')\n",
|
|
" while ' ' in s:\n",
|
|
" s = s.replace(' ', ' ')\n",
|
|
" return s\n",
|
|
"print('Cleaning AST Diffs')\n",
|
|
"ast_diff_date = {k: tmp_clean(v) for k, v in ast_diff_date.items()}\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TF-IDF Fitting finished, resulting vector shape: (41, 1000)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 4.2. Vectorize X\n",
|
|
"tf = TfidfVectorizer(ngram_range=(1, 4), max_features=1000)\n",
|
|
"X = [ast_diff_date[t] for t in labels['Time'] if t in ast_diff_date]\n",
|
|
"X = tf.fit_transform(X)\n",
|
|
"print('TF-IDF Fitting finished, resulting vector shape:', X.shape)\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Y vector size: 41\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 4.3. Vectorize Y\n",
|
|
"y = np.array([v != [] for v in labels['Patches Added']][1:])\n",
|
|
"print('Y vector size:', len(y))"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Train size: (30, 1000), Test size: (11, 1000)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"# 4.4. Train test split\n",
|
|
"X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
|
|
"\n",
|
|
"print(f'Train size: {X_train.shape}, Test size: {X_test.shape}')"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Training Gradient Boosted Regression Tree model...\n",
|
|
"Done!\n",
|
|
"y Test Label: [1 1 1 1 1 0 1 1 0 0 1]\n",
|
|
"y Prediction: [1 1 1 0 1 1 1 1 1 0 1]\n",
|
|
"Precision: 77.8\n",
|
|
"Recall: 87.5\n",
|
|
"F1: 82.4\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.metrics import precision_recall_fscore_support\n",
|
|
"from sklearn.model_selection import KFold\n",
|
|
"\n",
|
|
"def print_prec(y_test, y_pred):\n",
|
|
" print(f'y Test Label: {np.array([int(x) for x in y_test])}')\n",
|
|
" print(f'y Prediction: {np.array([int(x) for x in y_pred])}')\n",
|
|
"\n",
|
|
" prec, rec, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='binary')\n",
|
|
" print(f'Precision: {prec * 100:0.1f}')\n",
|
|
" print(f'Recall: {rec * 100:0.1f}')\n",
|
|
" print(f'F1: {f1 * 100:0.1f}')\n",
|
|
"\n",
|
|
"def k_fold(k,model):\n",
|
|
" kf=KFold(n_splits=k)\n",
|
|
" for train_index, test_index in kf.split(X):\n",
|
|
" #print(\"TRAIN:\", train_index, \"TEST:\", test_index)\n",
|
|
" X_train , X_test = X[train_index],X[test_index]\n",
|
|
" y_train , y_test = y[train_index] , y[test_index]\n",
|
|
" model.fit(X_train,y_train)\n",
|
|
" pred_values = model.predict(X_test)\n",
|
|
"\n",
|
|
" print_prec(y_test,pred_values)\n",
|
|
"\n",
|
|
"# 5.1. Gradient Boost Regression Tree model\n",
|
|
"\n",
|
|
"\n",
|
|
"def gbrt():\n",
|
|
" print('Training Gradient Boosted Regression Tree model...')\n",
|
|
" classifier: XGBClassifier = XGBClassifier(tree_method='gpu_hist', n_estimators=300)\n",
|
|
" classifier.fit(X_train, y_train)\n",
|
|
" print('Done!')\n",
|
|
"\n",
|
|
" # 5.1. Get internal accuracy\n",
|
|
" y_test_pred = classifier.predict(X_test)\n",
|
|
" print_prec(y_test, y_test_pred)\n",
|
|
"\n",
|
|
"\n",
|
|
"gbrt()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Training Random Forest model...\n",
|
|
"Done!\n",
|
|
"y Test Label: [1 1 1 1 1 0 1 1 0 0 1]\n",
|
|
"y Prediction: [1 1 1 0 1 1 1 1 0 0 1]\n",
|
|
"Precision: 87.5\n",
|
|
"Recall: 87.5\n",
|
|
"F1: 87.5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|
"\n",
|
|
"\n",
|
|
"# 5.2. Random Forest Model\n",
|
|
"def rf():\n",
|
|
" print('Training Random Forest model...')\n",
|
|
" classifier = RandomForestClassifier(n_estimators=300)\n",
|
|
" classifier.fit(X_train, y_train)\n",
|
|
" print('Done!')\n",
|
|
"\n",
|
|
" # 5.1. Get internal accuracy\n",
|
|
" y_test_pred = classifier.predict(X_test)\n",
|
|
" print_prec(y_test, y_test_pred)\n",
|
|
"\n",
|
|
"\n",
|
|
"rf()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Training Logistic Regression model...\n",
|
|
"Done!\n",
|
|
"y Test Label: [1 1 1 1 1 0 1 1 0 0 1]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1 1 1 1 1]\n",
|
|
"Precision: 72.7\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 84.2\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.linear_model import LogisticRegression\n",
|
|
"\n",
|
|
"\n",
|
|
"# 5.3. Logistic Regression Model\n",
|
|
"def logistic():\n",
|
|
" print('Training Logistic Regression model...')\n",
|
|
" classifier = LogisticRegression()\n",
|
|
" classifier.fit(X_train, y_train)\n",
|
|
" print('Done!')\n",
|
|
"\n",
|
|
" # 5.1. Get internal accuracy\n",
|
|
" y_test_pred = classifier.predict(X_test)\n",
|
|
" print_prec(y_test, y_test_pred)\n",
|
|
"\n",
|
|
"\n",
|
|
"logistic()\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Percentage of 1s in dataset: 63.4%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(f'Percentage of 1s in dataset: {sum(y) / len(y) * 100 :.1f}%')\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 80,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TRAIN: [ 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [0 1 2 3 4 5 6]\n",
|
|
"y Test Label: [1 1 1 1 1 1 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 85.7\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 92.3\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [ 7 8 9 10 11 12 13]\n",
|
|
"y Test Label: [0 1 1 1 1 1 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 71.4\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 83.3\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [14 15 16 17 18 19 20]\n",
|
|
"y Test Label: [1 1 0 1 1 1 1]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 85.7\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 92.3\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [21 22 23 24 25 26 27]\n",
|
|
"y Test Label: [1 0 1 1 1 0 1]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 71.4\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 83.3\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
|
|
" 24 25 26 27 35 36 37 38 39 40] TEST: [28 29 30 31 32 33 34]\n",
|
|
"y Test Label: [1 0 1 1 0 1 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 57.1\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 72.7\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
|
|
" 24 25 26 27 28 29 30 31 32 33 34] TEST: [35 36 37 38 39 40]\n",
|
|
"y Test Label: [0 0 0 0 0 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1]\n",
|
|
"Precision: 0.0\n",
|
|
"Recall: 0.0\n",
|
|
"F1: 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/azalea/.conda/envs/fixminerEnv/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, msg_start, len(result))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"k_fold(6,LogisticRegression())\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 82,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TRAIN: [ 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [0 1 2 3 4 5 6]\n",
|
|
"y Test Label: [1 1 1 1 1 1 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 85.7\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 92.3\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [ 7 8 9 10 11 12 13]\n",
|
|
"y Test Label: [0 1 1 1 1 1 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1 1]\n",
|
|
"Precision: 71.4\n",
|
|
"Recall: 100.0\n",
|
|
"F1: 83.3\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [14 15 16 17 18 19 20]\n",
|
|
"y Test Label: [1 1 0 1 1 1 1]\n",
|
|
"y Prediction: [1 1 0 1 1 0 1]\n",
|
|
"Precision: 100.0\n",
|
|
"Recall: 83.3\n",
|
|
"F1: 90.9\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [21 22 23 24 25 26 27]\n",
|
|
"y Test Label: [1 0 1 1 1 0 1]\n",
|
|
"y Prediction: [1 1 1 1 0 0 0]\n",
|
|
"Precision: 75.0\n",
|
|
"Recall: 60.0\n",
|
|
"F1: 66.7\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
|
|
" 24 25 26 27 35 36 37 38 39 40] TEST: [28 29 30 31 32 33 34]\n",
|
|
"y Test Label: [1 0 1 1 0 1 0]\n",
|
|
"y Prediction: [0 1 1 0 0 1 0]\n",
|
|
"Precision: 66.7\n",
|
|
"Recall: 50.0\n",
|
|
"F1: 57.1\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
|
|
" 24 25 26 27 28 29 30 31 32 33 34] TEST: [35 36 37 38 39 40]\n",
|
|
"y Test Label: [0 0 0 0 0 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1]\n",
|
|
"Precision: 0.0\n",
|
|
"Recall: 0.0\n",
|
|
"F1: 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/azalea/.conda/envs/fixminerEnv/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, msg_start, len(result))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"k_fold(6,RandomForestClassifier(n_estimators=300))"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 83,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TRAIN: [ 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [0 1 2 3 4 5 6]\n",
|
|
"y Test Label: [1 1 1 1 1 1 0]\n",
|
|
"y Prediction: [0 1 0 0 1 1 1]\n",
|
|
"Precision: 75.0\n",
|
|
"Recall: 50.0\n",
|
|
"F1: 60.0\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [ 7 8 9 10 11 12 13]\n",
|
|
"y Test Label: [0 1 1 1 1 1 0]\n",
|
|
"y Prediction: [1 1 1 1 0 1 1]\n",
|
|
"Precision: 66.7\n",
|
|
"Recall: 80.0\n",
|
|
"F1: 72.7\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 21 22 23 24 25 26 27 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [14 15 16 17 18 19 20]\n",
|
|
"y Test Label: [1 1 0 1 1 1 1]\n",
|
|
"y Prediction: [1 0 0 1 1 1 1]\n",
|
|
"Precision: 100.0\n",
|
|
"Recall: 83.3\n",
|
|
"F1: 90.9\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 28 29 30\n",
|
|
" 31 32 33 34 35 36 37 38 39 40] TEST: [21 22 23 24 25 26 27]\n",
|
|
"y Test Label: [1 0 1 1 1 0 1]\n",
|
|
"y Prediction: [1 1 1 1 0 0 1]\n",
|
|
"Precision: 80.0\n",
|
|
"Recall: 80.0\n",
|
|
"F1: 80.0\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
|
|
" 24 25 26 27 35 36 37 38 39 40] TEST: [28 29 30 31 32 33 34]\n",
|
|
"y Test Label: [1 0 1 1 0 1 0]\n",
|
|
"y Prediction: [0 1 1 0 1 0 1]\n",
|
|
"Precision: 25.0\n",
|
|
"Recall: 25.0\n",
|
|
"F1: 25.0\n",
|
|
"TRAIN: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
|
|
" 24 25 26 27 28 29 30 31 32 33 34] TEST: [35 36 37 38 39 40]\n",
|
|
"y Test Label: [0 0 0 0 0 0]\n",
|
|
"y Prediction: [1 1 1 1 1 1]\n",
|
|
"Precision: 0.0\n",
|
|
"Recall: 0.0\n",
|
|
"F1: 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/azalea/.conda/envs/fixminerEnv/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n",
|
|
" _warn_prf(average, modifier, msg_start, len(result))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"k_fold(6,XGBClassifier(tree_method='gpu_hist', n_estimators=300))\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |