17. LightGBM

LightGBM is another gradient boosting library. In this chapter we tune a LightGBM regressor's hyperparameters with Optuna.

17.1. Load data

We will use the diabetes data.

from sklearn.datasets import load_diabetes

# Load features and target as pandas objects (as_frame=True) and show their shapes.
X, y = load_diabetes(as_frame=True, return_X_y=True)
(X.shape, y.shape)
((442, 10), (442,))
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows for evaluation; the fixed seed makes the split reproducible.
X_tr, X_te, y_tr, y_te = train_test_split(
    X,
    y,
    random_state=37,
    test_size=0.10,
)

(X_tr.shape, X_te.shape, y_tr.shape, y_te.shape)
((397, 10), (45, 10), (397,), (45,))

17.2. Tuning

Now, the tuning begins. Optuna requires an objective function that takes a trial object and returns either a scalar or a tuple of scalars; returning a tuple makes it a multi-objective optimization. In this example we have only one objective: to minimize the mean absolute error (MAE).

from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import lightgbm as lgbm
import numpy as np
import optuna


def get_model(imputer_params=None, regressor_params=None):
    """Build an impute-then-regress pipeline.

    Parameters
    ----------
    imputer_params : dict, optional
        Keyword arguments forwarded to ``SimpleImputer``. ``None`` means
        use the imputer's defaults.
    regressor_params : dict, optional
        Keyword arguments forwarded to ``lgbm.LGBMRegressor``. ``None``
        means use the regressor's defaults.

    Returns
    -------
    Pipeline
        An unfitted pipeline with steps named ``'imputer'`` and
        ``'regressor'``.
    """
    # ``None`` sentinels instead of ``{}`` defaults: a mutable default is
    # shared across every call, which is a latent bug even though the dicts
    # are only unpacked here.
    if imputer_params is None:
        imputer_params = {}
    if regressor_params is None:
        regressor_params = {}

    model = Pipeline([
        ('imputer', SimpleImputer(**imputer_params)),
        ('regressor', lgbm.LGBMRegressor(**regressor_params)),
    ])

    return model

def objective(trial):
    """Train one candidate pipeline and return its MAE for Optuna to minimize.

    The imputation strategy and three LightGBM hyperparameters are sampled
    from ``trial``. The pipeline is fitted on ``X_tr``/``y_tr`` and scored on
    ``X_te``/``y_te``. MAE, RMSE and R^2 are stored as trial user attributes
    so they can be inspected after the study finishes.

    Parameters
    ----------
    trial : optuna.trial.Trial
        The trial used to sample hyperparameters and record metrics.

    Returns
    -------
    float
        Mean absolute error on the held-out split.
    """
    i_params = {
        'strategy': trial.suggest_categorical('strategy', ['mean', 'median', 'most_frequent']),
    }

    r_params = {
        'boosting_type': 'gbdt',
        'num_leaves': trial.suggest_int('num_leaves', 30, 50),
        # LightGBM treats max_depth <= 0 as "no depth limit", so 0 is a valid draw.
        'max_depth': trial.suggest_int('max_depth', 0, 100),
        'n_estimators': trial.suggest_int('n_estimators', 80, 200),
        # No ``class_weight`` here. It is a classification option, but
        # LightGBM's sklearn wrapper still turns it into per-sample weights in
        # ``fit``. For a continuous target, 'balanced' would reweight rows by
        # how rarely their exact target value occurs.
        'random_state': 37,
        'n_jobs': -1,
    }

    model = get_model(i_params, r_params)
    model.fit(X_tr, y_tr)

    y_pred = model.predict(X_te)

    # NOTE(review): hyperparameters are selected on the same split that is used
    # to report the scores, so these numbers are optimistically biased. A
    # separate validation split or cross-validation would avoid that.
    mae = mean_absolute_error(y_te, y_pred)
    # ``squared=False`` was deprecated in scikit-learn 1.4 and removed in 1.6.
    # Taking the square root explicitly works on every version.
    rmse = float(np.sqrt(mean_squared_error(y_te, y_pred)))
    r2s = r2_score(y_te, y_pred)

    trial.set_user_attr('mae', mae)
    trial.set_user_attr('rmse', rmse)
    trial.set_user_attr('r2s', r2s)

    return mae

After we create an objective function, we can create a study and perform optimization.

from pathlib import Path

# SQLite will not create a missing parent directory, so make sure ``_temp``
# exists before Optuna opens the storage URL below.
Path('_temp').mkdir(parents=True, exist_ok=True)

# Trials are persisted in SQLite. With ``load_if_exists`` a re-run resumes the
# stored study, so every execution of ``optimize`` adds ``n_trials`` more trials.
study = optuna.create_study(**{
    'study_name': 'lightgbm-study',
    'storage': 'sqlite:///_temp/lightgbm-study.db',
    'load_if_exists': True,
    'direction': 'minimize',
    'sampler': optuna.samplers.TPESampler(seed=37),
    # NOTE(review): ``objective`` never calls ``trial.report``, so this pruner
    # has no intermediate values to act on.
    'pruner': optuna.pruners.MedianPruner(n_warmup_steps=10)
})

study.optimize(**{
    'func': objective,
    'n_trials': 100,
    'n_jobs': 1,
    'show_progress_bar': False
})

Now we can look at the best hyperparameters, the best value (the objective value we are optimizing, here the MAE), and the best trial.

{'max_depth': 7, 'n_estimators': 81, 'num_leaves': 33, 'strategy': 'median'}
FrozenTrial(number=77, values=[43.272510110290995], datetime_start=datetime.datetime(2023, 6, 1, 22, 54, 40, 530753), datetime_complete=datetime.datetime(2023, 6, 1, 22, 54, 40, 578125), params={'max_depth': 7, 'n_estimators': 81, 'num_leaves': 33, 'strategy': 'median'}, distributions={'max_depth': IntDistribution(high=100, log=False, low=0, step=1), 'n_estimators': IntDistribution(high=200, log=False, low=80, step=1), 'num_leaves': IntDistribution(high=50, log=False, low=30, step=1), 'strategy': CategoricalDistribution(choices=('mean', 'median', 'most_frequent'))}, user_attrs={'mae': 43.272510110290995, 'r2s': 0.5377439510450933, 'rmse': 56.98427363699325}, system_attrs={}, intermediate_values={}, trial_id=78, state=TrialState.COMPLETE, value=None)

17.3. Plotting

There are several plots you can use to understand the hyperparameter optimization results.

from optuna.visualization import plot_optimization_history

# Objective value of each trial over the course of the study.
plot_optimization_history(**{
    'study': study
})
from optuna.visualization import plot_parallel_coordinate

# One line per trial across all sampled parameters and the objective value.
plot_parallel_coordinate(**{
    'study': study
})
from optuna.visualization import plot_param_importances

# Estimated importance of each sampled parameter for the objective value.
plot_param_importances(**{
    'study': study
})
from optuna.visualization import plot_slice

# Objective value against each of the listed LightGBM parameters.
plot_slice(**{
    'study': study,
    'params': ['num_leaves', 'max_depth', 'n_estimators']
})
from optuna.visualization import plot_contour

# Pairwise contour plots of the objective over the listed LightGBM parameters.
plot_contour(**{
    'study': study,
    'params': ['num_leaves', 'max_depth', 'n_estimators']
})
from optuna.visualization import plot_edf
