Commit faca315f authored by W.D.R.P. Sandeepa's avatar W.D.R.P. Sandeepa

Merge branch 'it18218640' into 'master'

implemented main function

See merge request !37
parents 8cf8adb8 e431ccee
......@@ -76,3 +76,26 @@ def build_model(input_shape, learning_rate, error="sparse_categorical_crossentro
model.summary()
return model
def main():
# load train/validation/test data splits
X_train, X_validation, X_test, y_train, y_validation, y_test = get_data_splits(DATA_PATH)
# build the CNN model
input_shape = (X_train.shape[1], X_train.shape[2], X_train.shape[3],)
model = build_model(input_shape, LEARNING_RATE)
# train the model
model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_data=(X_validation, y_validation))
# evaluate the model
test_error, test_accuracy = model.evaluate(X_test, y_test)
print(f"Test error: {test_error}, test accuracy: {test_accuracy}")
# save the model
model.save(SAVE_MODEL_PATH)
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment