diff --git a/model.py b/model.py index 09fb93c..c7644c0 100644 --- a/model.py +++ b/model.py @@ -55,11 +55,12 @@ def mse(y_test, y_pred): # Split data into train and test sets -X_train = X[:40] # first 40 examples (80% of data) -y_train = y[:40] +N = 30 +X_train = X[:N] # first 40 examples (80% of data) +y_train = y[:N] -X_test = X[40:] # last 10 examples (20% of data) -y_test = y[40:] +X_test = X[N:] # last 10 examples (20% of data) +y_test = y[N:] # Take a single example of X