Commit b6b4b9f5 authored by Nagahawatta S.S

Merge branch 'IT20090562' into 'master'

It20090562

See merge request !5
parents 15c30659 bce3060f
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import fasttext, glob\n",
"import tensorflow as tf\n",
"from datasets import Dataset, Audio\n",
"from transformers import WhisperProcessor, WhisperForConditionalGeneration"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.\n"
]
}
],
"source": [
"s2t_processor = WhisperProcessor.from_pretrained(\"Subhaka/whisper-small-Sinhala-Fine_Tune\")\n",
"s2t_model = WhisperForConditionalGeneration.from_pretrained(\"Subhaka/whisper-small-Sinhala-Fine_Tune\")\n",
"s2t_forced_decoder_ids = s2t_processor.get_decoder_prompt_ids(\n",
" language=\"sinhala\", \n",
" task=\"transcribe\"\n",
" )\n",
"embedding_model = fasttext.load_model(\"models/cc.si.300.bin\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def load_audio(audio_file):\n",
"\n",
" audio_data = Dataset.from_dict(\n",
" {\"audio\": [audio_file]}\n",
" ).cast_column(\"audio\", Audio())\n",
" audio_data = audio_data.cast_column(\n",
" \"audio\", \n",
" Audio(sampling_rate=16000)\n",
" )\n",
" audio_data = audio_data[0]['audio']['array']\n",
" return audio_data\n",
"\n",
"def transcribe(audio_file):\n",
" audio_data = load_audio(audio_file)\n",
" input_features = s2t_processor(\n",
" audio_data, \n",
" sampling_rate=16000, \n",
" return_tensors=\"pt\"\n",
" ).input_features\n",
" predicted_ids = s2t_model.generate(\n",
" input_features, \n",
" forced_decoder_ids=s2t_forced_decoder_ids\n",
" )\n",
" \n",
" # transcription = s2t_processor.batch_decode(predicted_ids)\n",
" transcription = s2t_processor.batch_decode(\n",
" predicted_ids, \n",
" skip_special_tokens=True\n",
" )\n",
" return transcription[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'අක්යෙ කාවෙ'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transcription = transcribe('data/pronouncing-evaluation/reference/1.wav')\n",
"transcription"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(ref_audio_dir='data/answering-evaluation/reference'):\n",
" ref_voice_files = glob.glob(f'{ref_audio_dir}/*.wav')\n",
" ref_voice_files = [voice_file.replace('\\\\', '/') for voice_file in ref_voice_files]\n",
" aut_voice_files = [voice_file.replace('/reference/', '/autism/') for voice_file in ref_voice_files]\n",
" non_aut_voice_files = [voice_file.replace('/reference/', '/non-autism/') for voice_file in ref_voice_files]\n",
"\n",
" embeddings_01 = np.zeros((len(ref_voice_files) * 2, 300))\n",
" embeddings_02 = np.zeros((len(ref_voice_files) * 2, 300))\n",
" labels = np.zeros(len(ref_voice_files) * 2)\n",
"\n",
" errorneous_idxs = []\n",
"\n",
" for idx in range(len(ref_voice_files)):\n",
" try:\n",
" transcription_01 = transcribe(ref_voice_files[idx])\n",
" transcription_02 = transcribe(aut_voice_files[idx])\n",
"\n",
" embeddings_01[idx] = embedding_model.get_sentence_vector(transcription_01)\n",
" embeddings_02[idx] = embedding_model.get_sentence_vector(transcription_02)\n",
" labels[idx] = 0\n",
" except:\n",
" errorneous_idxs.append(idx)\n",
" print('Errorneous reference file: ', ref_voice_files[idx])\n",
" print('Errorneous autism file: ', aut_voice_files[idx])\n",
"\n",
" try:\n",
" transcription_01 = transcribe(ref_voice_files[idx])\n",
" transcription_02 = transcribe(non_aut_voice_files[idx]) \n",
"\n",
" embeddings_01[idx + len(ref_voice_files)] = embedding_model.get_sentence_vector(transcription_01)\n",
" embeddings_02[idx + len(ref_voice_files)] = embedding_model.get_sentence_vector(transcription_02)\n",
" labels[idx + len(ref_voice_files)] = 1\n",
" except:\n",
" errorneous_idxs.append(idx)\n",
" print('Errorneous reference file: ', ref_voice_files[idx])\n",
" print('Errorneous non-autism file: ', non_aut_voice_files[idx])\n",
"\n",
" labels = np.array(labels)\n",
" \n",
" embeddings_01 = np.delete(embeddings_01, errorneous_idxs, axis=0)\n",
" embeddings_02 = np.delete(embeddings_02, errorneous_idxs, axis=0)\n",
" labels = np.delete(labels, errorneous_idxs, axis=0)\n",
"\n",
" random_idxs = np.random.permutation(len(labels))\n",
" embeddings_01 = embeddings_01[random_idxs]\n",
" embeddings_02 = embeddings_02[random_idxs]\n",
" labels = labels[random_idxs]\n",
" \n",
" return embeddings_01, embeddings_02, labels"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding 01 shape: (14, 300)\n",
"Embedding 02 shape: (14, 300)\n",
"Labels shape: (14,)\n"
]
}
],
"source": [
"embeddings_01, embeddings_02, labels = load_dataset()\n",
"\n",
"\n",
"print(\"Embedding 01 shape: \", embeddings_01.shape)\n",
"print(\"Embedding 02 shape: \", embeddings_02.shape)\n",
"print(\"Labels shape: \", labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
" inputs01 = tf.keras.Input(shape=(300,))\n",
" inputs02 = tf.keras.Input(shape=(300,))\n",
"\n",
" x1 = tf.keras.layers.Dense(300, activation='relu')(inputs01)\n",
" x1 = tf.keras.layers.Dropout(0.2)(x1)\n",
" x1 = tf.keras.layers.Dense(150, activation='relu')(x1)\n",
" x1 = tf.keras.layers.Dropout(0.2)(x1)\n",
" x1 = tf.keras.layers.Dense(30, activation='relu')(x1)\n",
" \n",
" x2 = tf.keras.layers.Dense(300, activation='relu')(inputs02)\n",
" x2 = tf.keras.layers.Dropout(0.2)(x2)\n",
" x2 = tf.keras.layers.Dense(150, activation='relu')(x2)\n",
" x2 = tf.keras.layers.Dropout(0.2)(x2)\n",
" x2 = tf.keras.layers.Dense(30, activation='relu')(x2)\n",
" \n",
" x = tf.keras.layers.concatenate([x1, x2])\n",
" x = tf.keras.layers.Dense(30, activation='relu')(x)\n",
" x = tf.keras.layers.Dropout(0.2)(x)\n",
" outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n",
"\n",
" model = tf.keras.Model(\n",
" inputs=[inputs01, inputs02], \n",
" outputs=outputs\n",
" )\n",
" \n",
" model.compile(\n",
" optimizer='adam',\n",
" loss='binary_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ]\n",
" )\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 300)] 0 [] \n",
" \n",
" input_2 (InputLayer) [(None, 300)] 0 [] \n",
" \n",
" dense (Dense) (None, 300) 90300 ['input_1[0][0]'] \n",
" \n",
" dense_3 (Dense) (None, 300) 90300 ['input_2[0][0]'] \n",
" \n",
" dropout (Dropout) (None, 300) 0 ['dense[0][0]'] \n",
" \n",
" dropout_2 (Dropout) (None, 300) 0 ['dense_3[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 150) 45150 ['dropout[0][0]'] \n",
" \n",
" dense_4 (Dense) (None, 150) 45150 ['dropout_2[0][0]'] \n",
" \n",
" dropout_1 (Dropout) (None, 150) 0 ['dense_1[0][0]'] \n",
" \n",
" dropout_3 (Dropout) (None, 150) 0 ['dense_4[0][0]'] \n",
" \n",
" dense_2 (Dense) (None, 30) 4530 ['dropout_1[0][0]'] \n",
" \n",
" dense_5 (Dense) (None, 30) 4530 ['dropout_3[0][0]'] \n",
" \n",
" concatenate (Concatenate) (None, 60) 0 ['dense_2[0][0]', \n",
" 'dense_5[0][0]'] \n",
" \n",
" dense_6 (Dense) (None, 30) 1830 ['concatenate[0][0]'] \n",
" \n",
" dropout_4 (Dropout) (None, 30) 0 ['dense_6[0][0]'] \n",
" \n",
" dense_7 (Dense) (None, 1) 31 ['dropout_4[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 281,821\n",
"Trainable params: 281,821\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"7/7 [==============================] - 5s 4ms/step - loss: 0.7084 - accuracy: 0.3571 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.0918 \n",
"Epoch 2/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6849 - accuracy: 0.5000 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7041\n",
"Epoch 3/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6786 - accuracy: 0.5714 - precision: 0.6667 - recall: 0.2857 - auc: 0.7449 \n",
"Epoch 4/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6603 - accuracy: 0.7143 - precision: 0.8000 - recall: 0.5714 - auc: 0.8265\n",
"Epoch 5/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.6189 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 6/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.5713 - accuracy: 0.8571 - precision: 0.8571 - recall: 0.8571 - auc: 0.9490\n",
"Epoch 7/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.4748 - accuracy: 0.9286 - precision: 1.0000 - recall: 0.8571 - auc: 1.0000\n",
"Epoch 8/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.3560 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 9/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.2779 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 10/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.1555 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 11/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0864 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 12/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 13/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0163 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 14/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0202 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 15/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0082 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 16/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0080 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 17/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0053 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 18/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0030 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 19/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0036 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 20/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0016 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 21/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0011 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 22/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0019 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 23/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0012 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 24/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 0.0030 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 25/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0016 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 26/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 7.1956e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 27/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 28/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.0398e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 29/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0099 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 30/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 9.1052e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 31/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0025 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 32/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 33/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0026 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 34/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.4648e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 35/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0014 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 36/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.5733e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 37/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 3.5510e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 38/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0017 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 39/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.6147e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 40/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.5398e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 41/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.6020e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 42/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.5038e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 43/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 7.2330e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 44/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.6800e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 45/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0034 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 46/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 4.7010e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 47/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 1.5909e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 48/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.7970e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 49/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 9.3579e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 50/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 1.6038e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 51/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.2203e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 52/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.4229e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 53/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.0325e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 54/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.4042e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 55/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 4.1154e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 56/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 6.5307e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 57/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 5.3958e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 58/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 3.9178e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 59/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 3.9669e-04 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 60/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 7.6923e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 61/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 0.0211 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 62/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 3.5197e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 63/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 1.8985e-05 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x22f8eceb760>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(\n",
" [embeddings_01, embeddings_02],\n",
" labels,\n",
" epochs=100,\n",
" batch_size=2,\n",
" callbacks=[\n",
" tf.keras.callbacks.EarlyStopping(\n",
" monitor='loss',\n",
" patience=10,\n",
" restore_best_weights=True\n",
" )\n",
" ] \n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"model.save('models/answering-evaluation.h5')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def inference_answer_validation(\n",
" audio_file01,\n",
" audio_file02\n",
" ):\n",
" transcription_01 = transcribe(audio_file01)\n",
" transcription_02 = transcribe(audio_file02)\n",
"\n",
" embedding01 = embedding_model.get_sentence_vector(transcription_01)\n",
" embedding02 = embedding_model.get_sentence_vector(transcription_02)\n",
"\n",
" embedding01 = np.expand_dims(embedding01, axis=0)\n",
" embedding02 = np.expand_dims(embedding02, axis=0)\n",
"\n",
" prediction = model.predict([embedding01, embedding02])\n",
" prediction = prediction.squeeze()\n",
"\n",
" # return 'autism' if prediction < 0.5 else 'non-autism' \n",
" print(prediction) "
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 27ms/step\n",
"0.9999763\n"
]
}
],
"source": [
"response = inference_answer_validation(\n",
" 'data/answering-evaluation/reference/Answer1.wav',\n",
" 'data/answering-evaluation/non-autism/Answer1.wav'\n",
" )\n",
"response"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import 'package:Autism/widgets/ButtonXl.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;

class Comp3Page1 extends StatefulWidget {
  const Comp3Page1({super.key});

  @override
  State<Comp3Page1> createState() => _Comp3Page1State();
}

class _Comp3Page1State extends State<Comp3Page1> {
  @override
  Widget build(BuildContext context) {
    return Column(
      children: [
        SizedBox(
          width: 180,
          child: Image.asset('assets/images/Component 3 - img 01.png'),
        ),
        SizedBox(height: 30),
        // 'ආරම්භ කරන්න' = 'Start'
        ButtonXL(route: '/Comp3Page2', title: 'ආරම්භ කරන්න', bg: MyStyles.cbtnPrimary),
      ],
    );
  }
}
import 'package:Autism/widgets/ButtonXl.dart';
import 'package:Autism/widgets/Instructions.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;

class Comp3Page2 extends StatefulWidget {
  const Comp3Page2({super.key});

  @override
  State<Comp3Page2> createState() => _Comp3Page2State();
}

class _Comp3Page2State extends State<Comp3Page2> {
  @override
  Widget build(BuildContext context) {
    return Column(
      children: [
        SizedBox(height: 10),
        // 'Instructions': ask the child the question for each picture shown, press the
        // microphone icon to record the answer, then press the arrow icon for the results page.
        Instructions(
          title: 'උපදෙස්',
          body:
              'ඊලග පියවරට බොත්තම එබිමෙන් පසුව දී පින්තුර ඇසුරෙන් ලමයාගෙන් ප්‍රශ්න අසන්න. පිලිතුර පටිගත කිරීමට මයික්‍රෆොනය සලකුන ඔබන්න. පටිගත කිරිම අවසන් වු පසු ඊතල සලකුන ඔබා ප්‍රතිපල පිටුවට පිවිසෙන්න.',
        ),
        SizedBox(height: 30),
        // 'ඊලග පියවරට' = 'Next step'
        ButtonXL(route: '/Comp3Page3', title: 'ඊලග පියවරට', bg: MyStyles.cbtnPrimary),
      ],
    );
  }
}
import 'package:Autism/widgets/ButtonXl.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;

class Comp3Page3 extends StatefulWidget {
  const Comp3Page3({super.key});

  @override
  State<Comp3Page3> createState() => _Comp3Page3State();
}

class _Comp3Page3State extends State<Comp3Page3> {
  @override
  Widget build(BuildContext context) {
    // Each question button hands Comp3Page4 the question text, the reference
    // answer recording, and the picture to show the child.
    return Column(
      children: [
        SizedBox(height: 30),
        ButtonXL(
          route: '/Comp3Page4',
          title: 'ප්‍රශ්නය 01', // 'Question 01'
          bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'පුතා කවුරු එක්කද ආවේ?', 'audio': 'assets/comthree/Answer1.wav', 'image': 'assets/comthree/1.png'},
        ),
        SizedBox(height: 30),
        ButtonXL(
          route: '/Comp3Page4',
          title: 'ප්‍රශ්නය 02',
          bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'මේ පින්තුරයේ ඇති බොලයේ පාට මොකක්ද?', 'audio': 'assets/comthree/Answer2.wav', 'image': 'assets/comthree/2.jpg'},
        ),
        SizedBox(height: 30),
        ButtonXL(
          route: '/Comp3Page4',
          title: 'ප්‍රශ්නය 03',
          bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'මේ පින්තුරයේ ඇති වාහනය මොකක්ද?', 'audio': 'assets/comthree/Answer3.wav', 'image': 'assets/comthree/3.jpg'},
        ),
        SizedBox(height: 30),
        ButtonXL(
          route: '/Comp3Page4',
          title: 'ප්‍රශ්නය 04',
          bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'පුතාට මොනවද බොන්න ඕනෙ?', 'audio': 'assets/comthree/Answer4.wav', 'image': 'assets/comthree/4.jpg'},
        ),
        SizedBox(height: 30),
        ButtonXL(
          route: '/Comp3Page4',
          title: 'ප්‍රශ්නය 05',
          bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'මේ පින්තුරේ ඉන්නෙ කවුද?', 'audio': 'assets/comthree/Answer5.wav', 'image': 'assets/comthree/5.png'},
        ),
        SizedBox(height: 30),
        ButtonXL(
          route: '/Results',
          title: 'අවසාන ප්‍රතිඵලය', // 'Final result'
          bg: MyStyles.cbtnPrimary,
          arguments: {'text': '', 'audio': '', 'image': ''},
        ),
        SizedBox(height: 30),
      ],
    );
  }
}
import 'dart:io';

import 'package:Autism/widgets/AudioInput.dart';
import 'package:Autism/widgets/ButtonIcon.dart';
import 'package:Autism/widgets/ImageCard.dart';
import 'package:Autism/widgets/Instructions.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;
import 'package:Autism/Api.dart' as Api;
import 'package:dio/dio.dart';
import 'package:flutter/services.dart';

class Comp3Page4 extends StatefulWidget {
  const Comp3Page4({super.key});

  @override
  State<Comp3Page4> createState() => _Comp3Page4State();
}

class _Comp3Page4State extends State<Comp3Page4> {
  File? recordedFile;
  String text = '';
  String image = '';
  String audio = '';
  String color = '';
  Future sendRequest() async {
    try {
      Response response;
      var dio = Dio();

      // bundle the reference answer (a bundled asset) and the child's recording
      ByteData assetByteData = await rootBundle.load(audio);
      List<int> assetBytes = assetByteData.buffer.asUint8List();

      FormData formData = FormData();
      formData.files.add(
        MapEntry(
          'files01',
          MultipartFile.fromBytes(assetBytes, filename: 'audio1.wav'),
        ),
      );
      formData.files.add(
        MapEntry(
          'files02',
          await MultipartFile.fromFile(recordedFile!.path, filename: 'audio2.wav'),
        ),
      );

      response = await dio.post(
        Api.Comp3Api,
        data: formData,
      );

      if (response.statusCode == 200) {
        if (response.data["answer-evaluation"] == "autism") {
          setState(() {
            color = "රතු පාට"; // 'red'
          });
        } else {
          setState(() {
            color = "කොළ පාට"; // 'green'
          });
        }
        nextPage('/Results');
      }
    } catch (e) {
      // surface failures instead of swallowing them silently
      print(e);
    }
  }
  void nextPage(String route) {
    Navigator.pushNamed(context, route, arguments: {'color': color});
  }

  @override
  Widget build(BuildContext context) {
    final arg = ModalRoute.of(context)!.settings.arguments as Map;
    text = arg['text'];
    image = arg['image'];
    audio = arg['audio'];

    return Column(
      children: [
        Container(
          alignment: Alignment.topLeft,
          padding: EdgeInsets.all(8.0),
          child: IconButton(
            icon: Icon(Icons.arrow_back),
            onPressed: () {
              Navigator.pop(context);
            },
          ),
        ),
        ImageCard(image: image),
        SizedBox(height: 10),
        Instructions(
          title: 'ප්‍රශ්නය', // 'Question'
          body: text,
        ),
        SizedBox(height: 10),
        Row(
          mainAxisAlignment: MainAxisAlignment.spaceEvenly,
          children: [
            AudioInput(
              audio: 'audio',
              rtn: (reco) {
                setState(() {
                  recordedFile = reco;
                });
              },
            ),
            recordedFile != null
                ? ButtonIcon(
                    click: () => sendRequest(),
                    icon: Icons.arrow_forward_ios,
                    bg: MyStyles.cbtnPrimary,
                  )
                : SizedBox(),
          ],
        )
      ],
    );
  }
}
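
A note on the serving side: the Flutter client above uploads the reference answer as `files01` and the child's recording as `files02` to `Api.Comp3Api`, then reads an `answer-evaluation` field ('autism' or 'non-autism') from the JSON response. That endpoint is not included in this merge request, so the following is only a minimal sketch of how the notebook's pieces could back it. FastAPI, the route path, and the direct reuse of the notebook's `transcribe` helper, `embedding_model`, and the saved `models/answering-evaluation.h5` are assumptions, not the project's actual server code.

# Hypothetical serving sketch -- not part of this merge request.
# Assumes the notebook's transcribe() and embedding_model are importable
# and that the client posts multipart fields 'files01' and 'files02'.
import tempfile

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile

app = FastAPI()
model = tf.keras.models.load_model('models/answering-evaluation.h5')

def embed(upload: UploadFile) -> np.ndarray:
    # write the upload to disk so the datasets Audio loader can decode it
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        tmp.write(upload.file.read())
        path = tmp.name
    transcription = transcribe(path)  # Whisper STT helper from the notebook
    vector = embedding_model.get_sentence_vector(transcription)  # 300-d fastText
    return np.expand_dims(vector, axis=0)

@app.post('/answer-evaluation')  # the client's Api.Comp3Api would point here
async def answer_evaluation(files01: UploadFile = File(...),
                            files02: UploadFile = File(...)):
    score = float(model.predict([embed(files01), embed(files02)]).squeeze())
    # mirror the client contract: scores below 0.5 read as 'autism'
    return {'answer-evaluation': 'autism' if score < 0.5 else 'non-autism'}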