Commit 3614c36f authored by RSRGJN Ananda

Merge branch 'IT20012724' into 'master'

IT20012724

See merge request !6
parents b6b4b9f5 cf2a9e57
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os, glob\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from pyannote.audio import Model, Inference\n",
"# from speechbrain.pretrained import SepformerSeparation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"embedding_model = Model.from_pretrained(\n",
" \"pyannote/embedding\", \n",
" use_auth_token=\"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\"\n",
" )\n",
"\n",
"# denoiser = SepformerSeparation.from_hparams(\n",
"# source=\"speechbrain/sepformer-wham-enhancement\", \n",
"# savedir='pretrained_models/sepformer-wham-enhancement'\n",
"# )\n",
"\n",
"embedding_inference = Inference(\n",
" embedding_model, \n",
" window=\"whole\"\n",
" )\n",
"\n",
"class_dict = {\n",
" 'autism': 0,\n",
" 'non-autism': 1\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# wavFile = 'data/pronouncing-evaluation/reference/1.wav'\n",
"# denoiser.separate_file(path=wavFile) \n",
"# denoised_wavFile = f\"results/denoised/{wavFile.split('/')[-1].split('.')[0]}_denoised.wav\"\n",
"# denoiser.save_file(denoised_wavFile)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def load_dataset(ref_audio_dir='data/pronouncing-evaluation/reference'):\n",
" ref_voice_files = glob.glob(f'{ref_audio_dir}/*.wav')\n",
" ref_voice_files = [voice_file.replace('\\\\', '/') for voice_file in ref_voice_files]\n",
" aut_voice_files = [voice_file.replace('/reference/', '/autism/') for voice_file in ref_voice_files]\n",
" non_aut_voice_files = [voice_file.replace('/reference/', '/non-autism/') for voice_file in ref_voice_files]\n",
"\n",
" embeddings_01 = np.zeros((len(ref_voice_files) * 2, 512))\n",
" embeddings_02 = np.zeros((len(ref_voice_files) * 2, 512))\n",
" labels = np.zeros(len(ref_voice_files) * 2)\n",
"\n",
" errorneous_idxs = []\n",
"\n",
" for idx in range(len(ref_voice_files)):\n",
" try:\n",
" embeddings_01[idx] = embedding_inference(ref_voice_files[idx])\n",
" embeddings_02[idx] = embedding_inference(aut_voice_files[idx])\n",
" labels[idx] = 0\n",
" except:\n",
" errorneous_idxs.append(idx)\n",
" print('Errorneous reference file: ', ref_voice_files[idx])\n",
" print('Errorneous autism file: ', aut_voice_files[idx])\n",
"\n",
" try:\n",
" embeddings_01[idx + len(ref_voice_files)] = embedding_inference(ref_voice_files[idx])\n",
" embeddings_02[idx + len(ref_voice_files)] = embedding_inference(non_aut_voice_files[idx])\n",
" labels[idx + len(ref_voice_files)] = 1\n",
" except:\n",
" errorneous_idxs.append(idx)\n",
" print('Errorneous reference file: ', ref_voice_files[idx])\n",
" print('Errorneous non-autism file: ', non_aut_voice_files[idx])\n",
"\n",
" labels = np.array(labels)\n",
" \n",
" embeddings_01 = np.delete(embeddings_01, errorneous_idxs, axis=0)\n",
" embeddings_02 = np.delete(embeddings_02, errorneous_idxs, axis=0)\n",
" labels = np.delete(labels, errorneous_idxs, axis=0)\n",
"\n",
" random_idxs = np.random.permutation(len(labels))\n",
" embeddings_01 = embeddings_01[random_idxs]\n",
" embeddings_02 = embeddings_02[random_idxs]\n",
" labels = labels[random_idxs]\n",
" \n",
" return embeddings_01, embeddings_02, labels"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding 01 shape: (14, 512)\n",
"Embedding 02 shape: (14, 512)\n",
"Labels shape: (14,)\n"
]
}
],
"source": [
"embeddings_01, embeddings_02, labels = load_dataset()\n",
"\n",
"\n",
"print(\"Embedding 01 shape: \", embeddings_01.shape)\n",
"print(\"Embedding 02 shape: \", embeddings_02.shape)\n",
"print(\"Labels shape: \", labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def build_model():\n",
" inputs01 = tf.keras.Input(shape=(512,))\n",
" inputs02 = tf.keras.Input(shape=(512,))\n",
"\n",
" x1 = tf.keras.layers.Dense(256, activation='relu')(inputs01)\n",
" x1 = tf.keras.layers.Dropout(0.2)(x1)\n",
" x1 = tf.keras.layers.Dense(128, activation='relu')(x1)\n",
" x1 = tf.keras.layers.Dropout(0.2)(x1)\n",
" x1 = tf.keras.layers.Dense(64, activation='relu')(x1)\n",
" \n",
" x2 = tf.keras.layers.Dense(256, activation='relu')(inputs02)\n",
" x2 = tf.keras.layers.Dropout(0.2)(x2)\n",
" x2 = tf.keras.layers.Dense(128, activation='relu')(x2)\n",
" x2 = tf.keras.layers.Dropout(0.2)(x2)\n",
" x2 = tf.keras.layers.Dense(64, activation='relu')(x2)\n",
" \n",
" x = tf.keras.layers.concatenate([x1, x2])\n",
" x = tf.keras.layers.Dense(32, activation='relu')(x)\n",
" x = tf.keras.layers.Dropout(0.2)(x)\n",
" outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n",
"\n",
" model = tf.keras.Model(\n",
" inputs=[inputs01, inputs02], \n",
" outputs=outputs\n",
" )\n",
" \n",
" model.compile(\n",
" optimizer='adam',\n",
" loss='binary_crossentropy',\n",
" metrics=[\n",
" tf.keras.metrics.BinaryAccuracy(name='accuracy'),\n",
" tf.keras.metrics.Precision(name='precision'),\n",
" tf.keras.metrics.Recall(name='recall'),\n",
" tf.keras.metrics.AUC(name='auc')\n",
" ]\n",
" )\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 512)] 0 [] \n",
" \n",
" input_2 (InputLayer) [(None, 512)] 0 [] \n",
" \n",
" dense (Dense) (None, 256) 131328 ['input_1[0][0]'] \n",
" \n",
" dense_3 (Dense) (None, 256) 131328 ['input_2[0][0]'] \n",
" \n",
" dropout (Dropout) (None, 256) 0 ['dense[0][0]'] \n",
" \n",
" dropout_2 (Dropout) (None, 256) 0 ['dense_3[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 128) 32896 ['dropout[0][0]'] \n",
" \n",
" dense_4 (Dense) (None, 128) 32896 ['dropout_2[0][0]'] \n",
" \n",
" dropout_1 (Dropout) (None, 128) 0 ['dense_1[0][0]'] \n",
" \n",
" dropout_3 (Dropout) (None, 128) 0 ['dense_4[0][0]'] \n",
" \n",
" dense_2 (Dense) (None, 64) 8256 ['dropout_1[0][0]'] \n",
" \n",
" dense_5 (Dense) (None, 64) 8256 ['dropout_3[0][0]'] \n",
" \n",
" concatenate (Concatenate) (None, 128) 0 ['dense_2[0][0]', \n",
" 'dense_5[0][0]'] \n",
" \n",
" dense_6 (Dense) (None, 32) 4128 ['concatenate[0][0]'] \n",
" \n",
" dropout_4 (Dropout) (None, 32) 0 ['dense_6[0][0]'] \n",
" \n",
" dense_7 (Dense) (None, 1) 33 ['dropout_4[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 349,121\n",
"Trainable params: 349,121\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"model = build_model()\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 2.0586e-31 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 2/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 7.7467e-27 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 3/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 2.0281e-22 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 4/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 4.1828e-09 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 5/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 7.9179e-21 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 6/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 6.5887e-24 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 7/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 4.3676e-28 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 8/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 2.5126e-08 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 9/100\n",
"7/7 [==============================] - 0s 2ms/step - loss: 1.5209e-32 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 10/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 5.7262e-21 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 11/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 3.8924e-19 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 12/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 13/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 8.1846e-21 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 14/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 6.8281e-17 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 15/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 1.1849e-07 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 16/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 1.4954e-18 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 17/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 18/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 9.5766e-09 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 19/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 2.3950e-24 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 20/100\n",
"7/7 [==============================] - 0s 3ms/step - loss: 5.0273e-33 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n",
"Epoch 21/100\n",
"7/7 [==============================] - 0s 4ms/step - loss: 0.0000e+00 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000\n",
"Epoch 22/100\n",
"7/7 [==============================] - 0s 5ms/step - loss: 3.8447e-12 - accuracy: 1.0000 - precision: 1.0000 - recall: 1.0000 - auc: 1.0000 \n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x19c08198730>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(\n",
" [embeddings_01, embeddings_02],\n",
" labels,\n",
" epochs=100,\n",
" batch_size=2,\n",
" callbacks=[\n",
" tf.keras.callbacks.EarlyStopping(\n",
" monitor='loss',\n",
" patience=10,\n",
" restore_best_weights=True\n",
" )\n",
" ] \n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"model.save('models/pronounce-validation.h5')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def inference_pronounce_validation(\n",
" audio_file01,\n",
" audio_file02\n",
" ):\n",
" embedding01 = embedding_inference(audio_file01)\n",
" embedding02 = embedding_inference(audio_file02)\n",
"\n",
" embedding01 = np.expand_dims(embedding01, axis=0)\n",
" embedding02 = np.expand_dims(embedding02, axis=0)\n",
"\n",
" prediction = model.predict([embedding01, embedding02], verbose=0)\n",
" prediction = prediction.squeeze()\n",
"\n",
" return 'non-autism' if float(prediction) == 1.0 else 'autism'"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'non-autism'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response = inference_pronounce_validation(\n",
" 'data/pronouncing-evaluation/reference/2.wav',\n",
" 'data/pronouncing-evaluation/non-autism/2.wav'\n",
" )\n",
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tf210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
import 'package:Autism/widgets/ButtonXl.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;

class Comp2Page1 extends StatefulWidget {
  const Comp2Page1({super.key});

  @override
  State<Comp2Page1> createState() => _Comp2Page1State();
}

class _Comp2Page1State extends State<Comp2Page1> {
  @override
  Widget build(BuildContext context) {
    void nextPage(String route) {
      Navigator.pushNamedAndRemoveUntil(context, route, (r) => false, arguments: {});
    }

    return Column(
      children: [
        SizedBox(
          width: 180,
          child: Image.asset('assets/images/Component 2 - img 01.png'),
        ),
        SizedBox(height: 30),
        // 'ආරම්භ කරන්න' = "Start"
        ButtonXL(route: '/Comp2Page2', title: 'ආරම්භ කරන්න', bg: MyStyles.cbtnPrimary),
      ],
    );
  }
}
import 'dart:io';

import 'package:Autism/widgets/AudioInput.dart';
import 'package:Autism/widgets/ButtonIcon.dart';
import 'package:Autism/widgets/ImageCard.dart';
import 'package:Autism/widgets/Instructions.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;
import 'package:Autism/Api.dart' as Api;
import 'package:dio/dio.dart';
import 'package:flutter/services.dart';

class Comp2Page3 extends StatefulWidget {
  const Comp2Page3({super.key});

  @override
  State<Comp2Page3> createState() => _Comp2Page3State();
}

class _Comp2Page3State extends State<Comp2Page3> {
  File? recordedFile;
  String text = '';
  String image = '';
  String audio = '';
  String color = '';
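  // Uploads the reference audio bundled with the app ('files01') together with
  // the child's recording ('files02') as multipart form data, then maps the
  // backend's "pronounce-validation" verdict to a Sinhala colour label
  // (red for autism, green for non-autism) shown on the Results page.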
  Future sendRequest() async {
    try {
      Response response;
      var dio = Dio();
      ByteData assetByteData = await rootBundle.load(audio);
      List<int> assetBytes = assetByteData.buffer.asUint8List();
      FormData formData = FormData();
      formData.files.add(
        MapEntry(
          'files01',
          MultipartFile.fromBytes(assetBytes, filename: 'audio1.wav'),
        ),
      );
      formData.files.add(
        MapEntry(
          'files02',
          await MultipartFile.fromFile(recordedFile!.path,
              filename: 'audio2.wav'),
        ),
      );
      // FormData formData = FormData.fromMap({
      //   'audio': await MultipartFile.fromBytes(assetBytes,
      //       filename: 'audio1.wav'),
      //   'audioa': await MultipartFile.fromFile(recordedFile!.path,
      //       filename: 'audio2.wav'),
      // });
      response = await dio.post(
        Api.Comp2Api,
        data: formData,
        // onSendProgress: (int sent, int total) {
        //   // print((100 * sent) / total);
        //   print(formData.files.map((e) => print(e.value.filename)));
        // },
      );
      if (response.statusCode == 200) {
        if (response.data["pronounce-validation"] == "autism") {
          setState(() {
            color = "රතු පාට"; // "red colour"
          });
        } else {
          setState(() {
            color = "කොළ පාට"; // "green colour"
          });
        }
        print(response.data);
        print(response.data["pronounce-validation"]);
        print(color);
        nextPage('/Results');
      }
    } catch (e) {
      print(e);
    }
  }
  void nextPage(String route) {
    Navigator.pushNamed(context, route, arguments: {'color': color});
  }

  @override
  Widget build(BuildContext context) {
    final arg = ModalRoute.of(context)!.settings.arguments as Map;
    text = arg['text'];
    image = arg['image'];
    audio = arg['audio'];
    return Column(
      children: [
        ImageCard(image: image),
        SizedBox(height: 10),
        Instructions(
          title: 'උපදෙස්', // "Instructions"
          body: text,
        ),
        SizedBox(height: 10),
        Row(
          mainAxisAlignment: MainAxisAlignment.spaceEvenly,
          children: [
            AudioInput(
              audio: 'audio',
              rtn: (reco) {
                setState(() {
                  recordedFile = reco;
                });
                print('recorded');
              },
            ),
            recordedFile != null
                ? ButtonIcon(
                    click: () => sendRequest(),
                    icon: Icons.arrow_forward_ios,
                    bg: MyStyles.cbtnPrimary,
                  )
                : SizedBox(),
          ],
        )
      ],
    );
  }
}
import 'package:Autism/widgets/ButtonXl.dart';
import 'package:flutter/material.dart';
import 'package:Autism/MyStyles.dart' as MyStyles;

class Comp2Page2 extends StatefulWidget {
  const Comp2Page2({super.key});

  @override
  State<Comp2Page2> createState() => _Comp2Page2State();
}

class _Comp2Page2State extends State<Comp2Page2> {
  @override
  Widget build(BuildContext context) {
    void nextPage(String route) {
      Navigator.pushNamedAndRemoveUntil(context, route, (r) => false, arguments: {});
    }
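    // Each button below opens Comp2Page3 with a prompt sentence ('text'),
    // its reference recording ('audio'), and an illustration ('image');
    // the last button goes straight to the final Results page.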
    return Column(
      children: [
        SizedBox(height: 30),
        // 'පින්තූරය 01' = "Picture 01"; prompt: "Came with big sister."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 01', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'අක්කි එක්ක ආවෙ.', 'audio': 'assets/comtwo/child05_01.wav', 'image': 'assets/comtwo/1.jpg'},
        ),
        SizedBox(height: 30),
        // Prompt: "There is an older sister."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 02', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'අක්කා කෙනෙක් ඉන්නවා', 'audio': 'assets/comtwo/child05_02.wav', 'image': 'assets/comtwo/2.jpg'},
        ),
        SizedBox(height: 30),
        // Prompt: "There is a little brother."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 03', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'මල්ලි කෙනෙක් ඉන්නවා', 'audio': 'assets/comtwo/child05_03.wav', 'image': 'assets/comtwo/3.jpg'},
        ),
        SizedBox(height: 30),
        // Prompt: "Came with mommy."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 04', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'අම්මි එක්ක ආවෙ.', 'audio': 'assets/comtwo/child05_04.wav', 'image': 'assets/comtwo/4.jpg'},
        ),
        SizedBox(height: 30),
        // Prompt: "Came with daddy."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 05', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'තාත්ති එක්ක ආවෙ.', 'audio': 'assets/comtwo/child05_05.wav', 'image': 'assets/comtwo/5.png'},
        ),
        SizedBox(height: 30),
        // Prompt: "Likes chocolates."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 06', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'චොකලට් වලට කැමතියි.', 'audio': 'assets/comtwo/child05_06.wav', 'image': 'assets/comtwo/6.png'},
        ),
        SizedBox(height: 30),
        // Prompt: "There is a big dog."
        ButtonXL(route: '/Comp2Page3', title: 'පින්තූරය 07', bg: MyStyles.cbtnPrimary,
          arguments: {'text': 'තඩි බව්වෙක් ඉන්නවා', 'audio': 'assets/comtwo/child05_07.wav', 'image': 'assets/comtwo/7.jpg'},
        ),
        SizedBox(height: 30),
        // 'අවසාන ප්‍රතිඵලය' = "Final result"
        ButtonXL(route: '/Results', title: 'අවසාන ප්‍රතිඵලය', bg: MyStyles.cbtnPrimary,
          arguments: {'text': '', 'audio': '', 'image': ''},
        ),
        SizedBox(height: 30),
      ],
    );
  }
}