Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
2
240
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
U.D.C.S.WIJESOORIYA
240
Commits
af144043
Commit
af144043
authored
Oct 07, 2022
by
Malsha Rathnasiri
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
format files
parent
6beb123f
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
179 additions
and
256 deletions
+179
-256
backend/backend/cms/model/predict.py
backend/backend/cms/model/predict.py
+27
-39
backend/backend/cms/model/train.py
backend/backend/cms/model/train.py
+142
-164
backend/backend/cms/views.py
backend/backend/cms/views.py
+10
-53
No files found.
backend/backend/cms/model/predict.py
View file @
af144043
#
import pickle
import
pickle
#
from keras.models import load_model
from
keras.models
import
load_model
import
numpy
as
np
# import numpy as np
import
IPython.display
as
ipd
import
random
# import IPython.display as ipd
from
sklearn.model_selection
import
train_test_split
from
sklearn.preprocessing
import
LabelEncoder
# import random
def
predict
(
samples
):
#load the trained model
model
=
load_model
(
r'./best_model_final.hdf5'
)
# from sklearn.model_selection import train_test_split
#load labels
# from sklearn.preprocessing import LabelEncoder
f1
=
open
(
'all_label.txt'
,
'rb'
)
all_label
=
pickle
.
load
(
f1
)
print
(
'loaded labels'
)
#preprocess recorded audio
le
=
LabelEncoder
()
y
=
le
.
fit_transform
(
all_label
)
classes
=
list
(
le
.
classes_
)
# def predict(samples):
def
predictSamples
(
audio
):
# model=load_model(r'./best_model_final.hdf5')
prob
=
model
.
predict
(
audio
.
reshape
(
1
,
8000
,
1
))
index
=
np
.
argmax
(
prob
[
0
])
return
classes
[
index
]
# f1 = open('all_label.txt', 'rb')
ipd
.
Audio
(
samples
,
rate
=
8000
)
# all_label = pickle.load(f1)
# print('loaded labels')
# # f2 = open('all_waves_file.txt', 'rb')
#run the prediction
# # all_wave = pickle.load(f2)
result
=
predictSamples
(
samples
)
# # print('loaded waves')
# le = LabelEncoder()
print
(
"Text:"
,
result
)
# y = le.fit_transform(all_label)
# classes = list(le.classes_)
# # train_data_file = open("train_data_file.txt", 'rb')
return
result
# # [x_tr, x_val, y_tr, y_val] = np.load(train_data_file, allow_pickle=True)
\ No newline at end of file
# # train_data_file.close()
# def predictSamples(audio):
# prob=model.predict(audio.reshape(1,8000,1))
# index=np.argmax(prob[0])
# return classes[index]
# # index=random.randint(0,len(x_val)-1)
# # samples=x_val[index].ravel()
# print(samples)
# # print("Audio:",classes[np.argmax(y_val[index])])
# ipd.Audio(samples, rate=8000)
# result = predictSamples(samples)
# print("Text:",result)
# return result
\ No newline at end of file
backend/backend/cms/model/train.py
View file @
af144043
This diff is collapsed.
Click to expand it.
backend/backend/cms/views.py
View file @
af144043
...
@@ -16,26 +16,23 @@ import librosa
...
@@ -16,26 +16,23 @@ import librosa
from
django.db.models
import
Q
from
django.db.models
import
Q
from
django.core.files.storage
import
FileSystemStorage
from
django.core.files.storage
import
FileSystemStorage
from
.models
import
Chat
,
Conversation
,
MlModel
from
.models
import
Chat
,
Conversation
,
MlModel
from
rest_framework_simplejwt.views
import
TokenObtainPairView
from
rest_framework_simplejwt.views
import
TokenObtainPairView
# from .model.train import train
from
.model.train
import
train
from
.model.predict
import
predict
# from .model.predict import predict
from
pydub
import
AudioSegment
from
pydub
import
AudioSegment
import
numpy
as
np
import
numpy
as
np
#Custom Viewset to get auth token
class
ObtainTokenPairWithUserView
(
TokenObtainPairView
):
class
ObtainTokenPairWithUserView
(
TokenObtainPairView
):
permission_classes
=
(
permissions
.
AllowAny
,)
permission_classes
=
(
permissions
.
AllowAny
,)
serializer_class
=
MyTokenObtainPairSerializer
serializer_class
=
MyTokenObtainPairSerializer
class
UserViewSet
(
viewsets
.
ModelViewSet
):
class
UserViewSet
(
viewsets
.
ModelViewSet
):
"""
"""
API endpoint that allows users to be viewed or edited.
API endpoint that allows users to be viewed or edited.
...
@@ -70,9 +67,9 @@ class MlModelViewSet(viewsets.ViewSet):
...
@@ -70,9 +67,9 @@ class MlModelViewSet(viewsets.ViewSet):
permission_classes
=
[
permissions
.
AllowAny
]
permission_classes
=
[
permissions
.
AllowAny
]
parser_classes
=
[
MultiPartParser
]
parser_classes
=
[
MultiPartParser
]
#Custom api to add sample chats
@
action
(
detail
=
False
)
@
action
(
detail
=
False
)
def
addChats
(
*
args
,
**
kwargs
):
def
addChats
(
*
args
,
**
kwargs
):
admin
=
User
.
objects
.
get
(
username
=
'admin'
)
admin
=
User
.
objects
.
get
(
username
=
'admin'
)
user2
=
User
.
objects
.
get
(
username
=
'user2'
)
user2
=
User
.
objects
.
get
(
username
=
'user2'
)
...
@@ -106,6 +103,7 @@ class MlModelViewSet(viewsets.ViewSet):
...
@@ -106,6 +103,7 @@ class MlModelViewSet(viewsets.ViewSet):
conversation
=
convo
,
from_user_id
=
chat
[
'from'
],
to_user_id
=
chat
[
'to'
],
messsage
=
chat
[
'message'
])
conversation
=
convo
,
from_user_id
=
chat
[
'from'
],
to_user_id
=
chat
[
'to'
],
messsage
=
chat
[
'message'
])
object
.
save
()
object
.
save
()
#Custom api to train the model
@
action
(
detail
=
False
)
@
action
(
detail
=
False
)
def
train
(
*
args
,
**
kwargs
):
def
train
(
*
args
,
**
kwargs
):
print
(
'Function ran'
)
print
(
'Function ran'
)
...
@@ -113,6 +111,7 @@ class MlModelViewSet(viewsets.ViewSet):
...
@@ -113,6 +111,7 @@ class MlModelViewSet(viewsets.ViewSet):
print
(
results
)
print
(
results
)
return
Response
({
'success'
:
True
})
return
Response
({
'success'
:
True
})
#Custom api to convert audio to text
@
action
(
detail
=
False
,
methods
=
[
"POST"
])
@
action
(
detail
=
False
,
methods
=
[
"POST"
])
def
detect
(
self
,
request
,
*
args
,
**
kwargs
):
def
detect
(
self
,
request
,
*
args
,
**
kwargs
):
print
(
'Function ran'
)
print
(
'Function ran'
)
...
@@ -135,7 +134,7 @@ class MlModelViewSet(viewsets.ViewSet):
...
@@ -135,7 +134,7 @@ class MlModelViewSet(viewsets.ViewSet):
print
(
'---------------------------------------------------------'
)
print
(
'---------------------------------------------------------'
)
if
(
samples
.
shape
[
0
]
>
8000
):
if
(
samples
.
shape
[
0
]
>
8000
):
print
(
'grate
e
r -------------------------------------'
)
print
(
'grater -------------------------------------'
)
samples
=
samples
[
-
8000
:]
samples
=
samples
[
-
8000
:]
print
(
samples
.
shape
)
print
(
samples
.
shape
)
else
:
else
:
...
@@ -144,57 +143,17 @@ class MlModelViewSet(viewsets.ViewSet):
...
@@ -144,57 +143,17 @@ class MlModelViewSet(viewsets.ViewSet):
samples
=
np
.
concatenate
((
samples
,
new_arr
))
samples
=
np
.
concatenate
((
samples
,
new_arr
))
print
(
samples
.
shape
)
print
(
samples
.
shape
)
# audio_file = request.data
results
=
predict
(
samples
)
# # Using File storage to save file for future converting
# fs = FileSystemStorage()
# file_name = fs.save(audio_file.name, audio_file)
# audio_file_url = fs.url(file_name)
# # Preparing paths for convertion
# upstream = os.path.dirname(os.path.dirname(os.path.dirname(
# os.path.abspath(__file__))))
# path = os.path.join(upstream, 'media', audio_file.name)
# mp3_filename = '.'.join([audio_file.name.split('.')[0], 'mp3'])
# new_path = os.path.join(
# upstream, 'media', mp3_filename)
# # Converting to mp3 here
# wma_version = AudioSegment.from_file(path, "wav")
# wma_version.export(new_path, format="mp3")
# user_id = self.request.user.id
# # I was trying to create a Track instance, the mp3 get's saved but it is not being saved using location specified in models.
# track = Track.objects.create(user_id=user_id)
# track.file.name = mp3_filename
# track.save()
results
=
{}
# predict(samples)
print
(
results
)
print
(
results
)
return
Response
({
'success'
:
True
,
'result'
:
results
})
return
Response
({
'success'
:
True
,
'result'
:
results
})
class
ChatViewSet
(
viewsets
.
ModelViewSet
):
class
ChatViewSet
(
viewsets
.
ModelViewSet
):
queryset
=
Chat
.
objects
.
all
()
.
order_by
(
'-timestamp'
)
queryset
=
Chat
.
objects
.
all
()
.
order_by
(
'-timestamp'
)
serializer_class
=
ChatSerializer
serializer_class
=
ChatSerializer
permission_classes
=
[
permissions
.
IsAuthenticated
]
#change to permissions.isAutheniticated
permission_classes
=
[
permissions
.
IsAuthenticated
]
@
action
(
methods
=
'POST'
,
detail
=
True
)
def
getChats
(
self
,
request
,
*
args
,
**
kwargs
):
#conversation_id hardcoded as we dont have many conversations at the moment
chats
=
Chat
.
objects
\
.
filter
(
conversation_id
=
1
)
\
.
filter
(
(
Q
(
from_user__user_id
=
request
.
user
.
id
)
|
Q
(
to_user__user_id
=
request
.
user
.
id
))
)
.
order_by
(
'-timestamp'
)
.
values
()
return
Response
({
chats
})
#ovveride defualt list action to get chats of specific user conversation
def
list
(
self
,
request
,
pk
=
None
):
def
list
(
self
,
request
,
pk
=
None
):
if
pk
==
None
:
if
pk
==
None
:
chats
=
Chat
.
objects
\
chats
=
Chat
.
objects
\
.
filter
(
conversation_id
=
1
)
\
.
filter
(
conversation_id
=
1
)
\
...
@@ -210,7 +169,6 @@ class ChatViewSet(viewsets.ModelViewSet):
...
@@ -210,7 +169,6 @@ class ChatViewSet(viewsets.ModelViewSet):
serializer
=
self
.
get_serializer
(
page
,
many
=
True
)
serializer
=
self
.
get_serializer
(
page
,
many
=
True
)
result_set
=
serializer
.
data
result_set
=
serializer
.
data
return
Response
(
result_set
)
return
Response
(
result_set
)
def
get_result_set
(
self
,
chats
):
def
get_result_set
(
self
,
chats
):
...
@@ -218,7 +176,6 @@ class ChatViewSet(viewsets.ModelViewSet):
...
@@ -218,7 +176,6 @@ class ChatViewSet(viewsets.ModelViewSet):
return
result_set
return
result_set
class
ConversationViewSet
(
viewsets
.
ModelViewSet
):
class
ConversationViewSet
(
viewsets
.
ModelViewSet
):
queryset
=
Conversation
.
objects
.
all
()
.
order_by
(
'id'
)
queryset
=
Conversation
.
objects
.
all
()
.
order_by
(
'id'
)
serializer_class
=
ConversationSerializer
serializer_class
=
ConversationSerializer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment