An opinionated guide to scikit-learn
- 12. February 2024
- #datascience
scikit-learn is more than simply a collection of machine learning algorithms - it’s a set of building blocks for creating, tuning and evaluating machine learning pipelines.
This notebook demonstrates how to utilize the scikit-learn API to create non-trivial pipelines. It’s assumed that the reader has some prior experience with Python, scikit-learn and machine learning.
Why scikit-learn?
scikit-learn is to machine learning libraries what Python is to programming languages; perhaps not the best at any specific task, but likely the second best in nearly every task.
scikit-learn is not the be-all end-all of machine learning:
- If you want to do deep learning, neural networks or just need automatic differentiation, use something like PyTorch, TensorFlow or JAX instead.
- If you need a not-so-common statistical model, you will likely end up using a more specialized Python package. You could even consider switching from Python to the R language, whose packages cover far more of statistics.
- Need Bayesian statistics? Use Stan or PyMC, not scikit-learn.
The list goes on and on.
However, if you want to try out a linear model, a tree model, a KNN and a neural net, then scikit-learn is a great choice! The classic machine learning algorithms are all found in scikit-learn, and they are implemented behind a unified, clean API. There are also many utility functions that integrate seamlessly with the models: metrics, cross validation splitters, hyperparameter tuning, etc.
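As a small illustration of that unified API (a sketch added here, not part of the original notebook), every supervised estimator exposes the same fit/predict interface, so swapping models is a one-line change:
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=5, random_state=0)

# Every estimator follows the same fit/predict contract
for model in [Ridge(), DecisionTreeRegressor(), KNeighborsRegressor()]:
    model.fit(X, y)
    print(type(model).__name__, model.predict(X[:3]).round(2))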
Scikit-learn mastery and mistakes
There are roughly three levels of scikit-learn mastery:
- Level 1: Being able to use the machine learning models.
- Level 2: Using building blocks that surround the models, e.g. metrics, transformers, cross validators, etc.
- Level 3: Writing custom code that follows the scikit-learn API to solve specific machine learning problems.
Framing it in a different way, we could say that beginners tend to make these mistakes:
- Only using the models, but not the full scikit-learn API, i.e., not taking advantage of built-in functionality.
- Creating additional functionality without conforming to the API. For instance, writing a cross validation routine that does not follow the API and is therefore unusable with meta-estimators such as grid search.
This notebook guides the reader through an example, showcasing how I often write my scikit-learn code.
I hope it’s useful to you! (see also commandment 6)
from watermark import watermark
print(watermark(packages="numpy,scipy,sklearn,pandas,matplotlib", python=True))
Python implementation: CPython
Python version       : 3.11.7
IPython version      : 8.16.1

numpy     : 1.26.4
scipy     : 1.12.0
sklearn   : 1.4.0
pandas    : 2.2.0
matplotlib: 3.8.2
import functools
import operator
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.dummy import DummyRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.linear_model import Ridge
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import make_scorer
from sklearn.model_selection import (
GridSearchCV,
RandomizedSearchCV,
cross_validate,
TimeSeriesSplit,
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
Load and preprocess data
To showcase scikit-learn, we’ll predict the IMDb score of future movies.
It’s kind of a silly example: not something that one would necessarily create a predictive model for, but it suffices to showcase sklearn.
In the code cell below, we:
- Download data from the repo github.com/maazh/IMDB-Movie-Dataset-Analysis.
- Select columns for the regression task.
- Convert dates, adding the century in the process.
- Create missing values where `budget_adj == 0`. One could argue that this preprocessing step should be cross validated over, and therefore should be part of the Pipeline, but we’ll just do it here for simplicity’s sake.
def add_century(release_date):
"""Add century to a release date such as '12/21/66'."""
month, date, year = release_date.split("/")
century = "19" if int(year) > 25 else "20"
return "/".join([month, date, century + year])
# Test the function above
assert add_century("6/9/15") == "6/9/2015"
assert add_century("12/21/66") == "12/21/1966"
# Download data
url = "https://raw.githubusercontent.com/maazh/IMDB-Movie-Dataset-Analysis/master/tmdb-movies.csv"
df = pd.read_csv(url)
# Select columns
COLUMNS = ["original_title", "runtime", "genres", "release_date",
"budget_adj", "keywords", "vote_average"]
df = df[COLUMNS]
# Convert dates
df = df.assign(release_date=pd.to_datetime(df.release_date.apply(add_century)))
# Convert genres and keywords to strings
df = df.assign(genres=lambda df: df["genres"].apply(str))
df = df.assign(keywords=lambda df: df["keywords"].apply(str))
# Map zero budget to missing values
df[["budget_adj"]] = df[["budget_adj"]].replace({0: np.nan})
- I often sample the dataset, so the development experience is fast and smooth.
- Don’t sample too few rows! We need enough rows to do meaningful modeling.
This is an important lesson. Too many people sit and wait five seconds, or half a minute, or even many minutes, for models to run. This is fine at the very end when we want to run through the full pipeline and search for hyperparameters, but initially we want to prioritize fast development speed.
# df = df.sample(1000) # Sample rows
df.sample(3, random_state=3)
  | original_title | runtime | genres | release_date | budget_adj | keywords | vote_average
---|---|---|---|---|---|---|---
9641 | No Way Out | 114 | Crime|Drama|Action|Thriller|Mystery | 1987-08-14 | 2.878621e+07 | homicide|pentagon|minister|us navy | 7.0 |
9785 | The Towering Inferno | 159 | Action|Thriller | 1974-12-14 | 6.191197e+07 | skyscraper|firemen|rescue|disaster|trapped | 7.2 |
9809 | Barry Lyndon | 184 | Drama|Romance|War | 1975-12-18 | 4.457003e+07 | palace|british army|fencing|epic|debt | 7.2 |
Decide on a cross validation strategy
- Always choose a cross validation strategy that matches how the model will be used.
- Use the same kind of splitter to create (1) cross validation sets and (2) a test set.
- We choose a time-based split, since that’s how the model will be used - predicting the score of future movies.
- If no existing cross-validator in scikit-learn is appropriate, create your own that follows the API (a minimal sketch is shown after the examples below).
- If we were modeling primarily to do inference (understanding how the features affect the IMDb score) instead of prediction (predicting future movies), then a shuffled K-fold cross validation strategy would be a good choice.
Some more examples of this principle:
- Time series and weather. If you are trying to predict the weather, don’t do shuffled K-fold cross validation - you’ll use information from the future to predict the weather of today. This violates the principle of matching the cross validation strategy to how the model will be used. Instead, use a time series split. If you are trying to understand the weather, then you’re justified in using a shuffled K-fold split.
- Medical clinics. Imagine that the training data consists of patients from three clinics A, B and C. If the goal is to predict on the patients of a fourth clinic D in the future, then the appropriate cross validation strategy is to create the datasets `[(A, B), (A, C), (B, C)]`. Removing each clinic in turn and predicting on it mimics how the model will be used.
- House properties. Assume we want to predict some property of newly built houses, with geographical position, wall insulation efficiency, size, etc. as features. Future houses are typically built on the outskirts of the city, so the model will extrapolate the geographical coordinates when we use it in the future. Similarly, we are likely to extrapolate on insulation levels, because newer houses are built with better insulation. Even though this does not look like a time series problem, we should cross validate using a time series split, and also use time to create the test set. This most closely mimics how the model will be used.
- Medical clinics revisited. Imagine a data set with information about clinics, clients and visits in time. Should you split by clinics? By clients? Or choose a point in time? Split randomly using shuffled K-fold? It depends on how the model is meant to be used.
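If no built-in splitter fits, a custom cross-validator only needs split and get_n_splits methods to be usable as the cv argument of cross_validate, GridSearchCV, etc. Here is a minimal sketch of a leave-one-clinic-out splitter for the clinics example above. This is an addition, not from the original notebook; clinic_labels is a hypothetical array with one clinic label per row, and scikit-learn’s own LeaveOneGroupOut and GroupKFold already cover this case.
class LeaveOneClinicOut:
    """Hold out each clinic in turn, mimicking prediction on an unseen clinic."""

    def __init__(self, clinic_labels):
        self.clinic_labels = np.asarray(clinic_labels)

    def get_n_splits(self, X=None, y=None, groups=None):
        return len(np.unique(self.clinic_labels))

    def split(self, X, y=None, groups=None):
        indices = np.arange(len(self.clinic_labels))
        for clinic in np.unique(self.clinic_labels):
            test_mask = self.clinic_labels == clinic
            # Train on every other clinic, validate on the held-out clinic
            yield indices[~test_mask], indices[test_mask]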
# For the TimeSeriesSplit to work correctly,
# the dataset must be sorted by time
df = df.sort_values("release_date")
# We'll use this cross validation strategy for hyperparameter tuning
cv = TimeSeriesSplit(n_splits=10)
# Split into train and test sets using an 80/20 split
# Similar to:
# *_, (train_idx, test_idx) = list(cv.split(df))
indices = np.arange(len(df))
i = int(len(df) * 0.8)
train_idx, test_idx = indices[:i], indices[i:]
df, df_test = df.iloc[train_idx, :], df.iloc[test_idx, :]
Create Transformers
Always use transformers when preprocessing data and generating features
There are two main reasons for this:
- Prevent data leakage. Transformations that learn from the data, such as min-max scaling, should always be coded up using a Transformer. This prevents information from leaking from the validation set to the training set.
- Feature engineering is also hyperparameter selection. Transformers allow you to tune hyperparameters. Consider something as simple as applying a logarithm. It’s not a transformation that learns parameters from the data, so no information can leak. However, even the choice of whether or not to apply the logarithm is a modeling choice. Don’t do this: apply the logarithm, note down the RMSE on a piece of paper, go back up in your code, comment out that line of code, re-run the notebook, check if the RMSE went down, etc. This approach does not scale! Instead, use a Transformer. See `BoxCox` below.
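To make the leakage point concrete, here is a small sketch (an addition, using toy data from make_regression, not part of the original notebook). In the first setup the StandardScaler is refit on each training fold inside the Pipeline; in the second, the full dataset is scaled up front, so statistics from the validation folds leak into the training folds.
from sklearn.datasets import make_regression

X_toy, y_toy = make_regression(n_samples=300, n_features=5, random_state=0)

# Leakage-free: the StandardScaler is refit on each training fold
pipe = Pipeline(steps=[("scaler", StandardScaler()), ("ridge", Ridge())])
print(cross_validate(pipe, X_toy, y_toy, cv=5)["test_score"].mean())

# Leaky: the scaler has already seen the validation folds before cross validation
X_leaky = StandardScaler().fit_transform(X_toy)
print(cross_validate(Ridge(), X_leaky, y_toy, cv=5)["test_score"].mean())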
class PassThrough(TransformerMixin, BaseEstimator):
"""Pass a DataFrame through (identity func), storing feature names."""
def fit(self, X, y=None):
return self
def transform(self, X, y=None):
self.feature_names_out_ = list(X.columns)
return X.values
def get_feature_names_out(self, input_features=None):
return self.feature_names_out_
class DateTimeSplitter(TransformerMixin, BaseEstimator):
"""Split datetime into 3 columns: (day, month, year)"""
def fit(self, X, y=None):
return self
def transform(self, X, y=None):
return np.vstack(
(X.release_date.dt.day,
X.release_date.dt.month,
X.release_date.dt.year)
).T
def get_feature_names_out(self, input_features=None):
return ["day", "month", "year"]
class BoxCox(TransformerMixin, BaseEstimator, OneToOneFeatureMixin):
"""Transform 'between' linear and logarithmic.
Gamma = 1 => x, Gamma -> 0 => log(x + 1)"""
# https://en.wikipedia.org/wiki/Power_transform#Box%E2%80%93Cox_transformation
def __init__(self, gamma=1):
self.gamma = gamma
def fit(self, X, y=None):
return self
def transform(self, X, y=None):
self.feature_names_out_ = list(X.columns)
return (np.power(X.values + 1, self.gamma) - 1) / self.gamma
def get_feature_names_out(self, input_features=None):
return self.feature_names_out_
class DenseCountVectorizer(CountVectorizer, TransformerMixin, BaseEstimator):
"""Like a CountVectorizer, but returns a dense Pandas results."""
def transform(self, raw_documents):
index = raw_documents.index
return pd.DataFrame(
super().transform(raw_documents).toarray(),
index=index,
columns=self.get_feature_names_out(),
)
def fit_transform(self, raw_documents, y=None):
index = raw_documents.index
return pd.DataFrame(
super().fit_transform(raw_documents, y).toarray(),
index=index,
columns=self.get_feature_names_out(),
)
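In the same spirit as the add_century test earlier, a quick sanity check of the BoxCox transformer can catch silly mistakes (this check is an addition, not part of the original notebook):
# Test the BoxCox transformer on a tiny DataFrame
toy = pd.DataFrame({"x": [0.0, 1.0, 9.0]})
# gamma = 1 is the identity transform
assert np.allclose(BoxCox(gamma=1).fit_transform(toy), toy.values)
# gamma -> 0 approaches log(x + 1)
assert np.allclose(BoxCox(gamma=1e-9).fit_transform(toy), np.log(toy.values + 1))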
Decide on a metric
Think about what the best metric is for your problem! Don’t blindly use RMSE.
- Are errors `(3, 3)` better or worse than errors `(5, 0)` for your application? (A quick numerical check follows this list.)
  - If `(3, 3)` is better, then RMSE might be appropriate, since `RMSE(3, 3) = 3 < 3.54 = RMSE(5, 0)`.
  - If `(5, 0)` is better, then MAE might be more appropriate, since `MAE(5, 0) = 2.5 < 3 = MAE(3, 3)`.
- Are errors really symmetric? For instance, is predicting a value that is too low equally bad as predicting a value that is too high?
- Are relative errors appropriate? For instance, if we’re predicting the dosage of a drug, then predicting 10 grams too much for a person of 50 kg bodyweight might be worse than predicting 10 grams too much for a person of 100 kg bodyweight.
- Ideally you want the metric and the loss function used in the model to coincide. For instance, if MAE is the most meaningful metric, then you probably want to use a QuantileRegressor, which actually minimizes the MAE, not Ridge which minimizes the RMSE.
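As a quick check of the numbers quoted above (an added snippet, not in the original notebook):
errors_a = np.array([3.0, 3.0])
errors_b = np.array([5.0, 0.0])
# RMSE prefers (3, 3): 3.0 < 3.54
print(np.sqrt(np.mean(errors_a**2)), np.sqrt(np.mean(errors_b**2)))
# MAE prefers (5, 0): 2.5 < 3.0
print(np.mean(np.abs(errors_a)), np.mean(np.abs(errors_b)))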
def RMSE(y, y_pred):
# Example showing how to create a custom scoring function
return np.sqrt(np.mean(np.power(y - y_pred, 2)))
# We'll use this scoring function during cross validation below.
# This is overkill - we could've used sklearn.metrics.mean_squared_error,
# but the point is to show how one can define custom scoring functions
scoring = make_scorer(RMSE, greater_is_better=False)
# Here we use the mean, since the mean minimizes our metric (RMSE)
dummy_pipeline = Pipeline(steps=[("dummy", DummyRegressor(strategy="mean"))])
results = cross_validate(
dummy_pipeline,
df,
df["vote_average"],
cv=cv, # We use the cross-validation strategy decided on above
scoring=scoring, # We use the metric decided on above
)
# Print out the average RMSE score over all folds
np.mean(-results["test_score"])
0.9133731762599415
Create a Pipeline for a Linear model
- Feature transformations should complement the model, for instance:
  - We can use `BoxCox` transformations for a linear model, but they would have no effect on a Random Forest.
  - We can use a `StandardScaler` for a linear model, which would be unnecessary in a Random Forest.
  - Why are they not necessary? Random Forests are invariant under monotonic transformations, including scaling and shifting (a small demo is given just before the Pipeline code below).
- Even though the Pipeline is a bit complex, using a Pipeline has many advantages:
  - Minimize the chance of data leakage.
  - Ability to optimize hyperparameters.
  - Easy to modify in the future.
On many problems, spending some time doing sensible feature engineering and using a linear model will be competitive with black box models such as Random Forests, Neural Networks, Boosting, etc. This depends very much on the problem and what the goal of the modeling work is, see commandment 5.
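To illustrate the invariance claim above, here is a small demo (an addition, not in the original notebook) showing that a decision tree makes identical predictions whether it is given a feature or a monotonic transformation of it:
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(42)
X_demo = rng.uniform(1, 100, size=(200, 1))
y_demo = rng.normal(size=200)

# Fit the same tree on the raw feature and on log(feature)
tree_raw = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_demo, y_demo)
tree_log = DecisionTreeRegressor(max_depth=3, random_state=0).fit(np.log(X_demo), y_demo)

# The split thresholds differ, but the predictions are identical
assert np.allclose(tree_raw.predict(X_demo), tree_log.predict(np.log(X_demo)))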
linear_pipeline = Pipeline(
steps=[
(
"column-transformer",
ColumnTransformer(
transformers=[
(
"numerical-runtime",
BoxCox(),
[
"runtime",
],
),
("numerical-budget_adj", BoxCox(), ["budget_adj"]),
("date", DateTimeSplitter(), ["release_date"]),
(
"label-genre",
DenseCountVectorizer(
tokenizer=lambda x: x.split("|"),
token_pattern=None,
),
"genres",
),
(
"label-keywords",
DenseCountVectorizer(
max_features=10,
tokenizer=lambda x: x.split("|"),
token_pattern=None,
),
"keywords",
),
],
remainder="drop",
).set_output(transform="pandas"),
),
# Imputation of 'budget', possibly conditioned on 'runtime' and 'year'
(
"imputation",
ColumnTransformer(
transformers=[
(
"numerical-imputation",
# Here we could have used a KNNImputer instead
SimpleImputer(strategy="median"),
[
"numerical-runtime__runtime",
"numerical-budget_adj__budget_adj",
"date__year",
],
),
],
remainder="passthrough",
).set_output(transform="pandas"),
),
("scaler", StandardScaler()),
("ridge", Ridge()),
]
)
linear_pipeline
Pipeline(steps=[('column-transformer', ColumnTransformer(transformers=[('numerical-runtime', BoxCox(), ['runtime']), ('numerical-budget_adj', BoxCox(), ['budget_adj']), ('date', DateTimeSplitter(), ['release_date']), ('label-genre', DenseCountVectorizer(token_pattern=None, tokenizer=<function <lambda> at 0x7f65b0028400>), 'genres'), ('label-keywords', DenseCountVectorizer(max_features=10, token_pattern=None, tokenizer=<function <lambda> at 0x7f65b0028040>), 'keywords')])), ('imputation', ColumnTransformer(remainder='passthrough', transformers=[('numerical-imputation', SimpleImputer(strategy='median'), ['numerical-runtime__runtime', 'numerical-budget_adj__budget_adj', 'date__year'])])), ('scaler', StandardScaler()), ('ridge', Ridge())])
Tune hyperparameters for the Linear model
- Hyperparameters should be tuned using the chosen CV strategy and the chosen metric.
- To speed up this part while experimenting, do the following:
  - Consider sub-sampling the data set at the very beginning of the notebook.
  - Set `n_iter` in `RandomizedSearchCV` to a low number.
  - Set `n_jobs=-1` to run on all available cores on your CPU.
  - Once you’re happy, do a full run while getting lunch or coffee.
- More advanced strategies than `RandomizedSearchCV` are possible, e.g. Bayesian optimization.
  - However, Bayesian optimization is the kind of fine-tuning that comes in later in an ML project.
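The double-underscore parameter names in the grid below follow the step__substep__param convention of Pipelines and ColumnTransformers. If you are ever unsure which names are available, you can list them (an added tip, not in the original notebook):
# List every tunable parameter of the pipeline, e.g. 'ridge__alpha'
print(sorted(linear_pipeline.get_params().keys()))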
param_grid = {
# Regularization for the linear model
"ridge__alpha": np.logspace(-2, 2, num=25),
# Logarithm, or identity function?
"column-transformer__numerical-runtime__gamma": np.linspace(0.01, 3, num=25),
"column-transformer__numerical-budget_adj__gamma": np.linspace(0.01, 3, num=25),
# How many labels are useful to include?
"column-transformer__label-keywords__max_features": [2**i for i in range(0, 11)],
}
grid_search = RandomizedSearchCV(
linear_pipeline,
param_grid,
cv=cv,
scoring=scoring,
n_iter=100,
refit=True,
n_jobs=-1,
random_state=42,
)
# Find the best hyper-parameters and print them, along with RMSE
grid_search.fit(df, df["vote_average"])
print(grid_search.best_params_, np.mean(-grid_search.best_score_))
# Retrieve the model with the best hyperparameters (the model is trained on all data)
linear_pipeline = grid_search.best_estimator_
{'ridge__alpha': 3.1622776601683795, 'column-transformer__numerical-runtime__gamma': 0.8820833333333333, 'column-transformer__numerical-budget_adj__gamma': 1.13125, 'column-transformer__label-keywords__max_features': 4} 0.8055133054752345
Inspect the Linear model
- Here we simply plot the model coefficients.
- The model coefficients were obtained on inputs scaled by the `StandardScaler`, so we can interpret them as feature importances.
- We can use `get_feature_names_out()` on the Pipeline, since each Transformer implements `get_feature_names_out()`.
# Get pretty variable names
variables_names = linear_pipeline[:1].get_feature_names_out()
# Get coefficients from the Ridge model
ridge_coefficients = linear_pipeline[-1].coef_
# Quick and dirty plot using pd.Series plotting functionality
pd.Series(ridge_coefficients, index=variables_names).sort_values().plot.barh(
figsize=(5, 5)
)
plt.tight_layout()
plt.show()
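Create a Pipeline for the Random Forest
- We now build a second Pipeline, this time with a RandomForestRegressor and simpler preprocessing: no BoxCox or StandardScaler, since those transformations would not affect a tree-based model (see the note above).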
rf_pipeline = Pipeline(
steps=[
(
"column-transformer",
ColumnTransformer(
transformers=[
(
"numerical",
SimpleImputer(strategy="median"),
["runtime", "budget_adj"],
),
("date", DateTimeSplitter(), ["release_date"]),
(
"label-genre",
DenseCountVectorizer(
tokenizer=lambda x: x.split("|"),
token_pattern=None,
),
"genres",
),
(
"label-keywords",
DenseCountVectorizer(
max_features=10,
tokenizer=lambda x: x.split("|"),
token_pattern=None,
),
"keywords",
),
]
).set_output(transform="pandas"),
),
(
"random-forest",
RandomForestRegressor(
n_estimators=100, criterion="squared_error", random_state=42
),
),
]
)
rf_pipeline
Pipeline(steps=[('column-transformer', ColumnTransformer(transformers=[('numerical', SimpleImputer(strategy='median'), ['runtime', 'budget_adj']), ('date', DateTimeSplitter(), ['release_date']), ('label-genre', DenseCountVectorizer(token_pattern=None, tokenizer=<function <lambda> at 0x7f65ae9fc360>), 'genres'), ('label-keywords', DenseCountVectorizer(max_features=10, token_pattern=None, tokenizer=<function <lambda> at 0x7f65ae9fc400>), 'keywords')])), ('random-forest', RandomForestRegressor(random_state=42))])
Tune hyperparameters for the Random Forest
- Don’t tune every conceivable parameter, scikit-learn has sensible defaults!
- Instead, choose a few main complexity parameters and tune those (here `min_samples_leaf`).
- Start with a large range, then decrease the range depending on the data set and problem as you iterate and experiment.
- Don’t experiment too much prematurely :)
param_grid = {
# Overall model complexity is captured by 'min_samples_leaf'
"random-forest__min_samples_leaf": [2**i for i in range(1, 7)],
# Adding even more labels does help the model, but only marginally
"column-transformer__label-keywords__max_features": [2**i for i in range(0, 6)],
}
grid_search = GridSearchCV(
rf_pipeline,
param_grid,
cv=cv,
scoring=scoring,
refit=True,
n_jobs=-1,
)
# Find the best hyper-parameters and print them, along with RMSE
grid_search.fit(df, df["vote_average"])
print(grid_search.best_params_, np.mean(-grid_search.best_score_))
# Retrieve best model, trained on the full data set
rf_pipeline = grid_search.best_estimator_
{'column-transformer__label-keywords__max_features': 32, 'random-forest__min_samples_leaf': 8} 0.7744583909813831
Inspect the Random Forest
- Model inspection becomes much more difficult, since the RF is a non-additive, black-box model.
- We look at feature importances using two methods:
  - Gini importances (see description here) on the features the model sees.
  - Permutation importances (evaluated on a test set) on the input variables.
We could also have computed permutation importances on the features the model sees (transformed features).
There are many ways to inspect black-box models such as a Random Forest, but they are out of scope for this notebook.
# Get pretty variable names
variables_names = [
varname.split("__")[-1].capitalize()
for varname in rf_pipeline[:1].get_feature_names_out()
]
# Compute and output Gini Importances
pd.Series(rf_pipeline[-1].feature_importances_, index=variables_names).loc[
lambda ser: ser > 1e-3
].sort_values().plot.barh(figsize=(5, 4.5))
plt.tight_layout()
plt.show()
# Get variable names to permute; no reason to permute features dropped by the pipeline
variables_names = functools.reduce(
operator.add,
[
(
[input_feature_names]
if isinstance(input_feature_names, str)
else input_feature_names
)
for (_, _, input_feature_names) in rf_pipeline[0].transformers_[:-1]
],
)
# Get indices for train and test
*_, (perm_train_idx, perm_test_idx) = TimeSeriesSplit(n_splits=2).split(df)
# Train the model on the earlier portion of the data
rf_pipeline.fit(df.iloc[perm_train_idx], df["vote_average"].iloc[perm_train_idx])
results = permutation_importance(
rf_pipeline,
# Evaluate permutation importances on the held-out (later) portion
df[variables_names].iloc[perm_test_idx],
df["vote_average"].iloc[perm_test_idx],
scoring=scoring,
n_repeats=10,
n_jobs=-1,
random_state=42,
)
pd.Series(results["importances_mean"], index=variables_names).sort_values().plot.barh(
figsize=(5, 1.5)
)
plt.tight_layout()
plt.show()
Evaluate models on the test set
- Finally, we evaluate the models on the test set.
- Only do this once, when you’re completely satisfied with your model and wish to report on it.
- If you peek at the test set, then change the model because of what you see, you’ve cheated.
- Loop over validation folds in the cross validation and assess performance if you like :)
In summary, neither model is very good. That’s to be expected: predicting IMDb ratings using only this limited information is a hard problem.
plt.figure(figsize=(5, 4))
plt.title("Evaluating models on the test set")
for pipeline in [dummy_pipeline, linear_pipeline, rf_pipeline]:
model_name, _ = pipeline.steps[-1]
pipeline.fit(df, df["vote_average"])
r2 = pipeline.score(df_test, df_test["vote_average"]).round(2)
rmse = -scoring(pipeline, df_test, df_test["vote_average"]).round(2)
plt.scatter(
# Jitter the data a little bit
df_test["vote_average"] + np.random.rand(len(df_test)) * 0.1,
pipeline.predict(df_test),
alpha=0.33,
s=3,
label=f"{model_name} (r2 = {r2}, rmse = {rmse})",
)
plt.plot([2, 9], [2, 9], color="black")
plt.xlabel("Actual score")
plt.ylabel("Predicted score")
plt.legend()
plt.tight_layout()
plt.show()
Summary
- Fast code is important in the development phase. Subsample data, use coarse hyperparam grids, etc.
- Choose a cross validation strategy that matches how the model will be used.
- Use Transformers to capture modeling choices and transformations, avoiding data leakage.
- Decide on a metric that makes sense for your problem - and ideally optimize this metric directly in the loss.
- Feature transformations are model dependent. A logarithm can help a linear model, but does not help a Random Forest.
- I like to always test two models: a dummy model (lower bound) and a complex model (upper bound).
- When you tune hyperparameters, focus on a few parameters that capture overall model complexity.
- Model inspection and exploratory data analysis (EDA) are always a good idea. We did not do much EDA here.
- Only use the test set once, at the very end of the analysis when you’re completely happy with the results.
On this regression task, with this dataset, it’s hard to get good performance. There are many reasons why a movie might receive a high or low rating, and most of the variables that capture this information are not in this dataset. This makes sense: it would be amazing if a few keywords, the runtime, etc. could reliably predict the IMDb score with high accuracy.
Hope you learned something!
References
Further reading:
- Hands-On Machine Learning with Scikit-Learn and TensorFlow
- API design for machine learning software: experiences from the scikit-learn project
Videos (old but good!):
- Jake Vanderplas, Olivier Grisel: Exploring Machine Learning with Scikit-learn - PyCon 2014
- Jake VanderPlas: Machine Learning with Scikit-Learn - PyCon 2015
- Jake VanderPlas: Machine Learning with Scikit Learn - PyData Seattle 2015
- Andreas Mueller - Machine Learning with Scikit-Learn - PyData Amsterdam 2016