# Source code for pactools.comodulogram

import warnings

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d, interp2d

from .dar_model.base_dar import BaseDAR
from .dar_model.dar import DAR
from .dar_model.preprocess import multiple_extract_driver
from .utils.parallel import Parallel, delayed
from .utils.progress_bar import ProgressBar
from .utils.spectrum import Bicoherence, Coherence
from .utils.maths import norm, argmax_2d, next_power2
from .utils.validation import check_array, check_random_state
from .utils.validation import check_consistent_shape, check_is_fitted
from .utils.viz import add_colorbar
from .bandpass_filter import multiple_band_pass
from .mne_api import MaskIterator, _check_mne


# Names of the PAC metrics accepted as the ``method`` parameter of
# Comodulogram, grouped by the computation path that ``fit`` dispatches to.

# Band-pass filtering followed by a modulation index (see _comodulogram).
STANDARD_PAC_METRICS = ['ozkurt', 'canolty', 'tort', 'penny', 'vanwijk']
# Driven auto-regressive (DAR) models (see _driven_comodulogram).
DAR_BASED_PAC_METRICS = ['duprelatour']
# Coherence between the low signal and the filtered high signal
# (see _coherence).
COHERENCE_PAC_METRICS = ['jiang', 'colgin']
# Bicoherence of a single signal (see _bicoherence).
BICOHERENCE_PAC_METRICS = ['sigl', 'nagashima', 'hagihira', 'bispectrum']


class Comodulogram(object):
    """An object to compute the comodulogram for phase-amplitude coupling.

    Parameters
    ----------
    fs : float
        Sampling frequency

    low_fq_range : array or list
        List of filtering frequencies (phase signal)

    high_fq_range : array or list or 'auto'
        List of filtering frequencies (amplitude signal)
        If 'auto', it uses np.linspace(max(low_fq_range), fs / 2.0, 80).

    low_fq_width : float
        Bandwidth of the band-pass filter (phase signal)

    high_fq_width : float or 'auto'
        Bandwidth of the band-pass filter (amplitude signal)
        If 'auto', it uses 2 * max(low_fq_range).

    method : string or DAR instance
        Modulation index method:

        - String in ('ozkurt', 'canolty', 'tort', 'penny', ), for a PAC
          estimation based on filtering and using the Hilbert transform.
        - String in ('vanwijk', ) for a joint AAC and PAC estimation
          based on filtering and using the Hilbert transform.
        - String in ('sigl', 'nagashima', 'hagihira', 'bispectrum', ), for
          a PAC estimation based on the bicoherence.
        - String in ('colgin', ) for a PAC estimation
          and in ('jiang', ) for a PAC directionality estimation,
          based on filtering and computing coherence.
        - String in ('duprelatour', ) or a DAR instance, for a PAC estimation
          based on a driven autoregressive model.

    n_surrogates : int
        Number of surrogates computed for the z-score
        If n_surrogates <= 1, the z-score is not computed.

    vmin, vmax : float or None
        If not None, they define the default min/max value of the plot,
        used by ``plot`` when its own vmin/vmax arguments are None.

    progress_bar : boolean
        If True, a progress bar is shown in stdout.

    ax_special : matplotlib.axes.Axes or None
        If not None, a special figure is drawn on it, depending on the
        PAC method used.

    minimum_shift : float
        Minimum time shift (in sec) for the surrogate analysis.

    random_state : None, int or np.random.RandomState instance
        Seed or random number generator for the surrogate analysis.

    coherence_params : dict
        Parameters for methods based on coherence or bicoherence.
        May contain:

        - block_length : int
            Block length
        - fft_length : int or None
            Length of the FFT
        - step : int or None
            Step between two blocks

        If the dictionary is empty, default values will be applied based on
        fs and low_fq_width, with 0.5 overlap windows and no zero-padding.

    extract_params : dict
        Parameters for DAR models driver extraction

    low_fq_width_2 : float
        Bandwidth of the band-pass filters centered on low_fq_range, for
        the amplitude signal. Used only with 'vanwijk' method.

    n_jobs : int
        Number of jobs to use in parallel computations.
        Requires scikit-learn installed.

    Examples
    --------
    >>> from pactools.comodulogram import Comodulogram
    >>> c = Comodulogram(fs=200., low_fq_range=np.arange(2, 4, 0.2),
    ...                  low_fq_width=2.)
    >>> c.fit(signal)
    >>> c.plot()
    >>> comod_array = c.comod_
    """

    def __init__(self, fs, low_fq_range, low_fq_width=2.,
                 high_fq_range='auto', high_fq_width='auto', method='tort',
                 n_surrogates=0, vmin=None, vmax=None, progress_bar=True,
                 ax_special=None, minimum_shift=1.0, random_state=None,
                 coherence_params=dict(), extract_params=dict(),
                 low_fq_width_2=4.0, n_jobs=1):
        # Parameters are stored untouched; validation and 'auto' resolution
        # are deferred to _check_params, called at the beginning of fit.
        # NOTE: the dict() defaults are shared between instances, but they
        # are only read (unpacked with ** or queried with .get) in this
        # module, never mutated.
        self.fs = fs
        self.low_fq_range = low_fq_range
        self.low_fq_width = low_fq_width
        self.high_fq_range = high_fq_range
        self.high_fq_width = high_fq_width
        self.method = method
        self.n_surrogates = n_surrogates
        self.vmin = vmin
        self.vmax = vmax
        self.progress_bar = progress_bar
        self.ax_special = ax_special
        self.minimum_shift = minimum_shift
        self.random_state = random_state
        self.coherence_params = coherence_params
        self.extract_params = extract_params
        self.low_fq_width_2 = low_fq_width_2
        self.n_jobs = n_jobs

    def _check_params(self):
        """Validate the parameters and resolve the 'auto' values in place."""
        high_fq_range = self.high_fq_range
        if isinstance(high_fq_range, str) and high_fq_range == 'auto':
            self.high_fq_range = np.linspace(
                max(self.low_fq_range), self.fs / 2.0, 80)

        high_fq_width = self.high_fq_width
        if isinstance(high_fq_width, str) and high_fq_width == 'auto':
            self.high_fq_width = max(self.low_fq_range) * 2

        self.random_state = check_random_state(self.random_state)

        if isinstance(self.method, str):
            self.method = self.method.lower()

        self.fs = float(self.fs)
        self.low_fq_range = np.atleast_1d(self.low_fq_range)
        self.high_fq_range = np.atleast_1d(self.high_fq_range)

        if self.ax_special is not None:
            assert isinstance(self.ax_special, matplotlib.axes.Axes)
            if self.low_fq_range.size > 1:
                raise ValueError("ax_special can only be used if low_fq_range "
                                 "contains only one frequency.")

    def fit(self, low_sig, high_sig=None, mask=None):
        """Call fit to compute the comodulogram.

        Parameters
        ----------
        low_sig : array, shape (n_epochs, n_points)
            Input data for the phase signal

        high_sig : array or None, shape (n_epochs, n_points)
            Input data for the amplitude signal.
            If None, we use low_sig for both signals.

        mask : array or list of array or None, shape (n_epochs, n_points)
            The PAC is only evaluated where the mask is False.
            Masking is done after filtering and Hilbert transform.
            If the method computes the bicoherence, the mask has to be
            unidimensional (n_points, ) and the same mask is applied on all
            epochs.
            If a list or a MaskIterator is given, the filtering is done only
            once and the comodulogram is computed on each mask.

        Attributes
        ----------
        comod_ : array, shape (len(low_fq_range), len(high_fq_range))
            Comodulogram for each couple of frequencies.
            If a list of mask is given, it returns a array of shape
            (n_masks, len(low_fq_range), len(high_fq_range))

        surrogates_ : array, shape (n_surrogates, len(low_fq_range),
                                    len(high_fq_range))
            Comodulograms for each time shift of the surrogate analysis.
            If a list of mask is given, it returns a array of shape
            (n_masks, n_surrogates, len(low_fq_range), len(high_fq_range))

        Returns
        -------
        self
        """
        self._check_params()
        low_sig = check_array(low_sig)
        high_sig = check_array(high_sig, accept_none=True)
        check_consistent_shape(low_sig, high_sig)

        # check the masks
        multiple_masks = (isinstance(mask, list)
                          or isinstance(mask, MaskIterator)
                          or (isinstance(mask, np.ndarray) and mask.ndim == 3))
        if not multiple_masks:
            mask = [mask]
        if not isinstance(mask, MaskIterator):
            mask = [check_array(m, dtype=bool, accept_none=True) for m in mask]
        n_masks = len(mask)

        # pre-compute all the random time shifts
        n_epochs, n_points = low_sig.shape
        self.shifts_ = _get_shifts(self.random_state, n_points,
                                   self.minimum_shift, self.fs,
                                   self.n_surrogates)

        if self.method in STANDARD_PAC_METRICS:
            if high_sig is None:
                high_sig = low_sig

            # NOTE: the boolean parameter is replaced by a ProgressBar
            # instance, which the helper functions update; ``save`` converts
            # it back to True.
            if self.progress_bar:
                self.progress_bar = ProgressBar(
                    'comodulogram: %s' % self.method,
                    max_value=self.low_fq_range.size * n_masks)

            # compute a number of band-pass filtered signals
            filtered_high = multiple_band_pass(
                high_sig, self.fs, self.high_fq_range, self.high_fq_width)
            filtered_low = multiple_band_pass(
                low_sig, self.fs, self.low_fq_range, self.low_fq_width)
            if self.method == 'vanwijk':
                filtered_low_2 = multiple_band_pass(
                    low_sig, self.fs, self.low_fq_range, self.low_fq_width_2)
            else:
                filtered_low_2 = None

            all_results = []
            for this_mask in mask:
                results = _comodulogram(self, filtered_low, filtered_high,
                                        this_mask, filtered_low_2)
                all_results.append(results)

        elif self.method in COHERENCE_PAC_METRICS:
            if high_sig is None:
                high_sig = low_sig

            if self.progress_bar:
                self.progress_bar = ProgressBar(
                    'coherence: %s' % self.method,
                    max_value=n_masks * len(self.shifts_))

            # compute a number of band-pass filtered signals
            filtered_high = multiple_band_pass(
                high_sig, self.fs, self.high_fq_range, self.high_fq_width)

            all_results = []
            for this_mask in mask:
                results = _coherence(self, low_sig, filtered_high, this_mask)
                all_results.append(results)

        # compute PAC with the bispectrum/bicoherence
        elif self.method in BICOHERENCE_PAC_METRICS:
            if high_sig is not None:
                raise ValueError(
                    "Impossible to use a bicoherence method (%s) on two "
                    "signals, please try another method." % self.method)
            if self.n_surrogates > 1:
                raise NotImplementedError(
                    "Surrogate analysis with a bicoherence method (%s) "
                    "is not implemented." % self.method)

            if self.progress_bar:
                self.progress_bar = ProgressBar(
                    'bicoherence: %s' % self.method, max_value=n_masks)

            all_results = []
            for this_mask in mask:
                results = _bicoherence(self, sig=low_sig, mask=this_mask)
                # wrap in a list to add the (single) shift axis
                all_results.append([results])
                if self.progress_bar:
                    self.progress_bar.update_with_increment_value(1)

        elif (isinstance(self.method, BaseDAR)
              or self.method in DAR_BASED_PAC_METRICS):
            all_results = _driven_comodulogram(self, low_sig=low_sig,
                                               high_sig=high_sig, mask=mask)

        else:
            raise ValueError('unknown method: %s' % self.method)

        # remove very small values
        all_results = np.asarray(all_results)
        all_results[np.abs(all_results) < 10 * np.finfo(np.float64).eps] = 0

        # axes are (n_masks, n_shifts, n_low, n_high); the first shift is
        # always 0, i.e. the actual comodulogram.
        self.comod_ = all_results[:, 0, :, :]
        self.surrogates_ = all_results[:, 1:, :, :]

        if not multiple_masks:
            self.comod_ = self.comod_[0]
            self.surrogates_ = self.surrogates_[0]

        return self

    @property
    def comod_z_score_(self):
        """Compute the z-score based on the comodulogram and the surrogates

        Returns
        -------
        comod_z_score_: array, shape (len(low_fq_range), len(high_fq_range))
            If self.comod_ is an array of n_masks comodulograms, so does this
            array: shape (n_masks, len(low_fq_range), len(high_fq_range))

        Raises
        ------
        ValueError
            If fewer than two surrogates are available.
        """
        check_is_fitted(self, 'surrogates_')

        # work on a copy: the z-score is computed in place below
        comod_z_score = self.comod_.copy()
        surrogates_ = self.surrogates_
        ndim = surrogates_.ndim
        if ndim == 3:
            # add the n_masks axis for generality
            surrogates_ = surrogates_[None, ...]
            comod_z_score = comod_z_score[None, ...]

        # surrogates_.shape is (n_masks, n_surrogates, n_low, n_high)
        if surrogates_.shape[1] <= 1:
            raise ValueError(
                "Impossible to compute comod_z_score_ since the surrogate "
                "comodulograms were not computed. Try to refit the "
                "estimator with n_surrogates > 1.")

        comod_z_score -= np.mean(surrogates_, axis=1)
        comod_z_score /= np.std(surrogates_, axis=1)

        if ndim == 3:
            comod_z_score = comod_z_score[0]
        return comod_z_score

    @property
    def surrogate_max_(self):
        """Compute the distribution of maxima of the surrogates comodulograms

        Returns
        -------
        surrogate_max_: array, shape (n_surrogates, )
            If self.comod_ is an array of n_masks comodulograms, the shape
            will be: shape (n_masks, n_surrogates)

        Raises
        ------
        ValueError
            If fewer than two surrogates are available.
        """
        check_is_fitted(self, 'surrogates_')

        surrogates_ = self.surrogates_
        ndim = surrogates_.ndim
        if ndim == 3:
            # add the n_masks axis for generality
            surrogates_ = surrogates_[None, ...]

        # surrogates_.shape is (n_masks, n_surrogates, n_low, n_high)
        if surrogates_.shape[1] <= 1:
            raise ValueError(
                "Impossible to compute surrogate_max_ since the surrogate "
                "comodulograms were not computed. Try to refit the "
                "estimator with n_surrogates > 1.")

        n_masks, n_surrogates, n_low, n_high = surrogates_.shape
        flattened = surrogates_.reshape(n_masks, n_surrogates, -1)
        surrogate_max = np.max(flattened, axis=2)

        if ndim == 3:
            surrogate_max = surrogate_max[0]
        return surrogate_max

    def plot(self, titles=None, axs=None, cmap=None, vmin=None, vmax=None,
             unit='', cbar=True, label=True, contour_level=None,
             contour_method='comod_max', tight_layout=True):
        """
        Plot the comodulograms computed during the fit

        Parameters
        ----------
        titles : list of string or None
            List of titles for each comodulogram

        axs : list or array of matplotlib.axes.Axes
            Axes where the comodulograms are drawn. If None, a new figure is
            created. Typical use is: fig, axs = plt.subplots(3, 4)

        cmap : colormap or None
            Colormap used in the plot. If None, it uses 'viridis' colormap
            ('RdBu_r' when the automatic range spans both signs).

        vmin, vmax : float or None
            If not None, they define the min/max value of the plot. If None,
            the values given to the constructor are used. If both are still
            None, they are set to (0, comodulograms.max()), or to a range
            symmetric around 0 when the comodulograms span both signs.

        unit : string (default: '')
            Unit of the comodulogram

        cbar : boolean
            Display colorbar or not

        label : boolean
            Display labels or not

        contour_level : None or float
            If not None, contours will be added around values where the
            significance level is above contour value. The significance
            level is computed only when n_surrogates >= 2.

        contour_method : str in ('comod_max', 'z_score'),
            (default:'comod_max')
            Select the method used to compute the significance level.

            - 'comod_max': Compute the maximum of each surrogate
              comodulogram, then estimate a threshold that correspond to a
              p-value specified in contour_level.
            - 'z_score': Compute the z-score for each couple of frequencies
              This method might suffer from multiple testing issues.

        tight_layout : boolean
            Use tight_layout or not

        Returns
        -------
        fig: matplotlib.figure.Figure
            The figure object
        """
        check_is_fitted(self, 'comod_')
        comod_ = self.comod_

        # resize for generality (local view only, self.comod_ is untouched)
        if comod_.ndim == 2:
            comod_ = comod_[None, :, :]
        n_comod = comod_.shape[0]

        # generate a new figure if no axes is provided
        if axs is None:
            n_lines = int(np.sqrt(n_comod))
            n_columns = int(np.ceil(n_comod / float(n_lines)))
            figsize = (4 * n_columns, 3 * n_lines)
            fig, axs = plt.subplots(n_lines, n_columns, figsize=figsize)
            axs = np.array(axs).ravel()
        else:
            # ravel first, so that a single Axes is accepted as well as a
            # list or an array of Axes
            axs = np.array(axs).ravel()
            fig = axs[0].figure

        # fall back on the values given to the constructor
        if vmin is None:
            vmin = self.vmin
        if vmax is None:
            vmax = self.vmax

        # Get auto values for vmin, vmax
        if vmin is None and vmax is None:
            vmin = min(0, comod_.min())
            vmax = max(0, comod_.max())
            if vmin < 0 and vmax > 0:
                # symmetric range and diverging colormap for signed data
                vmax = max(vmax, -vmin)
                vmin = -vmax
                if cmap is None:
                    cmap = plt.get_cmap('RdBu_r')

        if cmap is None:
            cmap = plt.get_cmap('viridis')

        extent = [
            self.low_fq_range[0],
            self.low_fq_range[-1],
            self.high_fq_range[0],
            self.high_fq_range[-1],
        ]

        # plot the image
        for i in range(n_comod):
            cax = axs[i].imshow(comod_[i].T, cmap=cmap, vmin=vmin, vmax=vmax,
                                aspect='auto', origin='lower', extent=extent,
                                interpolation='none')

        # optional parameter
        if titles is not None:
            for i in range(n_comod):
                axs[i].set_title(titles[i], fontsize=12)

        if label:
            axs[-1].set_xlabel('Driver frequency (Hz)')
            axs[0].set_ylabel('Signal frequency (Hz)')

        if tight_layout:
            fig.tight_layout()

        if cbar:
            # plot the colorbar once
            ax = axs[0] if len(axs) == 1 else None
            add_colorbar(fig, cax, vmin, vmax, unit=unit, ax=ax)

        # plot the contours
        if contour_level is not None:
            surrogate_max_ = self.surrogate_max_
            comod_z_score_ = self.comod_z_score_
            # resize for generality
            if surrogate_max_.ndim == 1:
                surrogate_max_ = surrogate_max_[None, :]
                comod_z_score_ = comod_z_score_[None, :, :]

            for i in range(n_comod):
                if contour_method == 'comod_max':
                    p_values = np.atleast_1d(contour_level)
                    percentiles = 100. * (1 - p_values)
                    levels = np.atleast_1d(
                        np.percentile(surrogate_max_[i], percentiles))
                    axs[i].contour(comod_[i].T, levels=levels, colors='w',
                                   origin='lower', extent=extent)

                elif contour_method == 'z_score':
                    levels = np.atleast_1d(contour_level)
                    axs[i].contour(comod_z_score_[i].T, levels=levels,
                                   colors='w', origin='lower', extent=extent)

                else:
                    raise ValueError(
                        "contour_method has to be one of ('comod_max', "
                        "'z_score'), got %s" % (contour_method, ))

        return fig

    def get_maximum_pac(self):
        """Get maximum PAC value in a comodulogram.

        Returns
        -------
        low_fq : float or array, shape (n_masks, )
            Low frequency of maximum PAC

        high_fq : float or array, shape (n_masks, )
            High frequency of maximum PAC

        pac_value : float or array, shape (n_masks, )
            Maximum PAC value
        """
        check_is_fitted(self, 'comod_')

        # Work on a local view: self.comod_ must keep its fitted shape, as
        # comod_z_score_ and plot rely on it.
        comod_ = self.comod_
        return_array = comod_.ndim != 2
        if not return_array:
            # only one comodulogram
            comod_ = comod_[None, :, :]

        n_masks = comod_.shape[0]

        # compute the maximum of the comodulogram, and get the frequencies
        max_pac_value = np.zeros(n_masks)
        low_fq = np.zeros(n_masks)
        high_fq = np.zeros(n_masks)
        for k, comodulogram in enumerate(comod_):
            i, j = argmax_2d(comodulogram)
            max_pac_value[k] = comodulogram[i, j]
            low_fq[k] = self.low_fq_range[i]
            high_fq[k] = self.high_fq_range[j]

        # return arrays or floats
        if return_array:
            return low_fq, high_fq, max_pac_value
        return low_fq[0], high_fq[0], max_pac_value[0]

    def save(self, fname, overwrite=False):
        """
        Save a comodulogram object on disk with h5py.

        Parameters
        ----------
        fname : string
            Filename to use.

        overwrite : boolean
            If True, overwrite file if it exists.

        Examples
        --------
        >>> from pactools.comodulogram import Comodulogram, read_comodulogram
        >>> est = Comodulogram(fs=1000., low_fq_range=[1, 2, 3])
        >>> fname = '/my_filename.hdf5'
        >>> est.save(fname)
        >>> est2 = read_comodulogram(fname)
        """
        mne = _check_mne('Comodulogram.save')
        write_hdf5 = mne.externals.h5io.write_hdf5

        save_vars = dict(vars(self))

        # A RandomState instance is stored as the marker 'mtrand' along with
        # its internal state; read_comodulogram rebuilds it from those.
        if 'random_state' in save_vars:
            rs = save_vars['random_state']
            if not isinstance(rs, np.random.RandomState):
                # before fit, random_state may still be None or a seed
                rs = check_random_state(rs)
            save_vars['random_state'] = 'mtrand'
            save_vars['random_state_state_'] = rs.__getstate__()

        # fit may have replaced the boolean by a ProgressBar instance
        if self.progress_bar:
            save_vars['progress_bar'] = True

        write_hdf5(fname, save_vars, title='comodulogram',
                   overwrite=overwrite, slash='replace')
# Number of phase bins used by the 'tort' modulation index
# (18 bins of 20 degrees, as in [Tort & al 2010]).
N_BINS_TORT = 18


def _comodulogram(estimator, filtered_low, filtered_high, mask,
                  filtered_low_2):
    """
    Helper function to compute the comodulogram.
    Used by PAC method in STANDARD_PAC_METRICS.

    Parameters
    ----------
    estimator : Comodulogram instance
        Provides method, shifts_, n_jobs, ax_special and progress_bar.
    filtered_low : array, first axis of size n_low
        Complex band-pass filtered signals for the phase.
    filtered_high : array, first axis of size n_high
        Complex band-pass filtered signals for the amplitude.
    mask : boolean array or None
        Samples where the mask is True are discarded.
    filtered_low_2 : array or None
        Second set of filtered low signals, used only with 'vanwijk'.

    Returns
    -------
    comod_list : array, shape (n_shifts, n_low, n_high)
    """
    use_low_amplitude = estimator.method == 'vanwijk'

    # The modulation index is only computed where mask is False
    if mask is not None:
        filtered_low = filtered_low[:, ~mask]
        filtered_high = filtered_high[:, ~mask]
        if use_low_amplitude:
            filtered_low_2 = filtered_low_2[:, ~mask]
    else:
        filtered_low = filtered_low.reshape(filtered_low.shape[0], -1)
        filtered_high = filtered_high.reshape(filtered_high.shape[0], -1)
        if use_low_amplitude:
            filtered_low_2 = filtered_low_2.reshape(
                filtered_low_2.shape[0], -1)

    n_low, _ = filtered_low.shape
    n_high, _ = filtered_high.shape
    n_shifts = estimator.shifts_.size

    # Phase of the low frequency signals. New arrays are created on purpose:
    # without a mask, the reshaped arrays are views on the caller's data,
    # which must not be overwritten since it is reused for the other masks.
    filtered_low = np.angle(filtered_low)

    # amplitude of the high frequency signals
    filtered_high = np.abs(filtered_high)
    norm_a = np.zeros(n_high)
    if estimator.method == 'ozkurt':
        for j in range(n_high):
            norm_a[j] = norm(filtered_high[j])

    # amplitude of the low frequency signals
    if use_low_amplitude:
        filtered_low_2 = np.abs(filtered_low_2)

    delayed_func = delayed(_loop_over_shifts)

    # Calculate the modulation index for each couple
    comod_list = np.zeros((n_shifts, n_low, n_high))
    for i in range(n_low):
        phase = filtered_low[i]

        # preprocess the phase array
        if estimator.method == 'tort':
            n_bins = N_BINS_TORT
            # make sure edge case for right edge is included
            eps = np.finfo(phase.dtype).eps * 2
            phase_bins = np.linspace(-np.pi, np.pi + eps, n_bins + 1)
            # get the indices of the bins to which each value belongs
            phase_preprocessed = np.digitize(phase, phase_bins) - 1
        elif estimator.method == 'penny':
            phase_preprocessed = np.c_[np.ones_like(phase),
                                       np.cos(phase),
                                       np.sin(phase)]
        elif estimator.method == 'vanwijk':
            phase_preprocessed = np.c_[np.ones_like(phase),
                                       np.cos(phase),
                                       np.sin(phase),
                                       filtered_low_2[i]]
        elif estimator.method in ('canolty', 'ozkurt'):
            phase_preprocessed = np.exp(1j * phase)
        else:
            raise ValueError('Unknown method %s.' % estimator.method)

        # one job per high frequency; each job loops over all the shifts
        mi_list = Parallel(n_jobs=estimator.n_jobs)(delayed_func(
            _one_modulation_index, estimator.shifts_,
            amplitude=filtered_high[j], phase_preprocessed=phase_preprocessed,
            norm_a=norm_a[j], method=estimator.method,
            ax_special=estimator.ax_special) for j in range(n_high))
        mi_list = np.array(mi_list).reshape(n_high, n_shifts).T
        comod_list[:, i, :] = mi_list

        if estimator.progress_bar:
            estimator.progress_bar.update_with_increment_value(1)

    return comod_list


def _loop_over_shifts(func, shifts, **kwargs):
    """Helper to loop over shifts inside a single job"""
    return [func(shift=sh, **kwargs) for sh in shifts]


def _one_modulation_index(amplitude, phase_preprocessed, norm_a, method,
                          shift, ax_special):
    """
    Compute one modulation index.
    Used by PAC method in STANDARD_PAC_METRICS.

    Parameters
    ----------
    amplitude : array, shape (n_samples, )
        Amplitude of the high frequency signal.
    phase_preprocessed : array, shape (n_samples, ) or (n_samples, n_cols)
        Method-dependent representation of the phase, as built in
        _comodulogram: bin indices ('tort'), complex exponential
        ('canolty', 'ozkurt') or design matrix ('penny', 'vanwijk').
    norm_a : float
        Norm of the amplitude, used only with 'ozkurt'.
    method : string
        One of STANDARD_PAC_METRICS.
    shift : int
        Time shift (in samples) applied to the phase; 0 means no shift.
    ax_special : matplotlib.axes.Axes or None
        If not None, and for the unshifted computation only, a
        method-specific figure is drawn ('canolty' and 'tort').

    Returns
    -------
    MI : float
    """
    # Shift for the surrogate analysis. The shift is applied along the time
    # axis: without axis=0, np.roll would flatten the 2D design matrix of
    # the 'penny' and 'vanwijk' methods and mix its columns.
    if shift != 0:
        phase_preprocessed = np.roll(phase_preprocessed, shift, axis=0)

    # Modulation index as in [Ozkurt & al 2011]
    if method == 'ozkurt':
        MI = np.abs(np.mean(amplitude * phase_preprocessed))
        MI *= np.sqrt(amplitude.size) / norm_a

    # Generalized linear models as in [Penny & al 2008] or [van Wijk & al 2015]
    elif method in ('penny', 'vanwijk'):
        # solve a linear regression problem:
        # amplitude = np.dot(phase_preprocessed, beta)
        PtP = np.dot(phase_preprocessed.T, phase_preprocessed)
        PtA = np.dot(phase_preprocessed.T, amplitude[:, None])
        beta = np.linalg.solve(PtP, PtA)
        residual = amplitude - np.dot(phase_preprocessed, beta).ravel()
        variance_amplitude = np.var(amplitude)
        variance_residual = np.var(residual)
        # fraction of the amplitude variance explained by the regression
        MI = (variance_amplitude - variance_residual) / variance_amplitude

    # Modulation index as in [Canolty & al 2006]
    elif method == 'canolty':
        z_array = amplitude * phase_preprocessed
        MI = np.abs(np.mean(z_array))

        if ax_special is not None and shift == 0:
            ax_special.plot(np.real(z_array), np.imag(z_array))
            ax_special.set_ylabel('Imaginary part of z(t)')
            ax_special.set_xlabel('Real part of z(t)')
            ax_special.set_title("Canolty's modulation index: %.3f" % MI)
            ax_special.grid(True)

    # Modulation index as in [Tort & al 2010]
    elif method == 'tort':
        # mean amplitude distribution along phase bins
        n_bins = N_BINS_TORT
        amplitude_dist = np.ones(n_bins)  # default is 1 to avoid log(0)
        for b in np.unique(phase_preprocessed):
            selection = amplitude[phase_preprocessed == b]
            amplitude_dist[b] = np.mean(selection)

        # Kullback-Leibler divergence of the distribution vs uniform
        amplitude_dist /= np.sum(amplitude_dist)
        divergence_kl = np.sum(
            amplitude_dist * np.log(amplitude_dist * n_bins))

        MI = divergence_kl / np.log(n_bins)

        if ax_special is not None and shift == 0:
            phase_bins = np.linspace(-np.pi, np.pi, n_bins + 1)
            # centers of the bins, in degrees
            phase_bins = 0.5 * (phase_bins[:-1] + phase_bins[1:]) / np.pi * 180
            ax_special.plot(phase_bins, amplitude_dist, '.-')
            ax_special.plot(phase_bins, np.ones(n_bins) / n_bins, '--')
            ax_special.set_ylim((0, 2. / n_bins))
            ax_special.set_xlim((-180, 180))
            ax_special.set_ylabel('Normalized mean amplitude')
            ax_special.set_xlabel('Phase (in degree)')
            ax_special.set_title("Tort's modulation index: %.3f" % MI)

    else:
        raise ValueError("Unknown method: %s" % (method, ))

    return MI


def _same_mask_on_all_epochs(sig, mask, method):
    """
    PAC metrics based on coherence or bicoherence,
    the same mask is applied on all epochs.

    Returns ``sig`` restricted, along its last axis, to the samples where
    the mask is False. If the (squeezed) mask has more than one dimension,
    a UserWarning is emitted and only its first row is used.
    """
    mask = np.squeeze(mask)
    if mask.ndim > 1:
        warnings.warn("For coherence methods (e.g. %s) the mask has "
                      "to be unidimensional, and the same mask is "
                      "applied on all epochs. Got shape %s, so only the "
                      "first row of the mask is used." % (
                          method,
                          mask.shape, ), UserWarning)
        mask = mask[0, :]
    sig = sig[..., ~mask]
    return sig


def _bicoherence(estimator, sig, mask):
    """
    Helper function for the comodulogram.
    Used by PAC method in BICOHERENCE_PAC_METRICS.

    Returns the bicoherence interpolated on
    (estimator.low_fq_range, estimator.high_fq_range).
    """
    # The modulation index is only computed where mask is False
    if mask is not None:
        sig = _same_mask_on_all_epochs(sig, mask, estimator.method)

    coherence_params = _define_default_coherence_params(
        estimator.fs, estimator.low_fq_width, estimator.method,
        **estimator.coherence_params)

    model = Bicoherence(**coherence_params)
    bicoh = model.fit(sig, method=estimator.method)

    # remove the redundant part
    # (np.flipud returns a view, so this zeroes entries of bicoh itself)
    n_freq = bicoh.shape[0]
    np.flipud(bicoh)[np.triu_indices(n_freq, 1)] = 0
    bicoh[np.triu_indices(n_freq, 1)] = 0

    frequencies = np.linspace(0, estimator.fs / 2., n_freq)
    comod = _interpolate(frequencies, frequencies, bicoh,
                         estimator.high_fq_range, estimator.low_fq_range)

    return comod


def _define_default_coherence_params(fs, low_fq_width, method, **user_params):
    """
    Define default values for Coherence and Bicoherence classes,
    if not defined in user_params dictionary.

    Returns a dictionary with at least the keys 'block_length',
    'fft_length', 'fs' and 'step'.
    """
    # the FFT length is chosen to have a frequency resolution of low_fq_width
    fft_length = fs / low_fq_width
    # but it is faster if it is a power of 2
    fft_length = 2 ** int(np.ceil(np.log2(fft_length)))

    # smoothing for bicoherence methods
    # (floor division, to keep an integer number of points)
    if method in BICOHERENCE_PAC_METRICS:
        fft_length //= 4

    # not smoothed because we convolve after
    if method == 'jiang':
        fft_length *= 2

    # the block length is chosen to avoid zero-padding
    block_length = fft_length

    has_block = 'block_length' in user_params
    has_fft = 'fft_length' in user_params
    if not has_block and not has_fft:
        user_params['block_length'] = block_length
        user_params['fft_length'] = fft_length
    elif has_block and not has_fft:
        user_params['fft_length'] = user_params['block_length']
    elif not has_block and has_fft:
        user_params['block_length'] = user_params['fft_length']

    if 'fs' not in user_params:
        user_params['fs'] = fs
    if 'step' not in user_params:
        user_params['step'] = None

    return user_params


def _coherence(estimator, low_sig, filtered_high, mask):
    """
    Helper function to compute the comodulogram.
    Used by PAC method in COHERENCE_PAC_METRICS.

    Returns an array with one comodulogram per shift of estimator.shifts_.
    """
    if mask is not None:
        low_sig = _same_mask_on_all_epochs(low_sig, mask, estimator.method)
        filtered_high = _same_mask_on_all_epochs(filtered_high, mask,
                                                 estimator.method)

    # amplitude of the high frequency signals
    filtered_high = np.abs(filtered_high)

    coherence_params = _define_default_coherence_params(
        estimator.fs, estimator.low_fq_width, estimator.method,
        **estimator.coherence_params)

    delayed_func = delayed(_one_coherence_modulation_index)
    generator = (delayed_func(
        shift=sh, fs=estimator.fs, low_sig=low_sig,
        filtered_high=filtered_high, method=estimator.method,
        low_fq_range=estimator.low_fq_range,
        coherence_params=coherence_params) for sh in estimator.shifts_)
    if estimator.progress_bar:
        generator = estimator.progress_bar(generator)
    comod_list = Parallel(n_jobs=estimator.n_jobs)(generator)
    comod_list = np.array(comod_list)

    return comod_list


def _one_coherence_modulation_index(fs, low_sig, filtered_high, method,
                                    low_fq_range, coherence_params, shift):
    """
    Compute one modulation index.
    Used by PAC method in COHERENCE_PAC_METRICS.
    """
    # shift for the surrogate analysis
    if shift != 0:
        low_sig = np.roll(low_sig, shift)

    # the actual frequency resolution is computed here
    delta_freq = fs / coherence_params['fft_length']

    model = Coherence(**coherence_params)
    coherence = model.fit(low_sig[None, :, :], filtered_high)[0]
    n_high, n_freq = coherence.shape
    frequencies = np.linspace(0, fs / 2., n_freq)

    # Coherence as in [Colgin & al 2009]
    if method == 'colgin':
        coherence = np.abs(coherence)

        comod = _interpolate(
            np.arange(n_high), frequencies, coherence,
            np.arange(n_high), low_fq_range)

    # Phase slope index as in [Jiang & al 2015]
    elif method == 'jiang':
        product = coherence[:, 1:] * np.conjugate(coherence[:, :-1])

        # we use a kernel of (ker * 2) with respect to the product,
        # i.e. a kernel of (ker * 2 + 1) with respect to the coherence.
        ker = 2
        kernel = np.ones(2 * ker) / (2 * ker)
        phase_slope_index = np.zeros((n_high, n_freq - (2 * ker)),
                                     dtype=np.complex128)
        for i in range(n_high):
            phase_slope_index[i] = np.convolve(product[i], kernel, 'valid')
        phase_slope_index = np.imag(phase_slope_index)
        frequencies = frequencies[ker:-ker]

        # transform the phase slope index into an approximated delay
        delay = phase_slope_index / (2. * np.pi * delta_freq)

        comod = _interpolate(
            np.arange(n_high), frequencies, delay,
            np.arange(n_high), low_fq_range)

    else:
        raise ValueError('Unknown method %s' % (method, ))

    return comod


def _interpolate(x1, y1, z1, x2, y2):
    """Helper to interpolate in 1d or 2d

    We interpolate to get the same shape than with other methods.

    Parameters
    ----------
    x1, y1 : arrays, shape (n_x1, ) and (n_y1, )
        Strictly increasing coordinates of the input grid.
    z1 : array, shape (n_x1, n_y1)
        Values on the input grid.
    x2, y2 : arrays, shape (n_x2, ) and (n_y2, )
        Coordinates where the values are interpolated.

    Returns
    -------
    z2 : array, shape (n_y2, n_x2)
        Linearly interpolated values, with z2[i, j] evaluated at
        (x2[j], y2[i]).
    """
    x2 = np.asarray(x2, dtype=float)
    y2 = np.asarray(y2, dtype=float)

    if x1.size > 1 and y1.size > 1:
        # Bilinear spline on a regular grid. This replaces
        # scipy.interpolate.interp2d (removed in SciPy 1.14), following the
        # SciPy transition guide. Points outside the grid take the value of
        # the nearest boundary.
        from scipy.interpolate import RectBivariateSpline

        spline = RectBivariateSpline(x1, y1, z1, kx=1, ky=1)
        xx, yy = np.meshgrid(x2, y2)
        z2 = spline.ev(xx.ravel(), yy.ravel()).reshape(xx.shape)
    elif x1.size == 1 and y1.size > 1:
        func = interp1d(y1, z1.ravel(), kind='linear', bounds_error=False)
        # keep the 2D layout (y2.size, 1)
        z2 = func(y2)[:, None]
    elif y1.size == 1 and x1.size > 1:
        func = interp1d(x1, z1.ravel(), kind='linear', bounds_error=False)
        # keep the 2D layout (1, x2.size)
        z2 = func(x2)[None, :]
    else:
        raise ValueError("Can't interpolate a scalar.")

    # z2.shape = (y2.size, x2.size)
    return z2


def _driven_comodulogram(estimator, low_sig, high_sig, mask):
    """
    Helper function for the comodulogram.
    Used by PAC method in DAR_BASED_PAC_METRICS.

    Returns an array of shape (n_masks, n_shifts, n_low_fq, n_high_fq).
    """
    model = estimator.method
    if model == 'duprelatour':
        model = DAR(ordar=10, ordriv=1)

    n_epochs = low_sig.shape[0]
    if high_sig is None:
        sigs = low_sig
    else:
        # hack to call only once extract
        high_sig = np.atleast_2d(high_sig)
        sigs = np.r_[high_sig, low_sig]

    # one job per driver frequency
    generator = (
        delayed(_driven_comodulogram_column)(estimator, filtered_signals,
                                             high_sig, mask, n_epochs)
        for filtered_signals in multiple_extract_driver(
            sigs=sigs, fs=estimator.fs, bandwidth=estimator.low_fq_width,
            frequency_range=estimator.low_fq_range,
            random_state=estimator.random_state, **estimator.extract_params))

    if estimator.progress_bar:
        bar = ProgressBar(max_value=len(estimator.low_fq_range) * len(mask),
                          title='comodulogram: %s' %
                          (model.get_title(name=True), ))
        generator = bar(generator)
    results = Parallel(n_jobs=estimator.n_jobs)(generator)
    if estimator.progress_bar:
        bar.update(cur_value=bar.max_value)

    all_results = np.array(results)

    # From: n_low_fq * n_masks * n_shifts * n_high_fq
    # Into: n_masks * n_shifts * n_low_fq * n_high_fq
    all_results = np.swapaxes(all_results, 0, 2)
    all_results = np.swapaxes(all_results, 0, 1)
    assert all_results.shape[0] == len(mask)
    assert all_results.shape[1] == len(estimator.shifts_)
    assert all_results.shape[2] == estimator.low_fq_range.size

    return all_results


def _driven_comodulogram_column(estimator, filtered_signals, high_sig, mask,
                                n_epochs):
    """
    Compute the modulation indexes for one driver frequency, for each mask
    and each shift. Used by PAC method in DAR_BASED_PAC_METRICS.

    Returns a list (one item per mask) of arrays (one row per shift).
    """
    extract_complex = estimator.extract_params.get('extract_complex', True)
    model = estimator.method
    if model == 'duprelatour':
        model = DAR(ordar=10, ordriv=1)

    if extract_complex:
        filtered_low, filtered_high, filtered_low_imag = filtered_signals
    else:
        filtered_low, filtered_high = filtered_signals

    sigdriv_imag = None
    if high_sig is None:
        sigin = np.array(filtered_high)
        sigdriv = np.array(filtered_low)
        if extract_complex:
            sigdriv_imag = np.array(filtered_low_imag)
    else:
        # high_sig and low_sig were stacked along the first axis
        # (see _driven_comodulogram): split them back
        sigin = np.array(filtered_high[:n_epochs])
        sigdriv = np.array(filtered_low[n_epochs:])
        if extract_complex:
            sigdriv_imag = np.array(filtered_low_imag[n_epochs:])

    sigin /= np.std(sigin)

    results = []
    for this_mask in mask:
        comod_list = [
            _one_driven_modulation_index(
                shift=sh, fs=estimator.fs, sigin=sigin, sigdriv=sigdriv,
                sigdriv_imag=sigdriv_imag, model=model, mask=this_mask,
                high_fq_range=estimator.high_fq_range,
                ax_special=estimator.ax_special)
            for sh in estimator.shifts_
        ]
        results.append(np.array(comod_list))

    return results


def _one_driven_modulation_index(fs, sigin, sigdriv, sigdriv_imag, model,
                                 mask, high_fq_range, ax_special, shift):
    """
    Compute one modulation index.
    Used by PAC method in DAR_BASED_PAC_METRICS.

    Returns an array with one value per frequency of high_fq_range.
    """
    # Shift for the surrogate analysis. The imaginary part of the driver is
    # shifted along with the real part, to keep the complex driver coherent.
    if shift != 0:
        sigdriv = np.roll(sigdriv, shift)
        if sigdriv_imag is not None:
            sigdriv_imag = np.roll(sigdriv_imag, shift)

    train_weights = ~mask if mask is not None else None

    # fit the model DAR on the data
    model.fit(fs=fs, sigin=sigin, sigdriv=sigdriv, sigdriv_imag=sigdriv_imag,
              train_weights=train_weights)

    # estimate the length of the padding for the FFT
    if len(high_fq_range) > 1:
        delta_f = np.diff(high_fq_range).mean()
        n_fft = next_power2(fs / delta_f)
    else:
        n_fft = 1024

    # get PSD difference
    spec, _, _, _ = model._amplitude_frequency(n_fft=n_fft)

    # KL divergence for each phase, as in [Tort & al 2010]
    n_freq, n_phases = spec.shape
    spec = 10. ** (spec / 20.)
    spec = spec / np.sum(spec, axis=1)[:, None]
    spec_diff = np.sum(spec * np.log(spec * n_phases), axis=1)
    spec_diff /= np.log(n_phases)

    # crop the spectrum to high_fq_range
    # (fs / 2. and not fs // 2: the upper bound is the Nyquist frequency,
    # as in the other helpers of this module)
    frequencies = np.linspace(0, fs / 2., spec_diff.size)
    spec_diff = np.interp(high_fq_range, frequencies, spec_diff)

    if ax_special is not None and shift == 0:
        model.plot(frange=[high_fq_range[0], high_fq_range[-1]],
                   ax=ax_special)

    return spec_diff


def _get_shifts(random_state, n_points, minimum_shift, fs, n_surrogates):
    """Compute the shifts for the surrogate analysis

    Returns an integer array of size max(1, 1 + n_surrogates), whose first
    item is always 0 (the unshifted computation).
    """
    n_iterations = max(1, 1 + n_surrogates)

    # shift at least minimum_shift seconds, i.e. n_minimum_shift points
    n_minimum_shift = max(1, int(fs * minimum_shift))

    if n_iterations > 1:
        if n_points - n_minimum_shift < n_minimum_shift:
            raise ValueError("The minimum shift is longer than half the "
                             "visible data.")

        shifts = random_state.randint(
            n_minimum_shift, n_points - n_minimum_shift, size=n_iterations)
    else:
        shifts = np.array([0])

    # the first has no shift since this is for the initial computation
    shifts[0] = 0

    return shifts


def read_comodulogram(fname):
    """Read a Comodulogram object saved with ``Comodulogram.save``.

    Parameters
    ----------
    fname : string
        Filename of the HDF5 file.

    Returns
    -------
    inst : Comodulogram instance
        Keys ending with an underscore are restored as attributes, the
        other ones are given to the constructor.
    """
    mne = _check_mne('read_comodulogram')
    read_hdf5 = mne.externals.h5io.read_hdf5
    data = read_hdf5(fname, 'comodulogram', slash='replace')

    init_params = {k: v for k, v in data.items() if not k.endswith('_')}

    # rebuild the random number generator from its saved state
    if 'random_state' in init_params:
        rs = init_params['random_state']
        if rs == 'mtrand':
            rs_state = data['random_state_state_']
            del data['random_state_state_']
            rs = np.random.RandomState()
            rs.set_state(rs_state)
            init_params['random_state'] = rs

    attrs = {k: v for k, v in data.items() if k.endswith('_')}

    inst = Comodulogram(**init_params)
    for k, v in attrs.items():
        setattr(inst, k, v)

    return inst