import numpy as np
import matplotlib.pyplot as plt
from .utils.validation import check_consistent_shape, check_array
from .utils.validation import check_is_fitted
from .dar_model import extract_driver
from .utils.progress_bar import ProgressBar
from .utils.viz import SEABORN_PALETTES
class DelayEstimator(object):
    """Estimate the optimal delay between the two components in PAC.

    In phase-amplitude coupling (PAC), the slow oscillation and the fast
    oscillations may be shifted in time with a constant temporal delay. This
    estimator computes the optimal delay, based on the maximum likelihood of
    DAR models fitted on a grid of candidate delays.

    Parameters
    ----------
    fs : float
        Sampling frequency
    dar_model : DAR instance
        DAR model used to fit the signal. It is refitted for every candidate
        delay, so its fitted state is overwritten by `fit`.
    low_fq : float
        Filtering frequency (phase signal)
    low_fq_width : float
        Bandwidth of the band-pass filter (phase signal)
    max_delay : float or 'auto'
        The delay grid will range from -max_delay to max_delay (in seconds,
        since it is multiplied by `fs` to get a number of time points).
        If 'auto', it uses 0.5 / low_fq.
    refit : boolean, default True
        If True, the model will be refitted with the best delay obtained
    random_state : None, int or np.random.RandomState instance
        Seed or random number generator for the surrogate analysis.

    References
    ----------
    Dupre la Tour et al. (2017). Non-linear Auto-Regressive Models for
    Cross-Frequency Coupling in Neural Time Series. bioRxiv, 159731.
    """

    def __init__(self, fs, dar_model, low_fq, low_fq_width, max_delay='auto',
                 refit=True, random_state=None):
        self.fs = fs
        self.dar_model = dar_model
        self.low_fq = low_fq
        self.low_fq_width = low_fq_width
        self.max_delay = max_delay
        self.refit = refit
        self.random_state = random_state

    def fit(self, low_sig, high_sig=None, mask=None):
        """Estimate the delay that minimizes the DAR negative log-likelihood.

        For each delay on the grid, the driver is circularly shifted in time
        (``np.roll``) and the DAR model is fitted twice: once on the signals
        as given, and once on the time-reversed signals. The criterion for
        that delay is the sum of the two negative log-likelihoods.

        Parameters
        ----------
        low_sig : array, shape (n_epochs, n_points)
            Input data for the phase signal
        high_sig : array or None, shape (n_epochs, n_points)
            Input data for the amplitude signal.
            If None, we use low_sig for both signals
        mask : array or None, shape (n_epochs, n_points)
            The model is only fitted where the mask is False.
            Masking is done after filtering, and is not delayed.

        Returns
        -------
        self : DelayEstimator
            The fitted estimator.

        Raises
        ------
        ValueError
            If the signals are shorter than the edge taper, or if `max_delay`
            is negative.

        Attributes
        ----------
        neg_log_likelihood_ : array, shape (n_delays, )
            Negative log-likelihood of dar_model, fitted on a grid of delays
        delays_ms_ : array, shape (n_delays, )
            Temporal delays (in ms), corresponding to self.neg_log_likelihood_
        best_delay_ms_ : float
            Temporal delay corresponding to the minimum negative log-likelihood
        best_model_ : fitted DAR instance
            If refit is True, the model is refitted with the best delay
        """
        self.low_sig = check_array(low_sig)
        self.high_sig = check_array(high_sig, accept_none=True)
        if self.high_sig is None:
            self.high_sig = self.low_sig
        self.mask = check_array(mask, accept_none=True)
        check_consistent_shape(self.low_sig, self.high_sig, self.mask)

        model = self.dar_model

        # Taper both ends of the phase signal, so that the driver decays
        # towards zero there: np.roll wraps samples around, and the taper
        # keeps the driver continuous at the wrap point.
        n_decay = max(int(0.5 * self.fs / self.low_fq), 5)
        n_points = self.low_sig.shape[1]
        if n_points < n_decay:
            raise ValueError('Signals are too short (%d points) for the edge '
                             'taper (%d points).' % (n_points, n_decay))
        window = np.blackman(n_decay * 2 - 1)[:n_decay]
        self.low_sig = self.low_sig.copy()  # copy to avoid modifying original
        self.low_sig[:, :n_decay] *= window
        self.low_sig[:, -n_decay:] *= window[::-1]

        # The driver (and its imaginary part) comes from the tapered signal.
        sigdriv, _, sigdriv_imag = extract_driver(
            self.low_sig, self.fs, self.low_fq, bandwidth=self.low_fq_width,
            fill=2, random_state=self.random_state)
        # ``sigin`` is always computed from the *untapered* amplitude signal:
        # ``self.high_sig`` still refers to the array from before the copy
        # above, even when ``high_sig`` is None. Only the rolled driver needs
        # the taper.
        _, sigin, _ = extract_driver(self.high_sig, self.fs, self.low_fq,
                                     bandwidth=self.low_fq_width, fill=2,
                                     random_state=self.random_state)

        if self.max_delay == 'auto':
            max_delay_point = int(0.5 / self.low_fq * self.fs)
        else:
            max_delay_point = int(self.max_delay * self.fs)
        if max_delay_point < 0:
            raise ValueError('max_delay must be non-negative, got %r.'
                             % (self.max_delay, ))

        # symmetric grid of delays in time points: -max, ..., 0, ..., max
        delays_point = np.arange(-max_delay_point, max_delay_point + 1)
        self.delays_ms_ = delays_point / self.fs * 1000.

        # weight 0 where the mask is True, so those samples are not fitted
        train_weights = (1. - self.mask) if self.mask is not None else None
        # The reversed fit must also reverse the weights, so that each weight
        # stays aligned with the time sample it belongs to.
        train_weights_rev = (train_weights[..., ::-1]
                             if train_weights is not None else None)

        bar = ProgressBar(title='delays', max_value=len(delays_point))
        neg_log_likelihood = np.zeros(len(delays_point))
        for i_delay, delay in enumerate(delays_point):
            # add delay
            sigdriv_ = np.roll(sigdriv, delay, axis=1)
            sigdriv_imag_ = np.roll(sigdriv_imag, delay, axis=1)

            # fit the model direct
            model.fit(sigin=sigin, sigdriv=sigdriv_,
                      sigdriv_imag=sigdriv_imag_, fs=self.fs,
                      train_weights=train_weights)
            neg_log_likelihood[i_delay] = model.get_criterion('-logl')

            # fit the model reverted (signals AND weights reversed in time)
            model.fit(sigin=sigin[..., ::-1], sigdriv=sigdriv_[..., ::-1],
                      sigdriv_imag=sigdriv_imag_[..., ::-1], fs=self.fs,
                      train_weights=train_weights_rev)
            neg_log_likelihood[i_delay] += model.get_criterion('-logl')

            bar.update(i_delay + 1)
        bar.close()

        self.neg_log_likelihood_ = neg_log_likelihood
        # compute the best delay
        i_best = np.nanargmin(neg_log_likelihood)
        self.best_delay_ms_ = self.delays_ms_[i_best]

        # refit the model with the best delay
        if self.refit:
            best_delay_point = delays_point[i_best]
            sigdriv_ = np.roll(sigdriv, best_delay_point, axis=1)
            sigdriv_imag_ = np.roll(sigdriv_imag, best_delay_point, axis=1)
            model.fit(sigin=sigin, sigdriv=sigdriv_,
                      sigdriv_imag=sigdriv_imag_, fs=self.fs,
                      train_weights=train_weights)
            self.best_model_ = model

        return self

    def plot(self, ax=None, write_tau=True):
        """Plot the negative log-likelihood as a function of the delay.

        The minimum (best delay) is highlighted with a diamond marker.

        Parameters
        ----------
        ax : matplotlib Axes or None
            Axes to draw on. If None, a new figure is created.
        write_tau : bool, default True
            If True, write the estimated delay (in ms) on the plot.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure instance containing the plot.
        """
        check_is_fitted(self, 'neg_log_likelihood_')
        if ax is None:
            fig = plt.figure()
            ax = fig.gca()
        else:
            fig = ax.figure

        palette = SEABORN_PALETTES['deep']
        red, purple = palette[2], palette[3]

        i_best = np.nanargmin(self.neg_log_likelihood_)
        ax.plot(self.delays_ms_, self.neg_log_likelihood_, color=purple)
        ax.plot(self.delays_ms_[i_best], self.neg_log_likelihood_[i_best], 'D',
                color=red)
        ax.set_xlabel('Delay (ms)')
        ax.set_ylabel('Neg. log likelihood / T')
        ax.grid(True)

        if write_tau:
            ax.text(0.5, 0.80, r'$\mathrm{Estimated}$',
                    horizontalalignment='center', transform=ax.transAxes)
            ax.text(0.5, 0.66, r'$\tau_0 = %.0f \;\mathrm{ms}$' %
                    (self.delays_ms_[i_best], ), horizontalalignment='center',
                    transform=ax.transAxes)
        return fig