import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from .comodulogram import multiple_band_pass
from .utils.peak_finder import peak_finder
from .utils.validation import check_consistent_shape, check_array
from .utils.validation import check_is_fitted
from .utils.viz import add_colorbar, mpl_palette
def _is_auto(value):
    """Return True if `value` is the string 'auto'.

    Checking the type first avoids an elementwise (and then ambiguous)
    comparison when `value` is a numpy array.
    """
    return isinstance(value, str) and value == 'auto'


class PeakLocking(object):
    """An object to compute time average and time-frequency averaged with
    peak-locking, to analyze phase-amplitude coupling.

    Parameters
    ----------
    fs : float
        Sampling frequency
    low_fq : float
        Filtering frequency (phase signal)
    low_fq_width : float
        Bandwidth of the band-pass filter (phase signal)
    high_fq_range : array or list, shape (n_high, ), or 'auto'
        List of filtering frequencies (amplitude signal)
        If 'auto', it uses np.linspace(low_fq, fs / 2, 40).
    high_fq_width : float or 'auto'
        Bandwidth of the band-pass filter (amplitude signal)
        If 'auto', it uses 2 * low_fq.
    t_plot : float
        Time to plot around the peaks (in second)
    filter_method : in {'mne', 'pactools'}
        Choose band pass filtering method (in multiple_band_pass)
        'mne': with mne.filter.band_pass_filter
        'pactools': with pactools.fir.BandPassFilter (default)
    peak_or_trough: in {'peak', 'trough'}
        Lock to the maximum (peak) or minimum (trough) of the slow
        oscillation.
    percentiles : list of float or string, shape (n_percentiles, )
        Percentiles (in [0, 100]) to compute for the time representation.
        It can also include 'mean', 'std+', 'std-', 'ste+' or 'ste-'
        (resp. mean, mean +/- standard deviation, or mean +/- standard
        error).

    Notes
    -----
    `fit` resolves the 'auto' values in place: after fitting,
    `high_fq_range` is an array, `high_fq_width` a float and `low_fq` a
    1-d array.
    """

    def __init__(self, fs, low_fq, low_fq_width=1.0, high_fq_range='auto',
                 high_fq_width='auto', t_plot=1.0, filter_method='pactools',
                 peak_or_trough='peak', percentiles=('std+', 'mean', 'std-')):
        self.fs = fs
        self.low_fq = low_fq
        self.high_fq_range = high_fq_range
        self.low_fq_width = low_fq_width
        self.high_fq_width = high_fq_width
        self.t_plot = t_plot
        self.filter_method = filter_method
        self.peak_or_trough = peak_or_trough
        self.percentiles = percentiles

    def _check_params(self):
        """Validate `peak_or_trough` and resolve the 'auto' parameters.

        Safe to call repeatedly: once resolved, array-valued parameters are
        never compared to the string 'auto' again.

        Raises
        ------
        ValueError
            If `peak_or_trough` is neither 'peak' nor 'trough'.
        """
        if self.peak_or_trough not in ('peak', 'trough'):
            raise ValueError("peak_or_trough should be 'peak' or 'trough', "
                             "got %r." % (self.peak_or_trough, ))
        self.low_fq = np.atleast_1d(self.low_fq)
        if _is_auto(self.high_fq_range):
            self.high_fq_range = np.linspace(self.low_fq[0], self.fs / 2.0,
                                             40)
        if _is_auto(self.high_fq_width):
            self.high_fq_width = 2 * self.low_fq[0]
        self.high_fq_range = np.asarray(self.high_fq_range)

    def fit(self, low_sig, high_sig=None, mask=None):
        """
        Compute peak-locked time-averaged and time-frequency representations.

        Parameters
        ----------
        low_sig : array, shape (n_epochs, n_points)
            Input data for the phase signal
        high_sig : array or None, shape (n_epochs, n_points)
            Input data for the amplitude signal.
            If None, we use low_sig for both signals
        mask : array or None, shape (n_epochs, n_points)
            The locking is only evaluated where the mask is False.
            Masking is done after filtering.

        Returns
        -------
        self

        Attributes
        ----------
        filtered_low_ : array, shape (n_epochs, n_points)
            Band-pass filtered phase signal (its real part is the driver).
        filtered_high_ : array, shape (n_high, n_epochs, n_points)
            Band-pass filtered amplitude signals.
        peak_loc : array of int
            Peak (or trough) indices in the ravelled (n_epochs * n_points)
            signal.
        peak_mag : array
            Value of the real part of filtered_low_ at each peak_loc.
        time_frequency_ : array, shape (n_high, n_window)
            Time-frequency representation, averaged with peak-locking.
            (n_window is the number of point in t_plot seconds)
        time_average_ : array, shape (n_percentiles, n_window)
            Time representation, averaged with peak-locking.
            (n_window is the number of point in t_plot seconds)
        """
        self._check_params()

        self.low_sig = check_array(low_sig)
        self.high_sig = check_array(high_sig, accept_none=True)
        if self.high_sig is None:
            self.high_sig = self.low_sig
        self.mask = check_array(mask, accept_none=True)
        check_consistent_shape(self.low_sig, self.high_sig, self.mask)

        # compute the slow oscillation
        # 1, n_epochs, n_points = filtered_low.shape
        filtered_low = multiple_band_pass(self.low_sig, self.fs, self.low_fq,
                                          self.low_fq_width,
                                          filter_method=self.filter_method)
        self.filtered_low_ = filtered_low[0]

        # Find the peaks through the phase of filtered_low_. np.angle lies
        # in (-pi, pi], so its local maxima (just before the wrap) sit at
        # phase ~pi, i.e. troughs of the real part. Shifting to [0, 2pi)
        # moves the wrap to phase 0, i.e. peaks of the real part.
        phase = np.angle(self.filtered_low_)
        if self.peak_or_trough == 'peak':
            phase = (phase + 2 * np.pi) % (2 * np.pi)
        self.peak_loc, _ = peak_finder_multi_epochs(
            phase, fs=self.fs, t_plot=self.t_plot, mask=self.mask,
            extrema=1)
        self.peak_mag = np.real(self.filtered_low_).ravel()[self.peak_loc]

        # extract several signals with band-pass filters
        # n_high, n_epochs, n_points = filtered_high.shape
        self.filtered_high_ = multiple_band_pass(
            self.high_sig, self.fs, self.high_fq_range, self.high_fq_width,
            n_cycles=None, filter_method=self.filter_method)

        # compute the peak locked time-frequency representation
        self.time_frequency_ = peak_locked_time_frequency(
            self.filtered_high_, self.fs, self.high_fq_range,
            peak_loc=self.peak_loc, t_plot=self.t_plot, mask=self.mask)

        # Compute the peak locked time representation. The mask is not
        # needed here: masked peaks were already removed by
        # peak_finder_multi_epochs.
        time_average = peak_locked_percentile(self.low_sig[None, :], self.fs,
                                              self.peak_loc, self.t_plot,
                                              self.percentiles)
        self.time_average_ = time_average[0, :, :]
        return self

    def plot_peaks(self, ax=None):
        """Plot the start of the first epoch with the detected peaks.

        Parameters
        ----------
        ax : matplotlib.axes.Axes or None
            Axes to draw on. Anything that is not an Axes instance
            (e.g. None) creates a new figure.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure containing the plot.
        """
        check_is_fitted(self, 'filtered_low_')
        if not isinstance(ax, matplotlib.axes.Axes):
            fig = plt.figure(figsize=(16, 5))
            ax = fig.gca()

        # at most 3000 samples of the first epoch
        n_point_plot = min(3000, self.low_sig.shape[1])
        time = np.arange(n_point_plot) / float(self.fs)
        filtered = np.real(self.filtered_low_[0, :n_point_plot])
        # peak_loc indexes the ravelled epochs; since n_point_plot does not
        # exceed n_points, indices below it all belong to the first epoch.
        visible = self.peak_loc < n_point_plot

        ax.plot(time, self.low_sig[0, :n_point_plot], label='signal')
        ax.plot(time, filtered, label='driver')
        ax.plot(self.peak_loc[visible] / float(self.fs),
                self.peak_mag[visible], 'o', label='peaks')
        ax.set_xlabel('Time (sec)')
        ax.set_title("Driver's peak detection")
        ax.legend(loc=0)
        return ax.figure

    def plot(self, axs=None, vmin=None, vmax=None, ylim=None):
        """Plot the peak-locked time-frequency map and time average.

        Parameters
        ----------
        axs : sequence of two matplotlib.axes.Axes, or None
            Axes for the time-frequency map and for the time average.
            If None, a new figure with two stacked axes is created.
        vmin, vmax : float or None
            Color limits of the time-frequency map. If vmax is None, the
            maximum absolute value is used; if vmin is None, -vmax is used.
        ylim : tuple of float or None
            Y limits of the time-average plot. If None, the automatic limits
            are extended downward by 20% to make room for the legend.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure instance containing the plot.
        """
        check_is_fitted(self, 'time_average_')
        if axs is None:
            fig, axs = plt.subplots(2, 1, sharex=True, figsize=(8, 8))
            axs = axs.ravel()
        else:
            fig = axs[0].figure

        # plot the peak-locked time-frequency
        ax = axs[0]
        if vmax is None:
            vmax = np.abs(self.time_frequency_).max()
        if vmin is None:
            vmin = -vmax
        extent = (
            -self.t_plot / 2,
            self.t_plot / 2,
            self.high_fq_range[0],
            self.high_fq_range[-1], )
        cax = ax.imshow(self.time_frequency_, cmap='RdBu_r',
                        vmin=vmin, vmax=vmax, aspect='auto', origin='lower',
                        interpolation='none', extent=extent)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_title('Driver peak-locked Time-frequency decomposition')

        # plot the colorbar in a right margin of *this* figure
        fig.tight_layout()
        fig.subplots_adjust(right=0.85)
        add_colorbar(fig, cax, vmin, vmax, unit='', ax=None)

        # plot the peak-locked time
        ax = axs[1]
        labels = {
            'std+': r'$\mu+\sigma$',
            'std-': r'$\mu-\sigma$',
            'ste+': r'$\mu+\sigma/\sqrt{n}$',
            'ste-': r'$\mu-\sigma/\sqrt{n}$',
            'mean': r'$\mu$',
        }
        colors = mpl_palette('viridis', n_colors=len(self.percentiles))
        n_points = self.time_average_.shape[1]
        time = (np.arange(n_points) - n_points // 2) / float(self.fs)
        for i, p in enumerate(self.percentiles):
            if isinstance(p, str):
                label = labels.get(p, p)
            else:
                # numeric percentile (int, float or numpy scalar)
                label = '%g %%' % p
            ax.plot(time, self.time_average_[i, :], color=colors[i],
                    label=label)
        ax.set_xlabel('Time (sec)')
        ax.set_title('Driver peak-locked average of raw signal')
        ax.legend(loc='lower center', ncol=5, labelspacing=0.)
        ax.grid(True)

        # make room for the legend or apply specified ylim
        if ylim is None:
            ylim = ax.get_ylim()
            ylim = (ylim[0] - (ylim[1] - ylim[0]) * 0.2, ylim[1])
        ax.set_ylim(ylim)
        return fig
def peak_finder_multi_epochs(x0, fs=None, t_plot=None, mask=None, thresh=None,
                             extrema=1):
    """Call peak_finder on each epoch and concatenate the results.

    The returned indices refer to the ravelled array ``x0.ravel()``, as if
    peak_finder had been called on it (epoch ``i`` is offset by
    ``i * n_points``). Peaks whose window of ``t_plot`` seconds would not fit
    inside their epoch, and peaks falling on masked samples, are removed.

    Parameters
    ----------
    x0 : array, shape (n_epochs, n_points)
        Signal in which to find the extrema.
    fs : float or None
        Sampling frequency.
    t_plot : float or None
        Window duration (in second) around each peak. Peaks too close to the
        epoch edges are removed only when both fs and t_plot are given.
    mask : array or None, shape (n_epochs, n_points)
        Peaks located where the mask is nonzero (True) are removed.
    thresh : float or None
        Passed to peak_finder.
    extrema : int
        Passed to peak_finder (1 for maxima, -1 for minima).

    Returns
    -------
    peak_inds : array of int
        Peak indices in the ravelled signal.
    peak_mags : array
        Peak magnitudes returned by peak_finder.

    Raises
    ------
    ValueError
        If no peak remains after the edge and mask selection.
    """
    n_epochs, n_points = x0.shape
    remove_edges = t_plot is not None and fs is not None
    if remove_edges:
        # same half-window as used in peak_locked_percentile: a window spans
        # [peak - n_half_window, peak + n_half_window] inclusive.
        n_half_window = int(fs * t_plot / 2.)

    all_inds, all_mags = [], []
    for i_epoch in range(n_epochs):
        peak_inds, peak_mags = peak_finder(x0[i_epoch], thresh=thresh,
                                           extrema=extrema)
        # integer dtype so that even an empty result can be used as index
        peak_inds = np.asarray(peak_inds, dtype=int)
        peak_mags = np.asarray(peak_mags)

        keep = np.ones(peak_inds.shape[0], dtype=bool)
        if remove_edges:
            # the whole window must lie inside [0, n_points)
            keep &= peak_inds >= n_half_window
            keep &= peak_inds < n_points - n_half_window
        if mask is not None:
            keep &= mask[i_epoch, peak_inds] == 0

        all_inds.append(peak_inds[keep] + i_epoch * n_points)
        all_mags.append(peak_mags[keep])

    if sum(inds.shape[0] for inds in all_inds) == 0:
        raise ValueError("No %s detected. The signal might be too short, "
                         "or the mask too strong. You can also try to reduce "
                         "the plotted time window `t_plot`." %
                         ("peak" if extrema > 0 else "trough"))
    return np.concatenate(all_inds), np.concatenate(all_mags)
def peak_locked_time_frequency(filtered_high, fs, high_fq_range, peak_loc,
                               t_plot, mask=None):
    """
    Compute the peak-locked Time-frequency representation.

    Each filtered signal is standardized (mean and standard deviation are
    computed on unmasked samples only, but applied everywhere), turned into
    power ``|x| ** 2``, centered by its mean power over unmasked samples,
    and finally averaged in windows around the peaks.

    The input array is not modified.

    Parameters
    ----------
    filtered_high : array, shape (n_high, n_epochs, n_points)
        Band-pass filtered signals (complex or real).
    fs : float
        Sampling frequency.
    high_fq_range : array, shape (n_high, )
        Unused; kept for backward compatibility.
    peak_loc : array of int
        Peak indices in the ravelled (n_epochs * n_points) signals.
    t_plot : float
        Window duration (in second) around each peak.
    mask : array or None, shape (n_epochs, n_points)
        Samples where the mask is nonzero (True) are excluded from the
        normalization statistics.

    Returns
    -------
    evoked_signals : array, shape (n_high, n_window)
        Peak-locked mean of the centered power.
    """
    n_high = filtered_high.shape[0]

    def _unmasked(x):
        # (n_high, n_valid) samples used to compute the statistics
        if mask is None:
            return x.reshape(n_high, -1)
        return x[:, mask == 0]

    valid = _unmasked(filtered_high)
    mean = valid.mean(axis=1)[:, None, None]
    std = valid.std(axis=1)[:, None, None]

    # Out-of-place on purpose: the caller's array (e.g. the filtered_high_
    # attribute of PeakLocking) must keep the filtered signals.
    normalized = (filtered_high - mean) / std
    # power, i.e. np.abs(normalized) ** 2
    power = np.real(normalized * np.conj(normalized))
    del normalized

    # subtract the mean power (computed on unmasked samples only)
    power -= _unmasked(power).mean(axis=1)[:, None, None]

    # compute the evoked signals (peak-locked mean)
    evoked_signals = peak_locked_percentile(power, fs, peak_loc, t_plot)
    return evoked_signals[:, 0, :]
def peak_locked_percentile(signals, fs, peak_loc, t_plot,
                           percentiles=('mean', )):
    """
    Compute statistics of each signal in windows locked to the peaks.

    Parameters
    ----------
    signals : array, shape (n_signals, n_epochs, n_points)
        Signals to average. Epochs are ravelled, so peak_loc indexes the
        concatenation of all epochs of a signal.
    fs : float
        Sampling frequency.
    peak_loc : array of int, shape (n_peaks, )
        Indices of the peak locations in the ravelled signals. Every window
        is assumed to fit inside the ravelled signal (callers such as
        peak_finder_multi_epochs remove peaks too close to the edges).
    t_plot : float
        Time (in second) to plot around the peaks.
    percentiles : sequence of number or string
        Numbers are percentiles in [0, 100] (computed with np.percentile).
        Strings can be 'mean', 'std+', 'std-' (mean +/- standard deviation)
        or 'ste+', 'ste-' (mean +/- standard error, std / sqrt(n_peaks)).

    Returns
    -------
    evoked_signals : array, shape (n_signals, n_percentiles, n_window)
        With n_window = 2 * int(fs * t_plot / 2) + 1, centered on the peaks.

    Raises
    ------
    ValueError
        If peak_loc is empty, or if a percentile string is not recognized.
    """
    n_signals = signals.shape[0]
    peak_loc = np.asarray(peak_loc)
    n_peaks = peak_loc.shape[0]
    if n_peaks == 0:
        raise ValueError('No peak location given.')

    n_half_window = int(fs * t_plot / 2.)
    n_window = 2 * n_half_window + 1
    # row k holds the sample indices of the window centered on peak_loc[k]
    indices = (peak_loc[:, None] +
               np.arange(-n_half_window, n_half_window + 1)[None, :])

    # ravel the epochs since we now have isolated events
    signals = signals.reshape(n_signals, -1)

    evoked_signals = np.zeros((n_signals, len(percentiles), n_window))
    for i_s in range(n_signals):
        windows = signals[i_s][indices]  # shape (n_peaks, n_window)
        mean = np.mean(windows, axis=0)
        std = None  # computed lazily, only if a spread is requested
        for i_p, p in enumerate(percentiles):
            if not isinstance(p, str):
                value = np.percentile(windows, p, axis=0)
            elif p == 'mean':
                value = mean
            elif p in ('std+', 'std-', 'ste+', 'ste-'):
                if std is None:
                    std = np.std(windows, axis=0)
                spread = std if p.startswith('std') else std / np.sqrt(n_peaks)
                value = mean + spread if p.endswith('+') else mean - spread
            else:
                raise ValueError('wrong percentile string: %s' % p)
            evoked_signals[i_s, i_p] = value
    return evoked_signals