Commit 77738a5f authored by RSRGJN Ananda

pronounce-validation model

parent b6b4b9f5
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os, glob\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from pyannote.audio import Model, Inference\n",
"# from speechbrain.pretrained import SepformerSeparation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"embedding_model = Model.from_pretrained(\n",
" \"pyannote/embedding\", \n",
" use_auth_token=\"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\"\n",
" )\n",
"\n",
"# denoiser = SepformerSeparation.from_hparams(\n",
"# source=\"speechbrain/sepformer-wham-enhancement\", \n",
"# savedir='pretrained_models/sepformer-wham-enhancement'\n",
"# )\n",
"\n",
"embedding_inference = Inference(\n",
" embedding_model, \n",
" window=\"whole\"\n",
" )\n",
"\n",
"class_dict = {\n",
" 'autism': 0,\n",
" 'non-autism': 1\n",
" }"
]
},
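{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick sanity check (an added sketch, not part of the original pipeline): embed one clip and confirm the 512-dimensional vector the classifier below expects. The path reuses the reference file from the commented-out denoiser cell and is assumed to exist."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check on an assumed-to-exist reference clip: Inference with\n",
"# window=\"whole\" returns one embedding vector for the entire file.\n",
"sample_embedding = embedding_inference('data/pronouncing-evaluation/reference/1.wav')\n",
"print(sample_embedding.shape)  # expected: (512,)"
]
},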
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# wavFile = 'data/pronouncing-evaluation/reference/1.wav'\n",
"# denoiser.separate_file(path=wavFile) \n",
"# denoised_wavFile = f\"results/denoised/{wavFile.split('/')[-1].split('.')[0]}_denoised.wav\"\n",
"# denoiser.save_file(denoised_wavFile)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(ref_audio_dir='data/pronouncing-evaluation/reference'):\n",
" ref_voice_files = glob.glob(f'{ref_audio_dir}/*.wav')\n",
" ref_voice_files = [voice_file.replace('\\\\', '/') for voice_file in ref_voice_files]\n",
" aut_voice_files = [voice_file.replace('/reference/', '/autism/') for voice_file in ref_voice_files]\n",
" non_aut_voice_files = [voice_file.replace('/reference/', '/non-autism/') for voice_file in ref_voice_files]\n",
"\n",
" embeddings_01 = np.zeros((len(ref_voice_files) * 2, 512))\n",
" embeddings_02 = np.zeros((len(ref_voice_files) * 2, 512))\n",
" labels = np.zeros(len(ref_voice_files) * 2)\n",
"\n",
" errorneous_idxs = []\n",
"\n",
" for idx in range(len(ref_voice_files)):\n",
" try:\n",
" embeddings_01[idx] = embedding_inference(ref_voice_files[idx])\n",
" embeddings_02[idx] = embedding_inference(aut_voice_files[idx])\n",
" labels[idx] = 0\n",
" except:\n",
" errorneous_idxs.append(idx)\n",
" print('Errorneous reference file: ', ref_voice_files[idx])\n",
" print('Errorneous autism file: ', aut_voice_files[idx])\n",
"\n",
" try:\n",
" embeddings_01[idx + len(ref_voice_files)] = embedding_inference(ref_voice_files[idx])\n",
" embeddings_02[idx + len(ref_voice_files)] = embedding_inference(non_aut_voice_files[idx])\n",
" labels[idx + len(ref_voice_files)] = 1\n",
" except:\n",
" errorneous_idxs.append(idx)\n",
" print('Errorneous reference file: ', ref_voice_files[idx])\n",
" print('Errorneous non-autism file: ', non_aut_voice_files[idx])\n",
"\n",
" labels = np.array(labels)\n",
" \n",
" embeddings_01 = np.delete(embeddings_01, errorneous_idxs, axis=0)\n",
" embeddings_02 = np.delete(embeddings_02, errorneous_idxs, axis=0)\n",
" labels = np.delete(labels, errorneous_idxs, axis=0)\n",
"\n",
" random_idxs = np.random.permutation(len(labels))\n",
" embeddings_01 = embeddings_01[random_idxs]\n",
" embeddings_02 = embeddings_02[random_idxs]\n",
" labels = labels[random_idxs]\n",
" \n",
" return embeddings_01, embeddings_02, labels"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding 01 shape: (14, 512)\n",
"Embedding 02 shape: (14, 512)\n",
"Labels shape: (14,)\n"
]
}
],
"source": [
"embeddings_01, embeddings_02, labels = load_dataset()\n",
"\n",
"\n",
"print(\"Embedding 01 shape: \", embeddings_01.shape)\n",
"print(\"Embedding 02 shape: \", embeddings_02.shape)\n",
"print(\"Labels shape: \", labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
" inputs01 = tf.keras.Input(shape=(512,))\n",
" inputs02 = tf.keras.Input(shape=(512,))\n",
"\n",
" x1 = tf.keras.layers.Dense(256, activation='relu')(inputs01)\n",
" x1 = tf.keras.layers.Dropout(0.2)(x1)\n",
" x1 = tf.keras.layers.Dense(128, activation='relu')(x1)\n",
" x1 = tf.keras.layers.Dropout(0.2)(x1)\n",
" x1 = tf.keras.layers.Dense(64, activation='relu')(x1)\n",
" \n",
" x2 = tf.keras.layers.Dense(256, activation='relu')(inputs02)\n",
" x2 = tf.keras.layers.Dropout(0.2)(x2)\n",
" x2 = tf.keras.layers.Dense(128, activation='relu')(x2)\n",
" x2 = tf.keras.layers.Dropout(0.2)(x2)\n",
" x2 = tf.keras.layers.Dense(64, activation='relu')(x2)\n",
" \n",
" x = tf.keras.layers.concatenate([x1, x2])\n",
" x = tf.keras.layers.Dense(32, activation='relu')(x)\n",
" x = tf.keras.layers.Dropout(0.2)(x)\n",
" outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n",
"\n",
" model = tf.keras.Model(\n",
" inputs=[inputs01, inputs02], \n",
" outputs=outputs\n",
" )\n",
" \n",
" model.compile(\n",
" optimizer='adam',\n",
" loss='binary_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ]\n",
" )\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 512)] 0 [] \n",
" \n",
" input_2 (InputLayer) [(None, 512)] 0 [] \n",
" \n",
" dense (Dense) (None, 256) 131328 ['input_1[0][0]'] \n",
" \n",
" dense_3 (Dense) (None, 256) 131328 ['input_2[0][0]'] \n",
" \n",
" dropout (Dropout) (None, 256) 0 ['dense[0][0]'] \n",
" \n",
" dropout_2 (Dropout) (None, 256) 0 ['dense_3[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 128) 32896 ['dropout[0][0]'] \n",
" \n",
" dense_4 (Dense) (None, 128) 32896 ['dropout_2[0][0]'] \n",
" \n",
" dropout_1 (Dropout) (None, 128) 0 ['dense_1[0][0]'] \n",
" \n",
" dropout_3 (Dropout) (None, 128) 0 ['dense_4[0][0]'] \n",
" \n",
" dense_2 (Dense) (None, 64) 8256 ['dropout_1[0][0]'] \n",
" \n",
" dense_5 (Dense) (None, 64) 8256 ['dropout_3[0][0]'] \n",
" \n",
" concatenate (Concatenate) (None, 128) 0 ['dense_2[0][0]', \n",
" 'dense_5[0][0]'] \n",
" \n",
" dense_6 (Dense) (None, 32) 4128 ['concatenate[0][0]'] \n",
" \n",
" dropout_4 (Dropout) (None, 32) 0 ['dense_6[0][0]'] \n",
" \n",
" dense_7 (Dense) (None, 1) 33 ['dropout_4[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 349,121\n",
"Trainable params: 349,121\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"model = build_model()\n",
"model.summary()"
]
},
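{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fit below trains on all 14 pairs with no held-out data, so the perfect training metrics say little about generalisation. A minimal sketch of a stratified split, assuming scikit-learn is available (with this few pairs the split is only illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: hold out ~20% of the pairs for validation (assumes scikit-learn).\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"(e1_train, e1_val,\n",
" e2_train, e2_val,\n",
" y_train, y_val) = train_test_split(\n",
"    embeddings_01, embeddings_02, labels,\n",
"    test_size=0.2, random_state=42, stratify=labels\n",
")\n",
"print(y_train.shape, y_val.shape)"
]
},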
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 2.0586e-31 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 2/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 7.7467e-27 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 3/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.0281e-22 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 4/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 4.1828e-09 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 5/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 7.9179e-21 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 6/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 6.5887e-24 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 7/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 4.3676e-28 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 8/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 2.5126e-08 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 9/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 1.5209e-32 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 10/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 5.7262e-21 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 11/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 3.8924e-19 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 12/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 13/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 8.1846e-21 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 14/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 6.8281e-17 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 15/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.1849e-07 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 16/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 1.4954e-18 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 17/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 18/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 9.5766e-09 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 19/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 2.3950e-24 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 20/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 5.0273e-33 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 21/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 22/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 3.8447e-12 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x19c08198730>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(\n",
" [embeddings_01, embeddings_02],\n",
" labels,\n",
" epochs=100,\n",
" batch_size=2,\n",
" callbacks=[\n",
" tf.keras.callbacks.EarlyStopping(\n",
" monitor='loss',\n",
" patience=10,\n",
" restore_best_weights=True\n",
" )\n",
" ] \n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"model.save('models/pronounce-validation.h5')"
]
},
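{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reload sketch (added for completeness): the saved HDF5 file can be restored with the standard Keras loader, so later inference does not require rebuilding the architecture."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore the trained model from disk; architecture, weights and the\n",
"# compiled metrics all come back from the HDF5 file.\n",
"reloaded_model = tf.keras.models.load_model('models/pronounce-validation.h5')\n",
"reloaded_model.summary()"
]
},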
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def inference_pronounce_validation(\n",
" audio_file01,\n",
" audio_file02\n",
" ):\n",
" embedding01 = embedding_inference(audio_file01)\n",
" embedding02 = embedding_inference(audio_file02)\n",
"\n",
" embedding01 = np.expand_dims(embedding01, axis=0)\n",
" embedding02 = np.expand_dims(embedding02, axis=0)\n",
"\n",
" prediction = model.predict([embedding01, embedding02], verbose=0)\n",
" prediction = prediction.squeeze()\n",
"\n",
" return 'non-autism' if float(prediction) == 1.0 else 'autism'"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'non-autism'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response = inference_pronounce_validation(\n",
" 'data/pronouncing-evaluation/reference/2.wav',\n",
" 'data/pronouncing-evaluation/non-autism/2.wav'\n",
" )\n",
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}