Commit 79f827fb authored by RR Nimesha Manchalee

Disease classifier added

parent dc889df1
{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# --- Paths ---------------------------------------------------------------\n",
"data_path = 'data/images/'  # directory handed to flow_from_directory\n",
"model_weights = 'weights/disease_classification.h5'  # where the model is saved/loaded\n",
"\n",
"# --- Data loading --------------------------------------------------------\n",
"batch_size = 32  # batch size of the training iterator\n",
"valid_size = 16  # batch size of the validation iterator\n",
"color_mode = 'rgb'\n",
"\n",
"width = height = 299\n",
"\n",
"# NOTE(review): Keras documents target_size as (height, width). The order\n",
"# below is harmless only while the image is square -- confirm before making\n",
"# width and height differ.\n",
"target_size = (width, height)\n",
"input_shape = (width, height, 3)\n",
"\n",
"# --- Augmentation --------------------------------------------------------\n",
"zoom_range = shear_range = shift_range = 0.3\n",
"rotation_range = 30\n",
"\n",
"val_split = 0.2  # fraction of the images reserved for validation\n",
"\n",
"# --- Classifier head -----------------------------------------------------\n",
"dense_1, dense_2, dense_3 = 512, 256, 64  # widths of the dense stages\n",
"num_classes = 5  # size of the softmax output\n",
"\n",
"# --- Training ------------------------------------------------------------\n",
"epochs = 50\n",
"rate = 0.2  # dropout rate\n",
"\n",
"verbose = 1"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
"\n",
"import cv2 as cv\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.keras.activations import relu\n",
"from tensorflow.keras.models import Model, load_model\n",
"from tensorflow.keras.layers import Dense, BatchNormalization, Dropout\n",
"############################################################################################\n",
"\n",
"\n",
"#set model to train on GPU\n",
"# physical_devices = tf.config.experimental.list_physical_devices('GPU')\n",
"# tf.config.experimental.set_memory_growth(physical_devices[0], True)\n",
"\n",
"def preprocessing_function(img):\n",
"    \"\"\"Return `img` after the Xception input preprocessing.\"\"\"\n",
"    return tf.keras.applications.xception.preprocess_input(img)\n",
"\n",
"def image_data_generator():\n",
"    \"\"\"Create the training and validation iterators over `data_path`.\n",
"\n",
"    Only the training iterator applies random augmentation. The validation\n",
"    iterator applies the Xception preprocessing alone, so validation metrics\n",
"    and any predictions made from it are computed on unmodified images and\n",
"    are repeatable. Both generators use the same `validation_split`, which\n",
"    is what lets `subset` select complementary parts of the directory.\n",
"\n",
"    Returns:\n",
"        tuple: (train_generator, validation_generator)\n",
"    \"\"\"\n",
"    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(\n",
"        rotation_range=rotation_range,\n",
"        shear_range=shear_range,\n",
"        zoom_range=zoom_range,\n",
"        width_shift_range=shift_range,\n",
"        height_shift_range=shift_range,\n",
"        horizontal_flip=True,\n",
"        validation_split=val_split,\n",
"        preprocessing_function=preprocessing_function\n",
"    )\n",
"\n",
"    # No augmentation here: validation images are only preprocessed.\n",
"    validation_datagen = tf.keras.preprocessing.image.ImageDataGenerator(\n",
"        validation_split=val_split,\n",
"        preprocessing_function=preprocessing_function\n",
"    )\n",
"\n",
"    # Settings shared by both iterators.\n",
"    flow_kwargs = dict(\n",
"        target_size=target_size,\n",
"        color_mode=color_mode,\n",
"        class_mode='categorical'\n",
"    )\n",
"\n",
"    train_generator = train_datagen.flow_from_directory(\n",
"        data_path,\n",
"        batch_size=batch_size,\n",
"        subset='training',\n",
"        shuffle=True,\n",
"        **flow_kwargs\n",
"    )\n",
"\n",
"    # shuffle=False keeps the order aligned with validation_generator.classes.\n",
"    validation_generator = validation_datagen.flow_from_directory(\n",
"        data_path,\n",
"        batch_size=valid_size,\n",
"        subset='validation',\n",
"        shuffle=False,\n",
"        **flow_kwargs\n",
"    )\n",
"\n",
"    return train_generator, validation_generator"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 4031 images belonging to 5 classes.\n",
"Found 1005 images belonging to 5 classes.\n"
]
}
],
"source": [
"train_generator, validation_generator = image_data_generator()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{0: 'CCI_Caterpillars',\n",
" 1: 'CCI_Leaflets',\n",
" 2: 'WCLWD_DryingofLeaflets',\n",
" 3: 'WCLWD_Flaccidity',\n",
" 4: 'WCLWD_Yellowing'}"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_dict = train_generator.class_indices\n",
"class_dict_rev = {v: k for k, v in class_dict.items()}\n",
"class_dict_rev"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class DiseaseClassification(object):\n",
"    \"\"\"Transfer-learning image classifier built on an ImageNet Xception.\n",
"\n",
"    Bundles model construction, training, saving/loading and single-image\n",
"    prediction. It reads the notebook-level configuration values and the\n",
"    `train_generator` / `validation_generator` objects, which must exist\n",
"    before an instance is created.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        self.train_generator = train_generator\n",
"        self.validation_generator = validation_generator\n",
"        # Floor division: a trailing partial batch is not counted as a step.\n",
"        self.train_step = self.train_generator.samples // batch_size\n",
"        self.validation_step = self.validation_generator.samples // valid_size\n",
"\n",
"        self.accuracy = tf.keras.metrics.CategoricalAccuracy()\n",
"        self.recall = tf.keras.metrics.Recall()\n",
"        self.precision = tf.keras.metrics.Precision()\n",
"        self.auc = tf.keras.metrics.AUC()\n",
"\n",
"        # class index -> class name, used to decode predictions\n",
"        self.id2disease = {\n",
"            index: name\n",
"            for name, index in self.train_generator.class_indices.items()\n",
"        }\n",
"\n",
"    def _compile(self):\n",
"        # Single place for the compile settings so that train() and\n",
"        # load_model() always report the same set of metrics.\n",
"        self.model.compile(\n",
"            optimizer='Adam',\n",
"            loss='categorical_crossentropy',\n",
"            metrics=[\n",
"                self.accuracy,\n",
"                self.recall,\n",
"                self.precision,\n",
"                self.auc\n",
"            ]\n",
"        )\n",
"\n",
"    def classifier(self, x):\n",
"        \"\"\"Stack the dense head on top of the backbone features `x`.\n",
"\n",
"        Each of the three stages (widths dense_1, dense_2, dense_3) is\n",
"        Dense(relu) -> Dense -> BatchNormalization -> relu -> Dropout(rate).\n",
"\n",
"        Returns:\n",
"            The output tensor of the last stage.\n",
"        \"\"\"\n",
"        # NOTE(review): the extent of this `if` was reconstructed from source\n",
"        # whose indentation had been lost. process() only ever builds the\n",
"        # model with trainable=False, for which every layer below is added;\n",
"        # confirm the intended head for trainable=True.\n",
"        if not self.trainable:\n",
"            for units in (dense_1, dense_2, dense_3):\n",
"                x = Dense(units, activation='relu')(x)\n",
"                x = Dense(units)(x)\n",
"                x = BatchNormalization()(x)\n",
"                x = relu(x)\n",
"                x = Dropout(rate)(x)\n",
"        return x\n",
"\n",
"    def model_conversion(self, trainable):\n",
"        \"\"\"Build `self.model` from an ImageNet Xception plus the custom head.\n",
"\n",
"        Args:\n",
"            trainable: value assigned to the Xception model's `trainable`\n",
"                flag; also stored on the instance for classifier().\n",
"        \"\"\"\n",
"        backbone = tf.keras.applications.Xception(weights='imagenet')\n",
"        backbone.trainable = trainable\n",
"\n",
"        self.trainable = trainable\n",
"\n",
"        # Branch off the second-to-last layer so that Xception's own final\n",
"        # layer is replaced by the head and softmax defined here.\n",
"        features = backbone.layers[-2].output\n",
"        head = self.classifier(features)\n",
"        outputs = Dense(num_classes, activation='softmax')(head)\n",
"\n",
"        self.model = Model(\n",
"            inputs=backbone.input,\n",
"            outputs=outputs\n",
"        )\n",
"\n",
"    def train(self):\n",
"        \"\"\"Compile and fit the model, stopping early once val_loss stalls.\"\"\"\n",
"        early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"            monitor='val_loss',\n",
"            patience=5,\n",
"            # Without this the model keeps the weights of the last epoch,\n",
"            # which is `patience` epochs past the best val_loss.\n",
"            restore_best_weights=True\n",
"        )\n",
"\n",
"        self._compile()\n",
"        self.model.fit(\n",
"            self.train_generator,\n",
"            steps_per_epoch=self.train_step,\n",
"            validation_data=self.validation_generator,\n",
"            validation_steps=self.validation_step,\n",
"            epochs=epochs,\n",
"            verbose=verbose,\n",
"            callbacks=[early_stopping]\n",
"        )\n",
"\n",
"    def save_model(self):\n",
"        \"\"\"Write the current model to `model_weights`.\"\"\"\n",
"        self.model.save(model_weights)\n",
"\n",
"    def load_model(self):\n",
"        \"\"\"Load the model stored at `model_weights` and recompile it.\"\"\"\n",
"        # Inside a method the bare name resolves to the imported Keras\n",
"        # load_model function, not to this method.\n",
"        self.model = load_model(model_weights)\n",
"        self._compile()\n",
"\n",
"    def predict(self, x):\n",
"        \"\"\"Return the predicted class name for a single image.\n",
"\n",
"        Args:\n",
"            x: one image array; it is preprocessed and given a batch axis\n",
"                here. Presumably already RGB and sized to the model input,\n",
"                since no resizing happens in this method -- TODO confirm.\n",
"        \"\"\"\n",
"        x = preprocessing_function(x)\n",
"        x = np.expand_dims(x, axis=0)\n",
"        P = self.model.predict(x)\n",
"        disease_id = int(np.argmax(P))\n",
"        return self.id2disease[disease_id]\n",
"\n",
"    def process(self):\n",
"        \"\"\"Train and save a model if none is stored yet, otherwise load it.\"\"\"\n",
"        if not os.path.exists(model_weights):\n",
"            self.model_conversion(False)\n",
"            self.train()\n",
"            self.save_model()\n",
"        else:\n",
"            self.load_model()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"125/125 [==============================] - 148s 1s/step - loss: 0.5661 - categorical_accuracy: 0.7944 - recall: 0.7097 - precision: 0.8789 - auc: 0.9607 - val_loss: 0.5780 - val_categorical_accuracy: 0.7984 - val_recall: 0.7117 - val_precision: 0.8803 - val_auc: 0.9592\n",
"Epoch 2/50\n",
"125/125 [==============================] - 160s 1s/step - loss: 0.2862 - categorical_accuracy: 0.9030 - recall: 0.8837 - precision: 0.9201 - auc: 0.9889 - val_loss: 0.3297 - val_categorical_accuracy: 0.8690 - val_recall: 0.8508 - val_precision: 0.8819 - val_auc: 0.9858\n",
"Epoch 3/50\n",
"125/125 [==============================] - 142s 1s/step - loss: 0.2331 - categorical_accuracy: 0.9200 - recall: 0.9092 - precision: 0.9294 - auc: 0.9917 - val_loss: 0.5819 - val_categorical_accuracy: 0.8246 - val_recall: 0.8216 - val_precision: 0.8257 - val_auc: 0.9713\n",
"Epoch 4/50\n",
"125/125 [==============================] - 137s 1s/step - loss: 0.2084 - categorical_accuracy: 0.9207 - recall: 0.9157 - precision: 0.9294 - auc: 0.9937 - val_loss: 0.1323 - val_categorical_accuracy: 0.9506 - val_recall: 0.9486 - val_precision: 0.9515 - val_auc: 0.9976\n",
"Epoch 5/50\n",
"125/125 [==============================] - 139s 1s/step - loss: 0.1792 - categorical_accuracy: 0.9352 - recall: 0.9280 - precision: 0.9431 - auc: 0.9952 - val_loss: 0.1164 - val_categorical_accuracy: 0.9526 - val_recall: 0.9506 - val_precision: 0.9564 - val_auc: 0.9982\n",
"Epoch 6/50\n",
"125/125 [==============================] - 141s 1s/step - loss: 0.1762 - categorical_accuracy: 0.9397 - recall: 0.9335 - precision: 0.9465 - auc: 0.9952 - val_loss: 0.3771 - val_categorical_accuracy: 0.8659 - val_recall: 0.8629 - val_precision: 0.8708 - val_auc: 0.9849\n",
"Epoch 7/50\n",
"125/125 [==============================] - 136s 1s/step - loss: 0.1890 - categorical_accuracy: 0.9335 - recall: 0.9282 - precision: 0.9407 - auc: 0.9945 - val_loss: 0.1592 - val_categorical_accuracy: 0.9395 - val_recall: 0.9355 - val_precision: 0.9450 - val_auc: 0.9964\n",
"Epoch 8/50\n",
"125/125 [==============================] - 140s 1s/step - loss: 0.1607 - categorical_accuracy: 0.9437 - recall: 0.9375 - precision: 0.9520 - auc: 0.9963 - val_loss: 0.4861 - val_categorical_accuracy: 0.8478 - val_recall: 0.8448 - val_precision: 0.8508 - val_auc: 0.9775\n",
"Epoch 9/50\n",
"125/125 [==============================] - 137s 1s/step - loss: 0.1472 - categorical_accuracy: 0.9510 - recall: 0.9452 - precision: 0.9541 - auc: 0.9966 - val_loss: 0.1752 - val_categorical_accuracy: 0.9405 - val_recall: 0.9375 - val_precision: 0.9432 - val_auc: 0.9953\n",
"Epoch 10/50\n",
"125/125 [==============================] - 139s 1s/step - loss: 0.1374 - categorical_accuracy: 0.9537 - recall: 0.9475 - precision: 0.9592 - auc: 0.9970 - val_loss: 0.1875 - val_categorical_accuracy: 0.9345 - val_recall: 0.9304 - val_precision: 0.9352 - val_auc: 0.9942\n"
]
}
],
"source": [
"model = DiseaseClassification()\n",
"model.process()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import classification_report, confusion_matrix "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63/63 [==============================] - 28s 454ms/step\n"
]
}
],
"source": [
"Y = validation_generator.classes\n",
"Y_pred = model.model.predict(validation_generator)\n",
"Y_pred = np.argmax(Y_pred, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" CCI_Caterpillars 1.00 1.00 1.00 198\n",
" CCI_Leaflets 1.00 0.99 1.00 159\n",
"WCLWD_DryingofLeaflets 0.88 0.98 0.93 217\n",
" WCLWD_Flaccidity 0.93 0.88 0.91 215\n",
" WCLWD_Yellowing 0.93 0.88 0.90 216\n",
"\n",
" accuracy 0.94 1005\n",
" macro avg 0.95 0.95 0.95 1005\n",
" weighted avg 0.95 0.94 0.94 1005\n",
"\n"
]
}
],
"source": [
"# Classification Report\n",
"print(classification_report(Y, Y_pred, target_names=class_dict_rev.values()))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x1000 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion Matrix\n",
"cm = confusion_matrix(Y, Y_pred)\n",
"class_names = class_dict_rev.values()\n",
"df_cm = pd.DataFrame(\n",
" cm, \n",
" index=class_names, \n",
" columns=class_names\n",
" ).astype(int)\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"sns.heatmap(df_cm, annot=True, fmt='d', cmap='Blues')\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"Actual\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import tensorflow as tf "
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"loaded_model = tf.keras.models.load_model(model_weights)\n",
"loaded_model.compile(\n",
" optimizer='Adam',\n",
" loss='categorical_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.CategoricalAccuracy(),\n",
" tf.keras.metrics.Recall(),\n",
" tf.keras.metrics.Precision()\n",
" ]\n",
" )\n",
"\n",
"class_dict = {\n",
" 0: 'CCI_Caterpillars',\n",
" 1: 'CCI_Leaflets',\n",
" 2: 'WCLWD_DryingofLeaflets',\n",
" 3: 'WCLWD_Flaccidity',\n",
" 4: 'WCLWD_Yellowing'}\n",
"\n",
"class_dict_rev = {v: k for k, v in class_dict.items()}"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"def retrieve_treatments(disease_name, path='data/treatments.xlsx'):\n",
"    \"\"\"Return the treatment rows recorded for `disease_name`.\n",
"\n",
"    Args:\n",
"        disease_name: value to match against the sheet's 'ClassName' column.\n",
"        path: Excel workbook to read.\n",
"\n",
"    Returns:\n",
"        list of dicts, one per matching row, without the 'ClassName' key.\n",
"    \"\"\"\n",
"    treatments = pd.read_excel(path)\n",
"    # Blank cells in these two columns inherit the value of the row above.\n",
"    # .ffill() replaces fillna(method='ffill'), which pandas has deprecated.\n",
"    for column in ('ClassName', 'Disease name'):\n",
"        treatments[column] = treatments[column].ffill()\n",
"    matches = treatments[treatments['ClassName'] == disease_name]\n",
"    # drop() returns a new frame instead of mutating the filtered slice.\n",
"    return matches.drop(columns='ClassName').to_dict('records')\n",
"\n",
"def retrieve_diseases(disease_name, path='data/Diseases.xlsx'):\n",
"    \"\"\"Return the disease-description rows recorded for `disease_name`.\n",
"\n",
"    Rows are selected by their 'ClassName' value; that column is removed\n",
"    from the returned list of row dicts.\n",
"    \"\"\"\n",
"    sheet = pd.read_excel(path)\n",
"    selected = sheet.loc[sheet['ClassName'] == disease_name]\n",
"    return selected.drop(columns='ClassName').to_dict('records')\n",
"\n",
"def retrieve_disease_data(img_path):\n",
"    \"\"\"Classify the image at `img_path` and look up the matching records.\n",
"\n",
"    Args:\n",
"        img_path: path of the image file to classify.\n",
"\n",
"    Returns:\n",
"        dict with keys 'disease' (predicted class name), 'treatments' and\n",
"        'disease_data' (lists of row dicts from the two spreadsheets).\n",
"\n",
"    Raises:\n",
"        ValueError: if OpenCV cannot read an image from `img_path`.\n",
"    \"\"\"\n",
"    img = cv.imread(img_path)\n",
"    if img is None:\n",
"        # cv.imread reports failure by returning None rather than raising.\n",
"        raise ValueError('could not read image: {!r}'.format(img_path))\n",
"    # OpenCV decodes to BGR, whereas the training iterators use\n",
"    # color_mode='rgb'; convert so the channels match.\n",
"    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)\n",
"    img = cv.resize(img, target_size)\n",
"    img = preprocessing_function(img)\n",
"    img = np.expand_dims(img, axis=0)  # add the batch axis\n",
"    P = loaded_model.predict(img)\n",
"    disease_id = int(np.argmax(P))\n",
"    disease = class_dict[disease_id]\n",
"\n",
"    return {\n",
"        'disease': disease,\n",
"        'treatments': retrieve_treatments(disease),\n",
"        'disease_data': retrieve_diseases(disease)\n",
"    }"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 35ms/step\n"
]
},
{
"data": {
"text/plain": [
"{'disease': 'CCI_Caterpillars',\n",
" 'treatments': [{'Disease name': 'Caterpillar infestation',\n",
" 'Treatment recommendation': 'Chemical bio catepillar and webworm control',\n",
" 'Treatment description': 'Biological Caterpillar and Webworm Control is an Easy-to-Mix Liquid concentrate. The active ingredient is Bacillus thuringiensis subsp. kurstaki (BTK), a species specific biological larvicide that kills only leaf eating caterpillar pests (Lepidoptera).',\n",
" 'Image URL': 'https://wholesalegrowersdirect.com/wp-content/uploads/2019/04/HGBCWC_1.jpg',\n",
" 'Available on': 'Biological Caterpillar & Webworm Control - Summit® Responsible Solutions (summitchemical.com)'},\n",
" {'Disease name': 'Caterpillar infestation',\n",
" 'Treatment recommendation': 'Bactur catepillar killer',\n",
" 'Treatment description': \"Grosafe's Bactur Caterpillar Killer controls caterpillars on vegetables, fruit and ornamentals using a naturally occurring bacteria, Bacillus thuringiensis. Apply when caterpillars are young before plant damage occurs.\\xa0\",\n",
" 'Image URL': 'https://plantdoctor.co.nz/assets/uploads/2016/05/Kiwicare-Organic-Caterpillar-Control.jpg',\n",
" 'Available on': 'Grosafe Bactur Caterpillar Killer 10g – GardenBarn'},\n",
" {'Disease name': 'Caterpillar infestation',\n",
" 'Treatment recommendation': 'Safer caterpillar killer',\n",
" 'Treatment description': 'Safer Brand Caterpillar Killer Concentrate uses a naturally occurring bacterium to kill and control caterpillars and other leaf-eating worms. The caterpillars and worms that are damaging your plants will stop feeding immediately after ingesting the bacterium, which is known scientifically as Bacillus Thuringiensis var. Kurstak. when Caterpillar Killer is used as directed, it has no effect on birds, earthworms, or beneficial insects such as honeybees and ladybugs',\n",
" 'Image URL': 'https://th.bing.com/th/id/OIP.AnFfBaKtSl8vHIa6s3nIXAHaHa?pid=ImgDet&w=600&h=600&rs=1',\n",
" 'Available on': 'Safer® Brand Caterpillar Killer II Concentrate -16 oz (saferbrand.com)'},\n",
" {'Disease name': 'Caterpillar infestation',\n",
" 'Treatment recommendation': 'Monocrotophos',\n",
" 'Treatment description': 'Monocrotophos is an organophosphate insecticide which is injected into the trunk of the tree to kill the leaf eating caterpillar, Opisina arenosella.',\n",
" 'Image URL': 'https://th.bing.com/th/id/R.191c69ca014d3dabd15538b89feed7ab?rik=S2NXZj%2fMeHWK0Q&pid=ImgRaw&r=0',\n",
" 'Available on': 'Monocrotophos 36 Sl Online | Buy Insecticide Online | Insecticide (agribegri.com)'}],\n",
" 'disease_data': [{'Disease Name ': 'Caterpillar infestation',\n",
" 'Image': 'https://th.bing.com/th/id/OIP.vqfwYctyk-apjCgn5sFqhwHaFj?pid=ImgDet&rs=1',\n",
" 'Disease description': 'Coconut caterpillar infestation is a condition where the black headed caterpillar,\\xa0Opisina arenosella\\xa0 feed on the leaves of trees, it is a serious pest causing significant damage and yield loss.',\n",
" 'Symptoms ': '1. Irregular holes in tree leaves: Infested leaves often exhibit irregularly shaped holes caused by the feeding activity of the caterpillars.\\n2. Skeletonization of leaves: The caterpillars consume the tender tissues of the leaves, leaving behind only the veins or a skeletal structure.\\n3. Defoliation: Severe infestations can result in significant leaf loss, leading to partial or complete defoliation of the tree.\\n4. Stunted growth: As the caterpillars continuously feed on the leaves, the overall growth and development of the tree may be stunted.\\n5. Presence of caterpillars: During the infestation, caterpillars can be observed on the tree leaves. They are usually light-colored or green, with distinct body segments and head capsules.\\n6. Silk threads or webs: In some cases, the caterpillars may leave behind silk threads or webs on the tree leaves as they move or feed.\\n7. Frass or droppings: The caterpillars excrete droppings called frass, which can be found on the leaves, ground, or surrounding areas. The presence of frass can be an indicator of an infestation.'}]}"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img_path = 'data/images/CCI_Caterpillars/CCI_1_26_jpg.rf.d14b2b0861f0b77adfb4947c8259f87a.jpg'\n",
"retrieve_disease_data(img_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('tf26')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "d4b521e29a846470c96e928a1c4aafac58a12234cdaa98f9ca60bc431873fee6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment