Commit da1a7ae0 authored by Nagahawatta S.S

Answer Validation Model

parent 15c30659
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import fasttext, glob\n",
"import tensorflow as tf\n",
"from datasets import Dataset, Audio\n",
"from transformers import WhisperProcessor, WhisperForConditionalGeneration"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
]
}
],
"source": [
"# Hugging Face checkpoint id shared by the Whisper processor and model.\n",
"WHISPER_CHECKPOINT = \"Subhaka/whisper-small-Sinhala-Fine_Tune\"\n",
"\n",
"s2t_processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT)\n",
"s2t_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_CHECKPOINT)\n",
"\n",
"# Decoder prompt ids for Sinhala transcription; passed to generate() by transcribe().\n",
"s2t_forced_decoder_ids = s2t_processor.get_decoder_prompt_ids(language=\"sinhala\", task=\"transcribe\")\n",
"\n",
"# fastText model used to turn transcriptions into sentence vectors.\n",
"embedding_model = fasttext.load_model(\"models/cc.si.300.bin\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def load_audio(audio_file, sampling_rate=16000):\n",
"    \"\"\"Decode one audio file into a waveform array at ``sampling_rate`` Hz.\n",
"\n",
"    Args:\n",
"        audio_file: Path of the audio file to decode.\n",
"        sampling_rate: Target sampling rate in Hz. Defaults to 16000, the\n",
"            rate ``transcribe`` hands to the Whisper processor.\n",
"\n",
"    Returns:\n",
"        The decoded samples stored under the ``'array'`` key of the audio column.\n",
"    \"\"\"\n",
"    # One cast is enough: the column is decoded (and resampled) when the row\n",
"    # is accessed, so casting to a plain ``Audio()`` first adds nothing.\n",
"    audio_dataset = Dataset.from_dict({\"audio\": [audio_file]}).cast_column(\n",
"        \"audio\",\n",
"        Audio(sampling_rate=sampling_rate),\n",
"    )\n",
"    return audio_dataset[0]['audio']['array']\n",
"\n",
"def transcribe(audio_file):\n",
"    \"\"\"Return the text the speech-to-text model generates for one audio file.\n",
"\n",
"    The waveform is loaded with ``load_audio``, converted to input features\n",
"    by the processor, decoded with the forced decoder prompt, and the first\n",
"    decoded string is returned with special tokens stripped.\n",
"    \"\"\"\n",
"    waveform = load_audio(audio_file)\n",
"    model_inputs = s2t_processor(waveform, sampling_rate=16000, return_tensors=\"pt\")\n",
"    token_ids = s2t_model.generate(\n",
"        model_inputs.input_features,\n",
"        forced_decoder_ids=s2t_forced_decoder_ids,\n",
"    )\n",
"    decoded = s2t_processor.batch_decode(token_ids, skip_special_tokens=True)\n",
"    return decoded[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'අක්යෙ කාවෙ'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Transcribe one reference recording and display the resulting text.\n",
"transcription = transcribe('data/pronouncing-evaluation/reference/1.wav')\n",
"transcription"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(ref_audio_dir='data/answering-evaluation/reference', seed=None):\n",
"    \"\"\"Build shuffled sentence-embedding pairs and labels from the answer recordings.\n",
"\n",
"    Every ``*.wav`` file in ``ref_audio_dir`` is paired with the same-named file\n",
"    in the sibling ``autism`` and ``non-autism`` directories (located by\n",
"    replacing ``/reference/`` in the path). Each recording is transcribed with\n",
"    ``transcribe`` and embedded with the fastText model.\n",
"\n",
"    Args:\n",
"        ref_audio_dir: Directory holding the reference recordings.\n",
"        seed: Optional seed for the final shuffle. ``None`` keeps the previous\n",
"            behaviour of drawing from NumPy's global random state.\n",
"\n",
"    Returns:\n",
"        ``(embeddings_01, embeddings_02, labels)``: reference embeddings,\n",
"        answer embeddings and labels (0 = autism pair, 1 = non-autism pair),\n",
"        all shuffled with the same permutation. Pairs whose audio could not\n",
"        be processed are left out.\n",
"    \"\"\"\n",
"    # Sort so the row layout does not depend on the filesystem's listing order.\n",
"    ref_voice_files = sorted(\n",
"        voice_file.replace('\\\\', '/')\n",
"        for voice_file in glob.glob(f'{ref_audio_dir}/*.wav')\n",
"    )\n",
"    aut_voice_files = [voice_file.replace('/reference/', '/autism/') for voice_file in ref_voice_files]\n",
"    non_aut_voice_files = [voice_file.replace('/reference/', '/non-autism/') for voice_file in ref_voice_files]\n",
"\n",
"    # Rows [0, n_refs) hold reference/autism pairs, rows [n_refs, 2 * n_refs)\n",
"    # hold reference/non-autism pairs.\n",
"    n_refs = len(ref_voice_files)\n",
"    embeddings_01 = np.zeros((n_refs * 2, 300))\n",
"    embeddings_02 = np.zeros((n_refs * 2, 300))\n",
"    labels = np.zeros(n_refs * 2)\n",
"\n",
"    failed_rows = []\n",
"\n",
"    for idx in range(n_refs):\n",
"        # The reference recording is shared by both pairs, so transcribe and\n",
"        # embed it once instead of once per pair.\n",
"        try:\n",
"            ref_embedding = embedding_model.get_sentence_vector(transcribe(ref_voice_files[idx]))\n",
"        except Exception as error:\n",
"            # Without a reference embedding neither pair is usable.\n",
"            failed_rows.extend([idx, idx + n_refs])\n",
"            print('Errorneous reference file: ', ref_voice_files[idx], repr(error))\n",
"            continue\n",
"\n",
"        pairs = (\n",
"            (idx, aut_voice_files[idx], 0, 'autism'),\n",
"            (idx + n_refs, non_aut_voice_files[idx], 1, 'non-autism'),\n",
"        )\n",
"        for row, answer_file, label, group in pairs:\n",
"            try:\n",
"                answer_embedding = embedding_model.get_sentence_vector(transcribe(answer_file))\n",
"            except Exception as error:\n",
"                # Record the row this pair occupies (not ``idx``); otherwise a\n",
"                # failed non-autism pair would delete the valid autism row and\n",
"                # leave an all-zero row behind.\n",
"                failed_rows.append(row)\n",
"                print('Errorneous reference file: ', ref_voice_files[idx])\n",
"                print(f'Errorneous {group} file: ', answer_file, repr(error))\n",
"                continue\n",
"\n",
"            embeddings_01[row] = ref_embedding\n",
"            embeddings_02[row] = answer_embedding\n",
"            labels[row] = label\n",
"\n",
"    embeddings_01 = np.delete(embeddings_01, failed_rows, axis=0)\n",
"    embeddings_02 = np.delete(embeddings_02, failed_rows, axis=0)\n",
"    labels = np.delete(labels, failed_rows, axis=0)\n",
"\n",
"    # Apply one permutation to all three arrays so the pairs stay aligned.\n",
"    if seed is None:\n",
"        random_idxs = np.random.permutation(len(labels))\n",
"    else:\n",
"        random_idxs = np.random.default_rng(seed).permutation(len(labels))\n",
"\n",
"    return embeddings_01[random_idxs], embeddings_02[random_idxs], labels[random_idxs]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding 01 shape: (14, 300)\n",
"Embedding 02 shape: (14, 300)\n",
"Labels shape: (14,)\n"
]
}
],
"source": [
"embeddings_01, embeddings_02, labels = load_dataset()\n",
"\n",
"# Report the shape of each array returned by load_dataset().\n",
"for caption, array in (\n",
"    (\"Embedding 01 shape: \", embeddings_01),\n",
"    (\"Embedding 02 shape: \", embeddings_02),\n",
"    (\"Labels shape: \", labels),\n",
"):\n",
"    print(caption, array.shape)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
"    \"\"\"Build and compile the two-input binary classifier.\n",
"\n",
"    Each 300-dimensional input goes through its own Dense 300 -> 150 -> 30\n",
"    stack (ReLU, dropout 0.2 after the first two layers). The two 30-unit\n",
"    outputs are concatenated, passed through Dense(30) + dropout, and mapped\n",
"    to a single sigmoid unit.\n",
"\n",
"    Returns:\n",
"        A ``tf.keras.Model`` compiled with Adam, binary cross-entropy and\n",
"        accuracy / precision / recall / AUC metrics.\n",
"    \"\"\"\n",
"    def encode(branch_input):\n",
"        # One tower; the weights are not shared between the two inputs.\n",
"        hidden = branch_input\n",
"        for units in (300, 150):\n",
"            hidden = tf.keras.layers.Dense(units, activation='relu')(hidden)\n",
"            hidden = tf.keras.layers.Dropout(0.2)(hidden)\n",
"        return tf.keras.layers.Dense(30, activation='relu')(hidden)\n",
"\n",
"    inputs01 = tf.keras.Input(shape=(300,))\n",
"    inputs02 = tf.keras.Input(shape=(300,))\n",
"\n",
"    # Towers are built first-input-then-second, matching the layer numbering\n",
"    # in the model summary.\n",
"    merged = tf.keras.layers.concatenate([encode(inputs01), encode(inputs02)])\n",
"    merged = tf.keras.layers.Dense(30, activation='relu')(merged)\n",
"    merged = tf.keras.layers.Dropout(0.2)(merged)\n",
"    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(merged)\n",
"\n",
"    model = tf.keras.Model(inputs=[inputs01, inputs02], outputs=outputs)\n",
"\n",
"    tracked_metrics = [\n",
"        tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
"        tf.keras.metrics.Precision(name='precision'),\n",
"        tf.keras.metrics.Recall(name='recall'),\n",
"        tf.keras.metrics.AUC(name='auc'),\n",
"    ]\n",
"    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=tracked_metrics)\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 300)] 0 [] \n",
" \n",
" input_2 (InputLayer) [(None, 300)] 0 [] \n",
" \n",
" dense (Dense) (None, 300) 90300 ['input_1[0][0]'] \n",
" \n",
" dense_3 (Dense) (None, 300) 90300 ['input_2[0][0]'] \n",
" \n",
" dropout (Dropout) (None, 300) 0 ['dense[0][0]'] \n",
" \n",
" dropout_2 (Dropout) (None, 300) 0 ['dense_3[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 150) 45150 ['dropout[0][0]'] \n",
" \n",
" dense_4 (Dense) (None, 150) 45150 ['dropout_2[0][0]'] \n",
" \n",
" dropout_1 (Dropout) (None, 150) 0 ['dense_1[0][0]'] \n",
" \n",
" dropout_3 (Dropout) (None, 150) 0 ['dense_4[0][0]'] \n",
" \n",
" dense_2 (Dense) (None, 30) 4530 ['dropout_1[0][0]'] \n",
" \n",
" dense_5 (Dense) (None, 30) 4530 ['dropout_3[0][0]'] \n",
" \n",
" concatenate (Concatenate) (None, 60) 0 ['dense_2[0][0]', \n",
" 'dense_5[0][0]'] \n",
" \n",
" dense_6 (Dense) (None, 30) 1830 ['concatenate[0][0]'] \n",
" \n",
" dropout_4 (Dropout) (None, 30) 0 ['dense_6[0][0]'] \n",
" \n",
" dense_7 (Dense) (None, 1) 31 ['dropout_4[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 281,821\n",
"Trainable params: 281,821\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"# Build the classifier and print its layer-by-layer summary.\n",
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"7/7 [==============================] - 5s 4ms/step - loss: 0.7084 - accuracy: 0.3571 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.0918 \n",
"Epoch 2/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6849 - accuracy: 0.5000 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7041\n",
"Epoch 3/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6786 - accuracy: 0.5714 - precision: 0.6667 - recall: 0.2857 - auc: 0.7449 \n",
"Epoch 4/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6603 - accuracy: 0.7143 - precision: 0.8000 - recall: 0.5714 - auc: 0.8265\n",
"Epoch 5/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6189 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 6/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.5713 - accuracy: 0.8571 - precision: 0.8571 - recall: 0.8571 - auc: 0.9490\n",
"Epoch 7/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.4748 - accuracy: 0.9286 - precision: 1.0000 - recall: 0.8571 - auc: 1.0000\n",
"Epoch 8/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.3560 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 9/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.2779 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 10/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.1555 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 11/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0864 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 12/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 13/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0163 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 14/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0202 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 15/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0082 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 16/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0080 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 17/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0053 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 18/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0030 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 19/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0036 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 20/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0016 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 21/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0011 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 22/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0019 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 23/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0012 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 24/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0030 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 25/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0016 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 26/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 7.1956e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 27/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 28/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.0398e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 29/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0099 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 30/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 9.1052e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 31/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0025 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 32/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 33/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0026 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 34/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.4648e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 35/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0014 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 36/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.5733e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 37/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 3.5510e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 38/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0017 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 39/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.6147e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 40/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.5398e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 41/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.6020e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 42/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.5038e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 43/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 7.2330e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 44/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.6800e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 45/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0034 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 46/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 4.7010e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 47/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 1.5909e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 48/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.7970e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 49/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 9.3579e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 50/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 1.6038e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 51/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.2203e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 52/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.4229e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 53/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.0325e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 54/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.4042e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 55/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 4.1154e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 56/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.5307e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 57/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.3958e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 58/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 3.9178e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 59/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 3.9669e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 60/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 7.6923e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 61/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0211 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 62/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 3.5197e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 63/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 1.8985e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x22f8eceb760>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Stop once the training loss has not improved for 10 epochs and restore\n",
"# the best weights seen so far.\n",
"early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"    monitor='loss',\n",
"    patience=10,\n",
"    restore_best_weights=True,\n",
")\n",
"\n",
"# NOTE(review): no validation data is passed, so the logged metrics describe\n",
"# the training set only.\n",
"model.fit(\n",
"    [embeddings_01, embeddings_02],\n",
"    labels,\n",
"    epochs=100,\n",
"    batch_size=2,\n",
"    callbacks=[early_stopping],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained model to disk.\n",
"model.save('models/answering-evaluation.h5')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def inference_answer_validation(\n",
"        audio_file01,\n",
"        audio_file02\n",
"    ):\n",
"    \"\"\"Score a pair of recordings with the trained model.\n",
"\n",
"    Both files are transcribed, embedded with the fastText model and fed to\n",
"    ``model`` as a batch of one.\n",
"\n",
"    Args:\n",
"        audio_file01: Path of the first recording (the reference answer in\n",
"            the training pairs).\n",
"        audio_file02: Path of the second recording (the answer compared\n",
"            against the reference).\n",
"\n",
"    Returns:\n",
"        The model's sigmoid output as a float. Training used label 0 for\n",
"        reference/autism pairs and 1 for reference/non-autism pairs.\n",
"    \"\"\"\n",
"    # Add a leading batch axis to each sentence vector for predict().\n",
"    embedding01, embedding02 = (\n",
"        np.expand_dims(embedding_model.get_sentence_vector(transcribe(audio_file)), axis=0)\n",
"        for audio_file in (audio_file01, audio_file02)\n",
"    )\n",
"\n",
"    prediction = model.predict([embedding01, embedding02])\n",
"\n",
"    # Return the score instead of printing it, so callers that store the\n",
"    # result no longer receive None.\n",
"    return float(prediction.squeeze())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 27ms/step\n",
"0.9999763\n"
]
}
],
"source": [
"# Run the validator on the reference and non-autism recordings of Answer1.\n",
"response = inference_answer_validation(\n",
" 'data/answering-evaluation/reference/Answer1.wav',\n",
" 'data/answering-evaluation/non-autism/Answer1.wav'\n",
" )\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment