.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

    .. rst-class:: sphx-glr-example-title

    .. _sphx_glr_auto_examples_plot_grid_search.py:


Grid-search and cross-validation
--------------------------------

This example presents the interface with scikit-learn's GridSearchCV.

It creates an artificial signal with phase-amplitude coupling (PAC), then fits
a DAR model over a grid of parameters, using cross-validation.

Cross-validation is done over epochs, with any splitting strategy provided in
scikit-learn (KFold, ...).

Note that the score computed by a DARSklearn model is the log-likelihood of the
model minus the log-likelihood of an autoregressive model of order 0.
This gives a reference that is stable across cross-validation splits.


.. code-block:: default

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.pipeline import Pipeline

    from pactools import simulate_pac
    from pactools.grid_search import ExtractDriver, AddDriverDelay
    from pactools.grid_search import DARSklearn, MultipleArray
    from pactools.grid_search import GridSearchCVProgressBar


Let's first create an artificial signal with PAC.


.. code-block:: default

    fs = 200.  # Hz
    high_fq = 50.0  # Hz
    low_fq = 5.0  # Hz
    low_fq_width = 1.0  # Hz

    n_epochs = 3
    n_points = 10000
    noise_level = 0.4

    low_sig = np.array([
        simulate_pac(n_points=n_points, fs=fs, high_fq=high_fq, low_fq=low_fq,
                     low_fq_width=low_fq_width, noise_level=noise_level,
                     random_state=i) for i in range(n_epochs)
    ])


Let's define the model with a scikit-learn pipeline.

In a pipeline, the output of each step is given as input to the next one.
Here we start with `ExtractDriver`, which extracts the driver with a bandpass
filter, and removes it from the modeled signal with a highpass filter.
Then we follow with `AddDriverDelay`, which adds a delay between the driver and
the modeled signal.
Finally, we define the DAR model with `DARSklearn`.


.. code-block:: default

    model = Pipeline(steps=[
        ('driver', ExtractDriver(fs=fs, low_fq=4., max_low_fq=7.,
                                 low_fq_width=low_fq_width, random_state=0)),
        ('add', AddDriverDelay()),
        ('dar', DARSklearn(fs=fs, max_ordar=100)),
    ])

    # grid of parameters over which we will loop
    param_grid = {
        'dar__ordar': np.arange(0, 110, 30),
        'dar__ordriv': [0, 1, 2],
        'add__delay': [0],
        'driver__low_fq': [3., 4., 5., 6., 7.],
        'driver__low_fq_width': [0.25, 0.5, 1.],
    }


Then we plug the model into GridSearchCV and fit it.

This performs a grid-search with cross-validation. First, multiple train and
test sets are defined by the splitting strategy, as set by the parameter `cv`
in GridSearchCV. Then, GridSearchCV loops over each parameter configuration,
fitting the model on each train set and evaluating it on the corresponding
test set.


.. code-block:: default

    # Plug the model and the parameter grid into a GridSearchCV estimator
    # (GridSearchCVProgressBar is identical to GridSearchCV, but it adds a nice
    # progress bar to monitor progress.)
    gscv = GridSearchCVProgressBar(model, param_grid=param_grid, cv=3,
                                   return_train_score=False, verbose=1)

    # Fit the grid-search. We use `MultipleArray` to put together low_sig and
    # high_sig. If high_sig is None, we use low_sig for both the driver and the
    # modeled signal.
    X = MultipleArray(low_sig, None)
    gscv.fit(X)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Fitting 3 folds for each of 180 candidates, totalling 540 fits
    [Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
    [Parallel(n_jobs=1)]: Done 540 out of 540 | elapsed:  2.2min finished

    GridSearchCVProgressBar(cv=3,
                            estimator=Pipeline(steps=[('driver',
                                                       ExtractDriver(fs=200.0,
                                                                     low_fq=4.0,
                                                                     max_low_fq=7.0,
                                                                     random_state=0)),
                                                      ('add', AddDriverDelay()),
                                                      ('dar',
                                                       DARSklearn(fs=200.0,
                                                                  max_ordar=100))]),
                            param_grid={'add__delay': [0],
                                        'dar__ordar': array([ 0, 30, 60, 90]),
                                        'dar__ordriv': [0, 1, 2],
                                        'driver__low_fq': [3.0, 4.0, 5.0, 6.0, 7.0],
                                        'driver__low_fq_width': [0.25, 0.5, 1.0]},
                            verbose=1)


Print the results of the grid search.


.. code-block:: default

    print("\nBest parameters set found over cross-validation:\n")
    print(gscv.best_params_)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Best parameters set found over cross-validation:

    {'add__delay': 0, 'dar__ordar': 90, 'dar__ordriv': 2, 'driver__low_fq': 5.0, 'driver__low_fq_width': 1.0}


Plot the results of the grid search.

As we grid-searched over many dimensions, the results are not easy to
visualize. We therefore choose two parameters to plot, and set the others to
their best values.


.. code-block:: default

    def plot_results(index='dar__ordar', columns='dar__ordriv'):
        """Select two hyperparameters from which we plot the fluctuations"""
        index = 'param_' + index
        columns = 'param_' + columns

        # prepare the results into a pandas.DataFrame
        df = pd.DataFrame(gscv.cv_results_)

        # Remove the other parameters by selecting their best values
        # (from gscv.best_params_)
        other = [c for c in df.columns if c[:6] == 'param_']
        other.remove(index)
        other.remove(columns)
        for col in other:
            df = df[df[col] == gscv.best_params_[col[6:]]]

        # Create pivot tables for easy plotting
        table_mean = df.pivot_table(index=index, columns=columns,
                                    values=['mean_test_score'])
        table_std = df.pivot_table(index=index, columns=columns,
                                   values=['std_test_score'])

        # plot the pivot tables
        plt.figure()
        ax = plt.gca()
        for col_mean, col_std in zip(table_mean.columns, table_std.columns):
            table_mean[col_mean].plot(ax=ax, yerr=table_std[col_std], marker='o',
                                      label=col_mean)
        plt.title('Grid-search results (higher is better)')
        plt.ylabel('log-likelihood compared to an AR(0)')
        plt.legend(title=table_mean.columns.names)
        plt.show()


    plot_results(index='dar__ordar', columns='dar__ordriv')
    plot_results(index='driver__low_fq', columns='driver__low_fq_width')


.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/images/sphx_glr_plot_grid_search_001.png
            :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/images/sphx_glr_plot_grid_search_002.png
            :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /home/tom/work/github/pactools/examples/plot_grid_search.py:135: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
      plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 2 minutes 14.877 seconds)


.. _sphx_glr_download_auto_examples_plot_grid_search.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_grid_search.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_grid_search.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_