Commit 3e6549d8 authored by IT20013950 Lakshani N.V.M.'s avatar IT20013950 Lakshani N.V.M.

Merge branch 'IT20013950-Lakshani' into 'master'

Frontend & k-NN model training

See merge request !4
parents 2a7d9ea1 39e09643
{
"metadata": {
"kernelspec": {
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.10",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"colab": {
"provenance": []
}
},
"nbformat_minor": 0,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"source": [
"# Initialisation"
],
"metadata": {
"id": "tFwb5L06sPmL"
}
},
{
"cell_type": "markdown",
"source": [
"## Installing and importing packages"
],
"metadata": {
"id": "iLJdHDgfsPmO"
}
},
{
"cell_type": "code",
"source": [
"!pip install numpy==1.20.0\n",
"!pip install ordered_set"
],
"metadata": {
"trusted": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WamT8ircsPmP",
"outputId": "4879fb24-f899-4675-a243-86cd93213aad"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting numpy==1.20.0\n",
" Downloading numpy-1.20.0.zip (8.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m45.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: numpy\n",
" \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n",
" \n",
" \u001b[31m×\u001b[0m \u001b[32mBuilding wheel for numpy \u001b[0m\u001b[1;32m(\u001b[0m\u001b[32mpyproject.toml\u001b[0m\u001b[1;32m)\u001b[0m did not run successfully.\n",
" \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n",
" \u001b[31m╰─>\u001b[0m See above for output.\n",
" \n",
" \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
" Building wheel for numpy (pyproject.toml) ... \u001b[?25l\u001b[?25herror\n",
"\u001b[31m ERROR: Failed building wheel for numpy\u001b[0m\u001b[31m\n",
"\u001b[0mFailed to build numpy\n",
"\u001b[31mERROR: Could not build wheels for numpy, which is required to install pyproject.toml-based projects\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ordered_set\n",
" Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)\n",
"Installing collected packages: ordered_set\n",
"Successfully installed ordered_set-4.1.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s0ZApxf4Oe2M",
"outputId": "02ceae9e-a309-4c2f-f511-1df025f0e613"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)"
],
"metadata": {
"id": "gf2uH7S1Obmw"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"\n",
"import os\n",
"for dirname, _, filenames in os.walk('/content/drive'):\n",
" for filename in filenames:\n",
" print(os.path.join(dirname, filename))\n"
],
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"trusted": true,
"id": "80TWwaECsPmQ"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"import time\n",
"\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"from sklearn.model_selection import ParameterGrid\n",
"from sklearn.svm import SVC\n",
"import sklearn.feature_extraction\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"\n",
"import nltk\n",
"from nltk.corpus import stopwords \n",
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import WordNetLemmatizer\n",
"nltk.download('wordnet')\n",
"nltk.download('stopwords')\n",
"nltk.download('punkt')\n",
"!unzip /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/\n",
"\n",
"from bs4 import BeautifulSoup\n",
"import re\n",
"import pickle\n",
"import seaborn as sns\n",
"\n",
"from ordered_set import OrderedSet\n",
"from scipy.sparse import lil_matrix\n",
"from itertools import compress\n"
],
"metadata": {
"trusted": true,
"id": "c0yWpzdksPmQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "87aa8bb6-c1cb-4cfb-f48c-79b5b717684d"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"unzip: cannot find or open /usr/share/nltk_data/corpora/wordnet.zip, /usr/share/nltk_data/corpora/wordnet.zip.zip or /usr/share/nltk_data/corpora/wordnet.zip.ZIP.\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
"[nltk_data] Package wordnet is already up-to-date!\n",
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Package stopwords is already up-to-date!\n",
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# import dataset\n",
"df = pd.read_csv('/content/sample_data/IMDB Dataset.csv',parse_dates=True)\n",
"df.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LyVwrbLFQs4g",
"outputId": "4174abd8-88bb-4eac-a05d-d944833c4bf9"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 50000 entries, 0 to 49999\n",
"Data columns (total 2 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 review 50000 non-null object\n",
" 1 sentiment 50000 non-null object\n",
"dtypes: object(2)\n",
"memory usage: 781.4+ KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"loaded_vocab = pickle.load(open('/content/sample_data/vectorizer_imdb.pkl', 'rb'))\n",
"\n",
"stop_words = set(stopwords.words('english'))\n",
"tokenizer = nltk.tokenize.toktok.ToktokTokenizer()\n",
"lemmatizer = WordNetLemmatizer() \n",
"loaded_vectorizer = TfidfVectorizer(min_df=2, vocabulary=loaded_vocab)\n",
"label_binarizer = sklearn.preprocessing.LabelBinarizer()\n",
"feature_names = loaded_vectorizer.get_feature_names_out()"
],
"metadata": {
"id": "KzLCu5aRPcwB"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Support Functions"
],
"metadata": {
"id": "rkeUtGt4sPmR"
}
},
{
"cell_type": "code",
"source": [
"def strip_html(text):\n",
" soup = BeautifulSoup(text, \"html.parser\")\n",
" return soup.get_text()\n",
"\n",
"def remove_special_characters(text, remove_digits=True):\n",
" pattern=r'[^a-zA-z0-9\\s]'\n",
" text=re.sub(pattern,'',text)\n",
" return text\n",
"\n",
"def remove_stopwords(text, is_lower_case=False):\n",
" tokens = tokenizer.tokenize(text)\n",
" tokens = [token.strip() for token in tokens]\n",
" if is_lower_case:\n",
" filtered_tokens = [token for token in tokens if token not in stop_words]\n",
" else:\n",
" filtered_tokens = [token for token in tokens if token.lower() not in stop_words]\n",
" filtered_text = ' '.join(filtered_tokens) \n",
" return filtered_text\n",
"\n",
"def lemmatize_text(text):\n",
" words=word_tokenize(text)\n",
" edited_text = ''\n",
" for word in words:\n",
" lemma_word=lemmatizer.lemmatize(word)\n",
" extra=\" \"+str(lemma_word)\n",
" edited_text+=extra\n",
" return edited_text"
],
"metadata": {
"trusted": true,
"id": "dLQymBfIsPmR"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Import DataSet and Preprocess"
],
"metadata": {
"id": "ewsKXh2XsPmS"
}
},
{
"cell_type": "code",
"source": [
"## Import\n",
"data = pd.read_csv('/content/sample_data/IMDB Dataset.csv')\n",
"data = data.sample(10000)"
],
"metadata": {
"trusted": true,
"id": "bdJixyiPsPmS"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"## Preprocess\n",
"data.review = data.review.str.lower()\n",
"data.review = data.review.apply(strip_html)\n",
"data.review = data.review.apply(remove_special_characters)\n",
"data.review = data.review.apply(remove_stopwords)\n",
"data.review = data.review.apply(lemmatize_text)"
],
"metadata": {
"trusted": true,
"id": "RwfbMIV0sPmS",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4f52511b-8026-44be-970a-3b18f38e2a70"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-13-7b716edd6682>:2: MarkupResemblesLocatorWarning: The input looks more like a filename than markup. You may want to open this file and pass the filehandle into Beautiful Soup.\n",
" soup = BeautifulSoup(text, \"html.parser\")\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"## Split Data\n",
"x_imdb = data['review']\n",
"y_imdb = data['sentiment']\n",
"\n",
"x_train_i, x_test_i, y_train_i, y_test_i = train_test_split(x_imdb,y_imdb,test_size=0.2)\n",
"x_test, x_val, y_test_i, y_val_i = train_test_split(x_test_i,y_test_i,test_size=0.5)"
],
"metadata": {
"trusted": true,
"id": "RFUFkBROsPmS"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"## X data -> transform into vectors\n",
"x_train_imdb = loaded_vectorizer.fit_transform(x_train_i)\n",
"x_test_imdb = loaded_vectorizer.transform(x_test)\n",
"x_val_imdb = loaded_vectorizer.transform(x_val)"
],
"metadata": {
"trusted": true,
"id": "NbYfcNFTsPmT"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x_train_imdb[0].shape"
],
"metadata": {
"trusted": true,
"id": "zTzI4orbsPmT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "11e313ba-de2d-42c3-930a-cea796f501d6"
},
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 26698)"
]
},
"metadata": {},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"source": [
"# Y data - Positive is 1\n",
"y_train_imdb = label_binarizer.fit_transform(y_train_i)\n",
"y_test_imdb = label_binarizer.fit_transform(y_test_i)\n",
"y_val_imdb = label_binarizer.fit_transform(y_val_i)"
],
"metadata": {
"trusted": true,
"id": "Q84MWx-QsPmT"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train RF model"
],
"metadata": {
"id": "_fW74XRBsPmT"
}
},
{
"cell_type": "code",
"source": [
"# Number of trees in random forest\n",
"n_estimators = [int(x) for x in np.linspace(start = 10, stop = 100, num = 10)]\n",
"# Maximum number of levels in tree\n",
"max_depth = [int(x) for x in np.linspace(10, 100, num = 5)]\n",
"max_depth.append(None)\n",
"# Minimum number of samples required to split a node\n",
"min_samples_split = [2, 5, 10]\n",
"# Minimum number of samples required at each leaf node\n",
"min_samples_leaf = [1, 2, 4]\n",
"# Method of selecting samples for training each tree\n",
"bootstrap = [True, False]\n",
"# Create the grid\n",
"grid_rf = {'n_estimators': n_estimators,\n",
" 'max_depth': max_depth,\n",
" 'min_samples_split': min_samples_split,\n",
" 'min_samples_leaf': min_samples_leaf,\n",
" 'bootstrap': bootstrap}\n",
"print(grid_rf)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.021625Z",
"iopub.status.idle": "2023-05-24T01:59:14.022487Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.022189Z",
"shell.execute_reply": "2023-05-24T01:59:14.022216Z"
},
"id": "HdpaVmxSsPmT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.ensemble import RandomForestClassifier"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.024015Z",
"iopub.status.idle": "2023-05-24T01:59:14.024474Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.024258Z",
"shell.execute_reply": "2023-05-24T01:59:14.024286Z"
},
"id": "X_WPq8AJsPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_rf = RandomizedSearchCV(RandomForestClassifier(), param_distributions = grid_rf, n_iter = 200, cv = 3, verbose=2, random_state=42, n_jobs = -1)# Fit the random search model\n",
"# # Fit the random search model\n",
"grid_imdb_rf.fit(x_train_imdb,y_train_imdb.ravel())\n",
"pickle.dump(grid_imdb_rf, open('grid_imdb_rf.pickle', \"wb\"))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2023-05-24T01:59:14.025697Z",
"iopub.status.idle": "2023-05-24T01:59:14.026125Z",
"shell.execute_reply.started": "2023-05-24T01:59:14.025909Z",
"shell.execute_reply": "2023-05-24T01:59:14.025929Z"
},
"id": "9iAEK_NasPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train SVC Model"
],
"metadata": {
"id": "qgKbzrRfsPmU"
}
},
{
"cell_type": "code",
"source": [
"# Param Optimisation\n",
"param_grid_imdb = {'C': [0.1,1, 10, 100], 'gamma': [1,0.1,0.01,0.001],'kernel': ['rbf']}\n",
"grid_imdb_svc = GridSearchCV(SVC(),param_grid_imdb,refit=True,verbose=2)"
],
"metadata": {
"id": "kBG69w4JsPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_svc.fit(x_train_imdb,y_train_imdb.ravel())\n",
"pickle.dump(grid_imdb_svc, open('grid_imdb_svc.pickle', \"wb\"))"
],
"metadata": {
"id": "7C7pX6QlsPmU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train KNN model"
],
"metadata": {
"id": "OUnEb3ZVsPmV"
}
},
{
"cell_type": "code",
"source": [
"grid_params_imdb_knn = { 'n_neighbors' : [30,40,50,60,70,80,90], 'metric' : ['manhattan', 'minkowski'], 'weights': ['uniform', 'distance']}\n",
"grid_imdb_knn = GridSearchCV(KNeighborsClassifier(), grid_params_imdb_knn, n_jobs=-1,verbose=2)"
],
"metadata": {
"id": "L4wJA2R0sPmV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_knn.fit(x_train_imdb,np.ravel(y_train_imdb,order='C'))\n",
"pickle.dump(grid_imdb_knn, open('grid_imdb_knn.pickle', \"wb\"))"
],
"metadata": {
"id": "zBwE0Cz-sPmV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train LR Model"
],
"metadata": {
"id": "MQdS2y_2sPmV"
}
},
{
"cell_type": "code",
"source": [
"param_grid_imdb_lr = [ \n",
" {'penalty' : ['l1', 'l2', 'elasticnet'],\n",
" 'C' : np.logspace(-4, 4, 20),\n",
" 'solver' : ['lbfgs','newton-cg','sag'],\n",
" 'max_iter' : [100, 1000, 5000]\n",
" }\n",
"]\n",
"grid_imdb_lr = GridSearchCV(LogisticRegression(), param_grid = param_grid_imdb_lr, cv = 3, verbose=2, n_jobs=-1)"
],
"metadata": {
"id": "xTyPiA4osPmV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grid_imdb_lr.fit(x_train_imdb, np.ravel(y_train_imdb,order='C'))\n",
"pickle.dump(grid_imdb_lr, open('grid_imdb_lr.pickle', \"wb\"))"
],
"metadata": {
"id": "vhVNQZITsPmW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Load Models"
],
"metadata": {
"id": "AcRWpiiesPmW"
}
},
{
"cell_type": "code",
"source": [
"# Load\n",
"loaded_knn_imdb = pickle.load(open('/content/sample_data/grid_imdb_knn.pickle', \"rb\"))"
],
"metadata": {
"trusted": true,
"id": "rgbWUgfosPmW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d8110a61-dec0-40e6-a337-e12de205e304"
},
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-23-6b9db75f90c6>:2: DeprecationWarning: Please use `csr_matrix` from the `scipy.sparse` namespace, the `scipy.sparse.csr` namespace is deprecated.\n",
" loaded_knn_imdb = pickle.load(open('/content/sample_data/grid_imdb_knn.pickle', \"rb\"))\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(loaded_knn_imdb.best_params_)"
],
"metadata": {
"trusted": true,
"id": "XpZjBtLgsPmW",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9e888c63-3f57-464b-a623-ce287fb10d1b"
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'metric': 'minkowski', 'n_neighbors': 90, 'weights': 'distance'}\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Info about models"
],
"metadata": {
"id": "tds-0KE9sPmX"
}
},
{
"cell_type": "code",
"source": [
"p = np.sum(y_train_imdb)/np.size(y_train_imdb)"
],
"metadata": {
"trusted": true,
"id": "XXrGyw5zsPmX"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"probs_knn = loaded_knn_imdb.predict(x_test_imdb)"
],
"metadata": {
"trusted": true,
"id": "EpPwD1xMsPmY"
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"loaded_knn_imdb.predict(x_test_imdb[0])"
],
"metadata": {
"trusted": true,
"id": "M3L6UcqxsPmq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "119c56e3-a7c7-4b48-daff-6e9797ed0fea"
},
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1])"
]
},
"metadata": {},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"source": [
"accuracy_knn = metrics.accuracy_score(y_test_imdb, probs_knn)"
],
"metadata": {
"trusted": true,
"id": "OOuTh7TlsPmq"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"The accuracy of the KNN model on the test data is %f\" %accuracy_knn)"
],
"metadata": {
"trusted": true,
"id": "Gaqh_vIpsPmr",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "be28e7bf-1ae0-4caa-f301-3a0db9d10f5b"
},
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The accuracy of the KNN model on the test data is 0.819000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"input_review = input('give me review')"
],
"metadata": {
"id": "Z6qC0v0JsPms",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ee6f61df-ae10-471a-83fc-535503f68f90"
},
"execution_count": 30,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"give me reviewI like apple\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"vector_input = loaded_vectorizer.transform([input_review]) "
],
"metadata": {
"id": "7ncfW-8lSwww"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"loaded_knn_imdb.predict(vector_input)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cOVi5sl5SzBP",
"outputId": "52f25084-f0ff-4a2e-cca6-1ea2e2f2901e"
},
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0])"
]
},
"metadata": {},
"execution_count": 32
}
]
}
]
}
\ No newline at end of file
......@@ -13,6 +13,7 @@ const useStyles = createStyles((theme) => ({
sections: {
display: "flex",
flexDirection: "row",
position: "relative",
}
}))
......@@ -23,7 +24,7 @@ export default function App() {
<Router>
<div className={classes.sections}>
<NavbarSimpleColored />
<div style={{width: window.innerWidth/5*4}}>
<div style={{width: window.innerWidth/5*4, position: "absolute", right: 0}}>
<Routes>
<Route path='/' element={<Home />} />
<Route path='/svm' element={<SVM />} />
......
import { Card, Image, Text } from '@mantine/core';
export function CardCustom(data: {src: string, title: string}) {
return (
<Card
shadow="xl"
padding="xl"
component="a"
// href="https://www.youtube.com/watch?v=dQw4w9WgXcQ"
target="_blank"
>
<Card.Section>
<Image
src={data.src}
height={160}
alt="No way!"
/>
</Card.Section>
<Text weight={500} size="lg" mt="md">
{data.title}
</Text>
{/* <Text mt="xs" color="dimmed" size="sm">
try {data.title} ...
</Text> */}
</Card>
);
}
\ No newline at end of file
import { Grid, Text, Title, Transition, createStyles } from "@mantine/core"
import { CardCustom } from "../Card/CardCustom"
const useStyle = createStyles((theme) => ({
container: {
height: "100%",
width: "100%",
},
transitionBox: {
display: "flex",
flexDirection: "column",
alignItems: "center",
justifyContent: "center",
},
description: {
width: "80%",
}
}))
function Home() {
const { classes } = useStyle()
return (
<div>Home</div>
<div className={classes.container}>
<Transition mounted={true} transition="pop" duration={1000} timingFunction="ease">
{
(styles) => (
<div className={classes.transitionBox}>
<Title order={1} size={60} mt={30} style={styles}>
Get Your Best XAI Solution
</Title>
<Text className={classes.description} mt={50} align="center">
Paragraphs are the building blocks of papers.
Many students define paragraphs in terms of length: a paragraph is a group of
at least five sentences, a paragraph is half a page long, etc. In reality, though, the unity and
coherence of ideas among sentences is what constitutes a paragraph. A paragraph is
defined as “a group of sentences or a single sentence that forms
a unit” (Lunsford and Connors 116). Length
and appearance do not determine whether a section in a paper is a paragraph. For instance, in some
styles of writing, particularly journalistic styles, a paragraph
can be just one sentence long. Ultimately, a paragraph is a
sentence or group of sentences that support one main idea. In this handout,
we will refer to this as the “controlling idea,” because it controls what happens in the
rest of the paragraph.
</Text>
<Grid ml={50} mr={50} mt={30}>
<Grid.Col span={3}>
<CardCustom
src="https://media.licdn.com/dms/image/C5112AQFY4bX3Y7jcHA/article-cover_image-shrink_600_2000/0/1565431496642?e=2147483647&v=beta&t=N2zhB7OhvdTrebqNkY2lxMaPHqPYpA4r3nj88msI-e0"
title="K-NN"
/>
</Grid.Col>
<Grid.Col span={3}>
<CardCustom
src=""
title="SVM"
/>
</Grid.Col>
<Grid.Col span={3}>
<CardCustom
src=""
title="Random Forest"
/>
</Grid.Col>
<Grid.Col span={3}>
<CardCustom
src="https://www.appstate.edu/~whiteheadjc/service/logit/logit.gif"
title="Logistic Regression"
/>
</Grid.Col>
</Grid>
</div>
)
}
</Transition>
</div>
)
}
......
......@@ -13,6 +13,7 @@ import { LOGO } from '../consts';
const useStyles = createStyles((theme) => ({
navbar: {
backgroundColor: theme.fn.variant({ variant: 'filled', color: theme.primaryColor }).background,
position: "fixed",
},
version: {
......@@ -85,10 +86,10 @@ const useStyles = createStyles((theme) => ({
const data = [
{ link: '/', label: 'Home', icon: HOME_IMG },
{ link: '/svm', label: 'SVM', icon: SVM_IMG },
{ link: '/knn', label: 'k-NN', icon: KNN_IMG },
{ link: '/rf', label: 'Random Forest', icon: RF_IMG },
{ link: '/lr', label: 'Logistic Regression', icon: LR_IMG },
{ link: '/svm', label: 'Get Your Rule', icon: SVM_IMG },
// { link: '/knn', label: 'Documentation', icon: KNN_IMG },
// { link: '/rf', label: 'Random Forest', icon: RF_IMG },
// { link: '/lr', label: 'Logistic Regression', icon: LR_IMG },
];
export function NavbarSimpleColored() {
......@@ -124,21 +125,15 @@ export function NavbarSimpleColored() {
{links}
</Navbar.Section>
<Navbar.Section className={classes.footer}>
{/* <Navbar.Section className={classes.footer}>
<a href="#" className={classes.link} onClick={(event) => {
event.preventDefault()
navigate("/settings")
}}>
{/* <IconSwitchHorizontal className={classes.linkIcon} stroke={1.5} /> */}
<img src={SETTING_IMG} className={classes.linkIcon} />
<span>Settings</span>
</a>
{/* <a href="#" className={classes.link} onClick={(event) => event.preventDefault()}>
<img src={HOME_IMG} className={classes.linkIcon} />
<span>Logout</span>
</a> */}
</Navbar.Section>
</Navbar.Section> */}
</Navbar>
);
}
\ No newline at end of file
import { Title, createStyles } from '@mantine/core'
import React from 'react'
const useStyle = createStyles((theme) => ({
table: {
border: "1px solid #ddd",
borderCollapse: "collapse",
},
td: {
border: "1px solid #ddd",
width: window.innerHeight / 4,
padding: "5px 10px",
},
negative: {
border: "1px solid #ddd",
width: window.innerHeight / 4,
padding: "5px 10px",
background: "green",
color: "white",
},
positive: {
border: "1px solid #ddd",
width: window.innerHeight / 4,
padding: "5px 10px",
background: "red",
color: "white",
}
}))
function CounterfactualTable(data: {tableData: any}) {
const { classes } = useStyle();
return (
<div>
<Title order={6} mt={30} mb={10}>Results</Title>
<table className={classes.table}>
{
data.tableData.map((e: any, i: number) => {
return (
<tr key={i}>
<td className={classes.positive}>{e.initial}</td>
{i === 0 ? <td className={classes.td} rowSpan={data.tableData.length}>Changes to</td> : null}
<td className={classes.negative}>{e.changedTo}</td>
</tr>
);
})
}
</table>
</div>
)
}
export default CounterfactualTable
\ No newline at end of file
import { Text, Title } from '@mantine/core';
export function Prediction(data: {prediction: string}) {
return (
<div>
<Title order={6} mt={60}>Prediction</Title>
<Text size={30} mt={10}>
{data.prediction}
</Text>
</div>
);
}
\ No newline at end of file
import { createStyles } from "@mantine/core"
import { TextAreaCustom } from "../_common/TextAreaCustom"
import { DropDownWidget } from "../_common/DropDownWidget";
import { ButtonCustom } from "../_common/ButtonCustom";
import { ScrollButton } from "../_common/ScrollButton";
import { CustomTitle } from "../_common/CustomTitle";
import { useWindowScroll } from "@mantine/hooks";
import { Sentence } from "./Sentence";
import { Prediction } from "./Prediction";
import CounterfactualTable from "./CounterfactualTable";
const useStyle = createStyles((theme) => ({
mainContainer: {
display: "flex",
flexDirection: "column",
alignItems: "center",
}
}))
function SVM() {
const { classes } = useStyle();
const [scroll, scrollTo] = useWindowScroll();
const tableData = [
{initial: "like", changedTo: "do not like"},
{initial: "happy", changedTo: "sad"},
{initial: "good", changedTo: "bad"},
];
return (
<div>SVM</div>
<div className={classes.mainContainer}>
<div id="getData">
<CustomTitle title="Get Your Rule" />
<TextAreaCustom
label="Enter the review"
placeholder="Enter the review"
/>
<DropDownWidget
label="Select the Algorithm"
placeholder="Algorithm"
/>
<ButtonCustom
label="Get the rule"
onClick={() => scrollTo({ y: window.innerHeight })}
/>
<ScrollButton />
</div>
<div id="viewResult" style={{width: window.innerWidth / 5 * 3, height: window.innerHeight}}>
<CustomTitle title="Counterfactual Result" />
<Sentence
sentence="I like apple"
hightLightWords={["like"]}
/>
<Prediction prediction="Positive" />
<CounterfactualTable tableData={tableData} />
</div>
</div>
)
}
......
import { Highlight, Text, Title } from '@mantine/core';
export function Sentence(data: {sentence: string, hightLightWords: string[]}) {
return (
<div>
<Title order={6} mt={30}>Sentence</Title>
<Text size={30} mt={10}>
<Highlight highlight={data.hightLightWords}>
{data.sentence}
</Highlight>
</Text>
</div>
);
}
\ No newline at end of file
import { Button, createStyles } from '@mantine/core';
const useStyles = createStyles((theme) => ({
container: {
width: window.innerWidth / 5 * 3,
display: "flex",
justifyContent: "end",
}
}))
export function ButtonCustom(data: {label: string, onClick: any}) {
const { classes } = useStyles()
return (
<div className={classes.container}>
<Button
variant="gradient"
gradient={{ from: 'teal', to: 'blue', deg: 60 }}
radius="xl"
size="lg"
mt={30}
mb={100}
onClick={data.onClick}
>
{data.label}
</Button>
</div>
);
}
\ No newline at end of file
import { Title } from '@mantine/core';
export function CustomTitle(data: {title: string}) {
return (
<>
<Title order={1} align='left' mt={50}>{data.title}</Title>
</>
);
}
\ No newline at end of file
import { Select } from '@mantine/core';
export function DropDownWidget(data: {label: string, placeholder: string}) {
return (
<Select
mt={30}
label={data.label}
placeholder={data.placeholder}
data={[
{ value: 'SVM', label: 'Support Vector Machine' },
{ value: 'random forest', label: 'Random Forest' },
{ value: 'logistic Regression', label: 'Logistic Regression' },
{ value: 'KNN', label: 'K-Nearest Neighbor' },
]}
style={{width: window.innerWidth / 5 * 3}}
/>
);
}
\ No newline at end of file
import { useWindowScroll } from '@mantine/hooks';
import { Button, createStyles } from '@mantine/core';
const useStyle = createStyles((theme) => ({
button: {
position: "fixed",
bottom: 50,
right: 50,
width: "50px",
height: "50px",
borderRadius: "50%",
}
}))
export function ScrollButton() {
const [scroll, scrollTo] = useWindowScroll();
const { classes } = useStyle()
return (
<Button className={classes.button} onClick={() => scrollTo({ y: 0 })}>^</Button>
);
}
\ No newline at end of file
import { Textarea } from "@mantine/core";
export function TextAreaCustom(data: {placeholder: string, label: string}) {
return (
<Textarea
mt={30}
placeholder={data.placeholder}
label={data.label}
style={{width: window.innerWidth / 5 * 3}}
minRows={10}
autosize={true}
withAsterisk
/>
);
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment