Add model jupyter source file.

parent 10283f70
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sulak\\anaconda3\\lib\\site-packages\\paramiko\\transport.py:219: CryptographyDeprecationWarning: Blowfish has been deprecated\n",
" \"class\": algorithms.Blowfish,\n"
]
}
],
"source": [
"import os, glob\n",
"import warnings\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from pyannote.audio import Model, Inference\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# The Hugging Face access token is read from the HF_TOKEN environment variable so\n",
"# that no credential is stored in the notebook. A token that was ever committed\n",
"# to version control must be treated as leaked and revoked on huggingface.co.\n",
"# When HF_TOKEN is unset, None is passed to from_pretrained instead of a token.\n",
"hf_token = os.environ.get(\"HF_TOKEN\")\n",
"\n",
"# Pretrained speaker-embedding network pulled from the Hugging Face hub.\n",
"embedding_model = Model.from_pretrained(\n",
"    \"pyannote/embedding\",\n",
"    use_auth_token=hf_token\n",
")\n",
"\n",
"# window=\"whole\" asks pyannote to process each input as one window instead of\n",
"# sliding over it (see the pyannote.audio Inference documentation).\n",
"embedding_inference = Inference(\n",
"    embedding_model,\n",
"    window=\"whole\"\n",
")\n",
"\n",
"# Maps the name of a class directory in the dataset to its integer label.\n",
"class_dict = {\n",
"    'autism': 0,\n",
"    'non-autism': 1\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(audio_dir='data/abnormality', embedding_dim=512):\n",
"    \"\"\"Embed every ``.wav`` file under ``audio_dir`` and pair it with a class label.\n",
"\n",
"    Files are matched with ``{audio_dir}/*/*/*.wav``. The directory two levels\n",
"    above each file (the one directly below ``audio_dir``) is looked up in\n",
"    ``class_dict`` to obtain the label. Files whose embedding cannot be\n",
"    computed are reported on stdout and left out of the result.\n",
"\n",
"    Args:\n",
"        audio_dir: Root directory of the dataset.\n",
"        embedding_dim: Length of the vector produced by ``embedding_inference``.\n",
"\n",
"    Returns:\n",
"        Tuple ``(embeddings, labels)`` of NumPy arrays with shapes\n",
"        ``(n_ok, embedding_dim)`` and ``(n_ok,)``, where ``n_ok`` is the number\n",
"        of files that were embedded successfully.\n",
"\n",
"    Raises:\n",
"        KeyError: If a class directory name is not a key of ``class_dict``.\n",
"    \"\"\"\n",
"    # Normalise Windows separators, then sort so the row order is reproducible\n",
"    # (glob does not guarantee any particular order).\n",
"    voice_files = sorted(\n",
"        path.replace('\\\\', '/')\n",
"        for path in glob.glob(f'{audio_dir}/*/*/*.wav')\n",
"    )\n",
"\n",
"    labels = np.array([class_dict[path.split('/')[-3]] for path in voice_files])\n",
"    embeddings = np.zeros((len(voice_files), embedding_dim))\n",
"\n",
"    failed_idxs = []\n",
"    for i, path in enumerate(voice_files):\n",
"        try:\n",
"            embeddings[i] = embedding_inference(path)\n",
"        except Exception as exc:\n",
"            # Catch Exception only: a bare except would also swallow\n",
"            # KeyboardInterrupt, making a long run impossible to stop.\n",
"            failed_idxs.append(i)\n",
"            print('Errorneous file: ', path, '-', repr(exc))\n",
"\n",
"    # Drop the rows (still all zeros) of the files that could not be embedded.\n",
"    embeddings = np.delete(embeddings, failed_idxs, axis=0)\n",
"    labels = np.delete(labels, failed_idxs, axis=0)\n",
"    return embeddings, labels"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding shape: (32, 512)\n",
"labels shape: (32,)\n"
]
}
],
"source": [
"embeddings, labels = load_dataset()\n",
"\n",
"# Report the array shapes as a quick sanity check of what was loaded.\n",
"for caption, array in ((\"Embedding shape: \", embeddings), (\"labels shape: \", labels)):\n",
"    print(caption, array.shape)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
"    \"\"\"Build and compile the dense binary classifier fed with 512-value embeddings.\n",
"\n",
"    Architecture: ``Dense(256, relu)``, then three stages of\n",
"    ``Dense -> BatchNormalization -> ReLU -> Dropout(0.2)`` with 128, 64 and 32\n",
"    units, followed by a single sigmoid unit named ``detection``.\n",
"\n",
"    Returns:\n",
"        A ``tf.keras.Model`` compiled with the Adam optimizer, binary\n",
"        cross-entropy loss and accuracy / precision / recall / AUC metrics.\n",
"    \"\"\"\n",
"    layers = tf.keras.layers\n",
"    metrics = tf.keras.metrics\n",
"\n",
"    inputs = tf.keras.Input(shape=(512,))\n",
"    hidden = layers.Dense(256, activation='relu')(inputs)\n",
"\n",
"    # Layers are created in the same order as if the stages were written out\n",
"    # one by one, so the auto-generated layer names are unchanged.\n",
"    for units in (128, 64, 32):\n",
"        hidden = layers.Dense(units)(hidden)\n",
"        hidden = layers.BatchNormalization()(hidden)\n",
"        hidden = layers.Activation('relu')(hidden)\n",
"        hidden = layers.Dropout(0.2)(hidden)\n",
"\n",
"    outputs = layers.Dense(1, activation='sigmoid', name='detection')(hidden)\n",
"\n",
"    model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
"    model.compile(\n",
"        optimizer='adam',\n",
"        loss='binary_crossentropy',\n",
"        metrics=[\n",
"            metrics.BinaryAccuracy(name='accuracy'),\n",
"            metrics.Precision(name='precision'),\n",
"            metrics.Recall(name='recall'),\n",
"            metrics.AUC(name='auc'),\n",
"        ],\n",
"    )\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 512)] 0 \n",
" \n",
" dense (Dense) (None, 256) 131328 \n",
" \n",
" dense_1 (Dense) (None, 128) 32896 \n",
" \n",
" batch_normalization (BatchN (None, 128) 512 \n",
" ormalization) \n",
" \n",
" activation (Activation) (None, 128) 0 \n",
" \n",
" dropout (Dropout) (None, 128) 0 \n",
" \n",
" dense_2 (Dense) (None, 64) 8256 \n",
" \n",
" batch_normalization_1 (Batc (None, 64) 256 \n",
" hNormalization) \n",
" \n",
" activation_1 (Activation) (None, 64) 0 \n",
" \n",
" dropout_1 (Dropout) (None, 64) 0 \n",
" \n",
" dense_3 (Dense) (None, 32) 2080 \n",
" \n",
" batch_normalization_2 (Batc (None, 32) 128 \n",
" hNormalization) \n",
" \n",
" activation_2 (Activation) (None, 32) 0 \n",
" \n",
" dropout_2 (Dropout) (None, 32) 0 \n",
" \n",
" detection (Dense) (None, 1) 33 \n",
" \n",
"=================================================================\n",
"Total params: 175,489\n",
"Trainable params: 175,041\n",
"Non-trainable params: 448\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Instantiate the compiled classifier and print its layer / parameter table.\n",
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"4/4 [==============================] - 1s 8ms/step - loss: 0.8720 - accuracy: 0.4062 - precision: 0.3846 - recall: 0.3125 - auc: 0.3613\n",
"Epoch 2/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.6819 - accuracy: 0.5625 - precision: 0.6250 - recall: 0.3125 - auc: 0.6250\n",
"Epoch 3/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.5912 - accuracy: 0.6875 - precision: 0.8000 - recall: 0.5000 - auc: 0.7383 \n",
"Epoch 4/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.4495 - accuracy: 0.8438 - precision: 0.8667 - recall: 0.8125 - auc: 0.8945\n",
"Epoch 5/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.3916 - accuracy: 0.8438 - precision: 0.9231 - recall: 0.7500 - auc: 0.9297\n",
"Epoch 6/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2661 - accuracy: 0.9062 - precision: 0.9333 - recall: 0.8750 - auc: 0.9883\n",
"Epoch 7/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3791 - accuracy: 0.8750 - precision: 0.8750 - recall: 0.8750 - auc: 0.9355\n",
"Epoch 8/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3156 - accuracy: 0.9062 - precision: 0.9333 - recall: 0.8750 - auc: 0.9922\n",
"Epoch 9/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.2801 - accuracy: 0.9375 - precision: 0.9375 - recall: 0.9375 - auc: 0.9805\n",
"Epoch 10/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2198 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 11/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2803 - accuracy: 0.9062 - precision: 1.0000 - recall: 0.8125 - auc: 0.9922\n",
"Epoch 12/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2149 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 13/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1541 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 14/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1529 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 15/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.2901 - accuracy: 0.8750 - precision: 0.8750 - recall: 0.8750 - auc: 0.9648\n",
"Epoch 16/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1161 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 17/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1315 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 18/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3717 - accuracy: 0.9062 - precision: 1.0000 - recall: 0.8125 - auc: 0.9043\n",
"Epoch 19/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1533 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 20/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1370 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 21/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1268 - accuracy: 0.9688 - precision: 0.9412 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 22/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1185 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 23/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1002 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 24/100\n",
"4/4 [==============================] - 0s 5ms/step - loss: 0.1243 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 25/100\n",
"4/4 [==============================] - 0s 5ms/step - loss: 0.1121 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 26/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.1200 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 27/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1230 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 28/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0778 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 29/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 30/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0756 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 31/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0807 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 32/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0835 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 33/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1426 - accuracy: 0.9375 - precision: 0.9375 - recall: 0.9375 - auc: 0.9961\n",
"Epoch 34/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0730 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 35/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0690 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 36/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 37/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0935 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 38/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1438 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 39/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0632 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 40/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0515 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 41/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 42/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0509 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 43/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0636 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 44/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0900 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 45/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0758 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 46/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0447 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 47/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1070 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 48/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1172 - accuracy: 0.9688 - precision: 0.9412 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 49/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0479 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 50/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0435 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 51/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1728 - accuracy: 0.9375 - precision: 1.0000 - recall: 0.8750 - auc: 0.9922\n",
"Epoch 52/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1023 - accuracy: 0.9688 - precision: 1.0000 - recall: 0.9375 - auc: 1.0000\n",
"Epoch 53/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0716 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 54/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0358 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 55/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0267 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 56/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0797 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 57/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0783 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 58/100\n",
"4/4 [==============================] - 0s 5ms/step - loss: 0.0528 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 59/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0318 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 60/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0753 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 61/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.3589 - accuracy: 0.8438 - precision: 0.7895 - recall: 0.9375 - auc: 0.9258 \n",
"Epoch 62/100\n",
"4/4 [==============================] - 0s 4ms/step - loss: 0.0808 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 63/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.0684 - accuracy: 0.9688 - precision: 0.9412 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 64/100\n",
"4/4 [==============================] - 0s 3ms/step - loss: 0.1127 - accuracy: 0.9375 - precision: 0.8889 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 65/100\n",
"4/4 [==============================] - 0s 7ms/step - loss: 0.1123 - accuracy: 0.9375 - precision: 0.8889 - recall: 1.0000 - auc: 1.0000\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x16a176b08b0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# No validation data is passed to fit(), so early stopping watches the training\n",
"# loss; the weights of the epoch with the lowest training loss are restored.\n",
"early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"    monitor='loss',\n",
"    patience=10,\n",
"    restore_best_weights=True\n",
")\n",
"\n",
"model.fit(\n",
"    x=embeddings,\n",
"    y=labels,\n",
"    batch_size=8,\n",
"    epochs=100,\n",
"    callbacks=[early_stopping]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained classifier to disk for later reuse.\n",
"MODEL_PATH = 'models/abnomility-sentiment.h5'\n",
"model.save(MODEL_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Reverse lookup (integer label -> class name) used to decode predictions.\n",
"class_dict_rev = dict(zip(class_dict.values(), class_dict.keys()))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def inference_abnomility_sentiment(audio_file):\n",
"    \"\"\"Classify one audio file and return its class name.\n",
"\n",
"    The file is embedded with ``embedding_inference``, scored by the global\n",
"    ``model``, and the rounded score is translated through ``class_dict_rev``.\n",
"\n",
"    Args:\n",
"        audio_file: Path of the audio file to classify.\n",
"\n",
"    Returns:\n",
"        The ``class_dict_rev`` entry for the rounded model output.\n",
"    \"\"\"\n",
"    # Add a leading batch axis before handing the embedding to Keras.\n",
"    batch = np.expand_dims(embedding_inference(audio_file), axis=0)\n",
"    probability = model.predict(batch).squeeze()\n",
"    class_id = int(np.round(probability))\n",
"    return class_dict_rev[class_id]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 153ms/step\n"
]
},
{
"data": {
"text/plain": [
"'autism'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Try the pipeline on one recording. NOTE(review): this path matches the glob used\n",
"# by load_dataset(), so it is not held-out data.\n",
"response = inference_abnomility_sentiment('data/abnormality/autism/Child 1 - 16/child16_8-තාත්ති එක්ක ආවෙ.wav')\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment