Commit 7cfb27ff authored by Malsha Rathnasiri's avatar Malsha Rathnasiri

Merge branch 'IT18094664' of http://gitlab.sliit.lk/2022-240/240 into IT18094664

parents 3ffcde55 af144043
# import pickle
# from keras.models import load_model
import pickle
from keras.models import load_model
import numpy as np
# import numpy as np
import IPython.display as ipd
import random
# import IPython.display as ipd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
# import random
def predict(samples):
#load the trained model
model=load_model(r'./best_model_final.hdf5')
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import LabelEncoder
#load labels
f1 = open('all_label.txt', 'rb')
all_label = pickle.load(f1)
print('loaded labels')
#preprocess recorded audio
le = LabelEncoder()
y = le.fit_transform(all_label)
classes = list(le.classes_)
# def predict(samples):
# model=load_model(r'./best_model_final.hdf5')
def predictSamples(audio):
prob=model.predict(audio.reshape(1,8000,1))
index=np.argmax(prob[0])
return classes[index]
# f1 = open('all_label.txt', 'rb')
# all_label = pickle.load(f1)
# print('loaded labels')
ipd.Audio(samples, rate=8000)
# # f2 = open('all_waves_file.txt', 'rb')
# # all_wave = pickle.load(f2)
# # print('loaded waves')
#run the prediction
result = predictSamples(samples)
# le = LabelEncoder()
# y = le.fit_transform(all_label)
# classes = list(le.classes_)
print("Text:",result)
# # train_data_file = open("train_data_file.txt", 'rb')
# # [x_tr, x_val, y_tr, y_val] = np.load(train_data_file, allow_pickle=True)
# # train_data_file.close()
# def predictSamples(audio):
# prob=model.predict(audio.reshape(1,8000,1))
# index=np.argmax(prob[0])
# return classes[index]
# # index=random.randint(0,len(x_val)-1)
# # samples=x_val[index].ravel()
# print(samples)
# # print("Audio:",classes[np.argmax(y_val[index])])
# ipd.Audio(samples, rate=8000)
# result = predictSamples(samples)
# print("Text:",result)
# return result
\ No newline at end of file
return result
\ No newline at end of file
This diff is collapsed.
......@@ -16,26 +16,23 @@ import librosa
from django.db.models import Q
from django.core.files.storage import FileSystemStorage
from .models import Chat, Conversation, MlModel
from rest_framework_simplejwt.views import TokenObtainPairView
# from .model.train import train
# from .model.predict import predict
from .model.train import train
from .model.predict import predict
from pydub import AudioSegment
import numpy as np
#Custom Viewset to get auth token
class ObtainTokenPairWithUserView(TokenObtainPairView):
permission_classes = (permissions.AllowAny,)
serializer_class = MyTokenObtainPairSerializer
class UserViewSet(viewsets.ModelViewSet):
"""
API endpoint that allows users to be viewed or edited.
......@@ -70,9 +67,9 @@ class MlModelViewSet(viewsets.ViewSet):
permission_classes = [permissions.AllowAny]
parser_classes = [MultiPartParser]
#Custom api to add sample chats
@action(detail=False)
def addChats(*args, **kwargs):
admin = User.objects.get(username='admin')
user2 = User.objects.get(username='user2')
......@@ -106,6 +103,7 @@ class MlModelViewSet(viewsets.ViewSet):
conversation=convo, from_user_id=chat['from'], to_user_id=chat['to'], messsage=chat['message'])
object.save()
#Custom api to train the model
@action(detail=False)
def train(*args, **kwargs):
print('Function ran')
......@@ -113,6 +111,7 @@ class MlModelViewSet(viewsets.ViewSet):
print(results)
return Response({'success': True})
#Custom api to convert audio to text
@action(detail=False, methods=["POST"])
def detect(self, request, *args, **kwargs):
print('Function ran')
......@@ -135,7 +134,7 @@ class MlModelViewSet(viewsets.ViewSet):
print('---------------------------------------------------------')
if(samples.shape[0] > 8000):
print('grateer -------------------------------------')
print('grater -------------------------------------')
samples = samples[-8000:]
print(samples.shape)
else:
......@@ -144,57 +143,17 @@ class MlModelViewSet(viewsets.ViewSet):
samples = np.concatenate((samples, new_arr))
print(samples.shape)
# audio_file = request.data
# # Using File storage to save file for future converting
# fs = FileSystemStorage()
# file_name = fs.save(audio_file.name, audio_file)
# audio_file_url = fs.url(file_name)
# # Preparing paths for convertion
# upstream = os.path.dirname(os.path.dirname(os.path.dirname(
# os.path.abspath(__file__))))
# path = os.path.join(upstream, 'media', audio_file.name)
# mp3_filename = '.'.join([audio_file.name.split('.')[0], 'mp3'])
# new_path = os.path.join(
# upstream, 'media', mp3_filename)
# # Converting to mp3 here
# wma_version = AudioSegment.from_file(path, "wav")
# wma_version.export(new_path, format="mp3")
# user_id = self.request.user.id
# # I was trying to create a Track instance, the mp3 get's saved but it is not being saved using location specified in models.
# track = Track.objects.create(user_id=user_id)
# track.file.name = mp3_filename
# track.save()
results = {} # predict(samples)
results = predict(samples)
print(results)
return Response({'success': True, 'result': results})
class ChatViewSet(viewsets.ModelViewSet):
queryset = Chat.objects.all().order_by('-timestamp')
serializer_class = ChatSerializer
permission_classes = [permissions.IsAuthenticated] #change to permissions.isAutheniticated
@action(methods='POST', detail=True)
def getChats(self, request, *args, **kwargs):
#conversation_id hardcoded as we dont have many conversations at the moment
chats = Chat.objects \
.filter(conversation_id=1) \
.filter(
(Q(from_user__user_id=request.user.id)
| Q(to_user__user_id=request.user.id))
).order_by('-timestamp').values()
return Response({chats})
permission_classes = [permissions.IsAuthenticated]
#ovveride defualt list action to get chats of specific user conversation
def list(self, request, pk=None):
if pk == None:
chats = Chat.objects \
.filter(conversation_id=1) \
......@@ -210,7 +169,6 @@ class ChatViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(page, many=True)
result_set = serializer.data
return Response(result_set)
def get_result_set(self, chats):
......@@ -218,7 +176,6 @@ class ChatViewSet(viewsets.ModelViewSet):
return result_set
class ConversationViewSet(viewsets.ModelViewSet):
queryset = Conversation.objects.all().order_by('id')
serializer_class = ConversationSerializer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment