Delete Final_1D_CNN_Model.ipynb

parent ab2de65c
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXK1h7IS5ua0"
},
"outputs": [],
"source": [
"#Import Libraries\n",
"\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense\n",
"from tensorflow.keras.layers import Flatten\n",
"from sklearn.model_selection import train_test_split\n",
"import time\n",
"\n",
"\n",
"#Data Balancing libraries\n",
"from imblearn.under_sampling import NearMiss\n",
"from imblearn.over_sampling import ADASYN\n",
"\n",
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from imblearn.combine import SMOTEENN #Hybrid method\n",
"\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, roc_curve, auc\n",
"\n",
"\n",
"#apply standardization\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"#Visual libraries\n",
"import seaborn as sns\n",
"\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import classification_report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MihfTY1_l4j0",
"outputId": "0560e2d4-429b-411c-eca1-c57ad70d1fff"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZCZ6F-vSF187"
},
"outputs": [],
"source": [
"data = pd.read_csv('/content/drive/MyDrive/ML Model Attack/disease_preprocess4.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fMhfrc507QwL"
},
"outputs": [],
"source": [
"data = pd.read_csv('/content/disease_preprocess4.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 226
},
"id": "JryGjddVC8xL",
"outputId": "49d6a37b-6b92-445f-b115-f4989ec40ef4"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "data"
},
"text/html": [
"\n",
" <div id=\"df-f7ee6c0e-29e6-47d6-bfb8-ab66e14a21cf\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>GeneralHealth</th>\n",
" <th>Checkup</th>\n",
" <th>Exercise</th>\n",
" <th>HeartDisease</th>\n",
" <th>Depression</th>\n",
" <th>Diabetes</th>\n",
" <th>Arthritis</th>\n",
" <th>Gender</th>\n",
" <th>AgeCategory</th>\n",
" <th>BMI</th>\n",
" <th>SmokingHistory</th>\n",
" <th>AlcoholConsumption</th>\n",
" <th>FriedPotatoConsumption</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>-2.159696</td>\n",
" <td>1</td>\n",
" <td>-0.621527</td>\n",
" <td>0.664502</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>10</td>\n",
" <td>-0.051548</td>\n",
" <td>0</td>\n",
" <td>-0.621527</td>\n",
" <td>-0.267579</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" <td>0.742649</td>\n",
" <td>0</td>\n",
" <td>-0.133707</td>\n",
" <td>1.130543</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>11</td>\n",
" <td>0.015913</td>\n",
" <td>0</td>\n",
" <td>-0.621527</td>\n",
" <td>0.198462</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>12</td>\n",
" <td>-0.652562</td>\n",
" <td>1</td>\n",
" <td>-0.621527</td>\n",
" <td>-0.733620</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f7ee6c0e-29e6-47d6-bfb8-ab66e14a21cf')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-f7ee6c0e-29e6-47d6-bfb8-ab66e14a21cf button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-f7ee6c0e-29e6-47d6-bfb8-ab66e14a21cf');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-f5949bf6-8493-4bae-90ed-a325b814bc5c\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-f5949bf6-8493-4bae-90ed-a325b814bc5c')\"\n",
" title=\"Suggest charts\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" --bg-color: #E8F0FE;\n",
" --fill-color: #1967D2;\n",
" --hover-bg-color: #E2EBFA;\n",
" --hover-fill-color: #174EA6;\n",
" --disabled-fill-color: #AAA;\n",
" --disabled-bg-color: #DDD;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" --bg-color: #3B4455;\n",
" --fill-color: #D2E3FC;\n",
" --hover-bg-color: #434B5C;\n",
" --hover-fill-color: #FFFFFF;\n",
" --disabled-bg-color: #3B4455;\n",
" --disabled-fill-color: #666;\n",
" }\n",
"\n",
" .colab-df-quickchart {\n",
" background-color: var(--bg-color);\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: var(--fill-color);\n",
" height: 32px;\n",
" padding: 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: var(--hover-bg-color);\n",
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: var(--button-hover-fill-color);\n",
" }\n",
"\n",
" .colab-df-quickchart-complete:disabled,\n",
" .colab-df-quickchart-complete:disabled:hover {\n",
" background-color: var(--disabled-bg-color);\n",
" fill: var(--disabled-fill-color);\n",
" box-shadow: none;\n",
" }\n",
"\n",
" .colab-df-spinner {\n",
" border: 2px solid var(--fill-color);\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" animation:\n",
" spin 1s steps(1) infinite;\n",
" }\n",
"\n",
" @keyframes spin {\n",
" 0% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" border-left-color: var(--fill-color);\n",
" }\n",
" 20% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 30% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 40% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 60% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 80% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" 90% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const quickchartButtonEl =\n",
" document.querySelector('#' + key + ' button');\n",
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
" try {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" } catch (error) {\n",
" console.error('Error during call to suggestCharts:', error);\n",
" }\n",
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-f5949bf6-8493-4bae-90ed-a325b814bc5c button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
"\n",
" </div>\n",
" </div>\n"
],
"text/plain": [
" GeneralHealth Checkup Exercise HeartDisease Depression Diabetes \\\n",
"0 1 2 0 0 0 0 \n",
"1 5 1 0 1 0 1 \n",
"2 5 1 1 0 0 1 \n",
"3 1 1 1 1 0 1 \n",
"4 4 1 0 0 0 0 \n",
"\n",
" Arthritis Gender AgeCategory BMI SmokingHistory \\\n",
"0 1 1 10 -2.159696 1 \n",
"1 0 1 10 -0.051548 0 \n",
"2 0 1 8 0.742649 0 \n",
"3 0 0 11 0.015913 0 \n",
"4 0 0 12 -0.652562 1 \n",
"\n",
" AlcoholConsumption FriedPotatoConsumption \n",
"0 -0.621527 0.664502 \n",
"1 -0.621527 -0.267579 \n",
"2 -0.133707 1.130543 \n",
"3 -0.621527 0.198462 \n",
"4 -0.621527 -0.733620 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F_PjX618F5l6"
},
"outputs": [],
"source": [
"data.columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jkmi2N0aC8nZ",
"outputId": "1df3449b-b629-4a79-fc10-e54d1b6b27a6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" GeneralHealth Checkup Exercise Depression Diabetes Arthritis \\\n",
"192601 5 1 1 0 0 0 \n",
"196337 5 1 0 0 0 0 \n",
"99888 5 1 1 0 0 0 \n",
"282228 4 3 1 0 0 0 \n",
"216188 3 1 1 0 0 0 \n",
"\n",
" Gender AgeCategory BMI SmokingHistory AlcoholConsumption \\\n",
"192601 1 5 -0.368920 1 0.841932 \n",
"196337 1 11 -0.333656 1 -0.499572 \n",
"99888 1 8 1.898681 0 -0.377617 \n",
"282228 0 11 0.728850 0 2.427347 \n",
"216188 0 7 0.314887 0 -0.621527 \n",
"\n",
" FriedPotatoConsumption \n",
"192601 1.596584 \n",
"196337 0.198462 \n",
"99888 0.198462 \n",
"282228 0.198462 \n",
"216188 -0.267579 \n",
"192601 0\n",
"196337 0\n",
"99888 0\n",
"282228 0\n",
"216188 0\n",
"Name: HeartDisease, dtype: int64\n"
]
}
],
"source": [
"# define target variable and features\n",
"\n",
"# Defining the features (X) and the target (y)\n",
"\n",
"X = data.drop('HeartDisease', axis=1) # Features\n",
"y = data['HeartDisease'] # Target variable\n",
"\n",
"# Performing the train-test split\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"X_train.shape, X_test.shape, y_train.shape, y_test.shape\n",
"\n",
"print(X_train.head())\n",
"print(y_train.head())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E2WufmJdz8g_"
},
"source": [
"##Perform Scaling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y3FJTGSrUJIh"
},
"outputs": [],
"source": [
"#apply standardization\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Create a StandardScaler instance\n",
"scaler = StandardScaler()\n",
"\n",
"\n",
"# Fit the scaler on the training data and transform it\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"\n",
"# Use the same scaler to transform the test data\n",
"X_test_scaled = scaler.transform(X_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HZThz9by_GfH",
"outputId": "160a500a-073e-40ab-ec29-9e981f26a1b9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HeartDisease\n",
"0 227109\n",
"1 19974\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"# Print the count of each class in the before resample data\n",
"print(y_train.value_counts())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zJtzN8ctUoIo"
},
"source": [
"## SMOTE and Random Combined"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DX0w3ww2UUFw"
},
"outputs": [],
"source": [
"# Resample the training data\n",
"\n",
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from imblearn.combine import SMOTEENN # A hybrid method\n",
"\n",
"\n",
"# Apply SMOTE to oversample the minority class\n",
"smote=SMOTE(sampling_strategy='auto', random_state=23)\n",
"X_train_smote, y_train_smote = smote.fit_resample(X_train_scaled, y_train)\n",
"\n",
"# Apply undersampling to the majority class\n",
"under_sampler = RandomUnderSampler(sampling_strategy='auto', random_state=23)\n",
"X_train_combined, y_train_combined = under_sampler.fit_resample(X_train_smote, y_train_smote)\n",
"\n",
"# Train and evaluate your machine learning model using X_train_combined and y_train_combined\n",
"# Evaluate the model on X_test_scaled and y_test\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9YO5Ql4l_Bl0"
},
"outputs": [],
"source": [
"# Print the count of each class in the resampled data\n",
"print(y_train.value_counts())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v9-BQ1bpiN3i"
},
"outputs": [],
"source": [
"# Visualization of target variable after resampling\n",
"\n",
"g = sns.countplot(x= y_train_combined,data=data, palette=\"muted\")\n",
"g.set_ylabel(\"Patients\", fontsize=14)\n",
"g.set_xlabel(\"Heart Disease\", fontsize=14)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "71z03akwbnSu"
},
"source": [
"### Model Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0rkXxtX4bqOH"
},
"outputs": [],
"source": [
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout, BatchNormalization\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.losses import BinaryCrossentropy\n",
"\n",
"input_shape = (X_train_combined.shape[1], 1)\n",
"model = Sequential()\n",
"\n",
"# Add Convolutional and Pooling layers\n",
"model.add(Conv1D(filters=128, kernel_size=3, activation='relu', input_shape=input_shape))\n",
"model.add(BatchNormalization()) # Add batch normalization\n",
"model.add(MaxPooling1D(pool_size=2))\n",
"model.add(Conv1D(filters=256, kernel_size=3, activation='relu'))\n",
"model.add(BatchNormalization()) # Add batch normalization\n",
"model.add(MaxPooling1D(pool_size=2))\n",
"\n",
"model.add(Flatten())\n",
"\n",
"# Add Dense layers\n",
"model.add(Dense(units=512, activation='relu'))\n",
"model.add(BatchNormalization()) # Add batch normalization\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(units=256, activation='relu'))\n",
"model.add(BatchNormalization()) # Add batch normalization\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(units=128, activation='relu'))\n",
"model.add(BatchNormalization()) # Add batch normalization\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(units=64, activation='relu'))\n",
"model.add(BatchNormalization()) # Add batch normalization\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(units=1, activation='sigmoid'))\n",
"\n",
"# Compile the model\n",
"model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])\n",
"#model.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qcpsJFNQbyEw"
},
"outputs": [],
"source": [
"start_time = time.time()\n",
"history = model.fit(X_train_combined, y_train_combined, epochs=10, validation_split=0.2, verbose=2)\n",
"end_time = time.time()\n",
"execution_time = end_time - start_time\n",
"print(\"Execution time:\", execution_time, \"seconds\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZuYctwmBb0Rw"
},
"outputs": [],
"source": [
"original_model_accuracy = model.evaluate(X_test_scaled, y_test)[1]\n",
"print(\"Original Model Accuracy:\", original_model_accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FDlrN-l0b1JI"
},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix , classification_report\n",
"y_pred = model.predict(X_test_scaled) > 0.5\n",
"print(confusion_matrix(y_test, y_pred))\n",
"print(classification_report(y_test, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EEdrLiFpvdWz"
},
"outputs": [],
"source": [
"#Import the necessary libraries\n",
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"y_pred = model.predict(X_test_scaled) # Replace 'model' with your trained model\n",
"\n",
"# Convert predicted probabilities to binary labels (0 or 1)\n",
"y_pred_binary = (y_pred > 0.5).astype(int)\n",
"\n",
"#compute the confusion matrix.\n",
"cm = confusion_matrix(y_test,y_pred_binary)\n",
"\n",
"#Plot the confusion matrix.\n",
"sns.heatmap(cm,\n",
" annot=True,\n",
" fmt='g',\n",
" xticklabels=['Class 0','Class 1'],\n",
" yticklabels=['Class 0','Class 1'])\n",
"plt.ylabel('Prediction',fontsize=13)\n",
"plt.xlabel('Actual',fontsize=13)\n",
"plt.title('Confusion Matrix',fontsize=17)\n",
"plt.show()"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"gpuType": "V28",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment