Commit c04a9215 authored by Shehan R.H.A's avatar Shehan R.H.A

shuffle commit

parent 0704f34f
{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"id":"sHGlTPw4o6zg"},"source":["**Mount the google drive to use data in the google Drive**"]},{"cell_type":"code","execution_count":69,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":5049,"status":"ok","timestamp":1683732979861,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"QWEznzron_JP","outputId":"13d4620f-4598-4c2f-d511-d6d4ae2bcf83"},"outputs":[{"name":"stdout","output_type":"stream","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"T9Ny4dmwODdZ"},"source":["## Import neccesary imports"]},{"cell_type":"code","execution_count":70,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683732979862,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"MqrtOs_Cle0d"},"outputs":[],"source":["import os\n","import pandas as pd\n","import numpy as np\n","import csv \n","import cv2\n","import tensorflow as tf"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"SgUTJ3b5o9J3"},"source":["### Avoid OOM Errors by setting GPU memory consumption growth"]},{"cell_type":"code","execution_count":71,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683732979863,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"xa9fLuq6o75y"},"outputs":[],"source":["gpus = tf.config.experimental.list_physical_devices('GPU')\n","for gpu in gpus: \n"," tf.config.experimental.set_memory_growth(gpu, True)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"OLA8w9rmto7z"},"source":["### Take video inputs"]},{"cell_type":"code","execution_count":72,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683732979864,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"KjE0vEk8ttr6"},"outputs":[],"source":["def getVideo(path):\n"," # Load the video file\n"," cap = cv2.VideoCapture(path)\n","\n"," # Define the shape of the input frames\n"," input_shape = (224, 224, 3)\n","\n"," # Initialize an empty list to store the preprocessed frames\n"," frames = []\n","\n"," for i in range(10):\n"," ret, frame = cap.read() \n"," # Resize the frame to match the input shape of the CNN model\n"," resized_frame = cv2.resize(frame, input_shape[:2])\n"," \n"," # Preprocess the frame by subtracting the mean pixel value and scaling the pixel values\n"," preprocessed_frame = tf.keras.applications.mobilenet_v2.preprocess_input(resized_frame)\n"," \n"," # Add the preprocessed frame to the list of frames\n"," frames.append(preprocessed_frame)\n"," \n"," # Convert the list of frames to a numpy array\n"," return np.array(frames)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"3MWM9sOv4GBo"},"source":[]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"H_1vw0Q_tlj5"},"source":["### Take csv inputs"]},{"cell_type":"code","execution_count":73,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683732979864,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"OVop4NQrkeBx"},"outputs":[],"source":["# Open the CSV file\n","def getDf(path):\n"," with open(path, 'r') as csv_file:\n","\n"," # Create a CSV reader object\n"," csv_reader = csv.reader(csv_file)\n","\n"," # Define the start and end rows of the segment you want to extract\n"," start_row = 2\n"," end_row = 1001\n","\n"," # Create an empty list to store the rows in the desired segment\n"," rows = []\n","\n"," # Use a loop to iterate over each row in the CSV file\n"," for row_num, row in enumerate(csv_reader):\n"," \n"," # Check if the row number is within the desired segment\n"," if row_num >= start_row and row_num <= end_row:\n"," \n"," # Append the row to the list\n"," rows.append(row)\n","\n"," # Create a pandas DataFrame from the list of rows\n"," df = pd.DataFrame(rows)\n"," return df.iloc[:, 4:9]"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_Qhml8_apyPd"},"source":["### Take ADHD data"]},{"cell_type":"code","execution_count":74,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683732979865,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"N8W4YyJ6p2cm"},"outputs":[],"source":["files_adhd_video = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/WithADHD\")\n","files_adhd_csv = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/WithADHD\")\n","ADHD = 1"]},{"cell_type":"code","execution_count":75,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683732979865,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"znJKi46ODXjY"},"outputs":[],"source":["# Define the shape of the empty array\n","X_empty_shape_csv = (0, 1000, 5)\n","X_empty_shape_video = (0, 224, 224, 3)"]},{"cell_type":"code","execution_count":76,"metadata":{"executionInfo":{"elapsed":16070,"status":"ok","timestamp":1683732995922,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"AzFxx_grqT_C"},"outputs":[],"source":["# Create empty lists\n","arr_adhd_csv = []\n","arr_adhd_video = []\n","\n","# Iterate over the files\n","for file_name in files_adhd_csv:\n"," # Load the CSV data\n"," data_csv = getDf(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/WithADHD\",file_name))\n"," # Load the video data\n"," data_video = getVideo(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/WithADHD\",file_name.split(\".\")[0]+\".mp4\"))\n"," # Append the data to the lists\n"," arr_adhd_csv.append(data_csv)\n"," arr_adhd_video.append(data_video)\n","\n","# Convert the lists to NumPy arrays\n","X_for_aug_adhd_csv = np.array(arr_adhd_csv)\n","X_for_aug_adhd_video = np.array(arr_adhd_video)"]},{"cell_type":"code","execution_count":77,"metadata":{"executionInfo":{"elapsed":31,"status":"ok","timestamp":1683732995923,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"tRgw7Zol0pB7"},"outputs":[],"source":["X_video_shape = (0,10, 224, 224, 3)\n","X_csv = np.empty(X_empty_shape_csv)\n","X_video = np.empty(X_video_shape)\n","X_adhd_all_csv = np.empty(X_empty_shape_csv)\n","X_adhd_all_video = np.empty(X_video_shape)"]},{"cell_type":"code","execution_count":78,"metadata":{"executionInfo":{"elapsed":1601,"status":"ok","timestamp":1683732997494,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"w1k6gVoxqXc6"},"outputs":[],"source":["X_csv = X_csv.astype(float)\n","X_for_aug_adhd_csv = X_for_aug_adhd_csv.astype(float)\n","X_adhd_all_csv = X_adhd_all_csv.astype(float)\n","#video\n","X_video = X_video.astype(float)\n","X_for_aug_adhd_video = X_for_aug_adhd_video.astype(float)\n","X_adhd_all_video = X_adhd_all_video.astype(float)"]},{"cell_type":"code","execution_count":79,"metadata":{"executionInfo":{"elapsed":38,"status":"ok","timestamp":1683732997495,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"OMjyhOLOqawv"},"outputs":[],"source":["arr_adhd = [] #for y values having ADHD\n","X_adhd_all_csv = np.concatenate((X_adhd_all_csv, X_for_aug_adhd_csv), axis=0)\n","X_adhd_all_video = np.concatenate((X_adhd_all_video, X_for_aug_adhd_video), axis=0)\n","for i in range(X_for_aug_adhd_csv.shape[0]):\n"," arr_adhd.append(ADHD)"]},{"cell_type":"code","execution_count":80,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":38,"status":"ok","timestamp":1683732997496,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"4EZVTGz3qdTe","outputId":"21ed981f-4668-42fd-c67e-d535e3c76fc9"},"outputs":[{"data":{"text/plain":["(99, 1000, 5)"]},"execution_count":80,"metadata":{},"output_type":"execute_result"}],"source":["X_for_aug_adhd_csv.shape"]},{"cell_type":"code","execution_count":81,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":38,"status":"ok","timestamp":1683732997498,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"9YvyeHPM9MsE","outputId":"b9666437-d2c3-4e59-c29d-b5afba01d256"},"outputs":[{"data":{"text/plain":["99"]},"execution_count":81,"metadata":{},"output_type":"execute_result"}],"source":["len(arr_adhd)"]},{"cell_type":"code","execution_count":82,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":35,"status":"ok","timestamp":1683732997500,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"1R3dLxz61Evw","outputId":"f1792dd8-a7b2-4f15-bdc4-460359b2abce"},"outputs":[{"data":{"text/plain":["(99, 10, 224, 224, 3)"]},"execution_count":82,"metadata":{},"output_type":"execute_result"}],"source":["X_for_aug_adhd_video.shape"]},{"cell_type":"code","execution_count":83,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":34,"status":"ok","timestamp":1683732997501,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"0o4kv_wNqtmS","outputId":"72c657d7-1a76-456e-e544-9c6b5422551b"},"outputs":[{"data":{"text/plain":["(99, 1000, 5)"]},"execution_count":83,"metadata":{},"output_type":"execute_result"}],"source":["X_adhd_all_csv.shape"]},{"cell_type":"code","execution_count":84,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":33,"status":"ok","timestamp":1683732997503,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"LsFpAxp29IEB","outputId":"206e8762-9b8a-4339-c85a-3b7b340fb73c"},"outputs":[{"data":{"text/plain":["(99, 10, 224, 224, 3)"]},"execution_count":84,"metadata":{},"output_type":"execute_result"}],"source":["X_adhd_all_video.shape"]},{"cell_type":"code","execution_count":85,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":27,"status":"ok","timestamp":1683732997504,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"3NDVhNk6qvZb","outputId":"73f6942b-89d7-4e45-a12e-aca1eb20e747"},"outputs":[{"data":{"text/plain":["99"]},"execution_count":85,"metadata":{},"output_type":"execute_result"}],"source":["len(arr_adhd)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"JaOGaqxyjJeI"},"source":["### Take non ADHD data"]},{"cell_type":"code","execution_count":86,"metadata":{"executionInfo":{"elapsed":23,"status":"ok","timestamp":1683732997505,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"rtsPvpeXjQcs"},"outputs":[],"source":["files_not_adhd_video = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/NotADHD\")\n","files_not_adhd_csv = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/NotADHD\")\n","NOT_ADHD = 0"]},{"cell_type":"code","execution_count":87,"metadata":{"executionInfo":{"elapsed":15172,"status":"ok","timestamp":1683733012655,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"99ylgrTJD3Fj"},"outputs":[],"source":["# Create empty list\n","arr_non_adhd_csv = []\n","arr_non_adhd_video = []\n","\n","# Iterate over the files\n","for file_name in files_not_adhd_csv:\n"," # Load the CSV data\n"," data_csv = getDf(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/NotADHD\",file_name))\n"," # Load the video data\n"," data_video = getVideo(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/NotADHD\",file_name.split(\".\")[0]+\".mp4\"))\n"," # Append the data to the lists\n"," arr_non_adhd_csv.append(data_csv)\n"," arr_non_adhd_video.append(data_video)\n","\n","# Convert the lists to NumPy arrays\n","X_for_aug_non_adhd = np.array(arr_adhd_csv)\n","X_for_aug_non_adhd_video = np.array(arr_adhd_video)"]},{"cell_type":"code","execution_count":88,"metadata":{"executionInfo":{"elapsed":18,"status":"ok","timestamp":1683733012656,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"N2Yyq3zflXSi"},"outputs":[],"source":["arr_no_adhd = [] #for y values not having ADHD\n","X_non_adhd_all_csv = np.empty(X_empty_shape_csv)\n","X_for_aug_non_adhd_csv = np.empty(X_empty_shape_csv)\n","X_non_adhd_all_video = np.empty(X_video_shape)\n","X_for_aug_non_adhd_video = np.empty(X_video_shape)"]},{"cell_type":"code","execution_count":89,"metadata":{"executionInfo":{"elapsed":18,"status":"ok","timestamp":1683733012657,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"ovHHn7mbRVj2"},"outputs":[],"source":["X_for_aug_non_adhd_csv = X_for_aug_non_adhd_csv.astype(float)\n","X_non_adhd_all_csv = X_non_adhd_all_csv.astype(float)\n","X_non_adhd_all_video = X_non_adhd_all_video.astype(float)\n","X_for_aug_non_adhd_video = X_for_aug_non_adhd_video.astype(float)"]},{"cell_type":"code","execution_count":90,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":18,"status":"ok","timestamp":1683733012657,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"93x-MSsoRX76","outputId":"e8f6191a-26ee-4e35-ad78-6adfe94167c1"},"outputs":[{"data":{"text/plain":["(0, 1000, 5)"]},"execution_count":90,"metadata":{},"output_type":"execute_result"}],"source":["X_non_adhd_all_csv = np.concatenate((X_non_adhd_all_csv, X_for_aug_non_adhd_csv), axis=0)\n","for i in range(X_for_aug_non_adhd_csv.shape[0]):\n"," arr_no_adhd.append(NOT_ADHD)\n","X_non_adhd_all_csv.shape"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"obV-t2jsemYs"},"source":["## Combine Data"]},{"cell_type":"code","execution_count":91,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1683733012657,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"EQkniTZiRkUh"},"outputs":[],"source":["Y = arr_adhd + arr_no_adhd \n","X_csv = np.concatenate((X_adhd_all_csv,X_non_adhd_all_csv), axis=0)\n","X_video = np.concatenate((X_adhd_all_video,X_non_adhd_all_video), axis=0) \n","Y = np.array(Y)\n","Y = Y.astype(int)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"rWKr2Y7yVl_j"},"source":["## Pre-process the data set"]},{"cell_type":"code","execution_count":92,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1683733012658,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"pFU06_tlQzZL"},"outputs":[],"source":["from keras.utils import to_categorical\n","\n","# Convert labels to one-hot encoded vectors\n","Y = to_categorical(Y, num_classes=2)"]},{"cell_type":"code","execution_count":93,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683733012658,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"siaUTmPRVqJG"},"outputs":[],"source":["# Get a random permutation of the indices\n","perm = np.random.permutation(len(Y))\n","\n","# Shuffle the arrays using the same permutation\n","Y = Y[perm]\n","X_video = X_video[perm]\n","X_csv = X_csv[perm]"]},{"cell_type":"code","execution_count":94,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683733012658,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"nSJYTUPVE9Ty"},"outputs":[],"source":["n = len(Y)\n","di = n//9\n","Y_train, X_csv_train , X_video_train = Y[:di], X_csv[:di] , X_video[:di]\n","Y_test , X_csv_test , X_video_test = Y[di:], X_csv[di:], X_video[di:]"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"HW6KYHVCPAJj"},"source":["### Model architecture"]},{"cell_type":"code","execution_count":95,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1683733012659,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"N5sOhxSrouBm"},"outputs":[],"source":["cnn_model = tf.keras.Sequential([\n"," tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', input_shape=(10, 224, 224, 3)),\n"," tf.keras.layers.MaxPooling3D((2, 2, 2)),\n"," tf.keras.layers.Conv3D(64, (3, 3, 3), activation='relu'),\n"," tf.keras.layers.MaxPooling3D((2, 2, 2)),\n"," tf.keras.layers.Flatten(),\n"," tf.keras.layers.Dense(128, activation='relu'),\n"," tf.keras.layers.Dense(2, activation='softmax')\n","])\n"]},{"cell_type":"code","execution_count":96,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683733012659,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"WnnCUVvtPGQP"},"outputs":[],"source":["cnn_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"]},{"cell_type":"code","execution_count":97,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683733012659,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"R6v0eGz8XF2E"},"outputs":[],"source":["# Define model architecture\n","model = tf.keras.Sequential([\n"," tf.keras.layers.Flatten(input_shape=(1000, 5)),\n"," tf.keras.layers.Dense(64, activation='relu'),\n"," tf.keras.layers.Dense(32, activation='relu'),\n"," tf.keras.layers.Dense(1, activation='sigmoid')\n","])\n","\n","# Compile model\n","model.compile(optimizer='adam',\n"," loss='binary_crossentropy',\n"," metrics=['accuracy'])\n"]},{"cell_type":"code","execution_count":98,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4886,"status":"ok","timestamp":1683733017534,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"YJ0At5ijvBf4","outputId":"66ab74b5-233a-4e09-8340-f0dc163d337e"},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch 1/10\n","1/1 [==============================] - 2s 2s/step - loss: 0.6351 - accuracy: 1.0000 - val_loss: 2.0911e-22 - val_accuracy: 1.0000\n","Epoch 2/10\n","1/1 [==============================] - 0s 211ms/step - loss: 1.4724e-22 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 3/10\n","1/1 [==============================] - 0s 186ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 4/10\n","1/1 [==============================] - 0s 174ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 5/10\n","1/1 [==============================] - 0s 171ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 6/10\n","1/1 [==============================] - 0s 164ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 7/10\n","1/1 [==============================] - 0s 169ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 8/10\n","1/1 [==============================] - 0s 167ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 9/10\n","1/1 [==============================] - 0s 166ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 10/10\n","1/1 [==============================] - 0s 158ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n"]}],"source":["# Train the model on the dataset\n","history_cnn = cnn_model.fit(X_video_train, Y, epochs=10, validation_split=0.2)"]},{"cell_type":"code","execution_count":99,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":744},"executionInfo":{"elapsed":29,"status":"error","timestamp":1683733017534,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"3KO5RnJfXYYN","outputId":"173edb44-d0a9-4e46-9c73-a895adf92e85"},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch 1/10\n"]},{"ename":"ValueError","evalue":"ignored","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-99-86b6726ab2b7>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Train the model on the dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mhistory\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_csv_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidation_split\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;31m# To get the full stack trace, call:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;31m# `tf.debugging.disable_traceback_filtering()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mtf__train_function\u001b[0;34m(iterator)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mdo_return\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mretval_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconverted_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mld\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_function\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mld\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mld\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfscope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mdo_return\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mValueError\u001b[0m: in user code:\n\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1284, in train_function *\n return step_function(self, iterator)\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1268, in step_function **\n outputs = model.distribute_strategy.run(run_step, args=(data,))\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1249, in run_step **\n outputs = model.train_step(data)\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1051, in train_step\n loss = self.compute_loss(x, y, y_pred, sample_weight)\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1109, in compute_loss\n return self.compiled_loss(\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/compile_utils.py\", line 265, in __call__\n loss_value = loss_obj(y_t, y_p, sample_weight=sw)\n File \"/usr/local/lib/python3.10/dist-packages/keras/losses.py\", line 142, in __call__\n losses = call_fn(y_true, y_pred)\n File \"/usr/local/lib/python3.10/dist-packages/keras/losses.py\", line 268, in call **\n return ag_fn(y_true, y_pred, **self._fn_kwargs)\n File \"/usr/local/lib/python3.10/dist-packages/keras/losses.py\", line 2156, in binary_crossentropy\n backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),\n File \"/usr/local/lib/python3.10/dist-packages/keras/backend.py\", line 5707, in binary_crossentropy\n return tf.nn.sigmoid_cross_entropy_with_logits(\n\n ValueError: `logits` and `labels` must have the same shape, received ((None, 1) vs (None, 2)).\n"]}],"source":["# Train the model on the dataset\n","history = model.fit(X_csv_train, Y, epochs=10, validation_split=0.2)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017535,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"0EdafLjBX3uE"},"outputs":[],"source":["# Freeze the layers of the CNN so they are not retrained during the joint training\n","for layer in cnn_model.layers:\n"," layer.trainable = False\n","\n","# Get the output of the last convolutional layer\n","cnn_output = cnn_model.layers[-2].output\n","\n","# Flatten the output into a vector\n","flattened_output = tf.keras.layers.Flatten()(cnn_output)\n","\n","# Feed the flattened output as input to the ANN\n","combined_model = tf.keras.models.Sequential([\n"," flattened_output,\n"," model\n","])\n","\n","# Compile the combined model\n","combined_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n","\n","# Train the combined model on new data\n","combined_model.fit([X_video_train, X_csv_train], Y, epochs=10, validation_split=0.2)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"c0VZ1vksPGJF"},"source":["## evaluate the model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017535,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"LP3YrQk_A0im"},"outputs":[],"source":["import matplotlib.pyplot as plt"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":28,"status":"aborted","timestamp":1683733017536,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"c43zv_NfDIe9"},"outputs":[],"source":["plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.title('Model loss')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch')\n","plt.legend(['Train', 'Validation'], loc='upper right')\n","plt.show()\n","\n","plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.title('Model accuracy')\n","plt.ylabel('Accuracy')\n","plt.xlabel('Epoch')\n","plt.legend(['Train', 'Validation'], loc='lower right')\n","plt.show()"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017536,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"G5XUZO-FDO30"},"outputs":[],"source":["# Evaluate model on test set\n","test_loss, test_acc = combined_model.evaluate([X_video_test,X_csv_test], Y_test)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017536,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"rIKbxSeVDlC4"},"outputs":[],"source":["from sklearn.metrics import precision_score, recall_score, f1_score, ConfusionMatrixDisplay,confusion_matrix"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":28,"status":"aborted","timestamp":1683733017537,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"aEN9miVvDaxL"},"outputs":[],"source":["# Make predictions on test data\n","y_pred = combined_model.predict([X_video_test,X_csv_test])\n","\n","# Convert predictions to binary labels\n","y_pred = (y_pred > 0.5).astype(int)\n","\n","# Calculate precision, recall, and F1 scores\n","precision = precision_score(Y_test, y_pred)\n","recall = recall_score(Y_test, y_pred)\n","f1 = f1_score(Y_test, y_pred)\n","\n","print('Precision:', precision)\n","print('Recall:', recall)\n","print('F1 score:', f1)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":30,"status":"aborted","timestamp":1683733017539,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"olorqG-KDxj5"},"outputs":[],"source":["data = {'Test_loss':test_loss*100, 'Test_acc':test_acc*100, 'Precision': precision*100, 'Recall': recall*100,'F1 score': f1*100}\n","courses = list(data.keys())\n","values = list(data.values())\n","fig = plt.figure(figsize = (10, 5))\n","colr = ['red', 'green', 'black', 'blue', 'orange']\n"," # creating the bar plot\n","plt.bar(courses, values, color = colr,\n"," width = 0.4)\n"," \n","plt.xlabel(\"Evaluation Criteria\")\n","plt.ylabel(\"Out of 100%\")\n","plt.title(\"CNN evaluation graph\")\n","plt.show()"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"clardaFZtKJ0"},"source":["## Predict for new values"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":30,"status":"aborted","timestamp":1683733017539,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"Fp2c_3xXtTwC"},"outputs":[],"source":["path_csv = \"/content/drive/MyDrive/Colab Notebooks/MathsGame/TestData/eeg/P135.csv\"\n","path_video = \"/content/drive/MyDrive/Colab Notebooks/MathsGame/TestData/video/P35.mp4\"\n","X_new_csv = np.array(getDf(path_csv))\n","X_new_video = np.array(getVideo(path_video))\n","X_new_csv = X_new_csv.astype(float)\n","X_new_video = X_new_video.astype(float)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":30,"status":"aborted","timestamp":1683733017539,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"IdeGNoFLtZjI"},"outputs":[],"source":["# Make predictions on new data\n","y_pred = cnn_model.predict(np.expand_dims(X_new_video, 0))\n","y_pred"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"CIDzh_9Fspmo"},"source":["## export and load the model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":31,"status":"aborted","timestamp":1683733017540,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"_-_fdsvZFBfE"},"outputs":[],"source":["from tensorflow.keras.models import load_model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":31,"status":"aborted","timestamp":1683733017540,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"t1crwULws2JV"},"outputs":[],"source":["combined_model.save(os.path.join('models','combined_model_model_math.h5'))"]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyO7AlHQ4YcLLKitBHaqXpyn","machine_shape":"hm","mount_file_id":"1ElxyHvhNwgqUVGNkG8CUuIgk8U-xgj8T","provenance":[]},"gpuClass":"standard","kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{"id":"sHGlTPw4o6zg"},"source":["**Mount the google drive to use data in the google Drive**"]},{"cell_type":"code","execution_count":69,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":5049,"status":"ok","timestamp":1683732979861,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"QWEznzron_JP","outputId":"13d4620f-4598-4c2f-d511-d6d4ae2bcf83"},"outputs":[{"name":"stdout","output_type":"stream","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"T9Ny4dmwODdZ"},"source":["## Import neccesary imports"]},{"cell_type":"code","execution_count":70,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683732979862,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"MqrtOs_Cle0d"},"outputs":[],"source":["import os\n","import pandas as pd\n","import numpy as np\n","import csv \n","import cv2\n","import tensorflow as tf"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"SgUTJ3b5o9J3"},"source":["### Avoid OOM Errors by setting GPU memory consumption growth"]},{"cell_type":"code","execution_count":71,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683732979863,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"xa9fLuq6o75y"},"outputs":[],"source":["gpus = tf.config.experimental.list_physical_devices('GPU')\n","for gpu in gpus: \n"," tf.config.experimental.set_memory_growth(gpu, True)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"OLA8w9rmto7z"},"source":["### Take video inputs"]},{"cell_type":"code","execution_count":72,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683732979864,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"KjE0vEk8ttr6"},"outputs":[],"source":["def getVideo(path):\n"," # Load the video file\n"," cap = cv2.VideoCapture(path)\n","\n"," # Define the shape of the input frames\n"," input_shape = (224, 224, 3)\n","\n"," # Initialize an empty list to store the preprocessed frames\n"," frames = []\n","\n"," for i in range(10):\n"," ret, frame = cap.read() \n"," # Resize the frame to match the input shape of the CNN model\n"," resized_frame = cv2.resize(frame, input_shape[:2])\n"," \n"," # Preprocess the frame by subtracting the mean pixel value and scaling the pixel values\n"," preprocessed_frame = tf.keras.applications.mobilenet_v2.preprocess_input(resized_frame)\n"," \n"," # Add the preprocessed frame to the list of frames\n"," frames.append(preprocessed_frame)\n"," \n"," # Convert the list of frames to a numpy array\n"," return np.array(frames)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"3MWM9sOv4GBo"},"source":[]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"H_1vw0Q_tlj5"},"source":["### Take csv inputs"]},{"cell_type":"code","execution_count":73,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683732979864,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"OVop4NQrkeBx"},"outputs":[],"source":["# Open the CSV file\n","def getDf(path):\n"," with open(path, 'r') as csv_file:\n","\n"," # Create a CSV reader object\n"," csv_reader = csv.reader(csv_file)\n","\n"," # Define the start and end rows of the segment you want to extract\n"," start_row = 2\n"," end_row = 1001\n","\n"," # Create an empty list to store the rows in the desired segment\n"," rows = []\n","\n"," # Use a loop to iterate over each row in the CSV file\n"," for row_num, row in enumerate(csv_reader):\n"," \n"," # Check if the row number is within the desired segment\n"," if row_num >= start_row and row_num <= end_row:\n"," \n"," # Append the row to the list\n"," rows.append(row)\n","\n"," # Create a pandas DataFrame from the list of rows\n"," df = pd.DataFrame(rows)\n"," return df.iloc[:, 4:9]"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"_Qhml8_apyPd"},"source":["### Take ADHD data"]},{"cell_type":"code","execution_count":74,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683732979865,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"N8W4YyJ6p2cm"},"outputs":[],"source":["files_adhd_video = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/WithADHD\")\n","files_adhd_csv = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/WithADHD\")\n","ADHD = 1"]},{"cell_type":"code","execution_count":75,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683732979865,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"znJKi46ODXjY"},"outputs":[],"source":["# Define the shape of the empty array\n","X_empty_shape_csv = (0, 1000, 5)\n","X_empty_shape_video = (0, 224, 224, 3)"]},{"cell_type":"code","execution_count":76,"metadata":{"executionInfo":{"elapsed":16070,"status":"ok","timestamp":1683732995922,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"AzFxx_grqT_C"},"outputs":[],"source":["# Create empty lists\n","arr_adhd_csv = []\n","arr_adhd_video = []\n","\n","# Iterate over the files\n","for file_name in files_adhd_csv:\n"," # Load the CSV data\n"," data_csv = getDf(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/WithADHD\",file_name))\n"," # Load the video data\n"," data_video = getVideo(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/WithADHD\",file_name.split(\".\")[0]+\".mp4\"))\n"," # Append the data to the lists\n"," arr_adhd_csv.append(data_csv)\n"," arr_adhd_video.append(data_video)\n","\n","# Convert the lists to NumPy arrays\n","X_for_aug_adhd_csv = np.array(arr_adhd_csv)\n","X_for_aug_adhd_video = np.array(arr_adhd_video)"]},{"cell_type":"code","execution_count":77,"metadata":{"executionInfo":{"elapsed":31,"status":"ok","timestamp":1683732995923,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"tRgw7Zol0pB7"},"outputs":[],"source":["X_video_shape = (0,10, 224, 224, 3)\n","X_csv = np.empty(X_empty_shape_csv)\n","X_video = np.empty(X_video_shape)\n","X_adhd_all_csv = np.empty(X_empty_shape_csv)\n","X_adhd_all_video = np.empty(X_video_shape)"]},{"cell_type":"code","execution_count":78,"metadata":{"executionInfo":{"elapsed":1601,"status":"ok","timestamp":1683732997494,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"w1k6gVoxqXc6"},"outputs":[],"source":["X_csv = X_csv.astype(float)\n","X_for_aug_adhd_csv = X_for_aug_adhd_csv.astype(float)\n","X_adhd_all_csv = X_adhd_all_csv.astype(float)\n","#video\n","X_video = X_video.astype(float)\n","X_for_aug_adhd_video = X_for_aug_adhd_video.astype(float)\n","X_adhd_all_video = X_adhd_all_video.astype(float)"]},{"cell_type":"code","execution_count":79,"metadata":{"executionInfo":{"elapsed":38,"status":"ok","timestamp":1683732997495,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"OMjyhOLOqawv"},"outputs":[],"source":["arr_adhd = [] #for y values having ADHD \n","X_adhd_all_csv = np.concatenate((X_adhd_all_csv, X_for_aug_adhd_csv), axis=0)\n","X_adhd_all_video = np.concatenate((X_adhd_all_video, X_for_aug_adhd_video), axis=0)\n","for i in range(X_for_aug_adhd_csv.shape[0]):\n"," arr_adhd.append(ADHD)"]},{"cell_type":"code","execution_count":80,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":38,"status":"ok","timestamp":1683732997496,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"4EZVTGz3qdTe","outputId":"21ed981f-4668-42fd-c67e-d535e3c76fc9"},"outputs":[{"data":{"text/plain":["(99, 1000, 5)"]},"execution_count":80,"metadata":{},"output_type":"execute_result"}],"source":["X_for_aug_adhd_csv.shape"]},{"cell_type":"code","execution_count":81,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":38,"status":"ok","timestamp":1683732997498,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"9YvyeHPM9MsE","outputId":"b9666437-d2c3-4e59-c29d-b5afba01d256"},"outputs":[{"data":{"text/plain":["99"]},"execution_count":81,"metadata":{},"output_type":"execute_result"}],"source":["len(arr_adhd)"]},{"cell_type":"code","execution_count":82,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":35,"status":"ok","timestamp":1683732997500,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"1R3dLxz61Evw","outputId":"f1792dd8-a7b2-4f15-bdc4-460359b2abce"},"outputs":[{"data":{"text/plain":["(99, 10, 224, 224, 3)"]},"execution_count":82,"metadata":{},"output_type":"execute_result"}],"source":["X_for_aug_adhd_video.shape"]},{"cell_type":"code","execution_count":83,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":34,"status":"ok","timestamp":1683732997501,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"0o4kv_wNqtmS","outputId":"72c657d7-1a76-456e-e544-9c6b5422551b"},"outputs":[{"data":{"text/plain":["(99, 1000, 5)"]},"execution_count":83,"metadata":{},"output_type":"execute_result"}],"source":["X_adhd_all_csv.shape"]},{"cell_type":"code","execution_count":84,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":33,"status":"ok","timestamp":1683732997503,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"LsFpAxp29IEB","outputId":"206e8762-9b8a-4339-c85a-3b7b340fb73c"},"outputs":[{"data":{"text/plain":["(99, 10, 224, 224, 3)"]},"execution_count":84,"metadata":{},"output_type":"execute_result"}],"source":["X_adhd_all_video.shape"]},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":27,"status":"ok","timestamp":1683732997504,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"3NDVhNk6qvZb","outputId":"73f6942b-89d7-4e45-a12e-aca1eb20e747"},"outputs":[{"ename":"NameError","evalue":"name 'arr_adhd' is not defined","output_type":"error","traceback":["\u001b[1;31m---------------------------------------------------------------------------\u001b[0m","\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)","Cell \u001b[1;32mIn[1], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[39mlen\u001b[39m(arr_adhd)\n","\u001b[1;31mNameError\u001b[0m: name 'arr_adhd' is not defined"]}],"source":["len(arr_adhd)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"JaOGaqxyjJeI"},"source":["### Take non ADHD data"]},{"cell_type":"code","execution_count":86,"metadata":{"executionInfo":{"elapsed":23,"status":"ok","timestamp":1683732997505,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"rtsPvpeXjQcs"},"outputs":[],"source":["files_not_adhd_video = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/NotADHD\")\n","files_not_adhd_csv = os.listdir(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/NotADHD\")\n","NOT_ADHD = 0"]},{"cell_type":"code","execution_count":87,"metadata":{"executionInfo":{"elapsed":15172,"status":"ok","timestamp":1683733012655,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"99ylgrTJD3Fj"},"outputs":[],"source":["# Create empty list\n","arr_non_adhd_csv = []\n","arr_non_adhd_video = []\n","\n","# Iterate over the file\n","for file_name in files_not_adhd_csv:\n"," # Load the CSV data\n"," data_csv = getDf(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_EEG/NotADHD\",file_name))\n"," # Load the video data\n"," data_video = getVideo(os.path.join(\"/content/drive/MyDrive/Colab Notebooks/MathsGame/Maths_VIDEO/NotADHD\",file_name.split(\".\")[0]+\".mp4\"))\n"," # Append the data to the lists\n"," arr_non_adhd_csv.append(data_csv)\n"," arr_non_adhd_video.append(data_video)\n","\n","# Convert the lists to NumPy arrays\n","X_for_aug_non_adhd = np.array(arr_adhd_csv)\n","X_for_aug_non_adhd_video = np.array(arr_adhd_video)"]},{"cell_type":"code","execution_count":88,"metadata":{"executionInfo":{"elapsed":18,"status":"ok","timestamp":1683733012656,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"N2Yyq3zflXSi"},"outputs":[],"source":["arr_no_adhd = [] #for y values not having ADHD\n","X_non_adhd_all_csv = np.empty(X_empty_shape_csv)\n","X_for_aug_non_adhd_csv = np.empty(X_empty_shape_csv)\n","X_non_adhd_all_video = np.empty(X_video_shape)\n","X_for_aug_non_adhd_video = np.empty(X_video_shape)"]},{"cell_type":"code","execution_count":89,"metadata":{"executionInfo":{"elapsed":18,"status":"ok","timestamp":1683733012657,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"ovHHn7mbRVj2"},"outputs":[],"source":["X_for_aug_non_adhd_csv = X_for_aug_non_adhd_csv.astype(float)\n","X_non_adhd_all_csv = X_non_adhd_all_csv.astype(float)\n","X_non_adhd_all_video = X_non_adhd_all_video.astype(float)\n","X_for_aug_non_adhd_video = X_for_aug_non_adhd_video.astype(float)"]},{"cell_type":"code","execution_count":90,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":18,"status":"ok","timestamp":1683733012657,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"93x-MSsoRX76","outputId":"e8f6191a-26ee-4e35-ad78-6adfe94167c1"},"outputs":[{"data":{"text/plain":["(0, 1000, 5)"]},"execution_count":90,"metadata":{},"output_type":"execute_result"}],"source":["X_non_adhd_all_csv = np.concatenate((X_non_adhd_all_csv, X_for_aug_non_adhd_csv), axis=0)\n","for i in range(X_for_aug_non_adhd_csv.shape[0]):\n"," arr_no_adhd.append(NOT_ADHD)\n","X_non_adhd_all_csv.shape"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"obV-t2jsemYs"},"source":["## Combine Data"]},{"cell_type":"code","execution_count":91,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1683733012657,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"EQkniTZiRkUh"},"outputs":[],"source":["Y = arr_adhd + arr_no_adhd \n","X_csv = np.concatenate((X_adhd_all_csv,X_non_adhd_all_csv), axis=0)\n","X_video = np.concatenate((X_adhd_all_video,X_non_adhd_all_video), axis=0) \n","Y = np.array(Y)\n","Y = Y.astype(int)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"rWKr2Y7yVl_j"},"source":["## Pre-process the data set"]},{"cell_type":"code","execution_count":92,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1683733012658,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"pFU06_tlQzZL"},"outputs":[],"source":["from keras.utils import to_categorical\n","\n","# Convert labels to one-hot encoded vectors\n","Y = to_categorical(Y, num_classes=2)"]},{"cell_type":"code","execution_count":93,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683733012658,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"siaUTmPRVqJG"},"outputs":[],"source":["# Get a random permutation of the indices\n","perm = np.random.permutation(len(Y))\n","\n","# Shuffle the arrays using the same permutation\n","Y = Y[perm]\n","X_video = X_video[perm]\n","X_csv = X_csv[perm]"]},{"cell_type":"code","execution_count":94,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683733012658,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"nSJYTUPVE9Ty"},"outputs":[],"source":["n = len(Y)\n","di = n//9\n","Y_train, X_csv_train , X_video_train = Y[:di], X_csv[:di] , X_video[:di]\n","Y_test , X_csv_test , X_video_test = Y[di:], X_csv[di:], X_video[di:]"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"HW6KYHVCPAJj"},"source":["### Model architecture"]},{"cell_type":"code","execution_count":95,"metadata":{"executionInfo":{"elapsed":14,"status":"ok","timestamp":1683733012659,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"N5sOhxSrouBm"},"outputs":[],"source":["cnn_model = tf.keras.Sequential([\n"," tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', input_shape=(10, 224, 224, 3)),\n"," tf.keras.layers.MaxPooling3D((2, 2, 2)),\n"," tf.keras.layers.Conv3D(64, (3, 3, 3), activation='relu'),\n"," tf.keras.layers.MaxPooling3D((2, 2, 2)),\n"," tf.keras.layers.Flatten(),\n"," tf.keras.layers.Dense(128, activation='relu'),\n"," tf.keras.layers.Dense(2, activation='softmax')\n","])\n"]},{"cell_type":"code","execution_count":96,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1683733012659,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"WnnCUVvtPGQP"},"outputs":[],"source":["cnn_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"]},{"cell_type":"code","execution_count":97,"metadata":{"executionInfo":{"elapsed":12,"status":"ok","timestamp":1683733012659,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"R6v0eGz8XF2E"},"outputs":[],"source":["# Define model architecture\n","model = tf.keras.Sequential([\n"," tf.keras.layers.Flatten(input_shape=(1000, 5)),\n"," tf.keras.layers.Dense(64, activation='relu'),\n"," tf.keras.layers.Dense(32, activation='relu'),\n"," tf.keras.layers.Dense(1, activation='sigmoid')\n","])\n","\n","# Compile model\n","model.compile(optimizer='adam',\n"," loss='binary_crossentropy',\n"," metrics=['accuracy'])\n"]},{"cell_type":"code","execution_count":98,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4886,"status":"ok","timestamp":1683733017534,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"YJ0At5ijvBf4","outputId":"66ab74b5-233a-4e09-8340-f0dc163d337e"},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch 1/10\n","1/1 [==============================] - 2s 2s/step - loss: 0.6351 - accuracy: 1.0000 - val_loss: 2.0911e-22 - val_accuracy: 1.0000\n","Epoch 2/10\n","1/1 [==============================] - 0s 211ms/step - loss: 1.4724e-22 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 3/10\n","1/1 [==============================] - 0s 186ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 4/10\n","1/1 [==============================] - 0s 174ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 5/10\n","1/1 [==============================] - 0s 171ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 6/10\n","1/1 [==============================] - 0s 164ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 7/10\n","1/1 [==============================] - 0s 169ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 8/10\n","1/1 [==============================] - 0s 167ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 9/10\n","1/1 [==============================] - 0s 166ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n","Epoch 10/10\n","1/1 [==============================] - 0s 158ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000\n"]}],"source":["# Train the model on the dataset\n","history_cnn = cnn_model.fit(X_video_train, Y, epochs=10, validation_split=0.2)"]},{"cell_type":"code","execution_count":99,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":744},"executionInfo":{"elapsed":29,"status":"error","timestamp":1683733017534,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"3KO5RnJfXYYN","outputId":"173edb44-d0a9-4e46-9c73-a895adf92e85"},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch 1/10\n"]},{"ename":"ValueError","evalue":"ignored","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-99-86b6726ab2b7>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Train the model on the dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mhistory\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_csv_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidation_split\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;31m# To get the full stack trace, call:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;31m# `tf.debugging.disable_traceback_filtering()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mtf__train_function\u001b[0;34m(iterator)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mdo_return\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mretval_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconverted_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mld\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_function\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mld\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mag__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mld\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfscope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mdo_return\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mValueError\u001b[0m: in user code:\n\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1284, in train_function *\n return step_function(self, iterator)\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1268, in step_function **\n outputs = model.distribute_strategy.run(run_step, args=(data,))\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1249, in run_step **\n outputs = model.train_step(data)\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1051, in train_step\n loss = self.compute_loss(x, y, y_pred, sample_weight)\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/training.py\", line 1109, in compute_loss\n return self.compiled_loss(\n File \"/usr/local/lib/python3.10/dist-packages/keras/engine/compile_utils.py\", line 265, in __call__\n loss_value = loss_obj(y_t, y_p, sample_weight=sw)\n File \"/usr/local/lib/python3.10/dist-packages/keras/losses.py\", line 142, in __call__\n losses = call_fn(y_true, y_pred)\n File \"/usr/local/lib/python3.10/dist-packages/keras/losses.py\", line 268, in call **\n return ag_fn(y_true, y_pred, **self._fn_kwargs)\n File \"/usr/local/lib/python3.10/dist-packages/keras/losses.py\", line 2156, in binary_crossentropy\n backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),\n File \"/usr/local/lib/python3.10/dist-packages/keras/backend.py\", line 5707, in binary_crossentropy\n return tf.nn.sigmoid_cross_entropy_with_logits(\n\n ValueError: `logits` and `labels` must have the same shape, received ((None, 1) vs (None, 2)).\n"]}],"source":["# Train the model on the dataset\n","history = model.fit(X_csv_train, Y, epochs=10, validation_split=0.2)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017535,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"0EdafLjBX3uE"},"outputs":[],"source":["# Freeze the layers of the CNN so they are not retrained during the joint training\n","for layer in cnn_model.layers:\n"," layer.trainable = False\n","\n","# Get the output of the last convolutional layer\n","cnn_output = cnn_model.layers[-2].output\n","\n","# Flatten the output into a vector\n","flattened_output = tf.keras.layers.Flatten()(cnn_output)\n","\n","# Feed the flattened output as input to the ANN\n","combined_model = tf.keras.models.Sequential([\n"," flattened_output,\n"," model\n","])\n","\n","# Compile the combined model\n","combined_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n","\n","# Train the combined model on new data\n","combined_model.fit([X_video_train, X_csv_train], Y, epochs=10, validation_split=0.2)"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"c0VZ1vksPGJF"},"source":["## evaluate the model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017535,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"LP3YrQk_A0im"},"outputs":[],"source":["import matplotlib.pyplot as plt"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":28,"status":"aborted","timestamp":1683733017536,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"c43zv_NfDIe9"},"outputs":[],"source":["plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.title('Model loss')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch')\n","plt.legend(['Train', 'Validation'], loc='upper right')\n","plt.show()\n","\n","plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.title('Model accuracy')\n","plt.ylabel('Accuracy')\n","plt.xlabel('Epoch')\n","plt.legend(['Train', 'Validation'], loc='lower right')\n","plt.show()"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017536,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"G5XUZO-FDO30"},"outputs":[],"source":["# Evaluate model on test set\n","test_loss, test_acc = combined_model.evaluate([X_video_test,X_csv_test], Y_test)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":27,"status":"aborted","timestamp":1683733017536,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"rIKbxSeVDlC4"},"outputs":[],"source":["from sklearn.metrics import precision_score, recall_score, f1_score, ConfusionMatrixDisplay,confusion_matrix"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":28,"status":"aborted","timestamp":1683733017537,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"aEN9miVvDaxL"},"outputs":[],"source":["# Make predictions on test data\n","y_pred = combined_model.predict([X_video_test,X_csv_test])\n","\n","# Convert predictions to binary labels\n","y_pred = (y_pred > 0.5).astype(int)\n","\n","# Calculate precision, recall, and F1 scores\n","precision = precision_score(Y_test, y_pred)\n","recall = recall_score(Y_test, y_pred)\n","f1 = f1_score(Y_test, y_pred)\n","\n","print('Precision:', precision)\n","print('Recall:', recall)\n","print('F1 score:', f1)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":30,"status":"aborted","timestamp":1683733017539,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"olorqG-KDxj5"},"outputs":[],"source":["data = {'Test_loss':test_loss*100, 'Test_acc':test_acc*100, 'Precision': precision*100, 'Recall': recall*100,'F1 score': f1*100}\n","courses = list(data.keys())\n","values = list(data.values())\n","fig = plt.figure(figsize = (10, 5))\n","colr = ['red', 'green', 'black', 'blue', 'orange']\n"," # creating the bar plot\n","plt.bar(courses, values, color = colr,\n"," width = 0.4)\n"," \n","plt.xlabel(\"Evaluation Criteria\")\n","plt.ylabel(\"Out of 100%\")\n","plt.title(\"CNN evaluation graph\")\n","plt.show()"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"clardaFZtKJ0"},"source":["## Predict for new values"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":30,"status":"aborted","timestamp":1683733017539,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"Fp2c_3xXtTwC"},"outputs":[],"source":["path_csv = \"/content/drive/MyDrive/Colab Notebooks/MathsGame/TestData/eeg/P135.csv\"\n","path_video = \"/content/drive/MyDrive/Colab Notebooks/MathsGame/TestData/video/P35.mp4\"\n","X_new_csv = np.array(getDf(path_csv))\n","X_new_video = np.array(getVideo(path_video))\n","X_new_csv = X_new_csv.astype(float)\n","X_new_video = X_new_video.astype(float)"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":30,"status":"aborted","timestamp":1683733017539,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"IdeGNoFLtZjI"},"outputs":[],"source":["# Make predictions on new data\n","y_pred = cnn_model.predict(np.expand_dims(X_new_video, 0))\n","y_pred"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"CIDzh_9Fspmo"},"source":["## export and load the model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":31,"status":"aborted","timestamp":1683733017540,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"_-_fdsvZFBfE"},"outputs":[],"source":["from tensorflow.keras.models import load_model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":31,"status":"aborted","timestamp":1683733017540,"user":{"displayName":"Research ADHD","userId":"15453477032449770552"},"user_tz":-330},"id":"t1crwULws2JV"},"outputs":[],"source":["combined_model.save(os.path.join('models','combined_model_model_math.h5'))"]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyO7AlHQ4YcLLKitBHaqXpyn","machine_shape":"hm","mount_file_id":"1ElxyHvhNwgqUVGNkG8CUuIgk8U-xgj8T","provenance":[]},"gpuClass":"standard","kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.9"}},"nbformat":4,"nbformat_minor":0}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment