Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
A
ASD_Detection
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
2023-161
ASD_Detection
Commits
66c8e0d1
Commit
66c8e0d1
authored
Sep 09, 2023
by
Wijegunarathna K. P. S. G. G.
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add model jupyter source file.
parent
10283f70
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
460 additions
and
0 deletions
+460
-0
Backend/abnomility-sentiment.ipynb
Backend/abnomility-sentiment.ipynb
+460
-0
No files found.
Backend/abnomility-sentiment.ipynb
0 → 100644
View file @
66c8e0d1
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sulak\\anaconda3\\lib\\site-packages\\paramiko\\transport.py:219: CryptographyDeprecationWarning: Blowfish has been deprecated\n",
" \"class\": algorithms.Blowfish,\n"
]
}
],
"source": [
"import os, glob\n",
"import warnings\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from pyannote.audio import Model, Inference\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Pretrained speaker-embedding model from the Hugging Face Hub. The access\n",
"# token is read from the HF_TOKEN environment variable so that no credential\n",
"# is stored in the notebook; a token that has been committed to version\n",
"# control must be treated as leaked and revoked.\n",
"# NOTE(review): when HF_TOKEN is unset, None is passed -- confirm that the\n",
"# installed pyannote / huggingface_hub versions then use cached credentials.\n",
"embedding_model = Model.from_pretrained(\n",
"    \"pyannote/embedding\",\n",
"    use_auth_token=os.environ.get(\"HF_TOKEN\"),\n",
")\n",
"# window=\"whole\" processes each file in one piece instead of sliding windows.\n",
"embedding_inference = Inference(embedding_model, window=\"whole\")\n",
"\n",
"# Class folder name -> integer label.\n",
"class_dict = {\n",
"    'autism': 0,\n",
"    'non-autism': 1,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(audio_dir='data/abnormality'):\n",
"    \"\"\"Embed every ``.wav`` file under *audio_dir* and pair it with its label.\n",
"\n",
"    Files are matched by ``<audio_dir>/<class>/<group>/<file>.wav``; the\n",
"    ``<class>`` folder name is looked up in ``class_dict`` to get the label.\n",
"    A file whose embedding fails is reported on stdout and left out of both\n",
"    returned arrays.\n",
"\n",
"    Args:\n",
"        audio_dir: Root directory of the dataset.\n",
"\n",
"    Returns:\n",
"        Tuple ``(embeddings, labels)`` of NumPy arrays holding one 512-wide\n",
"        row and one label per file that was embedded successfully.\n",
"\n",
"    Raises:\n",
"        KeyError: If a class folder name is not a key of ``class_dict``.\n",
"    \"\"\"\n",
"    voice_files = [\n",
"        voice_file.replace('\\\\', '/')  # normalise Windows path separators\n",
"        for voice_file in glob.glob(f'{audio_dir}/*/*/*.wav')\n",
"    ]\n",
"    # The class folder is the third path component counted from the end.\n",
"    labels = np.array(\n",
"        [class_dict[voice_file.split('/')[-3]] for voice_file in voice_files]\n",
"    )\n",
"\n",
"    embeddings = np.zeros((len(voice_files), 512))\n",
"    errorneous_idxs = []\n",
"    for i, voice_file in enumerate(voice_files):\n",
"        try:\n",
"            embeddings[i] = embedding_inference(voice_file)\n",
"        except Exception as exc:\n",
"            # Best effort: skip unreadable files, but say why, and let\n",
"            # KeyboardInterrupt / SystemExit propagate (a bare except would not).\n",
"            errorneous_idxs.append(i)\n",
"            print('Errorneous file: ', voice_file, '-', repr(exc))\n",
"\n",
"    # Drop the rows that belong to files which could not be embedded.\n",
"    embeddings = np.delete(embeddings, errorneous_idxs, axis=0)\n",
"    labels = np.delete(labels, errorneous_idxs, axis=0)\n",
"    return embeddings, labels"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding shape: (32, 512)\n",
"labels shape: (32,)\n"
]
}
],
"source": [
"# Embed the whole dataset once and report the resulting array shapes.\n",
"embeddings, labels = load_dataset()\n",
"\n",
"for caption, array in ((\"Embedding shape: \", embeddings), (\"labels shape: \", labels)):\n",
"    print(caption, array.shape)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def build_model(input_dim=512, hidden_units=(128, 64, 32), dropout_rate=0.2):\n",
"    \"\"\"Build and compile the binary classifier applied to the embeddings.\n",
"\n",
"    The network is Dense(256, relu), followed by one\n",
"    Dense -> BatchNormalization -> ReLU -> Dropout stage per entry of\n",
"    *hidden_units*, followed by a single sigmoid unit named ``detection``.\n",
"    The default arguments reproduce the original fixed architecture.\n",
"\n",
"    Args:\n",
"        input_dim: Length of each input vector.\n",
"        hidden_units: Width of the Dense layer in each normalised stage.\n",
"        dropout_rate: Dropout rate applied at the end of every stage.\n",
"\n",
"    Returns:\n",
"        A ``tf.keras.Model`` compiled with the Adam optimizer, binary\n",
"        cross-entropy loss and accuracy / precision / recall / AUC metrics.\n",
"    \"\"\"\n",
"    inputs = tf.keras.Input(shape=(input_dim,))\n",
"    x = tf.keras.layers.Dense(256, activation='relu')(inputs)\n",
"\n",
"    # The activation is a separate layer so that batch normalisation is\n",
"    # applied to the Dense output before the ReLU.\n",
"    for units in hidden_units:\n",
"        x = tf.keras.layers.Dense(units)(x)\n",
"        x = tf.keras.layers.BatchNormalization()(x)\n",
"        x = tf.keras.layers.Activation('relu')(x)\n",
"        x = tf.keras.layers.Dropout(dropout_rate)(x)\n",
"\n",
"    outputs = tf.keras.layers.Dense(1, activation='sigmoid', name='detection')(x)\n",
"\n",
"    model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
"    model.compile(\n",
"        optimizer='adam',\n",
"        loss='binary_crossentropy',\n",
"        metrics=[\n",
"            tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
"            tf.keras.metrics.Precision(name='precision'),\n",
"            tf.keras.metrics.Recall(name='recall'),\n",
"            tf.keras.metrics.AUC(name='auc'),\n",
"        ],\n",
"    )\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 512)] 0 \n",
" \n",
" dense (Dense) (None, 256) 131328 \n",
" \n",
" dense_1 (Dense) (None, 128) 32896 \n",
" \n",
" batch_normalization (BatchN (None, 128) 512 \n",
" ormalization) \n",
" \n",
" activation (Activation) (None, 128) 0 \n",
" \n",
" dropout (Dropout) (None, 128) 0 \n",
" \n",
" dense_2 (Dense) (None, 64) 8256 \n",
" \n",
" batch_normalization_1 (Batc (None, 64) 256 \n",
" hNormalization) \n",
" \n",
" activation_1 (Activation) (None, 64) 0 \n",
" \n",
" dropout_1 (Dropout) (None, 64) 0 \n",
" \n",
" dense_3 (Dense) (None, 32) 2080 \n",
" \n",
" batch_normalization_2 (Batc (None, 32) 128 \n",
" hNormalization) \n",
" \n",
" activation_2 (Activation) (None, 32) 0 \n",
" \n",
" dropout_2 (Dropout) (None, 32) 0 \n",
" \n",
" detection (Dense) (None, 1) 33 \n",
" \n",
"=================================================================\n",
"Total params: 175,489\n",
"Trainable params: 175,041\n",
"Non-trainable params: 448\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Create a freshly initialised classifier and print its layer table.\n",
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"4/4 [==============================] - 1s 8ms/step - loss: 0.8720 - accuracy: 0.4062 - precision: 0.3846 - recall: 0.3125 - auc: 0.3613\n",
"Epoch 2/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.6819 - accuracy: 0.5625 - precision: 0.6250 - recall: 0.3125 - auc: 0.6250\n",
"Epoch 3/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.5912 - accuracy: 0.6875 - precision: 0.8000 - recall: 0.5000 - auc: 0.7383 \n",
"Epoch 4/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.4495 - accuracy: 0.8438 - precision: 0.8667 - recall: 0.8125 - auc: 0.8945\n",
"Epoch 5/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.3916 - accuracy: 0.8438 - precision: 0.9231 - recall: 0.7500 - auc: 0.9297\n",
"Epoch 6/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2661 - accuracy: 0.9062 - precision: 0.9333 - recall: 0.8750 - auc: 0.9883\n",
"Epoch 7/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3791 - accuracy: 0.8750 - precision: 0.8750 - recall: 0.8750 - auc: 0.9355\n",
"Epoch 8/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3156 - accuracy: 0.9062 - precision: 0.9333 - recall: 0.8750 - auc: 0.9922\n",
"Epoch 9/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.2801 - accuracy: 0.9375 - precision: 0.9375 - recall: 0.9375 - auc: 0.9805\n",
"Epoch 10/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2198 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 11/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2803 - accuracy: 0.9062 - precision: 1.0000 - recall: 0.8125 - auc: 0.9922\n",
"Epoch 12/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2149 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 13/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1541 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 14/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1529 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 15/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2901 - accuracy: 0.8750 - precision: 0.8750 - recall: 0.8750 - auc: 0.9648\n",
"Epoch 16/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1161 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 17/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1315 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 18/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3717 - accuracy: 0.9062 - precision: 1.0000 - recall: 0.8125 - auc: 0.9043\n",
"Epoch 19/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1533 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 20/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1370 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 21/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1268 - accuracy: 0.9688 - precision: 0.9412 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 22/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1185 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 23/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1002 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 24/100\n",
"4/4 [==============================] - 0s 5ms/step - loss: 0.1243 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 25/100\n",
"4/4 [==============================] - 0s 5ms/step - loss: 0.1121 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 26/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1200 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 27/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1230 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 28/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0778 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 29/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 30/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0756 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 31/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0807 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 32/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0835 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 33/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1426 - accuracy: 0.9375 - precision: 0.9375 - recall: 0.9375 - auc: 0.9961\n",
"Epoch 34/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0730 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 35/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0690 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 36/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 37/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0935 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 38/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1438 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 39/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0632 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 40/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0515 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 41/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 42/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0509 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 43/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0636 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 44/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0900 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 45/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0758 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 46/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0447 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 47/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1070 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 48/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1172 - accuracy: 0.9688 - precision: 0.9412 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 49/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0479 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 50/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0435 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 51/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1728 - accuracy: 0.9375 - precision: 1.0000 - recall: 0.8750 - auc: 0.9922\n",
"Epoch 52/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1023 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 53/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0716 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 54/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0358 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 55/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0267 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 56/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0797 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 57/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0783 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 58/100\n",
"4/4 [==============================] - 0s 5ms/step - loss: 0.0528 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 59/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0318 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 60/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0753 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 61/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3589 - accuracy: 0.8438 - precision: 0.7895 - recall: 0.9375 - auc: 0.9258 \n",
"Epoch 62/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0808 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 63/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0684 - accuracy: 0.9688 - precision: 0.9412 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 64/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1127 - accuracy: 0.9375 - precision: 0.8889 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 65/100\n",
"4/4 [==============================] - 0s 7ms/step - loss: 0.1123 - accuracy: 0.9375 - precision: 0.8889 - recall: 1.0000 - auc: 1.0000\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x16a176b08b0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Stop once the training loss has not improved for 10 consecutive epochs and\n",
"# restore the weights of the best epoch seen.\n",
"early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"    monitor='loss',\n",
"    patience=10,\n",
"    restore_best_weights=True,\n",
")\n",
"\n",
"model.fit(\n",
"    embeddings,\n",
"    labels,\n",
"    epochs=100,\n",
"    batch_size=8,\n",
"    callbacks=[early_stopping],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained classifier to disk.\n",
"model_path = 'models/abnomility-sentiment.h5'\n",
"model.save(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Inverse of class_dict: integer label -> class name.\n",
"class_dict_rev = {label: class_name for class_name, label in class_dict.items()}"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def inference_abnomility_sentiment(audio_file):\n",
"    \"\"\"Return the predicted class name for a single audio file.\n",
"\n",
"    The file is embedded with ``embedding_inference``, a leading batch axis\n",
"    is added, and the output of the global ``model`` is rounded to the\n",
"    nearest integer and mapped to a class name through ``class_dict_rev``.\n",
"\n",
"    Args:\n",
"        audio_file: Path of the audio file to classify.\n",
"\n",
"    Returns:\n",
"        The ``class_dict_rev`` entry for the rounded model output.\n",
"    \"\"\"\n",
"    features = np.expand_dims(embedding_inference(audio_file), axis=0)\n",
"    score = model.predict(features).squeeze()\n",
"    class_index = int(np.round(score))\n",
"    return class_dict_rev[class_index]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 153ms/step\n"
]
},
{
"data": {
"text/plain": [
"'autism'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Smoke-test the inference helper on one recording from the dataset.\n",
"sample_file = 'data/abnormality/autism/Child 1 - 16/child16_8-තාත්ති එක්ක ආවෙ.wav'\n",
"response = inference_abnomility_sentiment(sample_file)\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment