Lambda-School-Labs/allay-ds

View on GitHub
exploration/train_ml_models.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Train Hate Speech Recognition Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "With the data cleaned and processed, this notebook implements model training on the data sets. The code in this notebook assumes that cleaned data is in the filepath `\"data/combined_deduped.csv\"`. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import spacy\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from sklearn.linear_model import LogisticRegression, SGDClassifier\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_fscore_support, classification_report\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer\n",
    "from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n",
    "from sklearn.pipeline import Pipeline, make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.svm import SVC, LinearSVC\n",
    "from tensorflow import keras\n",
    "from wandb.keras import WandbCallback\n",
    "from tensorflow.keras import regularizers\n",
    "from xgboost import XGBClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "nlp = spacy.load(\"en_core_web_md\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "We'll start by fitting basic models with default parameters. As machine learning models and neural networks have many hyperparameters that can be tweaked, these will serve as a baseline: combinations of hyperparameter settings that improve the model from this baseline should be investigated, while those that degrade model performance should be seen as less promising.\n",
    "\n",
    "As modeling is an iterative process, we start by establishing some functions to automate the repetitive aspects of establishing baselines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def reset_data_with_val():\n",
    "    \"\"\"\n",
    "    Loads data, returns training and validation data\"\"\"\n",
    "    tweets = pd.read_csv(\"data/combined_deduped.csv\")\n",
    "    train, test = train_test_split(tweets, test_size=.2, random_state=42)\n",
    "    train, val = train_test_split(train, test_size=.15, random_state=42)\n",
    "    \n",
    "    x_train, y_train, x_val, y_val = train[\"tweet\"], train[\"inappropriate\"], val[\"tweet\"], val[\"inappropriate\"]\n",
    "    \n",
    "    return x_train, y_train, x_val, y_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def baseline_model(vect, clf, x_train, y_train, x_val, y_val, scaling=False):\n",
    "    \"\"\"Takes a text vectorizer, a classification model, and \n",
    "    data and outputs that model's score on the data.\n",
    "    \n",
    "    Sub-arguments can be passed to the arguments in the function: \n",
    "    for example, passing `vect = CountVectorizer(stop_words='english')`.\n",
    "    However, as this function doesn't return a fitted model, just model\n",
    "    scores, it is primarily intended to quickly test many basic \n",
    "    model archetypes.\n",
    "\n",
    "    Arguments:\n",
    "        vect {vectorizer} -- e.g. CountVectorizer(), TfidfVectorizer()\n",
    "        clf {classifier} -- eg RandomForest(), MultinomialNB()\n",
    "        x_train {array} -- features and values of the training set\n",
    "        y_train {1d array} -- target values for the training set\n",
    "        x_val {array} -- features and values of the validation set\n",
    "        y_val {1d array} -- target values for the validation set\n",
    "    \n",
    "    Keyword Arguments:\n",
    "        scaling {bool} -- scale data if required by model (default: {False})\n",
    "    \n",
    "    Returns:\n",
    "        accuracy, precision, recall, f1, roc_auc -- metrics of model performance\n",
    "    \"\"\"\n",
    "    x_train = vect.fit_transform(x_train)\n",
    "    x_val = vect.transform(x_val)\n",
    "    \n",
    "    if scaling == True:\n",
    "        scaler = StandardScaler(with_mean=False)\n",
    "        x_train = scaler.fit_transform(x_train)\n",
    "        x_val = scaler.fit_transform(x_val)\n",
    "        \n",
    "    clf.fit(x_train, y_train)\n",
    "    y_pred = clf.predict(x_val)\n",
    "    \n",
    "    accuracy = clf.score(x_val, y_val)\n",
    "    precision, recall, f1, other = precision_recall_fscore_support(y_val, \n",
    "                                                                   y_pred, \n",
    "                                                                   average='binary')\n",
    "    roc_auc = roc_auc_score(y_val, y_pred)\n",
    "\n",
    "    return accuracy, precision, recall, f1, roc_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def report_to_wandb(accuracy, precision, recall, f1, roc_auc):\n",
    "    \"\"\"Reports a dictionary summarizing model performance\n",
    "    to the \"Weights and Biases\" app associated with this project.\n",
    "    \n",
    "    Arguments:\n",
    "        accuracy {float} -- accuracy of tested model\n",
    "        precision {float} -- precision of tested model\n",
    "        recall {float} -- recall of tested model\n",
    "        f1 {float} -- f1-score of tested model\n",
    "        roc_auc {float} -- roc_auc score of tested model\n",
    "    \"\"\"\n",
    "    wandb.log({'accuracy':accuracy, 'recall':recall, \n",
    "               'f1':f1, 'precision':precision, \n",
    "               'roc_auc_score':roc_auc})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "x_train, y_train, x_val, y_val = reset_data_with_val()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def implement_training(vect, clf, \n",
    "    x_train=x_train, y_train=y_train, \n",
    "    x_val=x_val, y_val=y_val):\n",
    "    \"\"\"This function initializes a weights and biases run, then calls the \n",
    "    helper functions in order to perform a full cycle of model training.\n",
    "\n",
    "    Note that attempting to define this function before x_train, y_train, \n",
    "    x_val, and y_val are assigned local variables will generate an error, as \n",
    "    they're called as the default arguments. These have been set as default\n",
    "    arguments for convenience when fitting multiple models. \n",
    "    \n",
    "    Arguments:\n",
    "        vect {Vectorizer} -- sklearn-compatible text vectorizer\n",
    "        clf {Classifier} -- sklearn-compatible classification model\n",
    "    \n",
    "    Keyword Arguments:\n",
    "        x_train {array} -- training features/values (default: {x_train})\n",
    "        y_train {1d array} -- training target values (default: {y_train})\n",
    "        x_val {array} -- validation features/values (default: {x_val})\n",
    "        y_val {1d array} -- validation target values (default: {y_val})\n",
    "    \"\"\"\n",
    "    wandb.init(project=\"allay-ds-23\")\n",
    "    accuracy, precision, recall, f1, roc_auc = baseline_model(vect, clf, \n",
    "                                        x_train, y_train, \n",
    "                                        x_val, y_val)\n",
    "    report_to_wandb(accuracy, precision, recall, f1, roc_auc)\n",
    "\n",
    "    return(\"Accuracy:\", accuracy, \"Precision:\", precision,\n",
    "        \"Recall:\", recall, \"F1:\", f1, \"ROC_AUC:\", roc_auc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Baseline ML Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Majority Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False    0.582441\n",
       "True     0.417559\n",
       "Name: inappropriate, dtype: float64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_val.value_counts(normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       False     0.5824    1.0000    0.7361     10223\n",
      "        True     0.0000    0.0000    0.0000      7329\n",
      "\n",
      "    accuracy                         0.5824     17552\n",
      "   macro avg     0.2912    0.5000    0.3681     17552\n",
      "weighted avg     0.3392    0.5824    0.4288     17552\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\ajenk\\.virtualenvs\\allay-ds-cRyEcJS9\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "y_pred = y_val * 0\n",
    "print(classification_report(y_val, y_pred, digits=4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = CountVectorizer()\n",
    "clf = LogisticRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/3iw8dreb\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/3iw8dreb</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\ajenk\\.virtualenvs\\allay-ds-cRyEcJS9\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8650296262534184,\n",
       " 'Precision:',\n",
       " 0.8802514566084023,\n",
       " 'Recall:',\n",
       " 0.7833265111202073,\n",
       " 'F1:',\n",
       " 0.8289654176593747,\n",
       " 'ROC_AUC:',\n",
       " 0.8534650749868863)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = TfidfVectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.863320419325433,\n",
       " 0.8849156777014366,\n",
       " 0.7732296356938191,\n",
       " 0.8253112939634457,\n",
       " 0.8505686474468312)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "baseline_model(vect, clf, x_train, y_train, x_val, y_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/3mcw46pj\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/3mcw46pj</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.863320419325433,\n",
       " '/nPrecision:',\n",
       " 0.8849156777014366,\n",
       " '/nRecall:',\n",
       " 0.7732296356938191,\n",
       " '/nF1:',\n",
       " 0.8253112939634457,\n",
       " '/nROC_AUC:',\n",
       " 0.8505686474468312)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = CountVectorizer()\n",
    "clf = MultinomialNB()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/2zh784do\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/2zh784do</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8469690063810392,\n",
       " 'Precision:',\n",
       " 0.8710244526130734,\n",
       " 'Recall:',\n",
       " 0.743621230727248,\n",
       " 'F1:',\n",
       " 0.8022964816723098,\n",
       " 'ROC_AUC:',\n",
       " 0.8323407924153701)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/37af9rkz\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/37af9rkz</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8229831358249772,\n",
       " 'Precision:',\n",
       " 0.9211891460494812,\n",
       " 'Recall:',\n",
       " 0.6299631600491199,\n",
       " 'F1:',\n",
       " 0.7482375820435944,\n",
       " 'ROC_AUC:',\n",
       " 0.7956623977884257)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vect = TfidfVectorizer()\n",
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = CountVectorizer()\n",
    "clf = LinearSVC()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/1myixgfe\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/1myixgfe</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\ajenk\\.virtualenvs\\allay-ds-cRyEcJS9\\lib\\site-packages\\sklearn\\svm\\_base.py:947: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  \"the number of iterations.\", ConvergenceWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8507292616226071,\n",
       " 'Precision:',\n",
       " 0.8453865336658354,\n",
       " 'Recall:',\n",
       " 0.7863282848956201,\n",
       " 'F1:',\n",
       " 0.8147886328290682,\n",
       " 'ROC_AUC:',\n",
       " 0.8416137169367077)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = TfidfVectorizer()\n",
    "clf = LinearSVC()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/6j5wh8mg\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/6j5wh8mg</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8605287146763901,\n",
       " 'Precision:',\n",
       " 0.8562253685593344,\n",
       " 'Recall:',\n",
       " 0.8003820439350525,\n",
       " 'F1:',\n",
       " 0.8273624823695346,\n",
       " 'ROC_AUC:',\n",
       " 0.8520153396824827)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Random Forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = CountVectorizer()\n",
    "clf = RandomForestClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/3w229j0j\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/3w229j0j</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8521536007292616,\n",
       " 'Precision:',\n",
       " 0.9043389135633755,\n",
       " 'Recall:',\n",
       " 0.7223359257743212,\n",
       " 'F1:',\n",
       " 0.8031555791549723,\n",
       " 'ROC_AUC:',\n",
       " 0.8337787425017551)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = TfidfVectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/189vdpvm\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/189vdpvm</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8327825888787602,\n",
       " 'Precision:',\n",
       " 0.8206363105662581,\n",
       " 'Recall:',\n",
       " 0.7672260881429935,\n",
       " 'F1:',\n",
       " 0.7930329313870671,\n",
       " 'ROC_AUC:',\n",
       " 0.8235034871899551)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Gradient-Boosted Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = CountVectorizer()\n",
    "clf = XGBClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/2niyxkx4\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/2niyxkx4</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.8648587055606198,\n",
       " 'Precision:',\n",
       " 0.9335315725030611,\n",
       " 'Recall:',\n",
       " 0.7282030290626279,\n",
       " 'F1:',\n",
       " 0.8181818181818182,\n",
       " 'ROC_AUC:',\n",
       " 0.8455159721269317)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "vect = TfidfVectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
       "                Project page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23</a><br/>\n",
       "                Run page: <a href=\"https://app.wandb.ai/alexmjn/allay-ds-23/runs/3el65qwi\" target=\"_blank\">https://app.wandb.ai/alexmjn/allay-ds-23/runs/3el65qwi</a><br/>\n",
       "            "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "('Accuracy:',\n",
       " 0.85859161349134,\n",
       " 'Precision:',\n",
       " 0.9106931028639214,\n",
       " 'Recall:',\n",
       " 0.7332514667758221,\n",
       " 'F1:',\n",
       " 0.8123960695389266,\n",
       " 'ROC_AUC:',\n",
       " 0.8408505206323598)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "implement_training(vect, clf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Performance Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "Our models are tracked to the website Weights and Biases (wandb.com). We'll be using that to systematically log model performance and improvement. For our baseline models, I've grabbed the data off Weights and Biases, dropped the models that either didn't finish or aren't part of the baseline, and graphed both by run number and by accuracy rank.\n",
    "\n",
    "These models are then compared to a baseline of .59 (majority-classifier model performance) and a target goal of .9 (what we think is a plausible score to shoot for with model tuning and improvement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "results_table = pd.read_csv(\"wandb//wandb_export_2020-04-23T12_51_21.251-04_00.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Name</th>\n",
       "      <th>_wandb</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>precision</th>\n",
       "      <th>f1</th>\n",
       "      <th>roc_auc_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>playful-glade-20</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.798200</td>\n",
       "      <td>0.854300</td>\n",
       "      <td>0.844300</td>\n",
       "      <td>0.820600</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>eager-snowflake-19</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.777700</td>\n",
       "      <td>0.844500</td>\n",
       "      <td>0.838200</td>\n",
       "      <td>0.806900</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>blooming-cosmos-18</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.783327</td>\n",
       "      <td>0.865030</td>\n",
       "      <td>0.880251</td>\n",
       "      <td>0.828965</td>\n",
       "      <td>0.853465</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>revived-snowball-17</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.781962</td>\n",
       "      <td>0.866112</td>\n",
       "      <td>0.884004</td>\n",
       "      <td>0.829858</td>\n",
       "      <td>0.854201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>fancy-dragon-16</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.733251</td>\n",
       "      <td>0.858592</td>\n",
       "      <td>0.910693</td>\n",
       "      <td>0.812396</td>\n",
       "      <td>0.840851</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Name  _wandb    recall  accuracy  precision        f1  \\\n",
       "0     playful-glade-20     NaN  0.798200  0.854300   0.844300  0.820600   \n",
       "1   eager-snowflake-19     NaN  0.777700  0.844500   0.838200  0.806900   \n",
       "2   blooming-cosmos-18     NaN  0.783327  0.865030   0.880251  0.828965   \n",
       "3  revived-snowball-17     NaN  0.781962  0.866112   0.884004  0.829858   \n",
       "4      fancy-dragon-16     NaN  0.733251  0.858592   0.910693  0.812396   \n",
       "\n",
       "   roc_auc_score  \n",
       "0            NaN  \n",
       "1            NaN  \n",
       "2       0.853465  \n",
       "3       0.854201  \n",
       "4       0.840851  "
      ]
     },
     "execution_count": 175,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the raw W&B export before cleaning.\n",
    "results_table.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Name</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>precision</th>\n",
       "      <th>f1</th>\n",
       "      <th>run_number</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>playful-glade-20</td>\n",
       "      <td>0.798200</td>\n",
       "      <td>0.854300</td>\n",
       "      <td>0.844300</td>\n",
       "      <td>0.820600</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>eager-snowflake-19</td>\n",
       "      <td>0.777700</td>\n",
       "      <td>0.844500</td>\n",
       "      <td>0.838200</td>\n",
       "      <td>0.806900</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>fancy-dragon-16</td>\n",
       "      <td>0.733251</td>\n",
       "      <td>0.858592</td>\n",
       "      <td>0.910693</td>\n",
       "      <td>0.812396</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fast-yogurt-15</td>\n",
       "      <td>0.728203</td>\n",
       "      <td>0.864859</td>\n",
       "      <td>0.933532</td>\n",
       "      <td>0.818182</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>stellar-forest-14</td>\n",
       "      <td>0.767226</td>\n",
       "      <td>0.832783</td>\n",
       "      <td>0.820636</td>\n",
       "      <td>0.793033</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>bright-frost-13</td>\n",
       "      <td>0.722336</td>\n",
       "      <td>0.852154</td>\n",
       "      <td>0.904339</td>\n",
       "      <td>0.803156</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>earthy-cherry-12</td>\n",
       "      <td>0.800382</td>\n",
       "      <td>0.860529</td>\n",
       "      <td>0.856225</td>\n",
       "      <td>0.827362</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>valiant-snow-11</td>\n",
       "      <td>0.786328</td>\n",
       "      <td>0.850729</td>\n",
       "      <td>0.845387</td>\n",
       "      <td>0.814789</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>radiant-gorge-9</td>\n",
       "      <td>0.629963</td>\n",
       "      <td>0.822983</td>\n",
       "      <td>0.921189</td>\n",
       "      <td>0.748238</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>morning-sun-8</td>\n",
       "      <td>0.743621</td>\n",
       "      <td>0.846969</td>\n",
       "      <td>0.871024</td>\n",
       "      <td>0.802296</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>stellar-dawn-7</td>\n",
       "      <td>0.773230</td>\n",
       "      <td>0.863320</td>\n",
       "      <td>0.884916</td>\n",
       "      <td>0.825311</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>fanciful-durian-6</td>\n",
       "      <td>0.783327</td>\n",
       "      <td>0.865030</td>\n",
       "      <td>0.880251</td>\n",
       "      <td>0.828965</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Name    recall  accuracy  precision        f1  run_number\n",
       "0     playful-glade-20  0.798200  0.854300   0.844300  0.820600          11\n",
       "1   eager-snowflake-19  0.777700  0.844500   0.838200  0.806900          10\n",
       "2      fancy-dragon-16  0.733251  0.858592   0.910693  0.812396           9\n",
       "3       fast-yogurt-15  0.728203  0.864859   0.933532  0.818182           8\n",
       "4    stellar-forest-14  0.767226  0.832783   0.820636  0.793033           7\n",
       "5      bright-frost-13  0.722336  0.852154   0.904339  0.803156           6\n",
       "6     earthy-cherry-12  0.800382  0.860529   0.856225  0.827362           5\n",
       "7      valiant-snow-11  0.786328  0.850729   0.845387  0.814789           4\n",
       "8      radiant-gorge-9  0.629963  0.822983   0.921189  0.748238           3\n",
       "9        morning-sun-8  0.743621  0.846969   0.871024  0.802296           2\n",
       "10      stellar-dawn-7  0.773230  0.863320   0.884916  0.825311           1\n",
       "11   fanciful-durian-6  0.783327  0.865030   0.880251  0.828965           0"
      ]
     },
     "execution_count": 176,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: this cell reassigns `results_table` from itself, so running it twice raises\n",
    "# KeyError (label 16 no longer exists after reset_index); re-run the read_csv cell first.\n",
    "# Keep only runs above 50% accuracy (presumably discards broken runs -- TODO confirm).\n",
    "results_table = results_table[results_table.accuracy > .5]\n",
    "# Hand-picked exclusions by row label: 2 = blooming-cosmos-18, 3 = revived-snowball-17;\n",
    "# label 16 is not shown in the preview above. TODO: record why these runs are excluded.\n",
    "results_table = results_table.drop(index=[2, 3, 16])\n",
    "results_table = results_table.drop(columns=[\"_wandb\", \"roc_auc_score\"])\n",
    "results_table = results_table.reset_index(drop=True)\n",
    "# Rows are newest-first (run-name suffixes descend), so number them in reverse so the\n",
    "# last row is run 0. Derived from the row count rather than a hard-coded 11.\n",
    "results_table[\"run_number\"] = len(results_table) - 1 - results_table.index\n",
    "\n",
    "results_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Accuracy of each baseline run in run order, against the majority-class\n",
    "# baseline and the target accuracy.\n",
    "MAJORITY_BASELINE = .582  # accuracy of always predicting the majority class\n",
    "TARGET_ACCURACY = .9\n",
    "\n",
    "# Apply the style *before* creating the figure; rcParams are read when the\n",
    "# figure/axes are built, so applying it afterwards only partially styles the plot.\n",
    "plt.style.use(\"fivethirtyeight\")\n",
    "fig, ax = plt.subplots(figsize=(16, 6))\n",
    "fig.patch.set(facecolor=\"white\")\n",
    "ax.set(facecolor=\"white\")\n",
    "\n",
    "ax.plot(results_table[\"run_number\"], results_table[\"accuracy\"])\n",
    "for level in (MAJORITY_BASELINE, TARGET_ACCURACY):\n",
    "    ax.axhline(y=level, linestyle=\"dotted\", linewidth=2.5, color=\"black\")\n",
    "ax.set_yticks([.5, .6, .7, .8, .9, 1])\n",
    "\n",
    "ax.set_xlabel(\"Run Number\", fontweight=\"bold\", size=15)\n",
    "ax.set_ylabel(\"Accuracy\", fontweight=\"bold\", size=15)\n",
    "ax.set_title(\"Accuracy vs Run Number: Baseline Models\", fontweight=\"bold\", size=20)\n",
    "ax.annotate(f\"Target Accuracy: {TARGET_ACCURACY:.0%}\", xy=(5, TARGET_ACCURACY + .025))\n",
    "ax.annotate(f\"Majority Classifier: {MAJORITY_BASELINE:.1%} Accuracy\", xy=(4.6, MAJORITY_BASELINE + .033))\n",
    "\n",
    "# Explicit .png extension (matches AccuracyVsRank.png); pass the figure's facecolor so the\n",
    "# saved file keeps the white background regardless of the style's savefig.facecolor.\n",
    "fig.savefig(\"AccuracyVsRun.png\", facecolor=fig.get_facecolor())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "# Order runs from least to most accurate. `model_rank` is the 0-based position in\n",
    "# that ascending order, so rank 0 is the *least* accurate run. mergesort is stable,\n",
    "# so runs with tied accuracy keep their current relative order.\n",
    "results_table = results_table.sort_values(by=[\"accuracy\"], kind=\"mergesort\").reset_index(drop=True)\n",
    "results_table[\"model_rank\"] = results_table.index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1152x432 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Accuracy of each baseline run ordered by rank (0 = least accurate), against the\n",
    "# majority-class baseline and the target accuracy.\n",
    "MAJORITY_BASELINE = .582  # accuracy of always predicting the majority class\n",
    "TARGET_ACCURACY = .9\n",
    "\n",
    "# Apply the style *before* creating the figure; rcParams are read when the\n",
    "# figure/axes are built, so applying it afterwards only partially styles the plot.\n",
    "plt.style.use(\"fivethirtyeight\")\n",
    "fig, ax = plt.subplots(figsize=(16, 6))\n",
    "fig.patch.set(facecolor=\"white\")\n",
    "ax.set(facecolor=\"white\")\n",
    "\n",
    "ax.plot(results_table[\"model_rank\"], results_table[\"accuracy\"])\n",
    "for level in (MAJORITY_BASELINE, TARGET_ACCURACY):\n",
    "    ax.axhline(y=level, linestyle=\"dotted\", linewidth=2.5, color=\"black\")\n",
    "ax.set_yticks([.5, .6, .7, .8, .9, 1])\n",
    "\n",
    "ax.set_xlabel(\"Model Rank\", fontweight=\"bold\", size=15)\n",
    "ax.set_ylabel(\"Accuracy\", fontweight=\"bold\", size=15)\n",
    "ax.set_title(\"Accuracy vs Model Rank: Baseline Models\", fontweight=\"bold\", size=20)\n",
    "ax.annotate(f\"Target Accuracy: {TARGET_ACCURACY:.0%}\", xy=(5, TARGET_ACCURACY + .03))\n",
    "ax.annotate(f\"Majority Classifier: {MAJORITY_BASELINE:.1%} Accuracy\", xy=(4.6, MAJORITY_BASELINE + .033))\n",
    "\n",
    "# Pass the figure's facecolor so the saved file keeps the white background\n",
    "# regardless of the style's savefig.facecolor.\n",
    "fig.savefig(\"AccuracyVsRank.png\", facecolor=fig.get_facecolor())\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "allay-ds-cRyEcJS9",
   "language": "python",
   "name": "allay-ds-cryecjs9"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}