Commit b733b112 authored by Malsha Rathnasiri's avatar Malsha Rathnasiri

improve train.py

parent eb14f11a
...@@ -97,37 +97,47 @@ def train(): ...@@ -97,37 +97,47 @@ def train():
train_audio_path = r'./backend/data/train/train/audio/' train_audio_path = r'./backend/data/train/train/audio/'
# all_wave = [] all_wave = []
# all_label = [] all_label = []
# for label in labels:
# print(label)
# waves = [f for f in os.listdir(
# train_audio_path + '/' + label) if f.endswith('.wav')]
# for wav in waves:
# samples, sample_rate = librosa.load(
# train_audio_path + '/' + label + '/' + wav, sr=16000)
# samples = librosa.resample(samples, sample_rate, 8000)
# if(len(samples) == 8000):
# all_wave.append(samples)
# all_label.append(label)
# print('3')
f1 = open('all_label.txt', 'rb') f1 = open('all_label.txt', 'rb')
all_label = pickle.load(f1) all_label = pickle.load(f1)
print('loaded labels')
f2 = open('all_waves_file.txt', 'rb') f2 = open('all_waves_file.txt', 'rb')
all_wave = pickle.load(f2) all_wave = pickle.load(f2)
print('loaded waves')
if(all_wave and all_label):
print('loaded labels and waves')
else:
print('Creating labels and waves files')
for label in labels:
print(label)
waves = [f for f in os.listdir(
train_audio_path + '/' + label) if f.endswith('.wav')]
for wav in waves:
samples, sample_rate = librosa.load(
train_audio_path + '/' + label + '/' + wav, sr=16000)
samples = librosa.resample(samples, sample_rate, 8000)
if(len(samples) == 8000):
all_wave.append(samples)
all_label.append(label)
# print('3')
all_labels_file = open('all_label.txt', 'wb+')
pickle.dump(file=all_labels_file, obj=all_label)
all_labels_file.close()
all_waves_file = open('all_waves_file.txt', 'wb+')
pickle.dump(file=all_waves_file, obj=all_wave)
all_waves_file.close()
print('Done: creating labels and waves files')
le = LabelEncoder() le = LabelEncoder()
y = le.fit_transform(all_label) y = le.fit_transform(all_label)
classes = list(le.classes_) classes = list(le.classes_)
print(classes)
print(all_wave)
print('4') print('4')
y = np_utils.to_categorical(y, num_classes=len(labels)) y = np_utils.to_categorical(y, num_classes=len(labels))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment