Commit 2a7d9ea1 authored by Srinidee

Logistic regression model development using the IMDB dataset.

parent 157ac1ff
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WNOJt2fcFRiH",
"outputId": "f7472149-8f61-46a1-9d95-ad06f7c90c56"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpy==1.20.0\n",
" Downloading numpy-1.20.0.zip (8.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m52.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: numpy\n",
" \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
" \n",
" \u001b[31m×\u001b[0m \u001b[32mBuilding wheel for numpy \u001b[0m\u001b[1;32m(\u001b[0m\u001b[32mpyproject.toml\u001b[0m\u001b[1;32m)\u001b[0m did not run successfully.\n",
" \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
" \u001b[31m╰─>\u001b[0m See above for output.\n",
" \n",
" \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
" Building wheel for numpy (pyproject.toml) ... \u001b[?25l\u001b[?25herror\n",
"\u001b[31m ERROR: Failed building wheel for numpy\u001b[0m\u001b[31m\n",
"\u001b[0mFailed to build numpy\n",
"\u001b[31mERROR: Could not build wheels for numpy, which is required to install pyproject.toml-based projects\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ordered_set\n",
" Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n",
"Installing collected packages: ordered_set\n",
"Successfully installed ordered_set-4.1.0\n"
]
}
],
"source": [
"# numpy==1.20.0 has no prebuilt wheel for this runtime's Python version, so pip\n",
"# falls back to building from source and fails (see the output above); the\n",
"# environment's existing numpy remains in use.\n",
"!pip install numpy==1.20.0\n",
"!pip install ordered_set"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IClKGk-9G-M_"
},
"outputs": [],
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B4HTD76gLyQO",
"outputId": "f196cd73-3993-4d19-aee7-b5a17594b88f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"unzip: cannot find or open /usr/share/nltk_data/corpora/wordnet.zip, /usr/share/nltk_data/corpora/wordnet.zip.zip or /usr/share/nltk_data/corpora/wordnet.zip.ZIP.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n"
]
}
],
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from sklearn.model_selection import ParameterGrid\n",
"from sklearn.svm import SVC\n",
"import sklearn.feature_extraction\n",
"import sklearn.preprocessing  # needed for LabelBinarizer below\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"\n",
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"nltk.download('wordnet')\n",
"# the download above goes to /root/nltk_data, so unzip the corpus there\n",
"!unzip -o /root/nltk_data/corpora/wordnet.zip -d /root/nltk_data/corpora/\n",
"\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"import pickle\n",
"import seaborn as sns\n",
"\n",
"from ordered_set import OrderedSet\n",
"from scipy.sparse import lil_matrix\n",
"from itertools import compress\n",
"\n",
"# assumes Google Drive is already mounted at /content/drive\n",
"loaded_vocab = pickle.load(open('/content/drive/MyDrive/vectorizer_imdb.pkl', 'rb'))\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"tokenizer = nltk.tokenize.toktok.ToktokTokenizer()\n",
"lemmatizer = WordNetLemmatizer()\n",
"loaded_vectorizer = TfidfVectorizer(min_df=2, vocabulary=loaded_vocab)\n",
"label_binarizer = sklearn.preprocessing.LabelBinarizer()\n",
"feature_names = loaded_vectorizer.get_feature_names_out()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pPO3G_BeNXTZ"
},
"outputs": [],
"source": [
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3AVqS5t5NnLV",
"outputId": "4a293def-3729-4fcc-8242-2a80a07907c0"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nltk.download('stopwords')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9TlvtkGsOEEn"
},
"outputs": [],
"source": [
"def strip_html(text):\n",
"    soup = BeautifulSoup(text, \"html.parser\")\n",
"    return soup.get_text()\n",
"\n",
"def remove_special_characters(text, remove_digits=True):\n",
"    # keep only letters, digits and whitespace (remove_digits is accepted but unused)\n",
"    pattern = r'[^a-zA-Z0-9\\s]'\n",
"    text = re.sub(pattern, '', text)\n",
"    return text\n",
"\n",
"def remove_stopwords(text, is_lower_case=False):\n",
"    tokens = tokenizer.tokenize(text)\n",
"    tokens = [token.strip() for token in tokens]\n",
"    if is_lower_case:\n",
"        filtered_tokens = [token for token in tokens if token not in stop_words]\n",
"    else:\n",
"        filtered_tokens = [token for token in tokens if token.lower() not in stop_words]\n",
"    filtered_text = ' '.join(filtered_tokens)\n",
"    return filtered_text\n",
"\n",
"def lemmatize_text(text):\n",
"    words = word_tokenize(text)\n",
"    edited_text = ''\n",
"    for word in words:\n",
"        lemma_word = lemmatizer.lemmatize(word)\n",
"        edited_text += \" \" + str(lemma_word)\n",
"    return edited_text"
]
},
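{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity-check sketch for the cleaning helpers above, run on a made-up\n",
"# review string (assumes the NLTK stopwords/punkt/wordnet downloads in the\n",
"# surrounding cells have already been run).\n",
"sample_review = \"<br />This movie was not good. I expected a 10, it is a 3!\"\n",
"cleaned = strip_html(sample_review)\n",
"cleaned = remove_special_characters(cleaned)\n",
"cleaned = remove_stopwords(cleaned)\n",
"cleaned = lemmatize_text(cleaned)\n",
"print(cleaned)"
]
},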
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-ubc9FDcOZNz"
},
"outputs": [],
"source": [
"## Import\n",
"data = pd.read_csv('/content/drive/MyDrive/IMDB Dataset.csv')\n",
"data = data.sample(10000)"
]
},
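{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sketch: check how balanced the sentiment classes are in the\n",
"# 10,000-review sample before preprocessing.\n",
"print(data['sentiment'].value_counts())"
]
},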
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zzu70vmEOhMt",
"outputId": "dcdb2c79-8356-4dd5-8af1-3d45a59abee7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-11-7b716edd6682>:2: MarkupResemblesLocatorWarning: The input looks more like a filename than markup. You may want to open this file and pass the filehandle into Beautiful Soup.\n",
" soup = BeautifulSoup(text, \"html.parser\")\n"
]
}
],
"source": [
"## Preprocess\n",
"data.review = data.review.str.lower()\n",
"data.review = data.review.apply(strip_html)\n",
"data.review = data.review.apply(remove_special_characters)\n",
"data.review = data.review.apply(remove_stopwords)\n",
"data.review = data.review.apply(lemmatize_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ms8uVfj0bJdT"
},
"outputs": [],
"source": [
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MFL0_u3QbUjc",
"outputId": "e01191d0-2a66-4977-ab08-b19e7705ab0f"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nltk.download('punkt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-miptIA5c_mk"
},
"outputs": [],
"source": [
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4K-RAkQndG88",
"outputId": "e43a70f8-1d73-4f21-808c-8670d6549f73"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nltk.download('stopwords')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GRueBFH-d1qE"
},
"outputs": [],
"source": [
"## Split Data\n",
"x_imdb = data['review']\n",
"y_imdb = data['sentiment']\n",
"\n",
"x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_imdb,y_imdb,test_size=0.2)\n",
"x_test, x_val, y_test_i, y_val_i = train_test_split(x_test_i,y_test_i,test_size=0.5)"
]
},
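{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sketch: confirm the two train_test_split calls above give an\n",
"# 80/10/10 train/test/validation split.\n",
"print(len(x_train_i), len(x_test), len(x_val))"
]
},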
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pA2pgBhdd4jm"
},
"outputs": [],
"source": [
"## X data\n",
"# Only the saved vocabulary is reused; idf weights are re-fit on this training split\n",
"x_train_imdb = loaded_vectorizer.fit_transform(x_train_i)\n",
"x_test_imdb = loaded_vectorizer.transform(x_test)\n",
"x_val_imdb = loaded_vectorizer.transform(x_val)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j1uHhqGPeDXc",
"outputId": "894ed3f7-117c-4601-891c-a7100527dfa5"
},
"outputs": [
{
"data": {
"text/plain": [
"(1, 26698)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_train_imdb[0].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b4RJXbcxeSKc"
},
"outputs": [],
"source": [
"# Y data - Positive is 1\n",
"y_train_imdb = label_binarizer.fit_transform(y_train_i)\n",
"y_test_imdb = label_binarizer.transform(y_test_i)\n",
"y_val_imdb = label_binarizer.transform(y_val_i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HdBinqNIea2L"
},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"param_grid_imdb_lr = [ \n",
" {'penalty' : ['l1', 'l2', 'elasticnet'],\n",
" 'C' : np.logspace(-4, 4, 20),\n",
" 'solver' : ['lbfgs','newton-cg','sag'],\n",
" 'max_iter' : [100, 1000, 5000]\n",
" }\n",
"]\n",
"grid_imdb_lr = GridSearchCV(LogisticRegression(), param_grid = param_grid_imdb_lr, cv = 3, verbose=2, n_jobs=-1)"
]
},
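{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of a grid restricted to valid solver/penalty pairs. In the grid above,\n",
"# lbfgs, newton-cg and sag support only the l2 penalty, which is why 1080 of the\n",
"# 1620 fits in the next cell fail; saga handles l1 and elasticnet (elasticnet\n",
"# also needs l1_ratio). This alternative grid is illustrative only and is not\n",
"# the one fitted below.\n",
"param_grid_imdb_lr_valid = [\n",
"    {'penalty': ['l2'],\n",
"     'C': np.logspace(-4, 4, 20),\n",
"     'solver': ['lbfgs', 'newton-cg', 'sag'],\n",
"     'max_iter': [100, 1000, 5000]},\n",
"    {'penalty': ['l1'],\n",
"     'C': np.logspace(-4, 4, 20),\n",
"     'solver': ['saga'],\n",
"     'max_iter': [100, 1000, 5000]},\n",
"    {'penalty': ['elasticnet'],\n",
"     'C': np.logspace(-4, 4, 20),\n",
"     'solver': ['saga'],\n",
"     'l1_ratio': [0.5],\n",
"     'max_iter': [100, 1000, 5000]}\n",
"]"
]
},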
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4jbagEFMfBnk",
"outputId": "ebabf88b-035d-42de-beaa-e1387dc5459c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 540 candidates, totalling 1620 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning: \n",
"1080 fits failed out of a total of 1620.\n",
"The score on these train-test partitions for these parameters will be set to nan.\n",
"If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
"\n",
"Below are more details about the failures:\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got l1 penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver newton-cg supports only 'l2' or 'none' penalties, got l1 penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver sag supports only 'l2' or 'none' penalties, got l1 penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got elasticnet penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver newton-cg supports only 'l2' or 'none' penalties, got elasticnet penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver sag supports only 'l2' or 'none' penalties, got elasticnet penalty.\n",
"\n",
" warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_search.py:952: UserWarning: One or more of the test scores are non-finite: [ nan nan nan 0.50124998 0.50124998 0.50124998\n",
" nan nan nan nan nan nan\n",
" 0.50124998 0.50124998 0.50124998 nan nan nan\n",
" nan nan nan 0.50124998 0.50124998 0.50124998\n",
" nan nan nan nan nan nan\n",
" 0.50124998 0.50124998 0.50124998 nan nan nan\n",
" nan nan nan 0.50124998 0.50124998 0.50124998\n",
" nan nan nan nan nan nan\n",
" 0.50124998 0.50124998 0.50124998 nan nan nan\n",
" nan nan nan 0.53499905 0.53499905 0.53312475\n",
" nan nan nan nan nan nan\n",
" 0.53499905 0.53499905 0.53674878 nan nan nan\n",
" nan nan nan 0.53499905 0.53499905 0.53749897\n",
" nan nan nan nan nan nan\n",
" 0.76587341 0.76587341 0.76687324 nan nan nan\n",
" nan nan nan 0.76587341 0.76587341 0.76562312\n",
" nan nan nan nan nan nan\n",
" 0.76587341 0.76587341 0.76574829 nan nan nan\n",
" nan nan nan 0.81412493 0.81412493 0.81412493\n",
" nan nan nan nan nan nan\n",
" 0.81412493 0.81412493 0.81424992 nan nan nan\n",
" nan nan nan 0.81412493 0.81412493 0.81424992\n",
" nan nan nan nan nan nan\n",
" 0.81587471 0.81587471 0.81562465 nan nan nan\n",
" nan nan nan 0.81587471 0.81587471 0.81587471\n",
" nan nan nan nan nan nan\n",
" 0.81587471 0.81587471 0.81562465 nan nan nan\n",
" nan nan nan 0.82037481 0.82037481 0.82037481\n",
" nan nan nan nan nan nan\n",
" 0.82037481 0.82037481 0.82037481 nan nan nan\n",
" nan nan nan 0.82037481 0.82037481 0.82037481\n",
" nan nan nan nan nan nan\n",
" 0.83187445 0.83187445 0.83187445 nan nan nan\n",
" nan nan nan 0.83187445 0.83187445 0.83187445\n",
" nan nan nan nan nan nan\n",
" 0.83187445 0.83187445 0.83187445 nan nan nan\n",
" nan nan nan 0.84899937 0.84899937 0.84899937\n",
" nan nan nan nan nan nan\n",
" 0.84899937 0.84899937 0.84899937 nan nan nan\n",
" nan nan nan 0.84899937 0.84899937 0.84899937\n",
" nan nan nan nan nan nan\n",
" 0.86224954 0.86224954 0.86224954 nan nan nan\n",
" nan nan nan 0.86224954 0.86224954 0.86224954\n",
" nan nan nan nan nan nan\n",
" 0.86224954 0.86224954 0.86224954 nan nan nan\n",
" nan nan nan 0.86849946 0.86849946 0.86849946\n",
" nan nan nan nan nan nan\n",
" 0.86849946 0.86849946 0.86849946 nan nan nan\n",
" nan nan nan 0.86849946 0.86849946 0.86849946\n",
" nan nan nan nan nan nan\n",
" 0.87087431 0.87087431 0.87087431 nan nan nan\n",
" nan nan nan 0.87087431 0.87087431 0.87087431\n",
" nan nan nan nan nan nan\n",
" 0.87087431 0.87087431 0.87087431 nan nan nan\n",
" nan nan nan 0.86899912 0.86899912 0.86899912\n",
" nan nan nan nan nan nan\n",
" 0.86899912 0.86899912 0.86899912 nan nan nan\n",
" nan nan nan 0.86899912 0.86899912 0.86899912\n",
" nan nan nan nan nan nan\n",
" 0.86937403 0.86937403 0.86949906 nan nan nan\n",
" nan nan nan 0.86937403 0.86937403 0.86924904\n",
" nan nan nan nan nan nan\n",
" 0.86937403 0.86937403 0.86937407 nan nan nan\n",
" nan nan nan 0.86562426 0.86574925 0.86512414\n",
" nan nan nan nan nan nan\n",
" 0.86574925 0.86574925 0.86549923 nan nan nan\n",
" nan nan nan 0.86574925 0.86574925 0.86574925\n",
" nan nan nan nan nan nan\n",
" 0.86149912 0.86187403 0.86274906 nan nan nan\n",
" nan nan nan 0.86199901 0.86187403 0.86187407\n",
" nan nan nan nan nan nan\n",
" 0.86199901 0.86187403 0.8616241 nan nan nan\n",
" nan nan nan 0.86012396 0.86049887 0.86337403\n",
" nan nan nan nan nan nan\n",
" 0.86049882 0.86049887 0.8598739 nan nan nan\n",
" nan nan nan 0.86049882 0.86049887 0.86074898\n",
" nan nan nan nan nan nan\n",
" 0.85987409 0.85899906 0.86287409 nan nan nan\n",
" nan nan nan 0.85899906 0.85899906 0.86062409\n",
" nan nan nan nan nan nan\n",
" 0.85899906 0.85899906 0.8601242 nan nan nan\n",
" nan nan nan 0.8589991 0.85899915 0.86387425\n",
" nan nan nan nan nan nan\n",
" 0.85899915 0.85899915 0.86299917 nan nan nan\n",
" nan nan nan 0.85899915 0.85899915 0.85949904\n",
" nan nan nan nan nan nan\n",
" 0.85849903 0.85799909 0.86174885 nan nan nan\n",
" nan nan nan 0.85799909 0.85799909 0.8611244\n",
" nan nan nan nan nan nan\n",
" 0.85799909 0.85799909 0.85887389 nan nan nan]\n",
" warnings.warn(\n"
]
}
],
"source": [
"grid_imdb_lr.fit(x_train_imdb, np.ravel(y_train_imdb,order='C'))\n",
"pickle.dump(grid_imdb_lr, open('grid_imdb_lr.pickle', \"wb\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Kz3dUqM0glC6"
},
"outputs": [],
"source": [
"# Load a previously saved copy of the fitted grid search from Drive\n",
"loaded_lr_imdb = pickle.load(open('/content/drive/MyDrive/grid_imdb_lr.pickle', \"rb\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mYl7xkluhEqL",
"outputId": "65c43228-85cc-4591-e9ac-cf387ddbf830"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'C': 4.281332398719396, 'max_iter': 100, 'penalty': 'l2', 'solver': 'lbfgs'}\n"
]
}
],
"source": [
"print(loaded_lr_imdb.best_params_)"
]
},
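{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the validation split (x_val_imdb / y_val_imdb) created earlier is never\n",
"# scored below, so this checks the tuned model against it as well.\n",
"val_preds_lr = loaded_lr_imdb.predict(x_val_imdb)\n",
"print('Validation accuracy:', metrics.accuracy_score(np.ravel(y_val_imdb), val_preds_lr))"
]
},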
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uM5zsNadhUcU"
},
"outputs": [],
"source": [
"# Fraction of positive labels in the training set (class-balance check)\n",
"p = np.sum(y_train_imdb)/np.size(y_train_imdb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nvWe0nKYhaLD"
},
"outputs": [],
"source": [
"# Note: predict() returns hard class labels (0/1), not probabilities, despite the variable name\n",
"probs_lr = loaded_lr_imdb.predict(x_test_imdb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-H4swmnehfWV",
"outputId": "68c1766f-eaa3-4c8e-cba7-6a3e87c0d804"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded_lr_imdb.predict(x_test_imdb[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L45shyPuhnmz"
},
"outputs": [],
"source": [
"accuracy_lr = metrics.accuracy_score(y_test_imdb, probs_lr)"
]
},
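{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: roc_auc_score is imported above but never used; ROC AUC on the\n",
"# positive-class probabilities complements the plain accuracy computed in the\n",
"# previous cell.\n",
"test_scores_lr = loaded_lr_imdb.predict_proba(x_test_imdb)[:, 1]\n",
"print('Test ROC AUC:', roc_auc_score(np.ravel(y_test_imdb), test_scores_lr))"
]
},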
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gSe44ZOihssa",
"outputId": "56131b8a-53f7-489b-a072-4364201bfb75"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy of the logistic regression model on the test data is 0.890000\n"
]
}
],
"source": [
"print(\"The accuracy of the logistic regression model on the test data is %f\" % accuracy_lr)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}