How to Do a Train Test Split in Sklearn


Last Updated on July 14, 2022 by Jay

A common task in machine learning is splitting a dataset into a training set and a test set. In this post, we'll look at how to do that with the sklearn library.

What Is a Train Test Split

If we use a data point to train a model, we shouldn't also use that same data point to evaluate the model.

A key ability our model should have is to generalize well to new, unseen data, not just to the data we used to train it. Because a model can simply memorize its training data, evaluating it on that same data tells us very little: the model already knows the answers.
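To make the memorization point concrete, here is a small sketch. It assumes synthetic data from sklearn's make_classification rather than the fruits dataset, and it uses an unrestricted decision tree, a model that can grow until it fits every training sample. The tree is then scored both on the data it was trained on and on held-out data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data stands in for a real dataset (an assumption for illustration)
X, y = make_classification(n_samples=300, n_features=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# With no depth limit, the tree keeps splitting until every leaf is pure,
# i.e. it effectively memorizes the training samples
tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

train_acc = tree.score(X_train, y_train)  # accuracy on data the model has seen
test_acc = tree.score(X_test, y_test)     # accuracy on data the model has not seen
print(f"train accuracy: {train_acc:.2f}, test accuracy: {test_acc:.2f}")
```

The training accuracy comes out at 1.00 because the tree has memorized those samples; the test accuracy is usually lower, and it is the more honest estimate of how the model will do on new data.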

Usually, we would split a given dataset into two parts – a training dataset and a test dataset:

  1. Training dataset – for training the model
  2. Test dataset – to evaluate the model

Since the model hasn't "seen" any of the test samples during training, we can use them to evaluate model performance. A common partition is 75% of the data for the training set and 25% for the test set.
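Before turning to sklearn, it can help to see how little a basic split involves. The sketch below shuffles the row positions with NumPy and holds out 25% of them for testing. The function name manual_train_test_split and its parameters are hypothetical, made up for this illustration; this is not how sklearn implements its own function.

```python
import numpy as np


def manual_train_test_split(X, y, test_fraction=0.25, seed=0):
    """Shuffle the row positions, then hold out the first n_test of them for testing."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(X))              # a random order of row positions
    n_test = int(np.ceil(test_fraction * len(X)))  # round the test size up
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


# 10 toy samples: row i of X is [2*i, 2*i + 1] and its label is i
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

X_train, X_test, y_train, y_test = manual_train_test_split(X, y)
print(len(X_train), len(X_test))  # 7 3
```

The same shuffled index array selects rows from both X and y, which is what keeps every label paired with its own features, and no sample can land in both sets.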


We’ll need a few libraries for this exercise. Bring up a command prompt window and install them:

pip install numpy pandas scikit-learn


We'll use a "fruits" dataset created by Dr. Ian Murray of the University of Edinburgh. Dr. Murray bought a few dozen oranges, lemons and apples of different varieties and recorded their measurements in a table.

import numpy as np
import pandas as pd

fruits = pd.read_csv('')

# Number of rows and columns
print(fruits.shape)

# First five records
print(fruits.head())

(59, 7)

   fruit_label fruit_name fruit_subtype  mass  width  height  color_score
0            1      apple  granny_smith   192    8.4     7.3         0.55
1            1      apple  granny_smith   180    8.0     6.8         0.59
2            1      apple  granny_smith   176    7.4     7.2         0.60
3            2   mandarin      mandarin    86    6.2     4.7         0.80
4            2   mandarin      mandarin    84    6.0     4.6         0.79

As we can see, the fruits dataset contains 59 records (rows) and 7 columns.

Prepare the Input and Output Variables

An input variable, sometimes also called a "feature", is an attribute of a data sample. For the fruits, the input (feature) variables are mass, width, height and color_score. We usually use X to represent the input variables.

An output variable, sometimes referred to as a "label", is the value we want the model to predict for a data sample. In this case, either fruit_label or fruit_name can serve as our output. We usually use y to represent the output variable.
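Because the dataset path is not shown above, the following sketch builds a small stand-in DataFrame from the five rows printed earlier (an assumption made so the example runs on its own) and then separates the input variables from the output variable. Note the double brackets for X, which give a DataFrame, versus the single brackets for y, which give a Series.

```python
import pandas as pd

# Stand-in for the fruits table, using only the five rows printed above
fruits = pd.DataFrame({
    "fruit_label": [1, 1, 1, 2, 2],
    "fruit_name": ["apple", "apple", "apple", "mandarin", "mandarin"],
    "fruit_subtype": ["granny_smith"] * 3 + ["mandarin"] * 2,
    "mass": [192, 180, 176, 86, 84],
    "width": [8.4, 8.0, 7.4, 6.2, 6.0],
    "height": [7.3, 6.8, 7.2, 4.7, 4.6],
    "color_score": [0.55, 0.59, 0.60, 0.80, 0.79],
})

# Input variables (features): a DataFrame with one column per attribute
X = fruits[["mass", "width", "height", "color_score"]]

# Output variable (label): a Series holding the fruit name of each row
y = fruits["fruit_name"]

print(X.shape, y.shape)  # (5, 4) (5,)
```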

Train Test Split in Sklearn

Since the train test split is such a common task in machine learning, the sklearn library provides a function, train_test_split, that does just that.

train_test_split returns four objects in the order X_train, X_test, y_train, y_test. The outputs keep the type of the inputs: when X is a pandas DataFrame and y is a pandas Series, as in our case, the two "X" objects are DataFrames and the two "y" objects are Series.
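A quick way to confirm the return order and the type behavior is to split the same small, hypothetical dataset twice: once as pandas objects and once as NumPy arrays. The column names a, b and label are made up for this sketch.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical data: 8 rows, 2 feature columns and 1 label column
df = pd.DataFrame({"a": range(8), "b": range(8, 16), "label": list("xyxyxyxy")})
X_df, y_ser = df[["a", "b"]], df["label"]

# pandas in -> pandas out, returned as X_train, X_test, y_train, y_test
X_train, X_test, y_train, y_test = train_test_split(X_df, y_ser, random_state=0)
print(type(X_train).__name__, type(y_train).__name__)  # DataFrame Series

# NumPy arrays in -> NumPy arrays out
Xa_train, Xa_test, ya_train, ya_test = train_test_split(
    X_df.to_numpy(), y_ser.to_numpy(), random_state=0
)
print(type(Xa_train).__name__, type(ya_train).__name__)  # ndarray ndarray
```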

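from sklearn.model_selection import train_test_split

# Features (X) and label (y)
X = fruits[['mass', 'width', 'height', 'color_score']]
y = fruits['fruit_name']

# Default split: 75% training, 25% test; random_state makes it reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

print(X_train.shape)
print(X_test.shape)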

(44, 4)

(15, 4)

Each label in y_train corresponds to the feature row at the same position in X_train, and each label in y_test corresponds to the matching row in X_test. See the screenshots below for the first several records in the training set; we can verify that the records (and their index values) match between X and y.

(Screenshot: X_train dataset)
(Screenshot: y_train dataset)
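Instead of comparing screenshots by eye, the alignment can also be checked in code. This sketch uses hypothetical data in which each label (even or odd) can be recomputed from its feature, so any mismatch between X and y would show up immediately. It relies on the shuffled pandas outputs keeping their original index labels.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def parity(v):
    """Label rule for the hypothetical data: 'even' or 'odd'."""
    return "even" if v % 2 == 0 else "odd"


X = pd.DataFrame({"value": range(12)})
y = pd.Series([parity(v) for v in range(12)], name="parity")

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# The shuffled rows keep their original index labels, shared by X and y
print(X_train.index.equals(y_train.index), X_test.index.equals(y_test.index))  # True True

# Each label still matches the feature row it came from
print((X_train["value"].map(parity) == y_train).all())  # True
```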

Data Partition

By default, the train_test_split function splits the original dataset into a 75% training set and a 25% test set. We can verify that with our fruits dataset:

  • Train: 44/59 ≈ 75% (74.6%)
  • Test: 15/59 ≈ 25% (25.4%)
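The percentages are only approximately 75/25 because 59 rows cannot be divided exactly. As far as I can tell, sklearn rounds the default 25% test share up to a whole number of rows and gives the rest to training. The sketch below reproduces that arithmetic and cross-checks it on 59 placeholder samples (np.arange(59), an assumption standing in for the fruits rows).

```python
import math

import numpy as np
from sklearn.model_selection import train_test_split

n_samples = 59

# With neither size given, the test share defaults to 0.25, rounded up
n_test = math.ceil(0.25 * n_samples)  # ceil(14.75) = 15
n_train = n_samples - n_test          # 59 - 15 = 44
print(n_train, n_test)  # 44 15

# Cross-check against sklearn; a single array returns just (train, test)
X_train, X_test = train_test_split(np.arange(n_samples), random_state=0)
print(len(X_train), len(X_test))  # 44 15
print(f"{len(X_train) / n_samples:.1%} / {len(X_test) / n_samples:.1%}")  # 74.6% / 25.4%
```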

We can customize the data partition by adding either of the following optional arguments: test_size or train_size. We only need to set one of them; for example, if we set train_size=0.8, then test_size automatically becomes 0.2.
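Here is a short sketch of both options on 59 placeholder samples (an assumption matching the fruits row count). With 59 rows, 80% is 47.2, so the sizes have to be rounded; in this case setting train_size=0.8 or test_size=0.2 yields the same 47/12 partition.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(59)  # 59 placeholder samples, matching the fruits row count

# Setting only train_size: the test set gets whatever remains
X_train, X_test = train_test_split(X, train_size=0.8, random_state=0)
print(len(X_train), len(X_test))  # 47 12

# Setting only test_size gives the same partition sizes here
X_train2, X_test2 = train_test_split(X, test_size=0.2, random_state=0)
print(len(X_train2), len(X_test2))  # 47 12
```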


To make the output of train_test_split reproducible, we can use the random_state argument, which acts as a "seed" for the random number generator that shuffles the data before splitting. Setting random_state=0 is good enough, but feel free to choose any integer. If random_state is left unset, each call can produce a different split.
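The following sketch checks reproducibility on small toy arrays (made up for illustration): two calls with the same integer seed return identical splits, while a different seed usually shuffles the rows differently.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)  # 10 toy samples
y = np.arange(10)

# Same integer seed -> identical X_train, X_test, y_train, y_test on every call
first = train_test_split(X, y, random_state=42)
second = train_test_split(X, y, random_state=42)
same = all(np.array_equal(a, b) for a, b in zip(first, second))
print(same)  # True

# A different seed usually selects a different test set
other = train_test_split(X, y, random_state=7)
print(np.array_equal(first[3], other[3]))
```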

