Commit f3a515f3 authored by W.D.R.P. Sandeepa's avatar W.D.R.P. Sandeepa

Merge branch 'it18218640' into 'master'

add graph

See merge request !68
parents a61c6276 85472019
...@@ -2,12 +2,13 @@ import json ...@@ -2,12 +2,13 @@ import json
import numpy as np import numpy as np
import tensorflow.keras as keras import tensorflow.keras as keras
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
DATA_PATH = "data.json" DATA_PATH = "data.json"
SAVE_MODEL_PATH = "model.h5" SAVE_MODEL_PATH = "model.h5"
LEARNING_RATE = 0.0001 LEARNING_RATE = 0.0001
EPOCHS = 40 EPOCHS = 60
BATCH_SIZE = 32 BATCH_SIZE = 32
NUM_KEYWORDS = 10 NUM_KEYWORDS = 10
...@@ -77,6 +78,22 @@ def build_model(input_shape, learning_rate, error="sparse_categorical_crossentro ...@@ -77,6 +78,22 @@ def build_model(input_shape, learning_rate, error="sparse_categorical_crossentro
return model return model
def show_graph(history):
plt.plot(history.history['loss'], 'r', label='train loss')
plt.plot(history.history['val_loss'], 'g', label='validation loss')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('Loss')
plt.show()
plt.subplot(1,2,1)
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.show()
def main(): def main():
...@@ -88,7 +105,7 @@ def main(): ...@@ -88,7 +105,7 @@ def main():
model = build_model(input_shape, LEARNING_RATE) model = build_model(input_shape, LEARNING_RATE)
# train the model # train the model
model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_data=(X_validation, y_validation)) history = model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_data=(X_validation, y_validation))
# evaluate the model # evaluate the model
test_error, test_accuracy = model.evaluate(X_test, y_test) test_error, test_accuracy = model.evaluate(X_test, y_test)
...@@ -97,5 +114,7 @@ def main(): ...@@ -97,5 +114,7 @@ def main():
# save the model # save the model
model.save(SAVE_MODEL_PATH) model.save(SAVE_MODEL_PATH)
show_graph(history)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment