Commit 9ba1e301 authored by W.D.R.P. Sandeepa's avatar W.D.R.P. Sandeepa

implemented get_data_splits function

parent c6817b52
......@@ -12,6 +12,22 @@ BATCH_SIZE = 32
NUM_KEYWORDS = 10
def get_data_splits(data_path, test_size=0.1, test_validation=0.1):
# load dataset
X, y = load_dataset(data_path)
# create train/validation/test splits
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)
X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size=test_validation)
# convert inputs from 2D TO 3D array
X_train = X_train[..., np.newaxis]
X_validation = X_validation[..., np.newaxis]
X_test = X_test[..., np.newaxis]
return X_train, X_validation, X_test, y_train, y_validation, y_test
def build_model(input_shape, learning_rate, error="sparse_categorical_crossentropy"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment