Commit 56052c65 authored by Manilka Shalinda's avatar Manilka Shalinda 💻

Upload New File

parent 768e2e06
import librosa
import tensorflow as tf
import numpy as np
import os
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import FileResponse
from typing import Optional
from pydub import AudioSegment
import io
app = FastAPI()
SAVED_MODEL_PATH = "../model.h5"
SAMPLES_TO_CONSIDER = 22050
class KeywordSpottingService:
"""Singleton class for keyword spotting inference with trained models."""
model = None
_mapping = [
"dataset\\backward",
"dataset\\bed",
"dataset\\bird",
"dataset\\cat",
"dataset\\dog",
"dataset\\down",
"dataset\\eight",
"dataset\\five",
"dataset\\follow",
"dataset\\forward",
"dataset\\four",
"dataset\\go",
"dataset\\happy",
"dataset\\house",
"dataset\\learn",
"dataset\\left",
"dataset\\nine",
"dataset\\no",
"dataset\\off",
"dataset\\on",
"dataset\\one",
"dataset\\right",
"dataset\\seven",
"dataset\\six",
"dataset\\stop",
"dataset\\three",
"dataset\\tree",
"dataset\\two",
"dataset\\up",
"dataset\\visual",
"dataset\\wow",
"dataset\\yes",
"dataset\\zero"
]
_instance = None
def __init__(self):
self.model = tf.keras.models.load_model(SAVED_MODEL_PATH)
def predict(self, audio_data: np.ndarray, sample_rate: int) -> str:
"""Predict keyword from audio data."""
MFCCs = self.preprocess(audio_data, sample_rate)
MFCCs = MFCCs[np.newaxis, ..., np.newaxis]
predictions = self.model.predict(MFCCs)
predicted_index = np.argmax(predictions)
predicted_keyword = self._mapping[predicted_index]
# Extracting only the word from the directory path
predicted_word = predicted_keyword.split("\\")[-1]
return predicted_word
def preprocess(self, audio_data: np.ndarray, sample_rate: int, num_mfcc: int = 13, n_fft: int = 2048, hop_length: int = 512) -> np.ndarray:
"""Extract MFCCs from audio data."""
if len(audio_data) >= SAMPLES_TO_CONSIDER:
audio_data = audio_data[:SAMPLES_TO_CONSIDER]
else:
# Pad the signal with zeros or truncate it to the required length
shortage = SAMPLES_TO_CONSIDER - len(audio_data)
audio_data = np.pad(audio_data, (0, shortage), mode='constant')
MFCCs = librosa.feature.mfcc(y=audio_data, sr=sample_rate, n_mfcc=num_mfcc, n_fft=n_fft, hop_length=hop_length)
return MFCCs.T
kss = KeywordSpottingService()
@app.post("/predict/upload/")
async def predict_keyword_upload(file: UploadFile = File(...)):
"""Endpoint to predict keyword from uploaded audio file."""
file_path = "temp_audio.wav"
try:
with open(file_path, "wb") as f:
f.write(await file.read())
audio_data, sample_rate = librosa.load(file_path)
predicted_keyword = kss.predict(audio_data, sample_rate)
return {"predicted_keyword": predicted_keyword}
finally:
os.remove(file_path)
@app.post("/predict/microphone/")
async def predict_keyword_microphone(audio_file: UploadFile = File(...)):
"""Endpoint to predict keyword from audio file sent from microphone."""
try:
audio_data = await audio_file.read()
audio = AudioSegment.from_file(io.BytesIO(audio_data))
audio_data = np.array(audio.get_array_of_samples())
sample_rate = audio.frame_rate
predicted_keyword = kss.predict(audio_data, sample_rate)
return {"predicted_keyword": predicted_keyword}
except Exception as e:
raise HTTPException(status_code=400, detail="Invalid audio file format")
@app.get("/")
async def get_index():
return FileResponse("index.html")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment