Commit 16228ea8 authored by Malsha Rathnasiri

minor changes

parent b4c8f105
import { BACKEND_URL } from "./constants"
import AsyncStorage from '@react-native-async-storage/async-storage';
import { useNavigation } from "@react-navigation/native";
import axios from "axios";
//"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiZXhwIjoxNjYxMzU2NDA5LCJpYXQiOjE2NjAxNDY4MDksImp0aSI6ImFhMTljMmY2ZDNkMzRiNDdhZmZmM2FjMzVjNzI4MWJhIiwidXNlcl9pZCI6MX0.IVzibo_Rf2xzoT1J5o1L3zwu3mco6ODcNPC-7imu3Lo"
@@ -8,21 +11,41 @@ import { useNavigation } from "@react-navigation/native";
const getHeaders = async () => {
var token = await AsyncStorage.getItem('access_token')
return new Headers({ authorization: `Bearer ${token}`, 'Content-Type': 'application/json' })
return {
Authorization: `JWT ${token}`, 'Content-Type': 'application/json',
}
}
export const create = async (resource, values) => {
const headers = await getHeaders()
console.log({headers})
const headers = await getHeaders()
console.log({ headers }, 'create')
try {
return fetch(`${BACKEND_URL}/${resource}/`, { method: 'POST', body: JSON.stringify(values), headers })
// return Axios.post(`${BACKEND_URL}/${resource}/`, values)
}
catch (e) {
console.log(e)
}
}
export const getList = async (resource) => {
export const getList = async (resource, params) => {
const url = new URL(`${BACKEND_URL}/${resource}/`)
if (params) {
Object.keys(params).map(key => {
url.searchParams.append(key, params[key])
})
}
const headers = await getHeaders()
console.log(headers, 'getList')
return fetch(url, { method: 'GET', headers: headers })
return axios.get(url.toString(), { headers })
}
export const getOne = async (resource, id) => {
const url = new URL(`${BACKEND_URL}/${resource}/${id}/`)
const headers = await getHeaders()
return fetch(`${BACKEND_URL}/${resource}/`, { method: 'GET', headers: headers })
console.log(headers, 'getOne')
return fetch(url, { method: "GET", headers: headers })
}
\ No newline at end of file
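A hypothetical usage sketch of the helpers above (not part of this commit). The field names mirror ChatScreen's onSendPress further down; getOne resolves to a raw fetch Response, so its body still has to be parsed:

import { create, getOne } from './api';

// Fetch one conversation, then post a chat message into it.
async function sendHello() {
  const convo = await (await getOne('conversations', 1)).json();
  // from_user, to_user and conversation_id are assumptions taken from ChatScreen's chatDetails shape.
  return create('chats', {
    message: 'hello',
    from_user: convo.from_user,
    to_user: convo.to_user,
    conversation: convo.conversation_id,
  });
}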
// export const BACKEND_URL = "http://192.168.8.103:8000"
export const BACKEND_ADDRESS = "dcf1-2402-d000-a500-101a-cd9c-120b-91b8-7774.in.ngrok.io"
import { Platform } from 'react-native'
export const BACKEND_ADDRESS = Platform.OS == 'web' ? "127.0.0.1:8000" : "93e8-2401-dd00-10-20-e170-748c-b618-e8f5.in.ngrok.io"
export const BACKEND_URL = `http://${BACKEND_ADDRESS}`
@@ -6,7 +6,7 @@ import _ from 'lodash'
import EditScreenInfo from '../components/EditScreenInfo';
// import { Text, View } from '../components/Themed';
import { RootTabScreenProps } from '../types';
import { create, getList } from '../api/api';
import { create, getList, getOne } from '../api/api';
import { BACKEND_ADDRESS } from '../api/constants';
import Ionicons from '@expo/vector-icons/Ionicons';
import { CONVO_DEFAULT_ICON_COLOR, styles } from '../util/styles';
@@ -37,7 +37,8 @@ export default function ChatScreen({ navigation }) {
const [detectedText, setDetectedText] = useState("")
const [playinId, setPlayingId] = useState(3)
const [chatDetails, setChatDetails] = useState({ from_user: 1, to_user: 2, conversation_id: 1 })
const [chatDetails, setChatDetails] = useState()
const [chats, setChats] = useState([])
const [input, setInput] = useState('test')
const [loading, setLoading] = useState(true)
@@ -59,12 +60,15 @@ export default function ChatScreen({ navigation }) {
console.log({ chats })
}
useEffect(() => { if (chatDetails) { loadChats(); setLoading(false) } }, [chatDetails])
// const ws = new WebSocket(`ws://${BACKEND_ADDRESS}/chatSocket/`)
useEffect(() => {
loadChats()
loadChatDetails()
// startWebsocket()
// loadSampleChats()
setLoading(false)
}, [])
// const startWebsocket = () => {
@@ -87,12 +91,22 @@ export default function ChatScreen({ navigation }) {
// }
const loadChatDetails = async () => {
await getOne('conversations', 1).then(res => {
return res.json()
}).then(res => {
console.log(res)
setChatDetails(res)
})
}
const loadChats = async () => {
await getList('chats').then(res => {
return res.json()
}).then(res => {
// console.log(res)
const chats = res.results
console.log("load chats")
const chats = res.results || []
const sections = [...new Set(chats.map(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0)))];
const sectionChats = sections.map(section => ({ title: section, data: chats.filter(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0) == section) }))
setChats(sectionChats)
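A hypothetical standalone version of the grouping done just above (input assumed to be the results array returned by the chats endpoint): chats are bucketed per calendar day so each bucket can feed one SectionList section.

// Group chats by calendar day; each bucket becomes one SectionList section.
function groupChatsByDay(chats) {
  // Normalise every timestamp to local midnight so messages from the same day share a key.
  const days = [...new Set(chats.map(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0)))];
  return days.map(day => ({
    title: day, // millisecond timestamp for the start of that day
    data: chats.filter(chat => new Date(chat.timestamp).setHours(0, 0, 0, 0) === day),
  }));
}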
@@ -103,16 +117,16 @@ export default function ChatScreen({ navigation }) {
const onSendPress = () => {
try {
create('chats', { message: input, from_user: chatDetails.from_user, to_user: chatDetails.to_user, conversation: chatDetails.conversation_id }).then(response => {
// console.log(response)
})
setLoading(true)
setInput('')
loadChats()
}
catch(e){
Toast.show({title: 'Error sending message. try again!', ...ERROR_TOAST_PROPS})
}
create('chats', { message: input, from_user: chatDetails.from_user, to_user: chatDetails.to_user, conversation: chatDetails.conversation_id }).then(response => {
// console.log(response)
})
setLoading(true)
setInput('')
loadChats()
}
catch (e) {
Toast.show({ title: 'Error sending message. try again!', ...ERROR_TOAST_PROPS })
}
}
return (
@@ -130,12 +144,18 @@ export default function ChatScreen({ navigation }) {
{loading ? <ActivityIndicator /> :
<SectionList
refreshing={loading}
onRefresh={() => loadChats()}
inverted={true}
onRefresh={() => {
// loadChatDetails() //remove
loadChats()
}}
inverted
sections={chats}
keyExtractor={(item, index) => item + index}
renderItem={({ item }) => {
// console.log({ item })
const timeString = new Date(item.timestamp).toLocaleTimeString()
const time = timeString.slice(-11, -6) + " " + timeString.slice(-2)
return (
<View style={{ margin: 5, padding: 5 }}><Text
style={[{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', backgroundColor: chatDetails.from_user == item.from_user ? 'lightgray' : '#FFDE03', borderRadius: 5, padding: 5 },
@@ -143,8 +163,8 @@ export default function ChatScreen({ navigation }) {
]}
key={item.timestamp}>{item.message}</Text>
<View style={{ flex: 1, flexDirection: chatDetails.from_user == item.from_user ? 'row-reverse' : 'row', textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left' }}>
<Text style={{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', color: 'gray', fontSize: 11 }}>{new Date(item.timestamp).toLocaleTimeString()}</Text>
{item.is_detected && <Ionicons name="mic" size={15} color={CONVO_DEFAULT_ICON_COLOR}/>}
<Text style={{ textAlign: chatDetails.from_user == item.from_user ? 'right' : 'left', color: 'gray', fontSize: 11 }}>{time}</Text>
{item.is_detected && <Ionicons name="mic" size={15} color={CONVO_DEFAULT_ICON_COLOR} />}
{/* {chatDetails.to_user == item.from_user && item.id != playinId && <Ionicons name="play" size={15} style={styles.chatIcon} />}
{chatDetails.to_user == item.from_user && item.id == playinId && <Ionicons name="pause" size={15} style={styles.chatIcon} />} */}
{chatDetails.to_user == item.from_user && <PlayMessage message={item.message} />}
@@ -200,11 +220,11 @@ export default function ChatScreen({ navigation }) {
</View>
<View style={{ flex: 0.075, padding:0, backgroundColor: 'white'}}>
<View style={{ flex: 0.075, padding: 0, backgroundColor: 'white' }}>
<View style={{ flexDirection: 'row', display: 'flex', height: '100%' }}>
<View style={{ flex: 0.8, height: '100%', flexDirection: 'column-reverse' }}>
<TextInput
style={{ borderWidth: 2, borderColor: 'gray', color: 'black', marginHorizontal: 5,paddingHorizontal: 10, borderRadius: 5 }}
style={{ borderWidth: 2, borderColor: 'gray', color: 'black', marginHorizontal: 5, paddingHorizontal: 10, borderRadius: 5 }}
defaultValue={input}
onChange={(e) => setInput(e.target.value)}></TextInput>
</View>
......
import pickle
from keras.models import load_model
# import pickle
# from keras.models import load_model
import numpy as np
# import numpy as np
import IPython.display as ipd
# import IPython.display as ipd
import random
# import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import LabelEncoder
def predict(samples):
model=load_model(r'./best_model_final.hdf5')
# def predict(samples):
# model=load_model(r'./best_model_final.hdf5')
f1 = open('all_label.txt', 'rb')
all_label = pickle.load(f1)
print('loaded labels')
# f1 = open('all_label.txt', 'rb')
# all_label = pickle.load(f1)
# print('loaded labels')
# f2 = open('all_waves_file.txt', 'rb')
# all_wave = pickle.load(f2)
# print('loaded waves')
# # f2 = open('all_waves_file.txt', 'rb')
# # all_wave = pickle.load(f2)
# # print('loaded waves')
le = LabelEncoder()
y = le.fit_transform(all_label)
classes = list(le.classes_)
# le = LabelEncoder()
# y = le.fit_transform(all_label)
# classes = list(le.classes_)
# train_data_file = open("train_data_file.txt", 'rb')
# [x_tr, x_val, y_tr, y_val] = np.load(train_data_file, allow_pickle=True)
# train_data_file.close()
# # train_data_file = open("train_data_file.txt", 'rb')
# # [x_tr, x_val, y_tr, y_val] = np.load(train_data_file, allow_pickle=True)
# # train_data_file.close()
def predictSamples(audio):
prob=model.predict(audio.reshape(1,8000,1))
index=np.argmax(prob[0])
return classes[index]
# def predictSamples(audio):
# prob=model.predict(audio.reshape(1,8000,1))
# index=np.argmax(prob[0])
# return classes[index]
# index=random.randint(0,len(x_val)-1)
# samples=x_val[index].ravel()
print(samples)
# print("Audio:",classes[np.argmax(y_val[index])])
ipd.Audio(samples, rate=8000)
# # index=random.randint(0,len(x_val)-1)
# # samples=x_val[index].ravel()
# print(samples)
# # print("Audio:",classes[np.argmax(y_val[index])])
# ipd.Audio(samples, rate=8000)
result = predictSamples(samples)
# result = predictSamples(samples)
print("Text:",result)
# print("Text:",result)
return result
\ No newline at end of file
# return result
\ No newline at end of file
import pickle
from matplotlib import pyplot
import os
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import warnings
# import pickle
# from matplotlib import pyplot
# import os
# import librosa
# import IPython.display as ipd
# import matplotlib.pyplot as plt
# import numpy as np
# from scipy.io import wavfile
# import warnings
from sklearn.preprocessing import LabelEncoder
# from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils
# from keras.utils import np_utils
from sklearn.model_selection import train_test_split
from keras.layers import Dense, Dropout, Flatten, Conv1D, Input, MaxPooling1D
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
K.clear_session()
warnings.filterwarnings("ignore")
# os.listdir('../../../data/')
classes = ['down', 'go', 'left', 'no', 'off',
'on', 'right', 'stop', 'up', 'yes']
def train():
print('1')
train_audio_path = r'./backend/data/train/train/audio/'
samples, sample_rate = librosa.load(
train_audio_path+'yes/0a7c2a8d_nohash_0.wav', sr=16000)
# fig = plt.figure(figsize=(14, 8))
# ax1 = fig.add_subplot(211)
# ax1.set_title('Raw wave of ' + r'../input/train/audio/yes/0a7c2a8d_nohash_0.wav')
# ax1.set_xlabel('time')
# ax1.set_ylabel('Amplitude')
# ax1.plot(np.linspace(0, sample_rate/len(samples), sample_rate), samples)
ipd.Audio(samples, rate=sample_rate)
# from sklearn.model_selection import train_test_split
# from keras.layers import Dense, Dropout, Flatten, Conv1D, Input, MaxPooling1D
# from keras.models import Model
# from keras.callbacks import EarlyStopping, ModelCheckpoint
# from keras import backend as K
# K.clear_session()
# warnings.filterwarnings("ignore")
# # os.listdir('../../../data/')
# classes = ['down', 'go', 'left', 'no', 'off',
# 'on', 'right', 'stop', 'up', 'yes']
# def train():
# print('1')
# train_audio_path = r'./backend/data/train/train/audio/'
# samples, sample_rate = librosa.load(
# train_audio_path+'yes/0a7c2a8d_nohash_0.wav', sr=16000)
# # fig = plt.figure(figsize=(14, 8))
# # ax1 = fig.add_subplot(211)
# # ax1.set_title('Raw wave of ' + r'../input/train/audio/yes/0a7c2a8d_nohash_0.wav')
# # ax1.set_xlabel('time')
# # ax1.set_ylabel('Amplitude')
# # ax1.plot(np.linspace(0, sample_rate/len(samples), sample_rate), samples)
print(sample_rate)
# ipd.Audio(samples, rate=sample_rate)
samples = librosa.resample(samples, sample_rate, 8000)
ipd.Audio(samples, rate=8000)
# print(sample_rate)
labels = os.listdir(train_audio_path)
# samples = librosa.resample(samples, sample_rate, 8000)
# ipd.Audio(samples, rate=8000)
# find count of each label and plot bar graph
no_of_recordings = []
for label in labels:
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
no_of_recordings.append(len(waves))
# labels = os.listdir(train_audio_path)
# plot
# plt.figure(figsize=(30,5))
index = np.arange(len(labels))
# plt.bar(index, no_of_recordings)
# plt.xlabel('Commands', fontsize=12)
# plt.ylabel('No of recordings', fontsize=12)
# plt.xticks(index, labels, fontsize=15, rotation=60)
# plt.title('No. of recordings for each command')
# plt.show()
# # find count of each label and plot bar graph
# no_of_recordings = []
# for label in labels:
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# no_of_recordings.append(len(waves))
print('2')
# # plot
# # plt.figure(figsize=(30,5))
# index = np.arange(len(labels))
# # plt.bar(index, no_of_recordings)
# # plt.xlabel('Commands', fontsize=12)
# # plt.ylabel('No of recordings', fontsize=12)
# # plt.xticks(index, labels, fontsize=15, rotation=60)
# # plt.title('No. of recordings for each command')
# # plt.show()
labels = ["yes", "no", "up", "down", "left",
"right", "on", "off", "stop", "go"]
# print('2')
# labels_file = open('./labels_file.bin', 'wb+')
# pickle.dump(obj=labels, file=labels_file)
# labels_file.close()
# labels = ["yes", "no", "up", "down", "left",
# "right", "on", "off", "stop", "go"]
# # file = open('./labels_file.bin', 'rb')
# # dict = pickle.load(file)
# # print('loaded')
# # print(dict)
# # print('fdnasf')
# # labels_file = open('./labels_file.bin', 'wb+')
# # pickle.dump(obj=labels, file=labels_file)
# # labels_file.close()
duration_of_recordings = []
for label in labels:
print('2.1', label)
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
for wav in waves:
sample_rate, samples = wavfile.read(
train_audio_path + '/' + label + '/' + wav)
duration_of_recordings.append(float(len(samples)/sample_rate))
# # # file = open('./labels_file.bin', 'rb')
# # # dict = pickle.load(file)
# # # print('loaded')
# # # print(dict)
# # # print('fdnasf')
plt.hist(np.array(duration_of_recordings))
# duration_of_recordings = []
# for label in labels:
# print('2.1', label)
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# for wav in waves:
# sample_rate, samples = wavfile.read(
# train_audio_path + '/' + label + '/' + wav)
# duration_of_recordings.append(float(len(samples)/sample_rate))
train_audio_path = r'./backend/data/train/train/audio/'
# plt.hist(np.array(duration_of_recordings))
all_wave = []
all_label = []
# train_audio_path = r'./backend/data/train/train/audio/'
f1 = open('all_label.txt', 'rb')
all_label = pickle.load(f1)
# all_wave = []
# all_label = []
f2 = open('all_waves_file.txt', 'rb')
all_wave = pickle.load(f2)
# f1 = open('all_label.txt', 'rb')
# all_label = pickle.load(f1)
if(all_wave and all_label):
print('loaded labels and waves')
else:
print('Creating labels and waves files')
for label in labels:
print(label)
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
for wav in waves:
samples, sample_rate = librosa.load(
train_audio_path + '/' + label + '/' + wav, sr=16000)
samples = librosa.resample(samples, sample_rate, 8000)
if(len(samples) == 8000):
all_wave.append(samples)
all_label.append(label)
# f2 = open('all_waves_file.txt', 'rb')
# all_wave = pickle.load(f2)
# print('3')
# if(all_wave and all_label):
# print('loaded labels and waves')
# else:
# print('Creating labels and waves files')
# for label in labels:
# print(label)
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# for wav in waves:
# samples, sample_rate = librosa.load(
# train_audio_path + '/' + label + '/' + wav, sr=16000)
# samples = librosa.resample(samples, sample_rate, 8000)
# if(len(samples) == 8000):
# all_wave.append(samples)
# all_label.append(label)
all_labels_file = open('all_label.txt', 'wb+')
pickle.dump(file=all_labels_file, obj=all_label)
all_labels_file.close()
# # print('3')
return False
# all_labels_file = open('all_label.txt', 'wb+')
# pickle.dump(file=all_labels_file, obj=all_label)
# all_labels_file.close()
all_waves_file = open('all_waves_file.txt', 'wb+')
pickle.dump(file=all_waves_file, obj=all_wave)
all_waves_file.close()
print('Done: creating labels and waves files')
# all_waves_file = open('all_waves_file.txt', 'wb+')
# pickle.dump(file=all_waves_file, obj=all_wave)
# all_waves_file.close()
return False
le = LabelEncoder()
y = le.fit_transform(all_label)
classes = list(le.classes_)
# print('Done: creating labels and waves files')
print('4')
y = np_utils.to_categorical(y, num_classes=len(labels))
# le = LabelEncoder()
# y = le.fit_transform(all_label)
# classes = list(le.classes_)
all_wave = np.array(all_wave).reshape(-1, 8000, 1)
# print('4')
x_tr, x_val, y_tr, y_val = train_test_split(np.array(all_wave), np.array(
y), stratify=y, test_size=0.2, random_state=777, shuffle=True)
# y = np_utils.to_categorical(y, num_classes=len(labels))
train_data_file = open('train_data_file.txt', 'wb+')
np.save(file=train_data_file, arr=np.array([x_tr, x_val, y_tr, y_val]))
train_data_file.close()
# all_wave = np.array(all_wave).reshape(-1, 8000, 1)
inputs = Input(shape=(8000, 1))
# x_tr, x_val, y_tr, y_val = train_test_split(np.array(all_wave), np.array(
# y), stratify=y, test_size=0.2, random_state=777, shuffle=True)
# First Conv1D layer
conv = Conv1D(8, 13, padding='valid', activation='relu', strides=1)(inputs)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# train_data_file = open('train_data_file.txt', 'wb+')
# np.save(file=train_data_file, arr=np.array([x_tr, x_val, y_tr, y_val]))
# train_data_file.close()
# Second Conv1D layer
conv = Conv1D(16, 11, padding='valid', activation='relu', strides=1)(conv)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# inputs = Input(shape=(8000, 1))
# Third Conv1D layer
conv = Conv1D(32, 9, padding='valid', activation='relu', strides=1)(conv)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# # First Conv1D layer
# conv = Conv1D(8, 13, padding='valid', activation='relu', strides=1)(inputs)
# conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
# Fourth Conv1D layer
conv = Conv1D(64, 7, padding='valid', activation='relu', strides=1)(conv)
conv = MaxPooling1D(3)(conv)
conv = Dropout(0.3)(conv)
# # Second Conv1D layer
# conv = Conv1D(16, 11, padding='valid', activation='relu', strides=1)(conv)
# conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
# Flatten layer
conv = Flatten()(conv)
# # Third Conv1D layer
# conv = Conv1D(32, 9, padding='valid', activation='relu', strides=1)(conv)
# conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
# Dense Layer 1
conv = Dense(256, activation='relu')(conv)
conv = Dropout(0.3)(conv)
# # Fourth Conv1D layer
# conv = Conv1D(64, 7, padding='valid', activation='relu', strides=1)(conv)
# conv = MaxPooling1D(3)(conv)
# conv = Dropout(0.3)(conv)
# Dense Layer 2
conv = Dense(128, activation='relu')(conv)
conv = Dropout(0.3)(conv)
# # Flatten layer
# conv = Flatten()(conv)
outputs = Dense(len(labels), activation='softmax')(conv)
# # Dense Layer 1
# conv = Dense(256, activation='relu')(conv)
# conv = Dropout(0.3)(conv)
model = Model(inputs, outputs)
model.summary()
# # Dense Layer 2
# conv = Dense(128, activation='relu')(conv)
# conv = Dropout(0.3)(conv)
model.compile(loss='categorical_crossentropy',
optimizer='adam', metrics=['accuracy'])
# outputs = Dense(len(labels), activation='softmax')(conv)
es = EarlyStopping(monitor='val_loss', mode='min',
verbose=1, patience=10, min_delta=0.0001)
mc = ModelCheckpoint('best_model.hdf5', monitor='val_accuracy',
verbose=1, save_best_only=True, mode='max')
# model = Model(inputs, outputs)
# model.summary()
history = model.fit(x_tr, y_tr, epochs=100, callbacks=[
es, mc], batch_size=32, validation_data=(x_val, y_val))
# model.compile(loss='categorical_crossentropy',
# optimizer='adam', metrics=['accuracy'])
# pyplot.plot(history.history['loss'], label='train')
# pyplot.plot(history.history['val_loss'], label='test')
# pyplot.legend()
# es = EarlyStopping(monitor='val_loss', mode='min',
# verbose=1, patience=10, min_delta=0.0001)
# mc = ModelCheckpoint('best_model.hdf5', monitor='val_accuracy',
# verbose=1, save_best_only=True, mode='max')
# pyplot.show()
# history = model.fit(x_tr, y_tr, epochs=100, callbacks=[
# es, mc], batch_size=32, validation_data=(x_val, y_val))
return history
# # pyplot.plot(history.history['loss'], label='train')
# # pyplot.plot(history.history['val_loss'], label='test')
# # pyplot.legend()
# # pyplot.show()
# return history
@@ -46,7 +46,7 @@ class MlModelSerializer(serializers.ModelSerializer):
'__all__'
)
class ChatSerialier(serializers.ModelSerializer):
class ChatSerializer(serializers.ModelSerializer):
class Meta:
model = Chat
fields = (
......
from http.client import HTTPResponse
from lib2to3.pytree import convert
from pyexpat import model
from django.contrib.auth.models import User, Group
from rest_framework import viewsets
from rest_framework import permissions
from backend.cms.serializers import MyTokenObtainPairSerializer
from backend.cms.serializers import MlModelSerializer, ChatSerialier, ConversationSerializer
from backend.cms.serializers import MlModelSerializer, ChatSerializer, ConversationSerializer
from backend.cms.serializers import UserSerializer, GroupSerializer
from rest_framework.decorators import action
from rest_framework.response import Response
import mimetypes
import os
from rest_framework.parsers import MultiPartParser
from io import BytesIO
from datetime import datetime, timedelta
import librosa
from django.db.models import Q
from django.core.files.storage import FileSystemStorage
from .models import Chat, Conversation, MlModel
from rest_framework_simplejwt.views import TokenObtainPairView
from .model.train import train
# from .model.train import train
from .model.predict import predict
# from .model.predict import predict
from pydub import AudioSegment
@@ -73,7 +71,7 @@ class MlModelViewSet(viewsets.ViewSet):
parser_classes = [MultiPartParser]
@action(detail=False)
def runAction(*args, **kwargs):
def addChats(*args, **kwargs):
admin = User.objects.get(username='admin')
user2 = User.objects.get(username='user2')
@@ -172,16 +170,47 @@ class MlModelViewSet(viewsets.ViewSet):
# track.file.name = mp3_filename
# track.save()
results = predict(samples)
results = {} # predict(samples)
print(results)
return Response({'success': True, 'result': results})
class ChatViewSet(viewsets.ModelViewSet):
queryset = Chat.objects.all().order_by('-timestamp')
serializer_class = ChatSerialier
serializer_class = ChatSerializer
permission_classes = [permissions.IsAuthenticated]
@action(methods='POST', detail=True)
def getChats(self, request, *args, **kwargs):
chats = Chat.objects.filter(Q(from_user__user_id=request.user.id) | Q(
to_user__user_id=request.user.id)).order_by('-timestamp').values()
return Response({chats})
def list(self, request, pk=None):
if pk == None:
chats = Chat.objects.filter(Q(from_user_id=request.user.id) | Q(
to_user_id=request.user.id))
else:
chats = Chat.objects.get(id=pk)
page = self.paginate_queryset(chats)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(chats, many=True)
result_set = serializer.data
return Response(result_set)
def get_result_set(self, chats):
result_set = ChatSerializer(chats, many=True).data
return result_set
class ConversationViewSet(viewsets.ModelViewSet):
queryset = Conversation.objects.all().order_by('id')
......
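ChatViewSet.list above returns a DRF-paginated payload (PageNumberPagination with PAGE_SIZE 10, per the settings diff below), i.e. an object with count, next, previous and results keys. A hypothetical client-side sketch that walks every page, assuming the axios-based getList helper from api/api.js:

import { getList } from './api';

// Collect all chats across pages; assumes getList resolves to an axios response
// whose parsed JSON body is available on .data.
async function fetchAllChats() {
  let page = 1;
  let all = [];
  while (true) {
    const { data } = await getList('chats', { page });
    all = all.concat(data.results || []);
    if (!data.next) break; // next is null on the last page
    page += 1;
  }
  return all;
}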
@@ -46,14 +46,16 @@ INSTALLED_APPS = [
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'corsheaders.middleware.CorsMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'corsheaders.middleware.CorsMiddleware',
]
ROOT_URLCONF = 'backend.urls'
@@ -92,18 +94,18 @@ DATABASES = {
# https://docs.djangoproject.com/en/4.0/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
# {
# 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
# },
# {
# 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
# },
# {
# 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
# },
# {
# 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
# },
]
@@ -132,10 +134,10 @@ DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
REST_FRAMEWORK = {
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'PAGE_SIZE': 10,
'DEFAULT_AUTHENTICATION_CLASSES': (
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework_simplejwt.authentication.JWTAuthentication',
'rest_framework.authentication.SessionAuthentication',
)
]
}
SIMPLE_JWT = {
@@ -148,7 +150,7 @@ SIMPLE_JWT = {
'SIGNING_KEY': SECRET_KEY,
'VERIFYING_KEY': None,
'AUTH_HEADER_TYPES': ('Bearer',),
'AUTH_HEADER_TYPES': ('JWT',),
'USER_ID_FIELD': 'id',
'USER_ID_CLAIM': 'user_id',
@@ -156,10 +158,11 @@ SIMPLE_JWT = {
'TOKEN_TYPE_CLAIM': 'token_type',
'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp',
'SLIDING_TOKEN_LIFETIME': datetime.timedelta(minutes=5),
'SLIDING_TOKEN_LIFETIME': datetime.timedelta(minutes=60),
'SLIDING_TOKEN_REFRESH_LIFETIME': datetime.timedelta(days=1),
}
CORS_ORIGIN_ALLOW_ALL = True
CORS_ALLOW_ALL_ORIGINS = True # If this is used then `CORS_ALLOWED_ORIGINS` will not have any effect
CORS_ALLOW_CREDENTIALS = True
ASGI_APPLICATION = "backend.asgi.application"
\ No newline at end of file
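With AUTH_HEADER_TYPES switched from ('Bearer',) to ('JWT',), clients now have to prefix the token with JWT, which matches the getHeaders change in api/api.js. A hypothetical raw request sketch (the /chats/ path mirrors the frontend's getList('chats') calls):

import AsyncStorage from '@react-native-async-storage/async-storage';

// Read the stored access token and send it with the JWT prefix expected by SIMPLE_JWT.
async function fetchChatsRaw(backendUrl) {
  const token = await AsyncStorage.getItem('access_token');
  return fetch(`${backendUrl}/chats/`, {
    method: 'GET',
    headers: { Authorization: `JWT ${token}`, 'Content-Type': 'application/json' },
  });
}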
@@ -5,6 +5,8 @@ from backend.cms import views
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView
from rest_framework.authtoken.views import obtain_auth_token
from django.contrib import admin
router = routers.DefaultRouter()
router.register(r'users', views.UserViewSet)
router.register(r'groups', views.GroupViewSet)
@@ -16,6 +18,7 @@ router.register(r'conversations', views.ConversationViewSet)
# Additionally, we include login URLs for the browsable API.
urlpatterns = [
path('', include(router.urls)),
path('admin/', admin.site.urls),
path('api-auth/', include('rest_framework.urls', namespace='rest_framework')),
re_path(r'^api/auth/token/obtain/$', ObtainTokenPairWithUserView.as_view()),
re_path(r'^api/auth/token/refresh/$', TokenRefreshView.as_view()),
......