Source code for pactools.delay_estimator

import numpy as np
import matplotlib.pyplot as plt

from .utils.validation import check_consistent_shape, check_array
from .utils.validation import check_is_fitted

from .dar_model import extract_driver
from .utils.progress_bar import ProgressBar
from .utils.viz import SEABORN_PALETTES


class DelayEstimator(object):
    """Estimate the optimal delay between the two components in PAC

    In phase-amplitude coupling (PAC), the slow oscillation and the fast
    oscillations may be shifted in time with a constant temporal delay.
    This estimator computes the optimal delay, based on the maximum
    likelihood of DAR models.

    Parameters
    ----------
    fs : float
        Sampling frequency

    dar_model : DAR instance
        DAR model used to fit the signal. This instance is fitted in place
        by ``fit`` (and stored as ``best_model_`` when ``refit`` is True).

    low_fq : float
        Filtering frequency (phase signal)

    low_fq_width : float
        Bandwidth of the band-pass filter (phase signal)

    max_delay : float or 'auto'
        The delay grid will range from -max_delay to max_delay, in seconds
        (it is multiplied by ``fs`` to get a number of time points).
        If 'auto', it uses 0.5 / low_fq.

    refit : boolean, default True
        If True, the model will be refitted with the best delay obtained

    random_state : None, int or np.random.RandomState instance
        Seed or random number generator for the surrogate analysis.

    References
    ----------
    Dupre la Tour et al. (2017). Non-linear Auto-Regressive Models for
    Cross-Frequency Coupling in Neural Time Series. bioRxiv, 159731.
    """

    def __init__(self, fs, dar_model, low_fq, low_fq_width, max_delay='auto',
                 refit=True, random_state=None):
        self.fs = fs
        self.dar_model = dar_model
        self.low_fq = low_fq
        self.low_fq_width = low_fq_width
        self.max_delay = max_delay
        self.refit = refit
        self.random_state = random_state

    def fit(self, low_sig, high_sig=None, mask=None):
        """Estimate the delay that maximizes the likelihood of the DAR model.

        For each delay of a symmetric grid, the driver extracted from
        ``low_sig`` is circularly shifted (``np.roll``) by that delay, and
        ``dar_model`` is fitted twice: once forward in time and once on the
        time-reversed signals. The two negative log-likelihoods are summed.

        Parameters
        ----------
        low_sig : array, shape (n_epochs, n_points)
            Input data for the phase signal

        high_sig : array or None, shape (n_epochs, n_points)
            Input data for the amplitude signal.
            If None, we use low_sig for both signals

        mask : array or None, shape (n_epochs, n_points)
            The model is only fitted where the mask is False.
            Masking is done after filtering, and is not delayed.

        Returns
        -------
        self : DelayEstimator
            The fitted estimator.

        Attributes
        ----------
        neg_log_likelihood_ : array, shape (n_delays, )
            Negative log-likelihood of dar_model, fitted on a grid of delays

        delays_ms_ : array, shape (n_delays, )
            Temporal delays (in ms), corresponding to self.neg_log_likelihood_

        best_delay_ms_ : float
            Temporal delay corresponding to the minimum negative
            log-likelihood

        best_model_ : fitted DAR instance
            If refit is True, the model is refitted with the best delay.
            This is the same object as ``dar_model``. When refit is False,
            ``dar_model`` is left fitted on the last (time-reversed) delay.
        """
        self.low_sig = check_array(low_sig)
        self.high_sig = check_array(high_sig, accept_none=True)
        # Decide *before* low_sig is copied below whether a distinct amplitude
        # signal was given: after the copy, an identity test between the two
        # attributes can no longer detect the shared case.
        separate_high_sig = self.high_sig is not None
        if not separate_high_sig:
            self.high_sig = self.low_sig
        self.mask = check_array(mask, accept_none=True)
        check_consistent_shape(self.low_sig, self.high_sig, self.mask)
        model = self.dar_model

        # window decay of sigdriv for continuity after np.roll
        n_decay = max(int(0.5 * self.fs / self.low_fq), 5)
        window = np.blackman(n_decay * 2 - 1)[:n_decay]
        self.low_sig = self.low_sig.copy()  # copy to avoid modifying original
        self.low_sig[:, :n_decay] *= window
        self.low_sig[:, -n_decay:] *= window[::-1]

        sigdriv, sigin, sigdriv_imag = extract_driver(
            self.low_sig, self.fs, self.low_fq, bandwidth=self.low_fq_width,
            fill=2, random_state=self.random_state)
        if separate_high_sig:
            _, sigin, _ = extract_driver(
                self.high_sig, self.fs, self.low_fq,
                bandwidth=self.low_fq_width, fill=2,
                random_state=self.random_state)

        if self.max_delay == 'auto':
            max_delay_point = int(0.5 / self.low_fq * self.fs)
        else:
            max_delay_point = int(self.max_delay * self.fs)

        # delay in time points: symmetric grid from -max_delay_point
        # to +max_delay_point (inclusive)
        delays_point = np.arange(max_delay_point + 1)
        delays_point = np.r_[-delays_point[:0:-1], delays_point]
        bar = ProgressBar(title='delays', max_value=len(delays_point))
        self.delays_ms_ = delays_point / self.fs * 1000.

        # Masked samples get a weight of 0, the others a weight of 1. The
        # mask is aligned with the samples, so the time-reversed fit must use
        # a time-reversed mask to exclude the same samples.
        if self.mask is not None:
            train_weights = 1. - self.mask
            train_weights_rev = train_weights[..., ::-1]
        else:
            train_weights = train_weights_rev = None

        neg_log_likelihood = np.zeros(len(delays_point))
        for i_delay, delay in enumerate(delays_point):
            # add delay (circular shift along the time axis)
            sigdriv_ = np.roll(sigdriv, delay, axis=1)
            sigdriv_imag_ = np.roll(sigdriv_imag, delay, axis=1)

            # fit the model direct
            model.fit(sigin=sigin, sigdriv=sigdriv_,
                      sigdriv_imag=sigdriv_imag_, fs=self.fs,
                      train_weights=train_weights)
            neg_log_likelihood[i_delay] = model.get_criterion('-logl')

            # fit the model reverted (signals and weights reversed together)
            model.fit(sigin=sigin[..., ::-1], sigdriv=sigdriv_[..., ::-1],
                      sigdriv_imag=sigdriv_imag_[..., ::-1], fs=self.fs,
                      train_weights=train_weights_rev)
            neg_log_likelihood[i_delay] += model.get_criterion('-logl')

            bar.update(i_delay + 1)
        bar.close()

        self.neg_log_likelihood_ = neg_log_likelihood

        # compute the best delay
        i_best = np.nanargmin(neg_log_likelihood)
        self.best_delay_ms_ = self.delays_ms_[i_best]

        # refit the model with the best delay
        if self.refit:
            best_delay_point = delays_point[i_best]
            sigdriv_ = np.roll(sigdriv, best_delay_point, axis=1)
            sigdriv_imag_ = np.roll(sigdriv_imag, best_delay_point, axis=1)
            model.fit(sigin=sigin, sigdriv=sigdriv_,
                      sigdriv_imag=sigdriv_imag_, fs=self.fs,
                      train_weights=train_weights)
            self.best_model_ = model
        return self

    def plot(self, ax=None, write_tau=True):
        """Plot the negative log-likelihood as a function of the delay.

        The minimum of the curve (the estimated delay) is marked with a
        diamond.

        Parameters
        ----------
        ax : matplotlib.axes.Axes or None
            Axes to draw on. If None, a new figure is created.

        write_tau : boolean, default True
            If True, write the estimated delay as text on the axes.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure instance containing the plot.
        """
        check_is_fitted(self, 'neg_log_likelihood_')
        if ax is None:
            fig = plt.figure()
            ax = fig.gca()
        else:
            fig = ax.figure

        _, _, red, purple, _, _ = SEABORN_PALETTES['deep']
        i_best = np.nanargmin(self.neg_log_likelihood_)
        ax.plot(self.delays_ms_, self.neg_log_likelihood_, color=purple)
        ax.plot(self.delays_ms_[i_best], self.neg_log_likelihood_[i_best],
                'D', color=red)
        ax.set_xlabel('Delay (ms)')
        ax.set_ylabel('Neg. log likelihood / T')
        ax.grid(True)
        if write_tau:
            ax.text(0.5, 0.80, r'$\mathrm{Estimated}$',
                    horizontalalignment='center', transform=ax.transAxes)
            ax.text(0.5, 0.66,
                    r'$\tau_0 = %.0f \;\mathrm{ms}$' %
                    (self.delays_ms_[i_best], ),
                    horizontalalignment='center', transform=ax.transAxes)
        return fig