Revert "fix: remove unnecessary code"

This reverts commit d3f12845.
parent 4521d87e
{
"metadata": {
"kernelspec": {
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.10",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"colab": {
"provenance": []
}
},
"nbformat_minor": 0,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"source": [
"# Initialisation"
],
"metadata": {
"id": "tFwb5L06sPmL"
}
},
{
"cell_type": "markdown",
"source": [
"## Installing and importing packages"
],
"metadata": {
"id": "iLJdHDgfsPmO"
}
},
{
"cell_type": "code",
"source": [
"# numpy==1.20.0 has no wheel for this Python 3.10 runtime and its source build fails,\n",
"# so that pin never took effect; the runtime's preinstalled numpy is used instead.\n",
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"%pip install -q ordered_set==4.1.0"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WamT8ircsPmP",
"outputId": "4879fb24-f899-4675-a243-86cd93213aad"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpy==1.20.0\n",
" Downloading numpy-1.20.0.zip (8.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m45.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: numpy\n",
" \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
" \n",
" \u001b[31m×\u001b[0m \u001b[32mBuilding wheel for numpy \u001b[0m\u001b[1;32m(\u001b[0m\u001b[32mpyproject.toml\u001b[0m\u001b[1;32m)\u001b[0m did not run successfully.\n",
" \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
" \u001b[31m╰─>\u001b[0m See above for output.\n",
" \n",
" \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
" Building wheel for numpy (pyproject.toml) ... \u001b[?25l\u001b[?25herror\n",
"\u001b[31m ERROR: Failed building wheel for numpy\u001b[0m\u001b[31m\n",
"\u001b[0mFailed to build numpy\n",
"\u001b[31mERROR: Could not build wheels for numpy, which is required to install pyproject.toml-based projects\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ordered_set\n",
" Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n",
"Installing collected packages: ordered_set\n",
"Successfully installed ordered_set-4.1.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s0ZApxf4Oe2M",
"outputId": "02ceae9e-a309-4c2f-f511-1df025f0e613"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)"
],
"metadata": {
"id": "gf2uH7S1Obmw"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"\n",
"import os\n",
"for dirname, _, filenames in os.walk('/content/drive'):\n",
" for filename in filenames:\n",
" print(os.path.join(dirname, filename))\n"
],
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"trusted": true,
"id": "80TWwaECsPmQ"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"import time\n",
"\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from sklearn.model_selection import ParameterGrid\n",
"from sklearn.svm import SVC\n",
"import sklearn.feature_extraction\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"\n",
"import nltk\n",
"from nltk.corpus import stopwords \n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"nltk.download('wordnet')\n",
"nltk.download('stopwords')\n",
"nltk.download('punkt')\n",
"!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/\n",
"\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"import pickle\n",
"import seaborn as sns\n",
"\n",
"from ordered_set import OrderedSet\n",
"from scipy.sparse import lil_matrix\n",
"from itertools import compress\n"
],
"metadata": {
"trusted": true,
"id": "c0yWpzdksPmQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "87aa8bb6-c1cb-4cfb-f48c-79b5b717684d"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"unzip: cannot find or open /usr/share/nltk_data/corpora/wordnet.zip, /usr/share/nltk_data/corpora/wordnet.zip.zip or /usr/share/nltk_data/corpora/wordnet.zip.ZIP.\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# import dataset\n",
"df = pd.read_csv('/content/sample_data/IMDB Dataset.csv',parse_dates=True)\n",
"df.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LyVwrbLFQs4g",
"outputId": "4174abd8-88bb-4eac-a05d-d944833c4bf9"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 50000 entries, 0 to 49999\n",
"Data columns (total 2 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 review 50000 non-null object\n",
" 1 sentiment 50000 non-null object\n",
"dtypes: object(2)\n",
"memory usage: 781.4+ KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Shared text-processing objects used by the preprocessing and vectorizing cells.\n",
"import sklearn.preprocessing  # explicit import: the other sklearn imports do not promise this submodule is loaded\n",
"\n",
"# NOTE(security): pickle.load can execute arbitrary code; only load files you created or trust.\n",
"# The context manager closes the file handle once the vocabulary has been read.\n",
"with open('/content/sample_data/vectorizer_imdb.pkl', 'rb') as vocab_file:\n",
"    loaded_vocab = pickle.load(vocab_file)\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"tokenizer = nltk.tokenize.toktok.ToktokTokenizer()\n",
"lemmatizer = WordNetLemmatizer()\n",
"# The vocabulary is fixed up front; scikit-learn documents min_df as ignored in that case.\n",
"loaded_vectorizer = TfidfVectorizer(min_df=2, vocabulary=loaded_vocab)\n",
"label_binarizer = sklearn.preprocessing.LabelBinarizer()\n",
"feature_names = loaded_vectorizer.get_feature_names_out()"
],
"metadata": {
"id": "KzLCu5aRPcwB"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Support Functions"
],
"metadata": {
"id": "rkeUtGt4sPmR"
}
},
{
"cell_type": "code",
"source": [
"def strip_html(text):\n",
"    \"\"\"Return the text content of `text` with HTML markup removed.\"\"\"\n",
"    return BeautifulSoup(text, \"html.parser\").get_text()\n",
"\n",
"\n",
"def remove_special_characters(text, remove_digits=True):\n",
"    \"\"\"Delete every character that is not an ASCII letter, a digit or whitespace.\n",
"\n",
"    `remove_digits` is kept in the signature for backward compatibility but is not used:\n",
"    digits are always preserved, exactly as before.\n",
"    \"\"\"\n",
"    # The character class used to say `A-z`; that ASCII range also spans the bracket,\n",
"    # backslash, caret, underscore and backtick characters, so they were never removed.\n",
"    pattern = r'[^a-zA-Z0-9\\s]'\n",
"    return re.sub(pattern, '', text)\n",
"\n",
"\n",
"def remove_stopwords(text, is_lower_case=False):\n",
"    \"\"\"Drop English stop words from `text`; the surviving tokens are joined by single spaces.\n",
"\n",
"    With `is_lower_case` true, tokens are compared to the stop-word set as-is; otherwise each\n",
"    token is lower-cased for the comparison only and keeps its original casing in the output.\n",
"    \"\"\"\n",
"    tokens = [token.strip() for token in tokenizer.tokenize(text)]\n",
"    if is_lower_case:\n",
"        kept = [token for token in tokens if token not in stop_words]\n",
"    else:\n",
"        kept = [token for token in tokens if token.lower() not in stop_words]\n",
"    return ' '.join(kept)\n",
"\n",
"\n",
"def lemmatize_text(text):\n",
"    \"\"\"Lemmatize every word of `text`.\n",
"\n",
"    Each lemma is emitted with a leading space, so a non-empty result starts with a space\n",
"    (unchanged from the previous implementation).\n",
"    \"\"\"\n",
"    return ''.join(' ' + str(lemmatizer.lemmatize(word)) for word in word_tokenize(text))"
],
"metadata": {
"trusted": true,
"id": "dLQymBfIsPmR"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Import DataSet and Preprocess"
],
"metadata": {
"id": "ewsKXh2XsPmS"
}
},
{
"cell_type": "code",
"source": [
"## Import\n",
"data = pd.read_csv('/content/sample_data/IMDB Dataset.csv')\n",
"# Work on a 10,000-review subsample; the fixed seed makes the draw repeatable across runs.\n",
"data = data.sample(10000, random_state=42)"
],
"metadata": {
"trusted": true,
"id": "bdJixyiPsPmS"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"## Preprocess\n",
"# Lower-case the reviews, then run each cleaning step over the column in order.\n",
"cleaned_reviews = data['review'].str.lower()\n",
"for cleaning_step in (strip_html, remove_special_characters, remove_stopwords, lemmatize_text):\n",
"    cleaned_reviews = cleaned_reviews.apply(cleaning_step)\n",
"data['review'] = cleaned_reviews"
],
"metadata": {
"trusted": true,
"id": "RwfbMIV0sPmS",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4f52511b-8026-44be-970a-3b18f38e2a70"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-13-7b716edd6682>:2: MarkupResemblesLocatorWarning: The input looks more like a filename than markup. You may want to open this file and pass the filehandle into Beautiful Soup.\n",
" soup = BeautifulSoup(text, \"html.parser\")\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"## Split Data\n",
"x_imdb = data['review']\n",
"y_imdb = data['sentiment']\n",
"\n",
"# 80% of the rows go to training; the remaining 20% is halved into test and validation.\n",
"# Both splits are seeded so the partitions are the same on every run.\n",
"x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_imdb, y_imdb, test_size=0.2, random_state=42)\n",
"x_test, x_val, y_test_i, y_val_i = train_test_split(x_test_i, y_test_i, test_size=0.5, random_state=42)"
],
"metadata": {
"trusted": true,
"id": "RFUFkBROsPmS"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"## X data -> transform into vectors\n",
"x_train_imdb = loaded_vectorizer.fit_transform(x_train_i)\n",
"x_test_imdb = loaded_vectorizer.transform(x_test)\n",
"x_val_imdb = loaded_vectorizer.transform(x_val)"
],
"metadata": {
"trusted": true,
"id": "NbYfcNFTsPmT"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x_train_imdb[0].shape"
],
"metadata": {
"trusted": true,
"id": "zTzI4orbsPmT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "11e313ba-de2d-42c3-930a-cea796f501d6"
},
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 26698)"
]
},
"metadata": {},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"source": [
"# Y data - Positive is 1\n",
"# Fit the binarizer on the training labels only and reuse that mapping for the other\n",
"# splits, so all three splits are guaranteed to share one label encoding.\n",
"y_train_imdb = label_binarizer.fit_transform(y_train_i)\n",
"y_test_imdb = label_binarizer.transform(y_test_i)\n",
"y_val_imdb = label_binarizer.transform(y_val_i)"
],
"metadata": {
"trusted": true,
"id": "Q84MWx-QsPmT"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train RF model"
],
"metadata": {
"id": "_fW74XRBsPmT"
}
},
{
"cell_type": "code",
"source": [
"# Search space for the random forest; the assembled grid is printed for reference.\n",
"n_estimators = np.linspace(start=10, stop=100, num=10).astype(int).tolist()  # number of trees\n",
"max_depth = np.linspace(10, 100, num=5).astype(int).tolist() + [None]  # levels per tree, plus None\n",
"min_samples_split = [2, 5, 10]  # samples required to split a node\n",
"min_samples_leaf = [1, 2, 4]  # samples required at each leaf\n",
"bootstrap = [True, False]  # how samples are selected for each tree\n",
"grid_rf = dict(\n",
"    n_estimators=n_estimators,\n",
"    max_depth=max_depth,\n",
"    min_samples_split=min_samples_split,\n",
"    min_samples_leaf=min_samples_leaf,\n",
"    bootstrap=bootstrap,\n",
")\n",
"print(grid_rf)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.021625Z",
"iopub.status.idle": "2023-05-24T01:59:14.022487Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.022189Z",
"shell.execute_reply": "2023-05-24T01:59:14.022216Z"
},
"id": "HdpaVmxSsPmT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.ensemble import RandomForestClassifier"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.024015Z",
"iopub.status.idle": "2023-05-24T01:59:14.024474Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.024258Z",
"shell.execute_reply": "2023-05-24T01:59:14.024286Z"
},
"id": "X_WPq8AJsPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Randomised search over grid_rf: 200 sampled settings, 3-fold cross-validation, all cores.\n",
"grid_imdb_rf = RandomizedSearchCV(\n",
"    RandomForestClassifier(),\n",
"    param_distributions=grid_rf,\n",
"    n_iter=200,\n",
"    cv=3,\n",
"    verbose=2,\n",
"    random_state=42,\n",
"    n_jobs=-1,\n",
")\n",
"# Fit the random search model\n",
"grid_imdb_rf.fit(x_train_imdb, y_train_imdb.ravel())\n",
"# Write through a context manager so the pickle file is flushed and closed.\n",
"with open('grid_imdb_rf.pickle', 'wb') as model_file:\n",
"    pickle.dump(grid_imdb_rf, model_file)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.025697Z",
"iopub.status.idle": "2023-05-24T01:59:14.026125Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.025909Z",
"shell.execute_reply": "2023-05-24T01:59:14.025929Z"
},
"id": "9iAEK_NasPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train SVC Model"
],
"metadata": {
"id": "qgKbzrRfsPmU"
}
},
{
"cell_type": "code",
"source": [
"# Param Optimisation\n",
"param_grid_imdb = {'C': [0.1,1, 10, 100], 'gamma': [1,0.1,0.01,0.001],'kernel': ['rbf']}\n",
"grid_imdb_svc = GridSearchCV(SVC(),param_grid_imdb,refit=True,verbose=2)"
],
"metadata": {
"id": "kBG69w4JsPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_svc.fit(x_train_imdb, y_train_imdb.ravel())\n",
"# Persist the fitted search; the context manager flushes and closes the file.\n",
"with open('grid_imdb_svc.pickle', 'wb') as model_file:\n",
"    pickle.dump(grid_imdb_svc, model_file)"
],
"metadata": {
"id": "7C7pX6QlsPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train KNN model"
],
"metadata": {
"id": "OUnEb3ZVsPmV"
}
},
{
"cell_type": "code",
"source": [
"# KNeighborsClassifier is not imported anywhere else in the notebook, so import it here.\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"grid_params_imdb_knn = {\n",
"    'n_neighbors': [30, 40, 50, 60, 70, 80, 90],\n",
"    'metric': ['manhattan', 'minkowski'],\n",
"    'weights': ['uniform', 'distance'],\n",
"}\n",
"grid_imdb_knn = GridSearchCV(KNeighborsClassifier(), grid_params_imdb_knn, n_jobs=-1, verbose=2)"
],
"metadata": {
"id": "L4wJA2R0sPmV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_knn.fit(x_train_imdb, np.ravel(y_train_imdb, order='C'))\n",
"# Persist the fitted search; the context manager flushes and closes the file.\n",
"with open('grid_imdb_knn.pickle', 'wb') as model_file:\n",
"    pickle.dump(grid_imdb_knn, model_file)"
],
"metadata": {
"id": "zBwE0Cz-sPmV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train LR Model"
],
"metadata": {
"id": "MQdS2y_2sPmV"
}
},
{
"cell_type": "code",
"source": [
"# LogisticRegression is not imported anywhere else in the notebook, so import it here.\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# lbfgs, newton-cg and sag accept only the l2 penalty; l1 and elasticnet need saga, and\n",
"# elasticnet additionally needs l1_ratio. Mixing them in one dict makes most fits fail,\n",
"# so the grid is split into groups that only contain valid combinations.\n",
"lr_c_values = np.logspace(-4, 4, 20)\n",
"lr_max_iter = [100, 1000, 5000]\n",
"param_grid_imdb_lr = [\n",
"    {'penalty': ['l2'],\n",
"     'C': lr_c_values,\n",
"     'solver': ['lbfgs', 'newton-cg', 'sag'],\n",
"     'max_iter': lr_max_iter},\n",
"    {'penalty': ['l1'],\n",
"     'C': lr_c_values,\n",
"     'solver': ['saga'],\n",
"     'max_iter': lr_max_iter},\n",
"    {'penalty': ['elasticnet'],\n",
"     'C': lr_c_values,\n",
"     'solver': ['saga'],\n",
"     'l1_ratio': [0.5],\n",
"     'max_iter': lr_max_iter},\n",
"]\n",
"grid_imdb_lr = GridSearchCV(LogisticRegression(), param_grid=param_grid_imdb_lr, cv=3, verbose=2, n_jobs=-1)"
],
"metadata": {
"id": "xTyPiA4osPmV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_lr.fit(x_train_imdb, np.ravel(y_train_imdb, order='C'))\n",
"# Persist the fitted search; the context manager flushes and closes the file.\n",
"with open('grid_imdb_lr.pickle', 'wb') as model_file:\n",
"    pickle.dump(grid_imdb_lr, model_file)"
],
"metadata": {
"id": "vhVNQZITsPmW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Load Models"
],
"metadata": {
"id": "AcRWpiiesPmW"
}
},
{
"cell_type": "code",
"source": [
"# Load\n",
"# NOTE(security): unpickling can execute arbitrary code; only load model files you trust.\n",
"# The context manager closes the file handle once the model has been read.\n",
"with open('/content/sample_data/grid_imdb_knn.pickle', 'rb') as model_file:\n",
"    loaded_knn_imdb = pickle.load(model_file)"
],
"metadata": {
"trusted": true,
"id": "rgbWUgfosPmW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d8110a61-dec0-40e6-a337-e12de205e304"
},
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-23-6b9db75f90c6>:2: DeprecationWarning: Please use `csr_matrix` from the `scipy.sparse` namespace, the `scipy.sparse.csr` namespace is deprecated.\n",
" loaded_knn_imdb = pickle.load(open('/content/sample_data/grid_imdb_knn.pickle', \"rb\"))\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(loaded_knn_imdb.best_params_)"
],
"metadata": {
"trusted": true,
"id": "XpZjBtLgsPmW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9e888c63-3f57-464b-a623-ce287fb10d1b"
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'metric': 'minkowski', 'n_neighbors': 90, 'weights': 'distance'}\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Info about models"
],
"metadata": {
"id": "tds-0KE9sPmX"
}
},
{
"cell_type": "code",
"source": [
"p = np.sum(y_train_imdb)/np.size(y_train_imdb)"
],
"metadata": {
"trusted": true,
"id": "XXrGyw5zsPmX"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"probs_knn = loaded_knn_imdb.predict(x_test_imdb)"
],
"metadata": {
"trusted": true,
"id": "EpPwD1xMsPmY"
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"loaded_knn_imdb.predict(x_test_imdb[0])"
],
"metadata": {
"trusted": true,
"id": "M3L6UcqxsPmq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "119c56e3-a7c7-4b48-daff-6e9797ed0fea"
},
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1])"
]
},
"metadata": {},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"source": [
"accuracy_knn = metrics.accuracy_score(y_test_imdb, probs_knn)"
],
"metadata": {
"trusted": true,
"id": "OOuTh7TlsPmq"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"The accuracy of the KNN model on the test data is %f\" %accuracy_knn)"
],
"metadata": {
"trusted": true,
"id": "Gaqh_vIpsPmr",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "be28e7bf-1ae0-4caa-f301-3a0db9d10f5b"
},
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The accuracy of the KNN model on the test data is 0.819000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"input_review = input('give me review')"
],
"metadata": {
"id": "Z6qC0v0JsPms",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ee6f61df-ae10-471a-83fc-535503f68f90"
},
"execution_count": 30,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"give me reviewI like apple\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Clean the typed review with the same steps this notebook applies to the dataset reviews\n",
"# (lower-case, strip HTML, drop special characters and stop words, lemmatize) before\n",
"# vectorizing, so the input is in the same form as the text the features were built from.\n",
"cleaned_input = input_review.lower()\n",
"for cleaning_step in (strip_html, remove_special_characters, remove_stopwords, lemmatize_text):\n",
"    cleaned_input = cleaning_step(cleaned_input)\n",
"vector_input = loaded_vectorizer.transform([cleaned_input])"
],
"metadata": {
"id": "7ncfW-8lSwww"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"loaded_knn_imdb.predict(vector_input)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cOVi5sl5SzBP",
"outputId": "52f25084-f0ff-4a2e-cca6-1ea2e2f2901e"
},
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0])"
]
},
"metadata": {},
"execution_count": 32
}
]
}
]
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WNOJt2fcFRiH",
"outputId": "f7472149-8f61-46a1-9d95-ad06f7c90c56"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpy==1.20.0\n",
" Downloading numpy-1.20.0.zip (8.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m52.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: numpy\n",
" \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
" \n",
" \u001b[31m×\u001b[0m \u001b[32mBuilding wheel for numpy \u001b[0m\u001b[1;32m(\u001b[0m\u001b[32mpyproject.toml\u001b[0m\u001b[1;32m)\u001b[0m did not run successfully.\n",
" \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
" \u001b[31m╰─>\u001b[0m See above for output.\n",
" \n",
" \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
" Building wheel for numpy (pyproject.toml) ... \u001b[?25l\u001b[?25herror\n",
"\u001b[31m ERROR: Failed building wheel for numpy\u001b[0m\u001b[31m\n",
"\u001b[0mFailed to build numpy\n",
"\u001b[31mERROR: Could not build wheels for numpy, which is required to install pyproject.toml-based projects\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ordered_set\n",
" Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n",
"Installing collected packages: ordered_set\n",
"Successfully installed ordered_set-4.1.0\n"
]
}
],
"source": [
"# numpy==1.20.0 has no wheel for this Python 3.10 runtime and its source build fails,\n",
"# so that pin never took effect; the runtime's preinstalled numpy is used instead.\n",
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"%pip install -q ordered_set==4.1.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IClKGk-9G-M_"
},
"outputs": [],
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B4HTD76gLyQO",
"outputId": "f196cd73-3993-4d19-aee7-b5a17594b88f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"unzip: cannot find or open /usr/share/nltk_data/corpora/wordnet.zip, /usr/share/nltk_data/corpora/wordnet.zip.zip or /usr/share/nltk_data/corpora/wordnet.zip.ZIP.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n"
]
}
],
"source": [
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from sklearn.model_selection import ParameterGrid\n",
"from sklearn.svm import SVC\n",
"import sklearn.feature_extraction\n",
"import sklearn.preprocessing  # explicit import: LabelBinarizer below is reached through this submodule\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"\n",
"import nltk\n",
"from nltk.corpus import stopwords \n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"nltk.download('wordnet')\n",
"!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/\n",
"\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"import pickle\n",
"import seaborn as sns\n",
"\n",
"from ordered_set import OrderedSet\n",
"from scipy.sparse import lil_matrix\n",
"from itertools import compress\n",
"\n",
"# NOTE(security): pickle.load can execute arbitrary code; only load files you created or trust.\n",
"# The context manager closes the file handle once the vocabulary has been read.\n",
"with open('/content/drive/MyDrive/vectorizer_imdb.pkl', 'rb') as vocab_file:\n",
"    loaded_vocab = pickle.load(vocab_file)\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"tokenizer = nltk.tokenize.toktok.ToktokTokenizer()\n",
"lemmatizer = WordNetLemmatizer()\n",
"# The vocabulary is fixed up front; scikit-learn documents min_df as ignored in that case.\n",
"loaded_vectorizer = TfidfVectorizer(min_df=2, vocabulary=loaded_vocab)\n",
"label_binarizer = sklearn.preprocessing.LabelBinarizer()\n",
"feature_names = loaded_vectorizer.get_feature_names_out()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pPO3G_BeNXTZ"
},
"outputs": [],
"source": [
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3AVqS5t5NnLV",
"outputId": "4a293def-3729-4fcc-8242-2a80a07907c0"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nltk.download('stopwords')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9TlvtkGsOEEn"
},
"outputs": [],
"source": [
"def strip_html(text):\n",
"    \"\"\"Return the text content of `text` with HTML markup removed.\"\"\"\n",
"    return BeautifulSoup(text, \"html.parser\").get_text()\n",
"\n",
"\n",
"def remove_special_characters(text, remove_digits=True):\n",
"    \"\"\"Delete every character that is not an ASCII letter, a digit or whitespace.\n",
"\n",
"    `remove_digits` is kept in the signature for backward compatibility but is not used:\n",
"    digits are always preserved, exactly as before.\n",
"    \"\"\"\n",
"    # The character class used to say `A-z`; that ASCII range also spans the bracket,\n",
"    # backslash, caret, underscore and backtick characters, so they were never removed.\n",
"    pattern = r'[^a-zA-Z0-9\\s]'\n",
"    return re.sub(pattern, '', text)\n",
"\n",
"\n",
"def remove_stopwords(text, is_lower_case=False):\n",
"    \"\"\"Drop English stop words from `text`; the surviving tokens are joined by single spaces.\n",
"\n",
"    With `is_lower_case` true, tokens are compared to the stop-word set as-is; otherwise each\n",
"    token is lower-cased for the comparison only and keeps its original casing in the output.\n",
"    \"\"\"\n",
"    tokens = [token.strip() for token in tokenizer.tokenize(text)]\n",
"    if is_lower_case:\n",
"        kept = [token for token in tokens if token not in stop_words]\n",
"    else:\n",
"        kept = [token for token in tokens if token.lower() not in stop_words]\n",
"    return ' '.join(kept)\n",
"\n",
"\n",
"def lemmatize_text(text):\n",
"    \"\"\"Lemmatize every word of `text`.\n",
"\n",
"    Each lemma is emitted with a leading space, so a non-empty result starts with a space\n",
"    (unchanged from the previous implementation).\n",
"    \"\"\"\n",
"    return ''.join(' ' + str(lemmatizer.lemmatize(word)) for word in word_tokenize(text))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-ubc9FDcOZNz"
},
"outputs": [],
"source": [
"## Import\n",
"data = pd.read_csv('/content/drive/MyDrive/IMDB Dataset.csv')\n",
"# Work on a 10,000-review subsample; the fixed seed makes the draw repeatable across runs.\n",
"data = data.sample(10000, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zzu70vmEOhMt",
"outputId": "dcdb2c79-8356-4dd5-8af1-3d45a59abee7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-11-7b716edd6682>:2: MarkupResemblesLocatorWarning: The input looks more like a filename than markup. You may want to open this file and pass the filehandle into Beautiful Soup.\n",
" soup = BeautifulSoup(text, \"html.parser\")\n"
]
}
],
"source": [
"## Preprocess\n",
"# Lower-case the reviews, then run each cleaning step over the column in order.\n",
"cleaned_reviews = data['review'].str.lower()\n",
"for cleaning_step in (strip_html, remove_special_characters, remove_stopwords, lemmatize_text):\n",
"    cleaned_reviews = cleaned_reviews.apply(cleaning_step)\n",
"data['review'] = cleaned_reviews"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ms8uVfj0bJdT"
},
"outputs": [],
"source": [
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MFL0_u3QbUjc",
"outputId": "e01191d0-2a66-4977-ab08-b19e7705ab0f"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nltk.download('punkt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-miptIA5c_mk"
},
"outputs": [],
"source": [
"import nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4K-RAkQndG88",
"outputId": "e43a70f8-1d73-4f21-808c-8670d6549f73"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nltk.download('stopwords')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GRueBFH-d1qE"
},
"outputs": [],
"source": [
"## Split Data\n",
"x_imdb = data['review']\n",
"y_imdb = data['sentiment']\n",
"\n",
"# 80% of the rows go to training; the remaining 20% is halved into test and validation.\n",
"# Both splits are seeded so the partitions are the same on every run.\n",
"x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_imdb, y_imdb, test_size=0.2, random_state=42)\n",
"x_test, x_val, y_test_i, y_val_i = train_test_split(x_test_i, y_test_i, test_size=0.5, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pA2pgBhdd4jm"
},
"outputs": [],
"source": [
"## X data\n",
"x_train_imdb = loaded_vectorizer.fit_transform(x_train_i)\n",
"x_test_imdb = loaded_vectorizer.transform(x_test)\n",
"x_val_imdb = loaded_vectorizer.transform(x_val)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j1uHhqGPeDXc",
"outputId": "894ed3f7-117c-4601-891c-a7100527dfa5"
},
"outputs": [
{
"data": {
"text/plain": [
"(1, 26698)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_train_imdb[0].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b4RJXbcxeSKc"
},
"outputs": [],
"source": [
"# Y data - Positive is 1\n",
"# Fit the binarizer on the training labels only and reuse that mapping for the other\n",
"# splits, so all three splits are guaranteed to share one label encoding.\n",
"y_train_imdb = label_binarizer.fit_transform(y_train_i)\n",
"y_test_imdb = label_binarizer.transform(y_test_i)\n",
"y_val_imdb = label_binarizer.transform(y_val_i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HdBinqNIea2L"
},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# lbfgs, newton-cg and sag accept only the l2 penalty; l1 and elasticnet need saga, and\n",
"# elasticnet additionally needs l1_ratio. Mixing them in one dict makes most fits fail,\n",
"# so the grid is split into groups that only contain valid combinations.\n",
"lr_c_values = np.logspace(-4, 4, 20)\n",
"lr_max_iter = [100, 1000, 5000]\n",
"param_grid_imdb_lr = [\n",
"    {'penalty': ['l2'],\n",
"     'C': lr_c_values,\n",
"     'solver': ['lbfgs', 'newton-cg', 'sag'],\n",
"     'max_iter': lr_max_iter},\n",
"    {'penalty': ['l1'],\n",
"     'C': lr_c_values,\n",
"     'solver': ['saga'],\n",
"     'max_iter': lr_max_iter},\n",
"    {'penalty': ['elasticnet'],\n",
"     'C': lr_c_values,\n",
"     'solver': ['saga'],\n",
"     'l1_ratio': [0.5],\n",
"     'max_iter': lr_max_iter},\n",
"]\n",
"grid_imdb_lr = GridSearchCV(LogisticRegression(), param_grid=param_grid_imdb_lr, cv=3, verbose=2, n_jobs=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4jbagEFMfBnk",
"outputId": "ebabf88b-035d-42de-beaa-e1387dc5459c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 3 folds for each of 540 candidates, totalling 1620 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning: \n",
"1080 fits failed out of a total of 1620.\n",
"The score on these train-test partitions for these parameters will be set to nan.\n",
"If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
"\n",
"Below are more details about the failures:\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got l1 penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver newton-cg supports only 'l2' or 'none' penalties, got l1 penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver sag supports only 'l2' or 'none' penalties, got l1 penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got elasticnet penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver newton-cg supports only 'l2' or 'none' penalties, got elasticnet penalty.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"180 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py\", line 686, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 1162, in fit\n",
" solver = _check_solver(self.solver, self.penalty, self.dual)\n",
" File \"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py\", line 54, in _check_solver\n",
" raise ValueError(\n",
"ValueError: Solver sag supports only 'l2' or 'none' penalties, got elasticnet penalty.\n",
"\n",
" warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
"/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_search.py:952: UserWarning: One or more of the test scores are non-finite: [ nan nan nan 0.50124998 0.50124998 0.50124998\n",
" nan nan nan nan nan nan\n",
" 0.50124998 0.50124998 0.50124998 nan nan nan\n",
" nan nan nan 0.50124998 0.50124998 0.50124998\n",
" nan nan nan nan nan nan\n",
" 0.50124998 0.50124998 0.50124998 nan nan nan\n",
" nan nan nan 0.50124998 0.50124998 0.50124998\n",
" nan nan nan nan nan nan\n",
" 0.50124998 0.50124998 0.50124998 nan nan nan\n",
" nan nan nan 0.53499905 0.53499905 0.53312475\n",
" nan nan nan nan nan nan\n",
" 0.53499905 0.53499905 0.53674878 nan nan nan\n",
" nan nan nan 0.53499905 0.53499905 0.53749897\n",
" nan nan nan nan nan nan\n",
" 0.76587341 0.76587341 0.76687324 nan nan nan\n",
" nan nan nan 0.76587341 0.76587341 0.76562312\n",
" nan nan nan nan nan nan\n",
" 0.76587341 0.76587341 0.76574829 nan nan nan\n",
" nan nan nan 0.81412493 0.81412493 0.81412493\n",
" nan nan nan nan nan nan\n",
" 0.81412493 0.81412493 0.81424992 nan nan nan\n",
" nan nan nan 0.81412493 0.81412493 0.81424992\n",
" nan nan nan nan nan nan\n",
" 0.81587471 0.81587471 0.81562465 nan nan nan\n",
" nan nan nan 0.81587471 0.81587471 0.81587471\n",
" nan nan nan nan nan nan\n",
" 0.81587471 0.81587471 0.81562465 nan nan nan\n",
" nan nan nan 0.82037481 0.82037481 0.82037481\n",
" nan nan nan nan nan nan\n",
" 0.82037481 0.82037481 0.82037481 nan nan nan\n",
" nan nan nan 0.82037481 0.82037481 0.82037481\n",
" nan nan nan nan nan nan\n",
" 0.83187445 0.83187445 0.83187445 nan nan nan\n",
" nan nan nan 0.83187445 0.83187445 0.83187445\n",
" nan nan nan nan nan nan\n",
" 0.83187445 0.83187445 0.83187445 nan nan nan\n",
" nan nan nan 0.84899937 0.84899937 0.84899937\n",
" nan nan nan nan nan nan\n",
" 0.84899937 0.84899937 0.84899937 nan nan nan\n",
" nan nan nan 0.84899937 0.84899937 0.84899937\n",
" nan nan nan nan nan nan\n",
" 0.86224954 0.86224954 0.86224954 nan nan nan\n",
" nan nan nan 0.86224954 0.86224954 0.86224954\n",
" nan nan nan nan nan nan\n",
" 0.86224954 0.86224954 0.86224954 nan nan nan\n",
" nan nan nan 0.86849946 0.86849946 0.86849946\n",
" nan nan nan nan nan nan\n",
" 0.86849946 0.86849946 0.86849946 nan nan nan\n",
" nan nan nan 0.86849946 0.86849946 0.86849946\n",
" nan nan nan nan nan nan\n",
" 0.87087431 0.87087431 0.87087431 nan nan nan\n",
" nan nan nan 0.87087431 0.87087431 0.87087431\n",
" nan nan nan nan nan nan\n",
" 0.87087431 0.87087431 0.87087431 nan nan nan\n",
" nan nan nan 0.86899912 0.86899912 0.86899912\n",
" nan nan nan nan nan nan\n",
" 0.86899912 0.86899912 0.86899912 nan nan nan\n",
" nan nan nan 0.86899912 0.86899912 0.86899912\n",
" nan nan nan nan nan nan\n",
" 0.86937403 0.86937403 0.86949906 nan nan nan\n",
" nan nan nan 0.86937403 0.86937403 0.86924904\n",
" nan nan nan nan nan nan\n",
" 0.86937403 0.86937403 0.86937407 nan nan nan\n",
" nan nan nan 0.86562426 0.86574925 0.86512414\n",
" nan nan nan nan nan nan\n",
" 0.86574925 0.86574925 0.86549923 nan nan nan\n",
" nan nan nan 0.86574925 0.86574925 0.86574925\n",
" nan nan nan nan nan nan\n",
" 0.86149912 0.86187403 0.86274906 nan nan nan\n",
" nan nan nan 0.86199901 0.86187403 0.86187407\n",
" nan nan nan nan nan nan\n",
" 0.86199901 0.86187403 0.8616241 nan nan nan\n",
" nan nan nan 0.86012396 0.86049887 0.86337403\n",
" nan nan nan nan nan nan\n",
" 0.86049882 0.86049887 0.8598739 nan nan nan\n",
" nan nan nan 0.86049882 0.86049887 0.86074898\n",
" nan nan nan nan nan nan\n",
" 0.85987409 0.85899906 0.86287409 nan nan nan\n",
" nan nan nan 0.85899906 0.85899906 0.86062409\n",
" nan nan nan nan nan nan\n",
" 0.85899906 0.85899906 0.8601242 nan nan nan\n",
" nan nan nan 0.8589991 0.85899915 0.86387425\n",
" nan nan nan nan nan nan\n",
" 0.85899915 0.85899915 0.86299917 nan nan nan\n",
" nan nan nan 0.85899915 0.85899915 0.85949904\n",
" nan nan nan nan nan nan\n",
" 0.85849903 0.85799909 0.86174885 nan nan nan\n",
" nan nan nan 0.85799909 0.85799909 0.8611244\n",
" nan nan nan nan nan nan\n",
" 0.85799909 0.85799909 0.85887389 nan nan nan]\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Run the hyper-parameter search; np.ravel flattens the label array to 1-D before fitting.\n",
"grid_imdb_lr.fit(x_train_imdb, np.ravel(y_train_imdb, order='C'))\n",
"# Save the fitted search object; the `with` block closes the file so the pickle is flushed to disk.\n",
"with open('grid_imdb_lr.pickle', \"wb\") as model_file:\n",
"    pickle.dump(grid_imdb_lr, model_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Kz3dUqM0glC6"
},
"outputs": [],
"source": [
"# Reload the pickled search object from Google Drive (only unpickle files you trust);\n",
"# the context manager closes the file handle as soon as loading finishes.\n",
"with open('/content/drive/MyDrive/grid_imdb_lr.pickle', \"rb\") as model_file:\n",
"    loaded_lr_imdb = pickle.load(model_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mYl7xkluhEqL",
"outputId": "65c43228-85cc-4591-e9ac-cf387ddbf830"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'C': 4.281332398719396, 'max_iter': 100, 'penalty': 'l2', 'solver': 'lbfgs'}\n"
]
}
],
"source": [
"print(loaded_lr_imdb.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uM5zsNadhUcU"
},
"outputs": [],
"source": [
"p = np.sum(y_train_imdb)/np.size(y_train_imdb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nvWe0nKYhaLD"
},
"outputs": [],
"source": [
"probs_lr = loaded_lr_imdb.predict(x_test_imdb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-H4swmnehfWV",
"outputId": "68c1766f-eaa3-4c8e-cba7-6a3e87c0d804"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded_lr_imdb.predict(x_test_imdb[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L45shyPuhnmz"
},
"outputs": [],
"source": [
"accuracy_lr = metrics.accuracy_score(y_test_imdb, probs_lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gSe44ZOihssa",
"outputId": "56131b8a-53f7-489b-a072-4364201bfb75"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy of the SVC model on the test data is 0.890000\n"
]
}
],
"source": [
"# Report the test accuracy of the logistic-regression search (the message previously said SVC).\n",
"print(\"The accuracy of the logistic regression model on the test data is %f\" %accuracy_lr)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aAnW5l0dvRT4",
"outputId": "e8853b7d-4ff8-45be-e321-0d902b9a6582"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpy==1.20.0\n",
" Downloading numpy-1.20.0.zip (8.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m36.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: numpy\n",
" \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
" \n",
" \u001b[31m×\u001b[0m \u001b[32mBuilding wheel for numpy \u001b[0m\u001b[1;32m(\u001b[0m\u001b[32mpyproject.toml\u001b[0m\u001b[1;32m)\u001b[0m did not run successfully.\n",
" \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
" \u001b[31m╰─>\u001b[0m See above for output.\n",
" \n",
" \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
" Building wheel for numpy (pyproject.toml) ... \u001b[?25l\u001b[?25herror\n",
"\u001b[31m ERROR: Failed building wheel for numpy\u001b[0m\u001b[31m\n",
"\u001b[0mFailed to build numpy\n",
"\u001b[31mERROR: Could not build wheels for numpy, which is required to install pyproject.toml-based projects\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ordered_set\n",
" Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n",
"Installing collected packages: ordered_set\n",
"Successfully installed ordered_set-4.1.0\n"
]
}
],
"source": [
"!pip install numpy==1.20.0\n",
"!pip install ordered_set"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HkTrqt4utY-B",
"outputId": "47471b6c-4455-405a-c4c5-4f6de1fba1bc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "F9UFXs5du3v2"
},
"outputs": [],
"source": [
"import numpy as np \n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1qf6NiHlvGJU",
"outputId": "0f1b7fde-8434-4937-d858-1004c5d70b63"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n",
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"unzip: cannot find or open /usr/share/nltk_data/corpora/wordnet.zip, /usr/share/nltk_data/corpora/wordnet.zip.zip or /usr/share/nltk_data/corpora/wordnet.zip.ZIP.\n"
]
}
],
"source": [
"import time\n",
"\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from sklearn.model_selection import ParameterGrid\n",
"from sklearn.svm import SVC\n",
"import sklearn.feature_extraction\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"\n",
"import nltk\n",
"from nltk.corpus import stopwords \n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"nltk.download('wordnet')\n",
"nltk.download('stopwords')\n",
"nltk.download('punkt')\n",
"!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/\n",
"\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"import pickle\n",
"import seaborn as sns\n",
"\n",
"from ordered_set import OrderedSet\n",
"from scipy.sparse import lil_matrix\n",
"from itertools import compress"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FyySbNNYwjuR",
"outputId": "c835e2a7-adca-4073-d5a6-8668daf35b6f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 50000 entries, 0 to 49999\n",
"Data columns (total 2 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 review 50000 non-null object\n",
" 1 sentiment 50000 non-null object\n",
"dtypes: object(2)\n",
"memory usage: 781.4+ KB\n"
]
}
],
"source": [
"# import dataset\n",
"df = pd.read_csv('/content/drive/MyDrive/IMDB Dataset.csv',parse_dates=True)\n",
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "nJp922LYwLRL"
},
"outputs": [],
"source": [
"# `sklearn.preprocessing` is imported explicitly: the import cell only names\n",
"# `sklearn.feature_extraction`, so the attribute access below would otherwise rely on\n",
"# another sklearn module having loaded the submodule as a side effect.\n",
"import sklearn.preprocessing\n",
"\n",
"# Load the pickled object used as the vectorizer vocabulary (only unpickle files you trust);\n",
"# the context manager closes the file handle once loading finishes.\n",
"with open('/content/drive/MyDrive/vectorizer_imdb.pkl', 'rb') as vocab_file:\n",
"    loaded_vocab = pickle.load(vocab_file)\n",
"\n",
"# Shared text-processing objects used by the support functions and the encoding cells.\n",
"stop_words = set(stopwords.words('english'))\n",
"tokenizer = nltk.tokenize.toktok.ToktokTokenizer()\n",
"lemmatizer = WordNetLemmatizer()\n",
"# TF-IDF vectorizer restricted to the loaded vocabulary.\n",
"loaded_vectorizer = TfidfVectorizer(min_df=2, vocabulary=loaded_vocab)\n",
"label_binarizer = sklearn.preprocessing.LabelBinarizer()\n",
"feature_names = loaded_vectorizer.get_feature_names_out()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "UEiQ73qMxroZ"
},
"source": [
"# **Support Functions**"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "IzjUkN0Txkbk"
},
"outputs": [],
"source": [
"#Remove HTML tags in review data\n",
"def strip_html(text):\n",
"    \"\"\"Parse `text` with BeautifulSoup's html.parser and return only its text content.\"\"\"\n",
"    return BeautifulSoup(text, \"html.parser\").get_text()\n",
"\n",
"#Remove special characters like %#( ) keeping only alphanumeric chracters\n",
"def remove_special_characters(text, remove_digits=True):\n",
" pattern=r'[^a-zA-z0-9\\s]'\n",
" text=re.sub(pattern,'',text)\n",
" return text\n",
"\n",
"#Remove stopwords in the review \n",
"def remove_stopwords(text, is_lower_case=False):\n",
"    \"\"\"Tokenize `text`, drop tokens found in `stop_words` and re-join the rest with spaces.\n",
"\n",
"    When `is_lower_case` is true the tokens are compared as-is; otherwise each token is\n",
"    lower-cased for the comparison only, and the kept tokens retain their original case.\n",
"    \"\"\"\n",
"    stripped = [tok.strip() for tok in tokenizer.tokenize(text)]\n",
"    # Pick the comparison key once instead of duplicating the filtering comprehension.\n",
"    normalize = (lambda tok: tok) if is_lower_case else str.lower\n",
"    kept = [tok for tok in stripped if normalize(tok) not in stop_words]\n",
"    return ' '.join(kept)\n",
"\n",
"#Get the root word of words\n",
"def lemmatize_text(text):\n",
"    \"\"\"Lemmatize each word token of `text`; every lemma is emitted with a leading space.\"\"\"\n",
"    lemmas = (str(lemmatizer.lemmatize(word)) for word in word_tokenize(text))\n",
"    return ''.join(\" \" + lemma for lemma in lemmas)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Bxq5shOnyFCV"
},
"source": [
"# **Import DataSet and Preprocess**"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "8mbindsxyHC9"
},
"outputs": [],
"source": [
"## Import\n",
"data = pd.read_csv('/content/drive/MyDrive/IMDB Dataset.csv')\n",
"## Reduce training time: keep a 10,000-review subset.\n",
"## The fixed seed makes the subset, and every result derived from it, reproducible across runs.\n",
"data = data.sample(10000, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zNAmNRgTybFh",
"outputId": "37c5b24f-f370-484b-ff17-724a28fa1b3b"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-7-5967f8106591>:3: MarkupResemblesLocatorWarning: The input looks more like a filename than markup. You may want to open this file and pass the filehandle into Beautiful Soup.\n",
" soup = BeautifulSoup(text, \"html.parser\")\n"
]
}
],
"source": [
"## Preprocess\n",
"# Lower-case first, then strip HTML, special characters and stopwords, and finally lemmatize.\n",
"data.review = (\n",
"    data.review.str.lower()\n",
"    .apply(strip_html)\n",
"    .apply(remove_special_characters)\n",
"    .apply(remove_stopwords)\n",
"    .apply(lemmatize_text)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "C6NiUM-0yycb"
},
"outputs": [],
"source": [
"## Split Data\n",
"x_imdb = data['review']\n",
"y_imdb = data['sentiment']\n",
"\n",
"# 80% of the rows go to training; the remaining 20% is halved into test and validation sets.\n",
"# Both splits are seeded so the partitions are identical on every run.\n",
"x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_imdb, y_imdb, test_size=0.2, random_state=42)\n",
"x_test, x_val, y_test_i, y_val_i = train_test_split(x_test_i, y_test_i, test_size=0.5, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "RjYCWtiXy2Rf"
},
"outputs": [],
"source": [
"## X data -> transform into vectors\n",
"x_train_imdb = loaded_vectorizer.fit_transform(x_train_i)\n",
"x_test_imdb = loaded_vectorizer.transform(x_test)\n",
"x_val_imdb = loaded_vectorizer.transform(x_val)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jUyM0xLOy5yW",
"outputId": "84fbad98-58ea-474f-d950-7bc104bcde19"
},
"outputs": [
{
"data": {
"text/plain": [
"(1, 26698)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_train_imdb[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "_Ci1-1Jcy9LI"
},
"outputs": [],
"source": [
"# Y data - Positive is 1\n",
"# Fit the binarizer on the training labels only, then reuse that fitted mapping for the\n",
"# test and validation labels so all three splits share one label encoding.\n",
"y_train_imdb = label_binarizer.fit_transform(y_train_i)\n",
"y_test_imdb = label_binarizer.transform(y_test_i)\n",
"y_val_imdb = label_binarizer.transform(y_val_i)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "iCgg6XhJzAQq"
},
"source": [
"# **Train RF model**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZgF4cqf8zBzU",
"outputId": "31f638bc-b900-44f6-ca74-29f16668b70c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'n_estimators': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100], 'max_depth': [10, 32, 55, 77, 100, None], 'min_samples_split': [2, 5, 10], 'min_samples_leaf': [1, 2, 4], 'bootstrap': [True, False]}\n"
]
}
],
"source": [
"# Number of trees in random forest\n",
"n_estimators = [int(x) for x in np.linspace(start = 10, stop = 100, num = 10)]\n",
"# Maximum number of levels in tree\n",
"max_depth = [int(x) for x in np.linspace(10, 100, num = 5)]\n",
"max_depth.append(None)\n",
"# Minimum number of samples required to split a node\n",
"min_samples_split = [2, 5, 10]\n",
"# Minimum number of samples required at each leaf node\n",
"min_samples_leaf = [1, 2, 4]\n",
"# Method of selecting samples for training each tree\n",
"bootstrap = [True, False]\n",
"# Create the grid\n",
"grid_rf = {'n_estimators': n_estimators,\n",
" 'max_depth': max_depth,\n",
" 'min_samples_split': min_samples_split,\n",
" 'min_samples_leaf': min_samples_leaf,\n",
" 'bootstrap': bootstrap}\n",
"print(grid_rf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iXNaDTrBzIlf"
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bd4mLe8HzwJn"
},
"outputs": [],
"source": [
"# A randomised search is used instead of an exhaustive grid because it takes less time;\n",
"# the trade-off is that it may settle on a sub-optimal parameter combination.\n",
"grid_imdb_rf = RandomizedSearchCV(\n",
"    RandomForestClassifier(),\n",
"    param_distributions=grid_rf,\n",
"    n_iter=200,\n",
"    cv=3,\n",
"    verbose=2,\n",
"    random_state=42,\n",
"    n_jobs=-1,\n",
")\n",
"# Fit the random search model (labels are flattened to 1-D with ravel)\n",
"grid_imdb_rf.fit(x_train_imdb, y_train_imdb.ravel())\n",
"# Save the trained model to a pickle file to use afterwards; the `with` block closes the\n",
"# file so the pickle is completely written before it is copied or reloaded.\n",
"with open('grid_imdb_rf.pickle', \"wb\") as model_file:\n",
"    pickle.dump(grid_imdb_rf, model_file)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "gFeISxuuGkIl"
},
"source": [
"# **Load Model**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "GqJETBX-GjSi"
},
"outputs": [],
"source": [
"# Reload the pickled search object from Google Drive (only unpickle files you trust);\n",
"# the context manager closes the file handle as soon as loading finishes.\n",
"with open('/content/drive/MyDrive/grid_imdb_rf.pickle', \"rb\") as model_file:\n",
"    loaded_rf_imdb = pickle.load(model_file)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IOUB66VHG2eB",
"outputId": "2cbfa219-b9a3-4ae4-d52a-83b6db5ece57"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'n_estimators': 100, 'min_samples_split': 5, 'min_samples_leaf': 4, 'max_depth': None, 'bootstrap': False}\n"
]
}
],
"source": [
"print(loaded_rf_imdb.best_params_)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "owOb6Gq-HJsO"
},
"source": [
"**Info about models**"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "RyEcQXC9HK8J"
},
"outputs": [],
"source": [
"p = np.sum(y_train_imdb)/np.size(y_train_imdb)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "CPYGvF-JHW0i"
},
"outputs": [],
"source": [
"probs_rf = loaded_rf_imdb.predict(x_test_imdb)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A1PDKThAAIja",
"outputId": "a74b437b-8f67-4be1-edb4-072e120d0a8f"
},
"outputs": [
{
"data": {
"text/plain": [
"array([0])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded_rf_imdb.predict(x_test_imdb[0])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "H8KoYli2Higc"
},
"outputs": [],
"source": [
"accuracy_rf = metrics.accuracy_score(y_test_imdb, probs_rf)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZUj6LAeKHjQE",
"outputId": "ce507192-34ed-4dde-c4f4-c10fa874578f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy of RF the model on the test data is 0.857000\n"
]
}
],
"source": [
"# Report the test accuracy of the loaded random-forest search.\n",
"print(\"The accuracy of the RF model on the test data is %f\" %accuracy_rf)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aXOhr74W_Do_",
"outputId": "2622a8f0-08cd-4889-cbc4-200f103c00af"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"give me reviewgive me reviewit is good movie\n"
]
}
],
"source": [
"input_review = input('give me review')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "rs0wvPMW_PTC"
},
"outputs": [],
"source": [
"# Apply the same cleaning steps used on the training reviews (lower-case, strip HTML,\n",
"# drop special characters and stopwords, lemmatize) before vectorizing, so the typed\n",
"# review is represented the same way as the text the model was fitted on.\n",
"cleaned_review = lemmatize_text(\n",
"    remove_stopwords(remove_special_characters(strip_html(input_review.lower())))\n",
")\n",
"vector_input = loaded_vectorizer.transform([cleaned_review])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "U-2KQ0UbAN2i",
"outputId": "5da64d03-b8cd-41f2-a5ed-bbb9f92612b4"
},
"outputs": [
{
"data": {
"text/plain": [
"array([1])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded_rf_imdb.predict(vector_input)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"metadata": {
"kernelspec": {
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.10",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"colab": {
"provenance": []
}
},
"nbformat_minor": 0,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"source": [
"# Initialisation"
],
"metadata": {
"id": "DdGNtVpTHIFF"
}
},
{
"cell_type": "markdown",
"source": [
"## Installing and importing packages"
],
"metadata": {
"id": "PkvoTFD1HIFI"
}
},
{
"cell_type": "code",
"source": [
"!pip install numpy==1.20.0\n",
"!pip install ordered_set"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1C1_8SsOHIFJ",
"outputId": "3cd86637-8171-4f8a-e4b5-3a02ba996509"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpy==1.20.0\n",
" Downloading numpy-1.20.0.zip (8.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m84.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: numpy\n",
" \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
" \n",
" \u001b[31m×\u001b[0m \u001b[32mBuilding wheel for numpy \u001b[0m\u001b[1;32m(\u001b[0m\u001b[32mpyproject.toml\u001b[0m\u001b[1;32m)\u001b[0m did not run successfully.\n",
" \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
" \u001b[31m╰─>\u001b[0m See above for output.\n",
" \n",
" \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
" Building wheel for numpy (pyproject.toml) ... \u001b[?25l\u001b[?25herror\n",
"\u001b[31m ERROR: Failed building wheel for numpy\u001b[0m\u001b[31m\n",
"\u001b[0mFailed to build numpy\n",
"\u001b[31mERROR: Could not build wheels for numpy, which is required to install pyproject.toml-based projects\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ordered_set\n",
" Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n",
"Installing collected packages: ordered_set\n",
"Successfully installed ordered_set-4.1.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"\n",
"import os\n",
"for dirname, _, filenames in os.walk('/kaggle/input'):\n",
" for filename in filenames:\n",
" print(os.path.join(dirname, filename))\n"
],
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"trusted": true,
"id": "3xvciL8JHIFK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"import time\n",
"\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from sklearn.model_selection import ParameterGrid\n",
"from sklearn.svm import SVC\n",
"import sklearn.feature_extraction\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"import nltk\n",
"from nltk.corpus import stopwords \n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"nltk.download('wordnet')\n",
"nltk.download('stopwords')\n",
"nltk.download('punkt')\n",
"#!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/\n",
"\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"import pickle\n",
"import seaborn as sns\n",
"\n",
"from ordered_set import OrderedSet\n",
"from scipy.sparse import lil_matrix\n",
"from itertools import compress\n",
"\n",
"# Mount Google Drive before reading from it: this cell comes before the notebook's\n",
"# drive-mount cell, so on a fresh kernel the path below would not exist yet.\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"# `sklearn.preprocessing` is imported explicitly: the imports above only name\n",
"# `sklearn.feature_extraction`, so the attribute access below would otherwise rely on\n",
"# another sklearn module having loaded the submodule as a side effect.\n",
"import sklearn.preprocessing\n",
"\n",
"# Load the pickled object used as the vectorizer vocabulary (only unpickle files you trust);\n",
"# the context manager closes the file handle once loading finishes.\n",
"with open('/content/drive/MyDrive/vectorizer_imdb.pkl', 'rb') as vocab_file:\n",
"    loaded_vocab = pickle.load(vocab_file)\n",
"\n",
"# Shared text-processing objects used by the support functions and the encoding cells.\n",
"stop_words = set(stopwords.words('english'))\n",
"tokenizer = nltk.tokenize.toktok.ToktokTokenizer()\n",
"lemmatizer = WordNetLemmatizer()\n",
"# TF-IDF vectorizer restricted to the loaded vocabulary.\n",
"loaded_vectorizer = TfidfVectorizer(min_df=2, vocabulary=loaded_vocab)\n",
"label_binarizer = sklearn.preprocessing.LabelBinarizer()\n",
"feature_names = loaded_vectorizer.get_feature_names_out()"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y4T0HxEMHIFK",
"outputId": "e62a6958-f37d-484d-f763-6d0e8da9f9ee"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uNJtfK9GSb-_",
"outputId": "70c7acc1-2ab9-4fe0-d736-714049e4a7c3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"DATA = \"/content/drive\""
],
"metadata": {
"id": "GMutHjwsSxZf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Support Functions"
],
"metadata": {
"id": "Mpag4FIXHIFL"
}
},
{
"cell_type": "code",
"source": [
"# Removes HTML tags in review data\n",
"def strip_html(text):\n",
"    \"\"\"Parse `text` with BeautifulSoup's html.parser and return only its text content.\"\"\"\n",
"    return BeautifulSoup(text, \"html.parser\").get_text()\n",
"\n",
"# Remove special characters like #$*( ) keeping only alphanumeric characters\n",
"def remove_special_characters(text, remove_digits=True):\n",
" pattern=r'[^a-zA-z0-9\\s]'\n",
" text=re.sub(pattern,'',text)\n",
" return text\n",
"\n",
"# Remove stopwords in the review\n",
"def remove_stopwords(text, is_lower_case=False):\n",
"    \"\"\"Tokenize `text`, drop tokens found in `stop_words` and re-join the rest with spaces.\n",
"\n",
"    When `is_lower_case` is true the tokens are compared as-is; otherwise each token is\n",
"    lower-cased for the comparison only, and the kept tokens retain their original case.\n",
"    \"\"\"\n",
"    stripped = [tok.strip() for tok in tokenizer.tokenize(text)]\n",
"    # Pick the comparison key once instead of duplicating the filtering comprehension.\n",
"    normalize = (lambda tok: tok) if is_lower_case else str.lower\n",
"    kept = [tok for tok in stripped if normalize(tok) not in stop_words]\n",
"    return ' '.join(kept)\n",
"\n",
"# Get the root word of words -> prevent overfitting, reduce training time\n",
"def lemmatize_text(text):\n",
" words=word_tokenize(text)\n",
" edited_text = ''\n",
" for word in words:\n",
" lemma_word=lemmatizer.lemmatize(word)\n",
" extra=\" \"+str(lemma_word)\n",
" edited_text+=extra\n",
" return edited_text"
],
"metadata": {
"trusted": true,
"id": "JElvoaekHIFL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Import DataSet and Preprocess"
],
"metadata": {
"id": "76mMXhx1HIFL"
}
},
{
"cell_type": "code",
"source": [
"## Import\n",
"data = pd.read_csv('/content/drive/MyDrive/IMDB Dataset.csv')\n",
"## Reduce training time\n",
"data = data.sample(10000)"
],
"metadata": {
"trusted": true,
"id": "P8sLi1NJHIFM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"data.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fPN5y4NfO3Zg",
"outputId": "bedb30b5-e8bd-4daa-800d-915e6be00254"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(10000, 2)"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"## Preprocess\n",
"data.review = data.review.str.lower()\n",
"data.review = data.review.apply(strip_html)\n",
"data.review = data.review.apply(remove_special_characters)\n",
"data.review = data.review.apply(remove_stopwords)\n",
"data.review = data.review.apply(lemmatize_text)"
],
"metadata": {
"trusted": true,
"id": "YxcaGylxHIFM",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "6a48219f-28b5-44b0-b593-a21235cda1a1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-7-24ac98f943f6>:3: MarkupResemblesLocatorWarning: The input looks more like a filename than markup. You may want to open this file and pass the filehandle into Beautiful Soup.\n",
" soup = BeautifulSoup(text, \"html.parser\")\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"## Split Data\n",
"x_imdb = data['review']\n",
"y_imdb = data['sentiment']\n",
"\n",
"x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_imdb,y_imdb,test_size=0.2)\n",
"x_test, x_val, y_test_i, y_val_i = train_test_split(x_test_i,y_test_i,test_size=0.5)"
],
"metadata": {
"trusted": true,
"id": "h59Tc2PgHIFM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"## X data -> transform into vectors\n",
"x_train_imdb = loaded_vectorizer.fit_transform(x_train_i)\n",
"x_test_imdb = loaded_vectorizer.transform(x_test)\n",
"x_val_imdb = loaded_vectorizer.transform(x_val)"
],
"metadata": {
"trusted": true,
"id": "w_duSrVEHIFN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x_train_imdb[0].shape"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DEzbTGCiHIFN",
"outputId": "0c696ccd-2d2f-4655-8f51-41727967efd8"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 26698)"
]
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"source": [
"# Y data - Positive is 1\n",
"y_train_imdb = label_binarizer.fit_transform(y_train_i)\n",
"y_test_imdb = label_binarizer.fit_transform(y_test_i)\n",
"y_val_imdb = label_binarizer.fit_transform(y_val_i)"
],
"metadata": {
"trusted": true,
"id": "FNU4C7_IHIFN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"y_train_imdb.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ksY1bxIcPNlw",
"outputId": "0c48f6e4-384e-497d-c632-95a8ff040155"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(8000, 1)"
]
},
"metadata": {},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"source": [
"## Train RF model"
],
"metadata": {
"id": "Wx3SYI82HIFN"
}
},
{
"cell_type": "code",
"source": [
"# Number of trees in random forest\n",
"n_estimators = [int(x) for x in np.linspace(start = 10, stop = 100, num = 10)]\n",
"# Maximum number of levels in tree\n",
"max_depth = [int(x) for x in np.linspace(10, 100, num = 5)]\n",
"max_depth.append(None)\n",
"# Minimum number of samples required to split a node\n",
"min_samples_split = [2, 5, 10]\n",
"# Minimum number of samples required at each leaf node\n",
"min_samples_leaf = [1, 2, 4]\n",
"# Method of selecting samples for training each tree\n",
"bootstrap = [True, False]\n",
"# Create the grid\n",
"grid_rf = {'n_estimators': n_estimators,\n",
" 'max_depth': max_depth,\n",
" 'min_samples_split': min_samples_split,\n",
" 'min_samples_leaf': min_samples_leaf,\n",
" 'bootstrap': bootstrap}\n",
"print(grid_rf)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.021625Z",
"iopub.status.idle": "2023-05-24T01:59:14.022487Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.022189Z",
"shell.execute_reply": "2023-05-24T01:59:14.022216Z"
},
"id": "Ssbu_t5UHIFN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.ensemble import RandomForestClassifier"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.024015Z",
"iopub.status.idle": "2023-05-24T01:59:14.024474Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.024258Z",
"shell.execute_reply": "2023-05-24T01:59:14.024286Z"
},
"id": "Uik1rEaAHIFO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Grid search -> exhaustive search -> Takes too much time -> Therefore use Randomised -> takes less time -> Issue - Sometimes go to a local minima\n",
"grid_imdb_rf = RandomizedSearchCV(RandomForestClassifier(), param_distributions = grid_rf, n_iter = 200, cv = 3, verbose=2, random_state=42, n_jobs = -1) # Fit the random search model\n",
"# # Fit the random search model\n",
"grid_imdb_rf.fit(x_train_imdb,y_train_imdb.ravel()) # -> can get the best hyperparameters from the input params list\n",
"# Save the trained Model to a pickle file to use afterwards\n",
"pickle.dump(grid_imdb_rf, open('grid_imdb_rf.pickle', \"wb\"))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.025697Z",
"iopub.status.idle": "2023-05-24T01:59:14.026125Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.025909Z",
"shell.execute_reply": "2023-05-24T01:59:14.025929Z"
},
"id": "mNns_pd2HIFO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train SVC Model"
],
"metadata": {
"id": "lk6mPEd6HIFO"
}
},
{
"cell_type": "code",
"source": [
"# Param Optimisation\n",
"param_grid_imdb = {'C': [0.1,1, 10, 100], 'gamma': [1,0.1,0.01,0.001],'kernel': ['rbf']}\n",
"grid_imdb_svc = GridSearchCV(SVC(),param_grid_imdb,refit=True,verbose=2)"
],
"metadata": {
"id": "t8Yb5GgRHIFO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_svc.fit(x_train_imdb,y_train_imdb.ravel())\n",
"pickle.dump(grid_imdb_svc, open('grid_imdb_svc.pickle', \"wb\"))"
],
"metadata": {
"id": "Lfzr1jIdHIFO",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9c954b27-6ad8-49d6-ce54-3ffc7f10055c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 16 candidates, totalling 80 fits\n",
"[CV] END .........................C=0.1, gamma=1, kernel=rbf; total time= 56.5s\n",
"[CV] END .........................C=0.1, gamma=1, kernel=rbf; total time= 1.1min\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Train KNN model"
],
"metadata": {
"id": "X-VdilC0HIFP"
}
},
{
"cell_type": "code",
"source": [
"grid_params_imdb_knn = { 'n_neighbors' : [30,40,50,60,70,80,90], 'metric' : ['manhattan', 'minkowski'], 'weights': ['uniform', 'distance']}\n",
"grid_imdb_knn = GridSearchCV(KNeighborsClassifier(), grid_params_imdb_knn, n_jobs=-1,verbose=2)"
],
"metadata": {
"id": "UNzp-0VxHIFP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_knn.fit(x_train_imdb,np.ravel(y_train_imdb,order='C'))\n",
"pickle.dump(grid_imdb_knn, open('grid_imdb_knn.pickle', \"wb\"))"
],
"metadata": {
"id": "ru9aCGFMHIFP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train LR Model"
],
"metadata": {
"id": "7VE2Buw4HIFP"
}
},
{
"cell_type": "code",
"source": [
"param_grid_imdb_lr = [ \n",
" {'penalty' : ['l1', 'l2', 'elasticnet'],\n",
" 'C' : np.logspace(-4, 4, 20),\n",
" 'solver' : ['lbfgs','newton-cg','sag'],\n",
" 'max_iter' : [100, 1000, 5000]\n",
" }\n",
"]\n",
"grid_imdb_lr = GridSearchCV(LogisticRegression(), param_grid = param_grid_imdb_lr, cv = 3, verbose=2, n_jobs=-1)"
],
"metadata": {
"id": "MXOWXQWYHIFP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_lr.fit(x_train_imdb, np.ravel(y_train_imdb,order='C'))\n",
"pickle.dump(grid_imdb_lr, open('grid_imdb_lr.pickle', \"wb\"))"
],
"metadata": {
"id": "6E-cFBvdHIFQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Load Models"
],
"metadata": {
"id": "3EMR5PGTHIFQ"
}
},
{
"cell_type": "code",
"source": [
"# Load\n",
"loaded_svc_imdb = pickle.load(open('/content/drive/MyDrive/grid_imdb_svc.pickle/grid_imdb_svc.pickle', \"rb\"))\n",
"#loaded_knn_imdb = pickle.load(open('/kaggle/input/models/grid_imdb_knn.pickle', \"rb\"))\n",
"#loaded_lr_imdb = pickle.load(open('/content/drive/MyDrive/Sentiment/grid_imdb_lr.pickle', \"rb\"))\n",
"#loaded_rf_imdb = pickle.load(open('/content/drive/MyDrive/Sentiment/grid_imdb_rf.pickle', \"rb\"))"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AIHKvuaNHIFQ",
"outputId": "02f711f8-ea7e-44f4-fd52-8c7f3cb20209"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-38-b2d9bcf25cb7>:2: DeprecationWarning: Please use `csr_matrix` from the `scipy.sparse` namespace, the `scipy.sparse.csr` namespace is deprecated.\n",
" loaded_svc_imdb = pickle.load(open('/content/drive/MyDrive/grid_imdb_svc.pickle/grid_imdb_svc.pickle', \"rb\"))\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(loaded_svc_imdb.best_params_)\n",
"#print(loaded_knn_imdb.best_params_)\n",
"#print(loaded_lr_imdb.best_params_)\n",
"#print(loaded_rf_imdb.best_params_)"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ICXhxhsfHIFQ",
"outputId": "a8e2ae1d-31b4-4e86-983d-3840611e72e5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'C': 10, 'gamma': 1, 'kernel': 'rbf'}\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Info about models"
],
"metadata": {
"id": "uqwjFg8xHIFQ"
}
},
{
"cell_type": "code",
"source": [
"p = np.sum(y_train_imdb)/np.size(y_train_imdb)"
],
"metadata": {
"trusted": true,
"id": "enIOyEwGHIFQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Probs\n",
"#probs_svc = loaded_svc_imdb.decision_function(x_test_imdb)\n",
"#probs_lr = loaded_lr_imdb.decision_function(x_test_imdb)\n",
"#probs_rf = loaded_rf_imdb.predict_log_proba(x_test_imdb)\n",
"#probs_knn = loaded_knn_imdb.predict_proba(x_test_imdb)"
],
"metadata": {
"trusted": true,
"id": "IPSJudReHIFQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#threshold_classifier_probs_lr = np.percentile(probs_lr,(100-(p*100)))\n",
"#threshold_classifier_probs_svc = np.percentile(probs_svc,(100-(p*100)))\n",
"#threshold_classifier_probs_rf = np.percentile(probs_rf,(100-(p*100)))\n",
"#threshold_classifier_probs_knn = np.percentile(probs_knn,(100-(p*100)))\n"
],
"metadata": {
"trusted": true,
"id": "tE4u6SdTHIFR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#predictions_probs_lr = (probs_lr >= threshold_classifier_probs_lr)\n",
"#predictions_probs_svc = (probs_svc >= threshold_classifier_probs_svc)\n",
"#predictions_probs_rf = (probs_rf >= threshold_classifier_probs_rf)\n",
"#predictions_probs_knn = (probs_knn >= threshold_classifier_probs_knn)\n"
],
"metadata": {
"trusted": true,
"id": "r_zvoZy2HIFR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"probs_svc = loaded_svc_imdb.predict(x_test_imdb)\n",
"#probs_lr = loaded_lr_imdb.predict(x_test_imdb)\n",
"#probs_rf = loaded_rf_imdb.predict(x_test_imdb)\n",
"#probs_knn = loaded_knn_imdb.predict(x_test_imdb)"
],
"metadata": {
"trusted": true,
"id": "XOIMo-zRHIFR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"probs_svc"
],
"metadata": {
"id": "j0nnMbDrRM9D",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2f76b966-11af-40b0-fc1e-67ddecbd5818"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n",
" 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0,\n",
" 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0,\n",
" 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,\n",
" 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0,\n",
" 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,\n",
" 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,\n",
" 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0,\n",
" 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,\n",
" 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1,\n",
" 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0,\n",
" 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0,\n",
" 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,\n",
" 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1,\n",
" 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1,\n",
" 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,\n",
" 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1,\n",
" 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0,\n",
" 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,\n",
" 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
" 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0,\n",
" 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1,\n",
" 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0,\n",
" 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1,\n",
" 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0,\n",
" 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,\n",
" 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1,\n",
" 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1,\n",
" 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1,\n",
" 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,\n",
" 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0,\n",
" 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
" 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
" 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,\n",
" 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1,\n",
" 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1,\n",
" 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1,\n",
" 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1,\n",
" 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0,\n",
" 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0,\n",
" 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,\n",
" 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,\n",
" 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0,\n",
" 1, 0, 0, 1, 1, 0, 1, 0, 1, 0])"
]
},
"metadata": {},
"execution_count": 42
}
]
},
{
"cell_type": "code",
"source": [
"loaded_svc_imdb.predict(x_test_imdb[0])"
],
"metadata": {
"trusted": true,
"id": "IfgApH2JHIFR",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a0b4d51b-9e10-48ec-de75-0baa60d3c5aa"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0])"
]
},
"metadata": {},
"execution_count": 43
}
]
},
{
"cell_type": "code",
"source": [
"loaded_rf_imdb.predict(x_test_imdb[0])"
],
"metadata": {
"trusted": true,
"id": "PXF0D1XRHIFR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"loaded_knn_imdb.predict(x_test_imdb[0])"
],
"metadata": {
"trusted": true,
"id": "rehc1SbsHIFa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"accuracy_svc = metrics.accuracy_score(y_test_imdb, probs_svc)\n",
"#accuracy_lr = metrics.accuracy_score(y_test_imdb, probs_lr)\n",
"#accuracy_rf = metrics.accuracy_score(y_test_imdb, probs_rf)\n",
"#accuracy_knn = metrics.accuracy_score(y_test_imdb, probs_knn)"
],
"metadata": {
"trusted": true,
"id": "47SX5b6wHIFb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"metrics.confusion_matrix(y_test_imdb, probs_svc)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RWS2AchT88e",
"outputId": "715a86a3-0aad-42da-8ec0-29f5cfe52f5a"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[437, 62],\n",
" [ 44, 457]])"
]
},
"metadata": {},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"source": [
"metrics.f1_score(y_test_imdb, probs_svc)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4cUSqt6_UTig",
"outputId": "68e958ab-90a2-46db-d3d7-4cf4623e4a4d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.896078431372549"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"metrics.roc_curve(y_test_imdb, probs_svc)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A9cFHtrgVSul",
"outputId": "1b83ae60-872b-4c2f-a954-7700236a392f"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(array([0. , 0.1242485, 1. ]),\n",
" array([0. , 0.91217565, 1. ]),\n",
" array([2, 1, 0]))"
]
},
"metadata": {},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"source": [
"print(\"The accuracy of the SVM model on the test data is %f\" %accuracy_svc)\n",
"#print(\"The accuracy of the SVC model on the test data is %f\" %accuracy_lr)\n",
"#print(\"The accuracy of RF the model on the test data is %f\" %accuracy_rf)\n",
"#print(\"The accuracy of the KNN model on the test data is %f\" %accuracy_knn)\n"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "phQzaoXCHIFb",
"outputId": "ff9385dc-daed-4d03-dd48-f7a7b80ed6c5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The accuracy of the SVM model on the test data is 0.894000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"x_val[0:2]"
],
"metadata": {
"id": "Lqz7xjHXRcx3",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1e787510-10c8-4293-aba9-e5bc540da9e7"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"23215 pathetic moviei wont waste much time commenti...\n",
"6878 ok gave 3 obviously money make film feel migh...\n",
"Name: review, dtype: object"
]
},
"metadata": {},
"execution_count": 32
}
]
},
{
"cell_type": "code",
"source": [
"input_review = input('give me review')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gA34GtvWHIFb",
"outputId": "5e901db6-ce1e-41a6-c642-35ce0bb6e1ae"
},
"execution_count": null,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"give me reviewTo put it simply, I enjoyed this film. The reason for my interest & enjoyment was not related to anything other than the subject matter itself. I had heard tales from my mother and grandmother about how Northern England working class life and attitudes used to be (as experienced by them)and this is an interesting depiction that seems to faithfully represent what they told me. In particular, the paternalistic but overbearing father who \"knows\" what is best for his family along with his stubborness when this paradigm is challenged. (Not much has changed there then!!)<br /><br />People who have seen the play will probably be disappointed with the film because the story does not easily transfer across the different media. In a sense however, the film is an historical document and I personally enjoyed it, if only because of the way it conveyed a social phenomenon.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"vector_input = loaded_vectorizer.transform([input_review])"
],
"metadata": {
"id": "HqR56VK5R65_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"loaded_svc_imdb.predict(vector_input)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A97C_KWfSBfK",
"outputId": "4385e2a8-7397-4bf1-d266-d5eb1210b943"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1])"
]
},
"metadata": {},
"execution_count": 49
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "CjobeYxSTtwC"
},
"execution_count": null,
"outputs": []
}
]
}
\ No newline at end of file
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
# testing
/coverage
# production
/build
# misc
.DS_Store
.env.local
.env.development.local
.env.test.local
.env.production.local
npm-debug.log*
yarn-debug.log*
yarn-error.log*
This source diff could not be displayed because it is too large. You can view the blob instead.
{
"name": "frontend",
"version": "0.1.0",
"private": true,
"dependencies": {
"@emotion/react": "^11.10.5",
"@mantine/core": "6.0.0",
"@mantine/hooks": "6.0.0",
"@testing-library/jest-dom": "^5.16.5",
"@testing-library/react": "^13.4.0",
"@testing-library/user-event": "^14.4.3",
"@types/jest": "^29.2.3",
"@types/node": "^18.11.9",
"@types/react": "^18.0.25",
"@types/react-dom": "^18.0.9",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-router-dom": "^6.11.2",
"react-scripts": "5.0.1",
"typescript": "^4.9.3",
"web-vitals": "^3.1.0"
},
"scripts": {
"start": "react-scripts start",
"build": "react-scripts build",
"test": "react-scripts test",
"eject": "react-scripts eject",
"typecheck": "tsc --noEmit"
},
"eslintConfig": {
"extends": [
"react-app",
"react-app/jest"
]
},
"browserslist": {
"production": [
">0.2%",
"not dead",
"not op_mini all"
],
"development": [
"last 1 chrome version",
"last 1 firefox version",
"last 1 safari version"
]
}
}
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<link rel="icon" href="%PUBLIC_URL%/favicon.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="theme-color" content="#000000" />
<meta
name="description"
content="Web site created using create-react-app"
/>
<link rel="apple-touch-icon" href="%PUBLIC_URL%/logo192.png" />
<!--
manifest.json provides metadata used when your web app is installed on a
user's mobile device or desktop. See https://developers.google.com/web/fundamentals/web-app-manifest/
-->
<link rel="manifest" href="%PUBLIC_URL%/manifest.json" />
<!--
Notice the use of %PUBLIC_URL% in the tags above.
It will be replaced with the URL of the `public` folder during the build.
Only files inside the `public` folder can be referenced from the HTML.
Unlike "/favicon.ico" or "favicon.ico", "%PUBLIC_URL%/favicon.ico" will
work correctly both with client-side routing and a non-root public URL.
Learn how to configure a non-root public URL by running `npm run build`.
-->
<title>React App</title>
</head>
<body>
<noscript>You need to enable JavaScript to run this app.</noscript>
<div id="root"></div>
<!--
This HTML file is a template.
If you open it directly in the browser, you will see an empty page.
You can add webfonts, meta tags, or analytics to this file.
The build step will place the bundled scripts into the <body> tag.
To begin the development, run `npm start` or `yarn start`.
To create a production bundle, use `npm run build` or `yarn build`.
-->
</body>
</html>
{
"short_name": "React App",
"name": "Create React App Sample",
"icons": [
{
"src": "favicon.ico",
"sizes": "64x64 32x32 24x24 16x16",
"type": "image/x-icon"
},
{
"src": "logo192.png",
"type": "image/png",
"sizes": "192x192"
},
{
"src": "logo512.png",
"type": "image/png",
"sizes": "512x512"
}
],
"start_url": ".",
"display": "standalone",
"theme_color": "#000000",
"background_color": "#ffffff"
}
# https://www.robotstxt.org/robotstxt.html
User-agent: *
Disallow:
import { createStyles } from "@mantine/core";
import { ThemeProvider } from "./ThemeProvider";
import { NavbarSimpleColored } from "./Components/NavBar/NavBar";
import {BrowserRouter as Router, Routes, Route} from 'react-router-dom';
import Knn from "./Components/k-NN/k-NN";
import SVM from "./Components/SVM/SVM";
import RandomForest from "./Components/RandomForest/RandomForest";
import LogisticRegression from "./Components/LogisticRegression/LogisticRegression";
import Settings from "./Components/Settings/Settings";
import Home from "./Components/Home/Home";
const useStyles = createStyles((theme) => ({
sections: {
display: "flex",
flexDirection: "row",
position: "relative",
}
}))
/**
 * Application shell: the sidebar navigation plus the routed page content,
 * which occupies the right-hand four fifths of the viewport.
 */
export default function App() {
  const { classes } = useStyles()
  return (
    <ThemeProvider>
      <Router>
        <div className={classes.sections}>
          <NavbarSimpleColored />
          {/* 80vw rather than window.innerWidth/5*4: same size on first paint,
              but it keeps tracking the viewport when the window is resized. */}
          <div style={{ width: "80vw", position: "absolute", right: 0 }}>
            <Routes>
              <Route path='/' element={<Home />} />
              <Route path='/svm' element={<SVM />} />
              <Route path='/knn' element={<Knn />} />
              <Route path='/rf' element={<RandomForest />} />
              <Route path='/lr' element={<LogisticRegression />} />
              <Route path='/settings' element={<Settings />} />
            </Routes>
          </div>
        </div>
      </Router>
    </ThemeProvider>
  );
}
import { Card, Image, Text } from '@mantine/core';
/**
 * Card with a cover image and a title underneath.
 *
 * @param data.src   URL of the cover image. Callers may pass an empty string,
 *                   in which case the image placeholder is shown instead.
 * @param data.title Text rendered below the image; also used as the image's alt text.
 */
export function CardCustom(data: {src: string, title: string}) {
  return (
    <Card
      shadow="xl"
      padding="xl"
      component="a"
      // href="https://www.youtube.com/watch?v=dQw4w9WgXcQ"
      target="_blank"
      // With target="_blank", keeps the opened page from reaching back via window.opener.
      rel="noopener noreferrer"
    >
      <Card.Section>
        <Image
          src={data.src}
          height={160}
          // Describe the picture with the card's own title instead of a fixed joke string.
          alt={data.title}
          // Show a placeholder instead of a broken image when src is empty or fails to load.
          withPlaceholder
        />
      </Card.Section>
      <Text weight={500} size="lg" mt="md">
        {data.title}
      </Text>
      {/* <Text mt="xs" color="dimmed" size="sm">
        try {data.title} ...
      </Text> */}
    </Card>
  );
}
\ No newline at end of file
import { Grid, Text, Title, Transition, createStyles } from "@mantine/core"
import { CardCustom } from "../Card/CardCustom"
const useStyle = createStyles((theme) => ({
container: {
height: "100%",
width: "100%",
},
transitionBox: {
display: "flex",
flexDirection: "column",
alignItems: "center",
justifyContent: "center",
},
description: {
width: "80%",
}
}))
/**
 * Landing page.
 *
 * Renders, inside a "pop" Transition, a headline, a centred description
 * paragraph and a grid of four CardCustom entries (K-NN, SVM, Random Forest,
 * Logistic Regression), each in a span-3 column.
 *
 * The transition styles are applied to the headline only.
 * The SVM and Random Forest cards are given an empty image src.
 * NOTE(review): the description paragraph looks like placeholder copy — confirm
 * the intended product text.
 */
function Home() {
// Class names generated by the useStyle hook defined above in this file.
const { classes } = useStyle()
return (
<div className={classes.container}>
<Transition mounted={true} transition="pop" duration={1000} timingFunction="ease">
{
(styles) => (
<div className={classes.transitionBox}>
<Title order={1} size={60} mt={30} style={styles}>
Get Your Best XAI Solution
</Title>
<Text className={classes.description} mt={50} align="center">
Paragraphs are the building blocks of papers.
Many students define paragraphs in terms of length: a paragraph is a group of
at least five sentences, a paragraph is half a page long, etc. In reality, though, the unity and
coherence of ideas among sentences is what constitutes a paragraph. A paragraph is
defined as “a group of sentences or a single sentence that forms
a unit” (Lunsford and Connors 116). Length
and appearance do not determine whether a section in a paper is a paragraph. For instance, in some
styles of writing, particularly journalistic styles, a paragraph
can be just one sentence long. Ultimately, a paragraph is a
sentence or group of sentences that support one main idea. In this handout,
we will refer to this as the “controlling idea,” because it controls what happens in the
rest of the paragraph.
</Text>
<Grid ml={50} mr={50} mt={30}>
<Grid.Col span={3}>
<CardCustom
src="https://media.licdn.com/dms/image/C5112AQFY4bX3Y7jcHA/article-cover_image-shrink_600_2000/0/1565431496642?e=2147483647&v=beta&t=N2zhB7OhvdTrebqNkY2lxMaPHqPYpA4r3nj88msI-e0"
title="K-NN"
/>
</Grid.Col>
<Grid.Col span={3}>
<CardCustom
src=""
title="SVM"
/>
</Grid.Col>
<Grid.Col span={3}>
<CardCustom
src=""
title="Random Forest"
/>
</Grid.Col>
<Grid.Col span={3}>
<CardCustom
src="https://www.appstate.edu/~whiteheadjc/service/logit/logit.gif"
title="Logistic Regression"
/>
</Grid.Col>
</Grid>
</div>
)
}
</Transition>
</div>
)
}
export default Home
\ No newline at end of file
/** Placeholder page for the Logistic Regression section; renders only its own name. */
const LogisticRegression = () => <div>LogisticRegression</div>
export default LogisticRegression
\ No newline at end of file
import { useState } from 'react';
import { createStyles, Navbar, Group, Code, getStylesRef, rem } from '@mantine/core';
import { useLocation, useNavigate } from "react-router-dom";
import HOME_IMG from '../../assets/home-icon.png';
import SVM_IMG from '../../assets/svm-icon.png'
import KNN_IMG from '../../assets/knn-icon.png'
import LR_IMG from '../../assets/lr-icon.png'
import RF_IMG from '../../assets/rf-icon.png'
import SETTING_IMG from '../../assets/setting-icon.png'
import { LOGO } from '../consts';
const useStyles = createStyles((theme) => ({
navbar: {
backgroundColor: theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background,
position: "fixed",
},
version: {
backgroundColor: theme.fn.lighten(
theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background!,
0.1
),
color: theme.white,
fontWeight: 700,
},
header: {
paddingBottom: theme.spacing.md,
marginBottom: `calc(${theme.spacing.md} * 1.5)`,
borderBottom: `${rem(1)} solid ${theme.fn.lighten(
theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background!,
0.1
)}`,
},
footer: {
paddingTop: theme.spacing.md,
marginTop: theme.spacing.md,
borderTop: `${rem(1)} solid ${theme.fn.lighten(
theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background!,
0.1
)}`,
},
link: {
...theme.fn.focusStyles(),
display: 'flex',
alignItems: 'center',
textDecoration: 'none',
fontSize: theme.fontSizes.sm,
color: theme.white,
padding: `${theme.spacing.xs} ${theme.spacing.sm}`,
borderRadius: theme.radius.sm,
fontWeight: 500,
'&:hover': {
backgroundColor: theme.fn.lighten(
theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background!,
0.1
),
},
},
linkIcon: {
ref: getStylesRef('icon'),
color: theme.white,
// opacity: 0.75,
marginRight: theme.spacing.sm,
width: "30px",
height: "30px"
},
linkActive: {
'&, &:hover': {
backgroundColor: theme.fn.lighten(
theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background!,
0.15
),
[`& .${getStylesRef('icon')}`]: {
opacity: 0.9,
},
},
},
}));
const data = [
{ link: '/', label: 'Home', icon: HOME_IMG },
{ link: '/svm', label: 'Get Your Rule', icon: SVM_IMG },
// { link: '/knn', label: 'Documentation', icon: KNN_IMG },
// { link: '/rf', label: 'Random Forest', icon: RF_IMG },
// { link: '/lr', label: 'Logistic Regression', icon: LR_IMG },
];
/**
 * Coloured sidebar navigation.
 *
 * Renders one link per entry of `data`. Clicking a link navigates client-side
 * through react-router. The highlighted entry is the one whose `link` equals
 * the current URL path.
 */
export function NavbarSimpleColored() {
  const { classes, cx } = useStyles();
  const navigate = useNavigate();
  // Derive the active entry from the router location instead of local state
  // (previously seeded with 'Billing', which matches no entry). The highlight
  // is then correct on first load, on deep links and after back/forward.
  const { pathname } = useLocation();

  const links = data.map((item) => (
    <a
      className={cx(classes.link, { [classes.linkActive]: item.link === pathname })}
      // A real href keeps "open in new tab" / "copy link" working; ordinary
      // clicks are still intercepted below and handled without a page reload.
      href={item.link}
      key={item.label}
      onClick={(event) => {
        event.preventDefault();
        navigate(item.link)
      }}
    >
      {/* <item.icon className={classes.linkIcon} stroke={1.5} /> */}
      {/* Decorative icon: the adjacent label already names the link. */}
      <img src={item.icon} className={classes.linkIcon} alt="" />
      <span>{item.label}</span>
    </a>
  ));

  return (
    // Viewport units replace window.innerHeight / window.innerWidth/5 so the
    // sidebar keeps its proportions when the window is resized.
    <Navbar height="100vh" width={{ sm: "20vw" }} p="md" className={classes.navbar}>
      <Navbar.Section grow>
        <Group className={classes.header} position="apart">
          {/* <MantineLogo size={28} inverted /> */}
          <img src={LOGO} className={classes.linkIcon} alt="Logo" />
          <Code className={classes.version}>v1.0.0</Code>
        </Group>
        {links}
      </Navbar.Section>

      {/* <Navbar.Section className={classes.footer}>
        <a href="#" className={classes.link} onClick={(event) => {
          event.preventDefault()
          navigate("/settings")
        }}>
          <img src={SETTING_IMG} className={classes.linkIcon} />
          <span>Settings</span>
        </a>
      </Navbar.Section> */}
    </Navbar>
  );
}
\ No newline at end of file
/** Placeholder component: renders a single <div> with the static text "RandomForest". */
export default function RandomForest() {
  return <div>RandomForest</div>;
}
\ No newline at end of file
import { Title, createStyles } from '@mantine/core'
import React from 'react'
/**
 * Style hook for the counterfactual result table.
 *
 * `td` is the plain cell; `negative` (green) and `positive` (red) are the same
 * cell with a filled background and white text.
 */
const useStyle = createStyles(() => {
  // Geometry shared by every cell variant.
  const cell = {
    border: "1px solid #ddd",
    width: window.innerHeight / 4,
    padding: "5px 10px",
  };

  // A cell filled with the given background colour and white text.
  const filledCell = (background: string) => ({
    ...cell,
    background,
    color: "white",
  });

  return {
    table: {
      border: "1px solid #ddd",
      borderCollapse: "collapse",
    },
    td: cell,
    negative: filledCell("green"),
    positive: filledCell("red"),
  };
});
/**
 * Renders counterfactual substitutions as a three-column table: the initial
 * value, a single "Changes to" cell spanning all rows, and the replacement.
 *
 * @param data.tableData Rows to display; each row is read through its
 *   `initial` and `changedTo` fields. A value that is not an array renders an
 *   empty table instead of throwing.
 */
function CounterfactualTable(data: {tableData: any}) {
  const { classes } = useStyle();
  // tableData is typed `any`, so guard against null/undefined/non-array input.
  const rows: any[] = Array.isArray(data.tableData) ? data.tableData : [];

  return (
    <div>
      <Title order={6} mt={30} mb={10}>Results</Title>
      <table className={classes.table}>
        {/* Rows are wrapped in <tbody>: React warns when <tr> is a direct child of <table>. */}
        <tbody>
          {rows.map((row: any, index: number) => (
            <tr key={index}>
              <td className={classes.positive}>{row.initial}</td>
              {/* The middle cell is emitted once, on the first row, and spans every row. */}
              {index === 0 ? <td className={classes.td} rowSpan={rows.length}>Changes to</td> : null}
              <td className={classes.negative}>{row.changedTo}</td>
            </tr>
          ))}
        </tbody>
      </table>
    </div>
  )
}

export default CounterfactualTable
\ No newline at end of file
import { Text, Title } from '@mantine/core';
/**
 * Shows a small "Prediction" heading followed by the predicted label in large text.
 *
 * @param data.prediction Text to display under the heading.
 */
export function Prediction(data: {prediction: string}) {
  const { prediction } = data;

  return (
    <div>
      <Title order={6} mt={60}>Prediction</Title>
      <Text size={30} mt={10}>{prediction}</Text>
    </div>
  );
}
\ No newline at end of file
import { createStyles } from "@mantine/core"
import { TextAreaCustom } from "../_common/TextAreaCustom"
import { DropDownWidget } from "../_common/DropDownWidget";
import { ButtonCustom } from "../_common/ButtonCustom";
import { ScrollButton } from "../_common/ScrollButton";
import { CustomTitle } from "../_common/CustomTitle";
import { useWindowScroll } from "@mantine/hooks";
import { Sentence } from "./Sentence";
import { Prediction } from "./Prediction";
import CounterfactualTable from "./CounterfactualTable";
/** Style hook for the page wrapper: a vertical flex column with horizontally centred children. */
const useStyle = createStyles(() => {
  return {
    mainContainer: {
      display: "flex",
      flexDirection: "column",
      alignItems: "center",
    },
  };
});
/**
 * "Get Your Rule" page.
 *
 * The first section collects a review and an algorithm choice; pressing the
 * button scrolls the window down by one viewport height. The second section
 * shows a counterfactual result built from the hard-coded sample data below.
 */
function SVM() {
  // Sample rows shown in the result table (static placeholder data).
  const tableData = [
    {initial: "like", changedTo: "do not like"},
    {initial: "happy", changedTo: "sad"},
    {initial: "good", changedTo: "bad"},
  ];

  const { classes } = useStyle();
  // Only the setter is needed; the current scroll position is not read here.
  const [, scrollTo] = useWindowScroll();

  return (
    <div className={classes.mainContainer}>
      <div id="getData">
        <CustomTitle title="Get Your Rule" />
        <TextAreaCustom
          label="Enter the review"
          placeholder="Enter the review"
        />
        <DropDownWidget
          label="Select the Algorithm"
          placeholder="Algorithm"
        />
        <ButtonCustom
          label="Get the rule"
          onClick={() => scrollTo({ y: window.innerHeight })}
        />
        <ScrollButton />
      </div>
      <div id="viewResult" style={{width: window.innerWidth / 5 * 3, height: window.innerHeight}}>
        <CustomTitle title="Counterfactual Result" />
        <Sentence
          sentence="I like apple"
          hightLightWords={["like"]}
        />
        <Prediction prediction="Positive" />
        <CounterfactualTable tableData={tableData} />
      </div>
    </div>
  )
}

export default SVM
\ No newline at end of file
import { Highlight, Text, Title } from '@mantine/core';
/**
 * Shows a small "Sentence" heading followed by the sentence in large text,
 * with the given words highlighted.
 *
 * @param data.sentence Text to display.
 * @param data.hightLightWords Words passed to Mantine's Highlight as the terms to mark.
 */
export function Sentence(data: {sentence: string, hightLightWords: string[]}) {
  const { sentence, hightLightWords } = data;

  return (
    <div>
      <Title order={6} mt={30}>Sentence</Title>
      <Text size={30} mt={10}>
        <Highlight highlight={hightLightWords}>{sentence}</Highlight>
      </Text>
    </div>
  );
}
\ No newline at end of file
/** Placeholder component: renders a single <div> with the static text "Settings". */
export default function Settings() {
  return <div>Settings</div>;
}
\ No newline at end of file
import { Button, createStyles } from '@mantine/core';
/** Style hook for the button row: three fifths of the window width, content pushed to the end. */
const useStyles = createStyles(() => {
  const rowWidth = window.innerWidth / 5 * 3;

  return {
    container: {
      width: rowWidth,
      display: "flex",
      justifyContent: "end",
    },
  };
});
/**
 * Large rounded gradient button aligned to the end of its row.
 *
 * @param data.label Text shown inside the button.
 * @param data.onClick Handler forwarded to the button's `onClick`.
 */
export function ButtonCustom(data: {label: string, onClick: any}) {
  const { classes } = useStyles()
  const { label, onClick } = data;
  const tealToBlue = { from: 'teal', to: 'blue', deg: 60 };

  return (
    <div className={classes.container}>
      <Button variant="gradient" gradient={tealToBlue} radius="xl" size="lg" mt={30} mb={100} onClick={onClick}>
        {label}
      </Button>
    </div>
  );
}
\ No newline at end of file
import { Title } from '@mantine/core';
/**
 * Left-aligned top-level page heading with a 50-unit top margin.
 *
 * @param data.title Heading text.
 */
export function CustomTitle(data: {title: string}) {
  const { title } = data;

  return <Title order={1} align='left' mt={50}>{title}</Title>;
}
\ No newline at end of file
import { Select } from '@mantine/core';
/**
 * Select box offering the four supported algorithms, sized to three fifths of
 * the window width.
 *
 * @param data.label Label shown above the select.
 * @param data.placeholder Placeholder shown while nothing is selected.
 */
export function DropDownWidget(data: {label: string, placeholder: string}) {
  const { label, placeholder } = data;

  // Choices offered to the user, in display order.
  const algorithms = [
    { value: 'SVM', label: 'Support Vector Machine' },
    { value: 'random forest', label: 'Random Forest' },
    { value: 'logistic Regression', label: 'Logistic Regression' },
    { value: 'KNN', label: 'K-Nearest Neighbor' },
  ];
  const widgetStyle = {width: window.innerWidth / 5 * 3};

  return (
    <Select mt={30} label={label} placeholder={placeholder} data={algorithms} style={widgetStyle} />
  );
}
\ No newline at end of file
import { useWindowScroll } from '@mantine/hooks';
import { Button, createStyles } from '@mantine/core';
/** Style hook for the floating button: a 50px circle pinned 50px from the bottom-right corner. */
const useStyle = createStyles(() => {
  const diameter = "50px";

  return {
    button: {
      position: "fixed",
      bottom: 50,
      right: 50,
      width: diameter,
      height: diameter,
      borderRadius: "50%",
    },
  };
});
/**
 * Floating button that scrolls the window back to the top (y = 0) when clicked.
 */
export function ScrollButton() {
  // Only the setter is needed; the current scroll position is not read here.
  const [, scrollTo] = useWindowScroll();
  const { classes } = useStyle()

  return (
    // The visible "^" glyph is not a meaningful name for assistive technology,
    // so the button is given an explicit accessible label.
    <Button
      className={classes.button}
      aria-label="Scroll to top"
      onClick={() => scrollTo({ y: 0 })}
    >
      ^
    </Button>
  );
}
\ No newline at end of file
import { Textarea } from "@mantine/core";
/**
 * Auto-growing textarea marked with an asterisk, at least ten rows tall and
 * three fifths of the window width.
 *
 * @param data.placeholder Placeholder text shown while the field is empty.
 * @param data.label Label shown above the field.
 */
export function TextAreaCustom(data: {placeholder: string, label: string}) {
  const { placeholder, label } = data;
  const fieldStyle = {width: window.innerWidth / 5 * 3};

  return (
    <Textarea
      mt={30}
      placeholder={placeholder}
      label={label}
      style={fieldStyle}
      minRows={10}
      autosize
      withAsterisk
    />
  );
}
\ No newline at end of file
import LOGO_IMG from '../assets/logo.png';
// Re-export of the imported logo image under a stable constant name.
export const LOGO = LOGO_IMG;
\ No newline at end of file
/** Placeholder component: renders a single <div> with the static text "k-NN". */
export default function Knn() {
  return <div>k-NN</div>;
}
\ No newline at end of file
import { MantineProvider, MantineThemeOverride } from "@mantine/core";
// Mantine theme override passed to MantineProvider by ThemeProvider below;
// it only sets the colour scheme to "light".
export const theme: MantineThemeOverride = {
colorScheme: "light",
};
/** Props accepted by ThemeProvider. */
interface ThemeProviderProps {
  /** Subtree that should render inside the Mantine provider. */
  children: React.ReactNode;
}

/**
 * Wraps its children in a MantineProvider configured with the module's `theme`,
 * with Mantine's global styles and CSS normalisation enabled.
 */
export function ThemeProvider(props: ThemeProviderProps) {
  const { children } = props;

  return (
    <MantineProvider theme={theme} withGlobalStyles withNormalizeCSS>
      {children}
    </MantineProvider>
  );
}
import { StrictMode } from "react";
import ReactDOM from "react-dom/client";
import App from "./App";
import reportWebVitals from "./reportWebVitals";
// Application entry point: look up the element with id "root" and render the
// app into it, wrapped in React's StrictMode.
const container = document.getElementById("root") as HTMLElement;

ReactDOM.createRoot(container).render(
  <StrictMode>
    <App />
  </StrictMode>
);

// If you want to start measuring performance in your app, pass a function
// to log results (for example: reportWebVitals(console.log))
// or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
reportWebVitals();
/// <reference types="react-scripts" />
import { ReportHandler } from 'web-vitals';
/**
 * Lazily loads `web-vitals` and registers the given callback with each of the
 * CLS, FID, FCP, LCP and TTFB collectors, in that order.
 *
 * Does nothing (and does not load the library) when no function is supplied.
 *
 * @param onPerfEntry Optional handler invoked by the web-vitals collectors.
 */
const reportWebVitals = (onPerfEntry?: ReportHandler) => {
  // Guard clause: skip the dynamic import entirely without a usable callback.
  if (!(onPerfEntry && onPerfEntry instanceof Function)) {
    return;
  }
  const handler: ReportHandler = onPerfEntry;

  import('web-vitals').then((vitals) => {
    vitals.getCLS(handler);
    vitals.getFID(handler);
    vitals.getFCP(handler);
    vitals.getLCP(handler);
    vitals.getTTFB(handler);
  });
};

export default reportWebVitals;
// jest-dom adds custom jest matchers for asserting on DOM nodes.
// allows you to do things like:
// expect(element).toHaveTextContent(/react/i)
// learn more: https://github.com/testing-library/jest-dom
import '@testing-library/jest-dom';
{
"compilerOptions": {
"target": "es5",
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": true,
"skipLibCheck": true,
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noFallthroughCasesInSwitch": true,
"module": "esnext",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx"
},
"include": [
"src"
]
}
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment