Commit 16228ea8 authored by Malsha Rathnasiri's avatar Malsha Rathnasiri

minor changes

parent b4c8f105
import { BACKEND_URL } from "./constants" import { BACKEND_URL } from "./constants"
import AsyncStorage from '@react-native-async-storage/async-storage'; import AsyncStorage from '@react-native-async-storage/async-storage';
import { useNavigation } from "@react-navigation/native"; import { useNavigation } from "@react-navigation/native";
import axios from "axios";
//"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiZXhwIjoxNjYxMzU2NDA5LCJpYXQiOjE2NjAxNDY4MDksImp0aSI6ImFhMTljMmY2ZDNkMzRiNDdhZmZmM2FjMzVjNzI4MWJhIiwidXNlcl9pZCI6MX0.IVzibo_Rf2xzoT1J5o1L3zwu3mco6ODcNPC-7imu3Lo" //"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiZXhwIjoxNjYxMzU2NDA5LCJpYXQiOjE2NjAxNDY4MDksImp0aSI6ImFhMTljMmY2ZDNkMzRiNDdhZmZmM2FjMzVjNzI4MWJhIiwidXNlcl9pZCI6MX0.IVzibo_Rf2xzoT1J5o1L3zwu3mco6ODcNPC-7imu3Lo"
...@@ -8,21 +11,41 @@ import { useNavigation } from "@react-navigation/native"; ...@@ -8,21 +11,41 @@ import { useNavigation } from "@react-navigation/native";
// Build the auth headers for API calls from the stored JWT access token.
// Returns a plain object (not a Headers instance) so it can be merged into
// both fetch and axios request options.
const getHeaders = async () => {
    // NOTE(review): token is null if the user never logged in — the backend
    // will then reject with 401; confirm callers handle that case.
    const token = await AsyncStorage.getItem('access_token')
    return {
        Authorization: `JWT ${token}`,
        'Content-Type': 'application/json',
    }
}
// POST `values` as JSON to `${BACKEND_URL}/${resource}/`.
// Returns the fetch Response promise; synchronous fetch errors are logged
// and swallowed (resolving to undefined), matching the previous behavior.
export const create = async (resource, values) => {
    // BUG FIX: getHeaders is async — without `await` a pending Promise was
    // passed as the headers object, so no Authorization header was ever sent.
    const headers = await getHeaders()
    try {
        return fetch(`${BACKEND_URL}/${resource}/`, {
            method: 'POST',
            body: JSON.stringify(values),
            headers,
        })
    }
    catch (e) {
        // Only synchronous errors land here (fetch rejections surface on the
        // returned promise). Removed the headers debug log — it leaked the JWT.
        console.log(e)
    }
}
// GET a list from `${BACKEND_URL}/${resource}/`, with optional query params.
// `params` is a flat object of key/value pairs appended as a query string.
export const getList = async (resource, params) => {
    const url = new URL(`${BACKEND_URL}/${resource}/`)
    if (params) {
        // forEach, not map: we only want the side effect of appending params.
        Object.keys(params).forEach(key => {
            url.searchParams.append(key, params[key])
        })
    }
    const headers = await getHeaders()
    // Removed: unreachable `return axios.get(...)` after this return, and the
    // debug log that printed the Authorization header (token leak).
    return fetch(url, { method: 'GET', headers })
}
// GET a single record: `${BACKEND_URL}/${resource}/${id}/`.
export const getOne = async (resource, id) => {
    const url = new URL(`${BACKEND_URL}/${resource}/${id}/`)
    const headers = await getHeaders()
    // Removed the `console.log(headers, 'getONe')` debug line — it printed
    // the bearer token to the console on every call.
    return fetch(url, { method: 'GET', headers })
}
\ No newline at end of file
// export const BACKEND_URL = "http://192.168.8.103:8000"
import { Platform } from 'react-native'

// On web the Django dev server is reachable directly; on a device we go
// through the ngrok tunnel. NOTE(review): this ngrok hostname rotates every
// session — consider moving it to an env/config value instead of editing code.
export const BACKEND_ADDRESS = Platform.OS === 'web'
    ? '127.0.0.1:8000'
    : '93e8-2401-dd00-10-20-e170-748c-b618-e8f5.in.ngrok.io'

export const BACKEND_URL = `http://${BACKEND_ADDRESS}`
...@@ -6,7 +6,7 @@ import _ from 'lodash' ...@@ -6,7 +6,7 @@ import _ from 'lodash'
import EditScreenInfo from '../components/EditScreenInfo'; import EditScreenInfo from '../components/EditScreenInfo';
// import { Text, View } from '../components/Themed'; // import { Text, View } from '../components/Themed';
import { RootTabScreenProps } from '../types'; import { RootTabScreenProps } from '../types';
import { create, getList } from '../api/api'; import { create, getList, getOne } from '../api/api';
import { BACKEND_ADDRESS } from '../api/constants'; import { BACKEND_ADDRESS } from '../api/constants';
import Ionicons from '@expo/vector-icons/Ionicons'; import Ionicons from '@expo/vector-icons/Ionicons';
import { CONVO_DEFAULT_ICON_COLOR, styles } from '../util/styles'; import { CONVO_DEFAULT_ICON_COLOR, styles } from '../util/styles';
...@@ -37,7 +37,8 @@ export default function ChatScreen({ navigation }) { ...@@ -37,7 +37,8 @@ export default function ChatScreen({ navigation }) {
const [detectedText, setDetectedText] = useState("") const [detectedText, setDetectedText] = useState("")
const [playinId, setPlayingId] = useState(3) const [playinId, setPlayingId] = useState(3)
const [chatDetails, setChatDetails] = useState({ from_user: 1, to_user: 2, conversation_id: 1 }) const [chatDetails, setChatDetails] = useState()
const [chats, setChats] = useState([]) const [chats, setChats] = useState([])
const [input, setInput] = useState('test') const [input, setInput] = useState('test')
const [loading, setLoading] = useState(true) const [loading, setLoading] = useState(true)
...@@ -59,12 +60,15 @@ export default function ChatScreen({ navigation }) { ...@@ -59,12 +60,15 @@ export default function ChatScreen({ navigation }) {
console.log({ chats }) console.log({ chats })
} }
useEffect(() => { if (chatDetails) { loadChats(); setLoading(false) } }, [chatDetails])
// const ws = new WebSocket(`ws://${BACKEND_ADDRESS}/chatSocket/`) // const ws = new WebSocket(`ws://${BACKEND_ADDRESS}/chatSocket/`)
useEffect(() => { useEffect(() => {
loadChats() loadChatDetails()
// startWebsocket() // startWebsocket()
// loadSampleChats() // loadSampleChats()
setLoading(false)
}, []) }, [])
// const startWebsocket = () => { // const startWebsocket = () => {
...@@ -87,12 +91,22 @@ export default function ChatScreen({ navigation }) { ...@@ -87,12 +91,22 @@ export default function ChatScreen({ navigation }) {
// } // }
// Fetch one conversation's details and store them in component state.
// `conversationId` defaults to 1 (the previous hard-coded value) so existing
// callers are unchanged, but other conversations can now be loaded too.
const loadChatDetails = async (conversationId = 1) => {
    // Straight await instead of the old mixed `await ... .then().then()` chain.
    const res = await getOne('conversations', conversationId)
    const details = await res.json()
    setChatDetails(details)
}
const loadChats = async () => { const loadChats = async () => {
await getList('chats').then(res => { await getList('chats').then(res => {
return res.json() return res.json()
}).then(res => { }).then(res => {
// console.log(res) console.log("load chats")
const chats = res.results const chats = res.results || []
const sections = [...new Set(chats.map(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0)))]; const sections = [...new Set(chats.map(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0)))];
const sectionChats = sections.map(section => ({ title: section, data: chats.filter(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0) == section) })) const sectionChats = sections.map(section => ({ title: section, data: chats.filter(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0) == section) }))
setChats(sectionChats) setChats(sectionChats)
...@@ -110,8 +124,8 @@ export default function ChatScreen({ navigation }) { ...@@ -110,8 +124,8 @@ export default function ChatScreen({ navigation }) {
setInput('') setInput('')
loadChats() loadChats()
} }
catch(e){ catch (e) {
Toast.show({title: 'Error sending message. try again!', ...ERROR_TOAST_PROPS}) Toast.show({ title: 'Error sending message. try again!', ...ERROR_TOAST_PROPS })
} }
} }
...@@ -130,12 +144,18 @@ export default function ChatScreen({ navigation }) { ...@@ -130,12 +144,18 @@ export default function ChatScreen({ navigation }) {
{loading ? <ActivityIndicator /> : {loading ? <ActivityIndicator /> :
<SectionList <SectionList
refreshing={loading} refreshing={loading}
onRefresh={() => loadChats()} onRefresh={() => {
inverted={true} // loadChatDetails() //remove
loadChats()
}}
inverted
sections={chats} sections={chats}
keyExtractor={(item, index) => item + index} keyExtractor={(item, index) => item + index}
renderItem={({ item }) => { renderItem={({ item }) => {
// console.log({ item }) // console.log({ item })
const timeString = new Date(item.timestamp).toLocaleTimeString()
const time = timeString.slice(-11, -6) + " " + timeString.slice(-2)
return ( return (
<View style={{ margin: 5, padding: 5 }}><Text <View style={{ margin: 5, padding: 5 }}><Text
style={[{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', backgroundColor: chatDetails.from_user == item.from_user ? 'lightgray' : '#FFDE03', borderRadius: 5, padding: 5 }, style={[{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', backgroundColor: chatDetails.from_user == item.from_user ? 'lightgray' : '#FFDE03', borderRadius: 5, padding: 5 },
...@@ -143,8 +163,8 @@ export default function ChatScreen({ navigation }) { ...@@ -143,8 +163,8 @@ export default function ChatScreen({ navigation }) {
]} ]}
key={item.timestamp}>{item.message}</Text> key={item.timestamp}>{item.message}</Text>
<View style={{ flex: 1, flexDirection: chatDetails.from_user == item.from_user ? 'row-reverse' : 'row', textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left' }}> <View style={{ flex: 1, flexDirection: chatDetails.from_user == item.from_user ? 'row-reverse' : 'row', textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left' }}>
<Text style={{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', color: 'gray', fontSize: 11 }}>{new Date(item.timestamp).toLocaleTimeString()}</Text> <Text style={{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', color: 'gray', fontSize: 11 }}>{time}</Text>
{item.is_detected && <Ionicons name="mic" size={15} color={CONVO_DEFAULT_ICON_COLOR}/>} {item.is_detected && <Ionicons name="mic" size={15} color={CONVO_DEFAULT_ICON_COLOR} />}
{/* {chatDetails.to_user == item.from_user && item.id != playinId && <Ionicons name="play" size={15} style={styles.chatIcon} />} {/* {chatDetails.to_user == item.from_user && item.id != playinId && <Ionicons name="play" size={15} style={styles.chatIcon} />}
{chatDetails.to_user == item.from_user && item.id == playinId && <Ionicons name="pause" size={15} style={styles.chatIcon} />} */} {chatDetails.to_user == item.from_user && item.id == playinId && <Ionicons name="pause" size={15} style={styles.chatIcon} />} */}
{chatDetails.to_user == item.from_user && <PlayMessage message={item.message} />} {chatDetails.to_user == item.from_user && <PlayMessage message={item.message} />}
...@@ -200,11 +220,11 @@ export default function ChatScreen({ navigation }) { ...@@ -200,11 +220,11 @@ export default function ChatScreen({ navigation }) {
</View> </View>
<View style={{ flex: 0.075, padding:0, backgroundColor: 'white'}}> <View style={{ flex: 0.075, padding: 0, backgroundColor: 'white' }}>
<View style={{ flexDirection: 'row', display: 'flex', height: '100%' }}> <View style={{ flexDirection: 'row', display: 'flex', height: '100%' }}>
<View style={{ flex: 0.8, height: '100%', flexDirection: 'column-reverse' }}> <View style={{ flex: 0.8, height: '100%', flexDirection: 'column-reverse' }}>
<TextInput <TextInput
style={{ borderWidth: 2, borderColor: 'gray', color: 'black', marginHorizontal: 5,paddingHorizontal: 10, borderRadius: 5 }} style={{ borderWidth: 2, borderColor: 'gray', color: 'black', marginHorizontal: 5, paddingHorizontal: 10, borderRadius: 5 }}
defaultValue={input} defaultValue={input}
onChange={(e) => setInput(e.target.value)}></TextInput> onChange={(e) => setInput(e.target.value)}></TextInput>
</View> </View>
......
import pickle # import pickle
from keras.models import load_model # from keras.models import load_model
import numpy as np # import numpy as np
import IPython.display as ipd # import IPython.display as ipd
import random # import random
from sklearn.model_selection import train_test_split # from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder # from sklearn.preprocessing import LabelEncoder
def predict(samples): # def predict(samples):
model=load_model(r'./best_model_final.hdf5') # model=load_model(r'./best_model_final.hdf5')
f1 = open('all_label.txt', 'rb') # f1 = open('all_label.txt', 'rb')
all_label = pickle.load(f1) # all_label = pickle.load(f1)
print('loaded labels') # print('loaded labels')
# f2 = open('all_waves_file.txt', 'rb') # # f2 = open('all_waves_file.txt', 'rb')
# all_wave = pickle.load(f2) # # all_wave = pickle.load(f2)
# print('loaded waves') # # print('loaded waves')
le = LabelEncoder() # le = LabelEncoder()
y = le.fit_transform(all_label) # y = le.fit_transform(all_label)
classes = list(le.classes_) # classes = list(le.classes_)
# train_data_file = open("train_data_file.txt", 'rb') # # train_data_file = open("train_data_file.txt", 'rb')
# [x_tr, x_val, y_tr, y_val] = np.load(train_data_file, allow_pickle=True) # # [x_tr, x_val, y_tr, y_val] = np.load(train_data_file, allow_pickle=True)
# train_data_file.close() # # train_data_file.close()
def predictSamples(audio): # def predictSamples(audio):
prob=model.predict(audio.reshape(1,8000,1)) # prob=model.predict(audio.reshape(1,8000,1))
index=np.argmax(prob[0]) # index=np.argmax(prob[0])
return classes[index] # return classes[index]
# index=random.randint(0,len(x_val)-1) # # index=random.randint(0,len(x_val)-1)
# samples=x_val[index].ravel() # # samples=x_val[index].ravel()
print(samples) # print(samples)
# print("Audio:",classes[np.argmax(y_val[index])]) # # print("Audio:",classes[np.argmax(y_val[index])])
ipd.Audio(samples, rate=8000) # ipd.Audio(samples, rate=8000)
result = predictSamples(samples) # result = predictSamples(samples)
print("Text:",result) # print("Text:",result)
return result # return result
\ No newline at end of file \ No newline at end of file
import pickle # import pickle
from matplotlib import pyplot # from matplotlib import pyplot
import os # import os
import librosa # import librosa
import IPython.display as ipd # import IPython.display as ipd
import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
import numpy as np # import numpy as np
from scipy.io import wavfile # from scipy.io import wavfile
import warnings # import warnings
from sklearn.preprocessing import LabelEncoder # from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils # from keras.utils import np_utils
# from sklearn.model_selection import train_test_split
# from keras.layers import Dense, Dropout, Flatten, Conv1D, Input, MaxPooling1D
# from keras.models import Model
# from keras.callbacks import EarlyStopping, ModelCheckpoint
# from keras import backend as K
# K.clear_session()
# warnings.filterwarnings("ignore")
# # os.listdir('../../../data/')
# classes = ['down', 'go', 'left', 'no', 'off',
# 'on', 'right', 'stop', 'up', 'yes']
from sklearn.model_selection import train_test_split
from keras.layers import Dense, Dropout, Flatten, Conv1D, Input, MaxPooling1D
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
K.clear_session()
warnings.filterwarnings("ignore")
# os.listdir('../../../data/')
classes = ['down', 'go', 'left', 'no', 'off',
'on', 'right', 'stop', 'up', 'yes']
def train():
print('1')
train_audio_path = r'./backend/data/train/train/audio/'
samples, sample_rate = librosa.load(
train_audio_path+'yes/0a7c2a8d_nohash_0.wav', sr=16000)
# fig = plt.figure(figsize=(14, 8))
# ax1 = fig.add_subplot(211)
# ax1.set_title('Raw wave of ' + r'../input/train/audio/yes/0a7c2a8d_nohash_0.wav')
# ax1.set_xlabel('time')
# ax1.set_ylabel('Amplitude')
# ax1.plot(np.linspace(0, sample_rate/len(samples), sample_rate), samples)
ipd.Audio(samples, rate=sample_rate)
print(sample_rate) # def train():
samples = librosa.resample(samples, sample_rate, 8000) # print('1')
ipd.Audio(samples, rate=8000) # train_audio_path = r'./backend/data/train/train/audio/'
# samples, sample_rate = librosa.load(
# train_audio_path+'yes/0a7c2a8d_nohash_0.wav', sr=16000)
# # fig = plt.figure(figsize=(14, 8))
# # ax1 = fig.add_subplot(211)
# # ax1.set_title('Raw wave of ' + r'../input/train/audio/yes/0a7c2a8d_nohash_0.wav')
# # ax1.set_xlabel('time')
# # ax1.set_ylabel('Amplitude')
# # ax1.plot(np.linspace(0, sample_rate/len(samples), sample_rate), samples)
labels = os.listdir(train_audio_path) # ipd.Audio(samples, rate=sample_rate)
# find count of each label and plot bar graph # print(sample_rate)
no_of_recordings = []
for label in labels:
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
no_of_recordings.append(len(waves))
# plot # samples = librosa.resample(samples, sample_rate, 8000)
# plt.figure(figsize=(30,5)) # ipd.Audio(samples, rate=8000)
index = np.arange(len(labels))
# plt.bar(index, no_of_recordings)
# plt.xlabel('Commands', fontsize=12)
# plt.ylabel('No of recordings', fontsize=12)
# plt.xticks(index, labels, fontsize=15, rotation=60)
# plt.title('No. of recordings for each command')
# plt.show()
print('2') # labels = os.listdir(train_audio_path)
labels = ["yes", "no", "up", "down", "left", # # find count of each label and plot bar graph
"right", "on", "off", "stop", "go"] # no_of_recordings = []
# for label in labels:
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# no_of_recordings.append(len(waves))
# labels_file = open('./labels_file.bin', 'wb+') # # plot
# pickle.dump(obj=labels, file=labels_file) # # plt.figure(figsize=(30,5))
# labels_file.close() # index = np.arange(len(labels))
# # plt.bar(index, no_of_recordings)
# # plt.xlabel('Commands', fontsize=12)
# # plt.ylabel('No of recordings', fontsize=12)
# # plt.xticks(index, labels, fontsize=15, rotation=60)
# # plt.title('No. of recordings for each command')
# # plt.show()
# # file = open('./labels_file.bin', 'rb') # print('2')
# # dict = pickle.load(file)
# # print('loaded')
# # print(dict)
# # print('fdnasf')
duration_of_recordings = [] # labels = ["yes", "no", "up", "down", "left",
for label in labels: # "right", "on", "off", "stop", "go"]
print('2.1', label)
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
for wav in waves:
sample_rate, samples = wavfile.read(
train_audio_path + '/' + label + '/' + wav)
duration_of_recordings.append(float(len(samples)/sample_rate))
plt.hist(np.array(duration_of_recordings)) # # labels_file = open('./labels_file.bin', 'wb+')
# # pickle.dump(obj=labels, file=labels_file)
# # labels_file.close()
train_audio_path = r'./backend/data/train/train/audio/' # # # file = open('./labels_file.bin', 'rb')
# # # dict = pickle.load(file)
# # # print('loaded')
# # # print(dict)
# # # print('fdnasf')
all_wave = [] # duration_of_recordings = []
all_label = [] # for label in labels:
# print('2.1', label)
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# for wav in waves:
# sample_rate, samples = wavfile.read(
# train_audio_path + '/' + label + '/' + wav)
# duration_of_recordings.append(float(len(samples)/sample_rate))
f1 = open('all_label.txt', 'rb') # plt.hist(np.array(duration_of_recordings))
all_label = pickle.load(f1)
f2 = open('all_waves_file.txt', 'rb') # train_audio_path = r'./backend/data/train/train/audio/'
all_wave = pickle.load(f2)
if(all_wave and all_label): # all_wave = []
print('loaded labels and waves') # all_label = []
else:
print('Creating labels and waves files')
for label in labels:
print(label)
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
for wav in waves:
samples, sample_rate = librosa.load(
train_audio_path + '/' + label + '/' + wav, sr=16000)
samples = librosa.resample(samples, sample_rate, 8000)
if(len(samples) == 8000):
all_wave.append(samples)
all_label.append(label)
# print('3') # f1 = open('all_label.txt', 'rb')
# all_label = pickle.load(f1)
all_labels_file = open('all_label.txt', 'wb+') # f2 = open('all_waves_file.txt', 'rb')
pickle.dump(file=all_labels_file, obj=all_label) # all_wave = pickle.load(f2)
all_labels_file.close()
return False # if(all_wave and all_label):
# print('loaded labels and waves')
# else:
# print('Creating labels and waves files')
# for label in labels:
# print(label)
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# for wav in waves:
# samples, sample_rate = librosa.load(
# train_audio_path + '/' + label + '/' + wav, sr=16000)
# samples = librosa.resample(samples, sample_rate, 8000)
# if(len(samples) == 8000):
# all_wave.append(samples)
# all_label.append(label)
all_waves_file = open('all_waves_file.txt', 'wb+') # # print('3')
pickle.dump(file=all_waves_file, obj=all_wave)
all_waves_file.close()
print('Done: creating labels and waves files') # all_labels_file = open('all_label.txt', 'wb+')
# pickle.dump(file=all_labels_file, obj=all_label)
# all_labels_file.close()
return False
le = LabelEncoder()
y = le.fit_transform(all_label)
classes = list(le.classes_)
print('4')
y = np_utils.to_categorical(y, num_classes=len(labels)) # all_waves_file = open('all_waves_file.txt', 'wb+')
# pickle.dump(file=all_waves_file, obj=all_wave)
# all_waves_file.close()
all_wave = np.array(all_wave).reshape(-1, 8000, 1) # print('Done: creating labels and waves files')
x_tr, x_val, y_tr, y_val = train_test_split(np.array(all_wave), np.array(
y), stratify=y, test_size=0.2, random_state=777, shuffle=True)
train_data_file = open('train_data_file.txt', 'wb+') # le = LabelEncoder()
np.save(file=train_data_file, arr=np.array([x_tr, x_val, y_tr, y_val])) # y = le.fit_transform(all_label)
train_data_file.close() # classes = list(le.classes_)
inputs = Input(shape=(8000, 1)) # print('4')
# First Conv1D layer # y = np_utils.to_categorical(y, num_classes=len(labels))
conv = Conv1D(8, 13, padding='valid', activation='relu', strides=1)(inputs)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# Second Conv1D layer # all_wave = np.array(all_wave).reshape(-1, 8000, 1)
conv = Conv1D(16, 11, padding='valid', activation='relu', strides=1)(conv)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# Third Conv1D layer # x_tr, x_val, y_tr, y_val = train_test_split(np.array(all_wave), np.array(
conv = Conv1D(32, 9, padding='valid', activation='relu', strides=1)(conv) # y), stratify=y, test_size=0.2, random_state=777, shuffle=True)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# Fourth Conv1D layer # train_data_file = open('train_data_file.txt', 'wb+')
conv = Conv1D(64, 7, padding='valid', activation='relu', strides=1)(conv) # np.save(file=train_data_file, arr=np.array([x_tr, x_val, y_tr, y_val]))
conv = MaxPooling1D(3)(conv) # train_data_file.close()
conv = Dropout(0.3)(conv)
# Flatten layer # inputs = Input(shape=(8000, 1))
conv = Flatten()(conv)
# Dense Layer 1 # # First Conv1D layer
conv = Dense(256, activation='relu')(conv) # conv = Conv1D(8, 13, padding='valid', activation='relu', strides=1)(inputs)
conv = Dropout(0.3)(conv) # conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
# Dense Layer 2 # # Second Conv1D layer
conv = Dense(128, activation='relu')(conv) # conv = Conv1D(16, 11, padding='valid', activation='relu', strides=1)(conv)
conv = Dropout(0.3)(conv) # conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
outputs = Dense(len(labels), activation='softmax')(conv) # # Third Conv1D layer
# conv = Conv1D(32, 9, padding='valid', activation='relu', strides=1)(conv)
# conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
model = Model(inputs, outputs) # # Fourth Conv1D layer
model.summary() # conv = Conv1D(64, 7, padding='valid', activation='relu', strides=1)(conv)
# conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
model.compile(loss='categorical_crossentropy', # # Flatten layer
optimizer='adam', metrics=['accuracy']) # conv = Flatten()(conv)
es = EarlyStopping(monitor='val_loss', mode='min', # # Dense Layer 1
verbose=1, patience=10, min_delta=0.0001) # conv = Dense(256, activation='relu')(conv)
mc = ModelCheckpoint('best_model.hdf5', monitor='val_accuracy', # conv = Dropout(0.3)(conv)
verbose=1, save_best_only=True, mode='max')
history = model.fit(x_tr, y_tr, epochs=100, callbacks=[ # # Dense Layer 2
es, mc], batch_size=32, validation_data=(x_val, y_val)) # conv = Dense(128, activation='relu')(conv)
# conv = Dropout(0.3)(conv)
# pyplot.plot(history.history['loss'], label='train') # outputs = Dense(len(labels), activation='softmax')(conv)
# pyplot.plot(history.history['val_loss'], label='test')
# pyplot.legend()
# pyplot.show() # model = Model(inputs, outputs)
# model.summary()
return history # model.compile(loss='categorical_crossentropy',
# optimizer='adam', metrics=['accuracy'])
# es = EarlyStopping(monitor='val_loss', mode='min',
# verbose=1, patience=10, min_delta=0.0001)
# mc = ModelCheckpoint('best_model.hdf5', monitor='val_accuracy',
# verbose=1, save_best_only=True, mode='max')
# history = model.fit(x_tr, y_tr, epochs=100, callbacks=[
# es, mc], batch_size=32, validation_data=(x_val, y_val))
# # pyplot.plot(history.history['loss'], label='train')
# # pyplot.plot(history.history['val_loss'], label='test')
# # pyplot.legend()
# # pyplot.show()
# return history
...@@ -46,7 +46,7 @@ class MlModelSerializer(serializers.ModelSerializer): ...@@ -46,7 +46,7 @@ class MlModelSerializer(serializers.ModelSerializer):
'__all__' '__all__'
) )
class ChatSerialier(serializers.ModelSerializer): class ChatSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Chat model = Chat
fields = ( fields = (
......
from http.client import HTTPResponse
from lib2to3.pytree import convert
from pyexpat import model
from django.contrib.auth.models import User, Group from django.contrib.auth.models import User, Group
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework import permissions from rest_framework import permissions
from backend.cms.serializers import MyTokenObtainPairSerializer from backend.cms.serializers import MyTokenObtainPairSerializer
from backend.cms.serializers import MlModelSerializer, ChatSerialier, ConversationSerializer from backend.cms.serializers import MlModelSerializer, ChatSerializer, ConversationSerializer
from backend.cms.serializers import UserSerializer, GroupSerializer from backend.cms.serializers import UserSerializer, GroupSerializer
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
import mimetypes
import os import os
from rest_framework.parsers import MultiPartParser from rest_framework.parsers import MultiPartParser
from io import BytesIO
from datetime import datetime, timedelta from datetime import datetime, timedelta
import librosa import librosa
from django.db.models import Q
from django.core.files.storage import FileSystemStorage from django.core.files.storage import FileSystemStorage
from .models import Chat, Conversation, MlModel from .models import Chat, Conversation, MlModel
from rest_framework_simplejwt.views import TokenObtainPairView from rest_framework_simplejwt.views import TokenObtainPairView
from .model.train import train # from .model.train import train
from .model.predict import predict # from .model.predict import predict
from pydub import AudioSegment from pydub import AudioSegment
...@@ -73,7 +71,7 @@ class MlModelViewSet(viewsets.ViewSet): ...@@ -73,7 +71,7 @@ class MlModelViewSet(viewsets.ViewSet):
parser_classes = [MultiPartParser] parser_classes = [MultiPartParser]
@action(detail=False) @action(detail=False)
def runAction(*args, **kwargs): def addChats(*args, **kwargs):
admin = User.objects.get(username='admin') admin = User.objects.get(username='admin')
user2 = User.objects.get(username='user2') user2 = User.objects.get(username='user2')
...@@ -172,16 +170,47 @@ class MlModelViewSet(viewsets.ViewSet): ...@@ -172,16 +170,47 @@ class MlModelViewSet(viewsets.ViewSet):
# track.file.name = mp3_filename # track.file.name = mp3_filename
# track.save() # track.save()
results = predict(samples) results = {} # predict(samples)
print(results) print(results)
return Response({'success': True, 'result': results}) return Response({'success': True, 'result': results})
class ChatViewSet(viewsets.ModelViewSet): class ChatViewSet(viewsets.ModelViewSet):
queryset = Chat.objects.all().order_by('-timestamp') queryset = Chat.objects.all().order_by('-timestamp')
serializer_class = ChatSerialier serializer_class = ChatSerializer
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
# BUG FIX: @action(methods=...) expects a list of method names, not a string —
# a bare 'POST' string is iterated character-by-character by DRF's routing.
@action(methods=['post'], detail=True)
def getChats(self, request, *args, **kwargs):
    """Return every chat the requesting user sent or received, newest first."""
    chats = Chat.objects.filter(
        Q(from_user__user_id=request.user.id) |
        Q(to_user__user_id=request.user.id)
    ).order_by('-timestamp').values()
    # BUG FIX: was `Response({chats})`, which wrapped the queryset in a set
    # and produced an unserializable payload; return the values directly.
    return Response(chats)
def list(self, request, pk=None):
    """List chats involving the requesting user (paginated when configured).

    With `pk` given, narrows to that single chat instead.
    NOTE(review): `Chat.objects.get(id=pk)` returns a single instance, not a
    queryset — paginate_queryset on it is questionable; confirm intended use.
    """
    if pk is None:  # `is None`, not `== None`
        chats = Chat.objects.filter(
            Q(from_user_id=request.user.id) | Q(to_user_id=request.user.id))
    else:
        chats = Chat.objects.get(id=pk)
    page = self.paginate_queryset(chats)
    if page is not None:
        serializer = self.get_serializer(page, many=True)
        return self.get_paginated_response(serializer.data)
    # BUG FIX: the unpaginated path previously serialized `page` (which is
    # None here), returning an empty/invalid body; serialize `chats` instead.
    serializer = self.get_serializer(chats, many=True)
    return Response(serializer.data)
def get_result_set(self, chats):
    """Serialize a chat queryset/list into plain result data."""
    return ChatSerializer(chats, many=True).data
class ConversationViewSet(viewsets.ModelViewSet): class ConversationViewSet(viewsets.ModelViewSet):
queryset = Conversation.objects.all().order_by('id') queryset = Conversation.objects.all().order_by('id')
......
...@@ -46,14 +46,16 @@ INSTALLED_APPS = [ ...@@ -46,14 +46,16 @@ INSTALLED_APPS = [
] ]
MIDDLEWARE = [ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
'corsheaders.middleware.CorsMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
'corsheaders.middleware.CorsMiddleware',
] ]
ROOT_URLCONF = 'backend.urls' ROOT_URLCONF = 'backend.urls'
...@@ -92,18 +94,18 @@ DATABASES = { ...@@ -92,18 +94,18 @@ DATABASES = {
# https://docs.djangoproject.com/en/4.0/ref/settings/#auth-password-validators # https://docs.djangoproject.com/en/4.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [ AUTH_PASSWORD_VALIDATORS = [
{ # {
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', # 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
}, # },
{ # {
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', # 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
}, # },
{ # {
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', # 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
}, # },
{ # {
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', # 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
}, # },
] ]
...@@ -132,10 +134,10 @@ DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' ...@@ -132,10 +134,10 @@ DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
REST_FRAMEWORK = { REST_FRAMEWORK = {
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'PAGE_SIZE': 10, 'PAGE_SIZE': 10,
'DEFAULT_AUTHENTICATION_CLASSES': ( 'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework_simplejwt.authentication.JWTAuthentication', 'rest_framework_simplejwt.authentication.JWTAuthentication',
'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.SessionAuthentication',
) ]
} }
SIMPLE_JWT = { SIMPLE_JWT = {
...@@ -148,7 +150,7 @@ SIMPLE_JWT = { ...@@ -148,7 +150,7 @@ SIMPLE_JWT = {
'SIGNING_KEY': SECRET_KEY, 'SIGNING_KEY': SECRET_KEY,
'VERIFYING_KEY': None, 'VERIFYING_KEY': None,
'AUTH_HEADER_TYPES': ('Bearer',), 'AUTH_HEADER_TYPES': ('JWT',),
'USER_ID_FIELD': 'id', 'USER_ID_FIELD': 'id',
'USER_ID_CLAIM': 'user_id', 'USER_ID_CLAIM': 'user_id',
...@@ -156,10 +158,11 @@ SIMPLE_JWT = { ...@@ -156,10 +158,11 @@ SIMPLE_JWT = {
'TOKEN_TYPE_CLAIM': 'token_type', 'TOKEN_TYPE_CLAIM': 'token_type',
'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp', 'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp',
'SLIDING_TOKEN_LIFETIME': datetime.timedelta(minutes=5), 'SLIDING_TOKEN_LIFETIME': datetime.timedelta(minutes=60),
'SLIDING_TOKEN_REFRESH_LIFETIME': datetime.timedelta(days=1), 'SLIDING_TOKEN_REFRESH_LIFETIME': datetime.timedelta(days=1),
} }
CORS_ORIGIN_ALLOW_ALL = True CORS_ALLOW_ALL_ORIGINS = True # If this is used then `CORS_ALLOWED_ORIGINS` will not have any effect
CORS_ALLOW_CREDENTIALS = True
ASGI_APPLICATION = "backend.asgi.application" ASGI_APPLICATION = "backend.asgi.application"
\ No newline at end of file
...@@ -5,6 +5,8 @@ from backend.cms import views ...@@ -5,6 +5,8 @@ from backend.cms import views
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView
from rest_framework.authtoken.views import obtain_auth_token from rest_framework.authtoken.views import obtain_auth_token
from django.contrib import admin
router = routers.DefaultRouter() router = routers.DefaultRouter()
router.register(r'users', views.UserViewSet) router.register(r'users', views.UserViewSet)
router.register(r'groups', views.GroupViewSet) router.register(r'groups', views.GroupViewSet)
...@@ -16,6 +18,7 @@ router.register(r'conversations', views.ConversationViewSet) ...@@ -16,6 +18,7 @@ router.register(r'conversations', views.ConversationViewSet)
# Additionally, we include login URLs for the browsable API. # Additionally, we include login URLs for the browsable API.
urlpatterns = [ urlpatterns = [
path('', include(router.urls)), path('', include(router.urls)),
path('admin/', admin.site.urls),
path('api-auth/', include('rest_framework.urls', namespace='rest_framework')), path('api-auth/', include('rest_framework.urls', namespace='rest_framework')),
re_path(r'^api/auth/token/obtain/$', ObtainTokenPairWithUserView.as_view()), re_path(r'^api/auth/token/obtain/$', ObtainTokenPairWithUserView.as_view()),
re_path(r'^api/auth/token/refresh/$', TokenRefreshView.as_view()), re_path(r'^api/auth/token/refresh/$', TokenRefreshView.as_view()),
......
No preview for this file type
No preview for this file type
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment