Commit a5f7dbf9 authored by W.D.R.P. Sandeepa's avatar W.D.R.P. Sandeepa

Merge branch 'it18218640' into 'master'

implemented get_data_splits function

See merge request !35
parents 74a60172 4dda1df9
...@@ -12,6 +12,22 @@ BATCH_SIZE = 32 ...@@ -12,6 +12,22 @@ BATCH_SIZE = 32
NUM_KEYWORDS = 10 NUM_KEYWORDS = 10
def get_data_splits(data_path, test_size=0.1, test_validation=0.1):
# load dataset
X, y = load_dataset(data_path)
# create train/validation/test splits
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)
X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size=test_validation)
# convert inputs from 2D TO 3D array
X_train = X_train[..., np.newaxis]
X_validation = X_validation[..., np.newaxis]
X_test = X_test[..., np.newaxis]
return X_train, X_validation, X_test, y_train, y_validation, y_test
def build_model(input_shape, learning_rate, error="sparse_categorical_crossentropy"): def build_model(input_shape, learning_rate, error="sparse_categorical_crossentropy"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment