Source code for pactools.utils.spectrum

import numpy as np
import scipy as sp
from scipy.linalg import hankel
from scipy.signal import hilbert
import matplotlib.pyplot as plt

from .maths import square, is_power2, prime_factors, compute_n_fft, next_power2
from .viz import compute_vmin_vmax


class Spectrum(object):
    """Spectral estimator following Welch's method

    Parameters
    ----------
    block_length : int
        Length of each signal block, on which we estimate the spectrum
    fft_length : int or None
        Length of FFT, should be a power of 2 greater or equal to
        block_length. If None, it is set to ``next_power2(block_length)``.
    step : int or None
        Step between successive blocks.
        If None, it is set to half the block length (i.e. 0.5 overlap)
    wfunc : function
        Function used to compute the weighting window on each block.
        Examples: np.ones, np.hamming, np.bartlett, np.blackman, ...
    fs : float
        Sampling frequency
    donorm : boolean
        If True, the amplitude is normalized by the squared window sum
    """

    def __init__(self, block_length=1024, fft_length=None, step=None,
                 wfunc=np.hamming, fs=1., donorm=True):
        self.block_length = block_length
        self.fft_length = fft_length
        self.step = step
        self.wfunc = wfunc
        self.fs = fs
        self.donorm = donorm
        # PSD arrays stored by `periodogram` (several of them if hold=True)
        self.psd = []

    def check_params(self):
        """Validate the parameters and compute the effective FFT length/step.

        Side effect: ``self.block_length`` is cast to int.

        Returns
        -------
        fft_length : int
            ``self.fft_length``, or ``next_power2(block_length)`` if None.
        step : int
            ``self.step``, or ``max(block_length // 2, 1)`` if None.

        Raises
        ------
        ValueError
            If block_length is not positive, if fft_length is not a power
            of 2 or is smaller than block_length, or if step is not in
            ]0, block_length].
        """
        # block_length
        if self.block_length <= 0:
            raise ValueError('Block length should be positive, got %s'
                             % (self.block_length, ))
        self.block_length = int(self.block_length)

        # fft_length
        if self.fft_length is None:
            fft_length = next_power2(self.block_length)
        else:
            fft_length = int(self.fft_length)
        if not is_power2(fft_length):
            raise ValueError('FFT length should be a power of 2')
        if fft_length < self.block_length:
            raise ValueError('Block length is greater than FFT length')

        # step
        if self.step is None:
            step = max(int(self.block_length // 2), 1)
        else:
            step = int(self.step)
        if step <= 0 or step > self.block_length:
            raise ValueError('Invalid step between blocks: %s' % (step, ))

        return fft_length, step

    def periodogram(self, signals, hold=False, mean_psd=False):
        """Compute the power spectral density of each epoch (Welch method).

        Each epoch is cut into blocks of ``block_length`` samples, shifted by
        ``step``. Each block is multiplied by the window, zero-padded to
        ``fft_length``, and the squared FFT magnitudes are averaged over the
        blocks. The result is in linear power units (`plot` converts it to
        dB). If an epoch is shorter than ``block_length``, a single block of
        the epoch length is used.

        Parameters
        ----------
        signals : array, shape (n_epochs, n_points) or (n_points, )
            Signals from which one computes the power spectrum
        hold : boolean, default = False
            If True, the estimation is appended to the list of previous
            estimations, else, the list is emptied and only the current
            estimation is stored.
        mean_psd : boolean, default = False
            If True, the PSD is the mean PSD over all epochs.

        Returns
        -------
        psd : array, shape (n_epochs, n_freq) or (1, n_freq) if mean_psd
            Power spectrum estimated with Welch's method on each epoch.
            n_freq = fft_length // 2 + 1

        Raises
        ------
        ValueError
            If the estimator parameters are invalid (see `check_params`).
        IndexError
            If the signals contain no sample.
        """
        fft_length, step = self.check_params()

        signals = np.atleast_2d(signals)
        n_epochs, n_points = signals.shape
        if n_points == 0:
            raise IndexError('spectrum: signals have 0 samples, cannot '
                             'estimate a spectrum')
        block_length = min(self.block_length, n_points)
        window = self.wfunc(block_length)
        n_freq = fft_length // 2 + 1

        psd = np.zeros((n_epochs, n_freq))
        for i, sig in enumerate(signals):
            block = np.arange(block_length)
            # iterate on blocks
            count = 0
            while block[-1] < sig.size:
                block_fft = sp.fft.fft(window * sig[block], fft_length, 0)
                psd[i] += np.abs(block_fft[:n_freq]) ** 2
                count += 1
                block = block + step
            if count == 0:
                raise IndexError(
                    'spectrum: first block has %d samples but sig has %d '
                    'samples' % (block[-1] + 1, sig.size))

            # average over blocks, optionally normalizing by the window gain
            if self.donorm:
                scale = 1.0 / (count * (np.sum(window) ** 2))
            else:
                scale = 1.0 / count
            psd[i] *= scale

        if mean_psd:
            psd = np.mean(psd, axis=0)[None, :]

        if not hold:
            self.psd = []
        self.psd.append(psd)
        return psd

    def plot(self, title='', fscale='lin', labels=None, fig=None, axes=None,
             replicate=None, colors=None):
        """Plot the stored power spectral densities, in dB.

        Warning: the plot will only appear after plt.show()

        Parameters
        ----------
        title : str
            Title of the plot (and the window).
        fscale : str
            Kind of frequency scale ('lin' or 'log'). Any value other than
            'log' gives a linear scale.
        labels : list of str, str, or None
            List of labels for the plots, one per stored PSD. A legend is
            drawn only when labels are given.
        fig : matplotlib.Figure
            Specific figure to plot on.
        axes : matplotlib.Axes
            Specific axes to draw on. Overrides `fig`.
        replicate : int or None
            Number of replications of the spectrum across frequencies;
            odd replicas are mirrored in frequency. None means 0.
        colors : list or None
            One matplotlib color per stored PSD (None for automatic).

        Returns
        -------
        fig : matplotlib.Figure
            Figure instance used in plotting.

        Raises
        ------
        TypeError
            If `fig` or `axes` has the wrong type.
        """
        fft_length, _ = self.check_params()

        if labels is None:
            plot_legend = False
            labels = [''] * len(self.psd)
        else:
            plot_legend = True
            if isinstance(labels, str):
                labels = [labels]
        if replicate is None:
            replicate = 0
        if colors is None:
            colors = [None] * len(self.psd)

        if axes is None:
            if fig is None:
                fig = plt.figure(title)
                axes = fig.gca()
            else:
                if not isinstance(fig, plt.Figure):
                    raise TypeError('fig must be matplotlib Figure, got {}'
                                    ' instead.'.format(type(fig)))
                axes = fig.gca()
        else:
            # validate if axes is correct
            if not isinstance(axes, plt.Axes):
                raise TypeError('axes must be matplotlib Axes, got {}'
                                ' instead.'.format(type(axes)))
            fig = axes.figure

        self.fscale = fscale
        if self.fscale == 'log':
            axes.set_xscale('log')
        else:
            axes.set_xscale('linear')

        fmax = self.fs / 2
        freq = np.linspace(0, fmax, fft_length // 2 + 1)
        for label_, color, psd in zip(labels, colors, self.psd):
            # clip to avoid log10(0)
            psd = 10.0 * np.log10(np.maximum(psd, 1.0e-16))
            color_ = color
            for i in range(replicate + 1):
                label = label_ if i == 0 else ''
                # odd replicas are flipped along the frequency axis
                lines = axes.plot(freq + i * fmax, psd.T[::(-1) ** i],
                                  label=label, color=color_)
                # reuse the first replica's color for the next replicas
                color_ = lines[-1].get_color()

        axes.grid(True)
        axes.set_title(title)
        axes.set_xlabel('Frequency (Hz)')
        axes.set_ylabel('Amplitude (dB)')
        if plot_legend:
            axes.legend(loc=0)
        return fig

    def main_frequency(self):
        """Return the frequency of the maximum of the last stored spectrum.

        Only the first row of the last PSD is used, and the DC bin
        (index 0) is skipped.

        Raises
        ------
        IndexError
            If no spectrum has been estimated yet.
        """
        fft_length, _ = self.check_params()
        n_freq = fft_length // 2 + 1
        freq = np.linspace(0, self.fs / 2, n_freq)
        psd = self.psd[-1][0, :]
        return freq[np.argmax(psd[1:]) + 1]
class Coherence(Spectrum):
    """Coherence estimator

    Parameters
    ----------
    block_length : int
        Length of each signal block, on which we estimate the spectrum
    fft_length : int or None
        Length of FFT, should be a power of 2 greater or equal to
        block_length. If None, it is set to ``next_power2(block_length)``.
    step : int or None
        Step between successive blocks.
        If None, it is set to half the block length (i.e. 0.5 overlap)
    wfunc : function
        Function used to compute the weighting window on each block.
        Examples: np.ones, np.hamming, np.bartlett, np.blackman, ...
    fs : float
        Sampling frequency
    """

    def __init__(self, block_length=1024, fft_length=None, step=None,
                 wfunc=np.hamming, fs=1.):
        super(Coherence, self).__init__(block_length=block_length,
                                        fft_length=fft_length, step=step,
                                        wfunc=wfunc, fs=fs)
        # complex coherence, set by `fit`
        self.coherence = None

    def fit(self, sigs_a, sigs_b):
        """Compute the complex coherence between two sets of signals.

        For each pair (i_a, i_b), the windowed FFTs of all blocks of all
        epochs are accumulated as sum(F_a * conj(F_b)), and the sum is
        normalized by sqrt(sum |F_a|^2 * sum |F_b|^2). Swapping sigs_a and
        sigs_b yields the conjugate of the transposed result.

        Parameters
        ----------
        sigs_a : array, shape (n_signals_a, n_epochs, n_points)
            Signal from which one computes the coherence
        sigs_b : array, shape (n_signals_b, n_epochs, n_points)
            Signal from which one computes the coherence

        Returns
        -------
        coherence : array, shape (n_signals_a, n_signals_b, n_freqs)
            Complex coherence of sigs_a and sigs_b over all epochs.
            n_freqs = fft_length // 2 + 1

        Raises
        ------
        ValueError
            If the inputs are not 3D, or have different (n_epochs, n_points).
        IndexError
            If the signals contain no sample.
        """
        fft_length, step = self.check_params()
        n_signals_a, n_epochs, n_points = sigs_a.shape
        # unpacking also checks that sigs_b is 3D
        n_signals_b, _, _ = sigs_b.shape
        if sigs_a.shape[1:] != sigs_b.shape[1:]:
            raise ValueError('Incompatible shapes: %s and %s' %
                             (sigs_a.shape[1:], sigs_b.shape[1:]))
        if n_points == 0:
            raise IndexError(
                'coherence: first block needs %d samples but sigs has shape '
                '%s' % (self.block_length, sigs_a.shape))

        block_length = min(self.block_length, n_points)
        window = self.wfunc(block_length)
        n_freq = fft_length // 2 + 1

        coherence = np.zeros((n_signals_a, n_signals_b, n_freq),
                             dtype=np.complex128)
        norm_a = np.zeros((n_signals_a, n_freq), dtype=np.float64)
        norm_b = np.zeros((n_signals_b, n_freq), dtype=np.float64)

        # iterate on blocks
        count = 0
        for i_epoch in range(n_epochs):
            block = np.arange(block_length)
            while block[-1] < n_points:
                # one FFT per signal and per block, shape (n_signals, n_freq)
                F_a = sp.fft.fft(window * sigs_a[:, i_epoch, block],
                                 fft_length, axis=-1)[:, :n_freq]
                F_b = sp.fft.fft(window * sigs_b[:, i_epoch, block],
                                 fft_length, axis=-1)[:, :n_freq]
                norm_a += square(F_a)
                norm_b += square(F_b)
                coherence += F_a[:, None, :] * np.conjugate(F_b)[None, :, :]
                count += 1
                block = block + step

        if count == 0:
            raise IndexError(
                'coherence: first block needs %d samples but sigs has shape '
                '%s' % (block_length, sigs_a.shape))

        normalization = np.sqrt(norm_a[:, None, :] * norm_b[None, :, :])
        coherence /= normalization

        self.coherence = coherence
        return self.coherence

    def plot(self, fig=None, ax=None):
        """Not Implemented: returns `fig` unchanged."""
        return fig

    def main_frequency(self):
        """Not Implemented: returns None."""
        pass
class Bicoherence(Spectrum):
    """Bicoherence estimator

    Parameters
    ----------
    block_length : int
        Length of each signal block, on which we estimate the spectrum
    fft_length : int or None
        Length of FFT, should be a power of 2 greater or equal to
        block_length. If None, it is set to ``next_power2(block_length)``.
    step : int or None
        Step between successive blocks.
        If None, it is set to half the block length (i.e. 0.5 overlap)
    wfunc : function
        Function used to compute the weighting window on each block.
        Examples: np.ones, np.hamming, np.bartlett, np.blackman, ...
    fs : float
        Sampling frequency
    """

    def __init__(self, block_length=1024, fft_length=None, step=None,
                 wfunc=np.hamming, fs=1.):
        super(Bicoherence, self).__init__(block_length=block_length,
                                          fft_length=fft_length, step=step,
                                          wfunc=wfunc, fs=fs)

    def fit(self, sigs, method='hagihira'):
        """Compute the bicoherence of one signal.

        For each block, the triple product F(f1) * F(f2) * conj(F(f1 + f2))
        is accumulated. The magnitude of the sum is then normalized according
        to `method` ('bispectrum' returns the log of the magnitude instead).

        Parameters
        ----------
        sigs : array, shape (n_epochs, n_points) or (n_points, )
            Signal from which one computes the bicoherence
        method : string in ('hagihira', 'sigl', 'nagashima', 'bispectrum')
            Normalization used for the bicoherence

        Returns
        -------
        bicoherence : array, shape (n_freq, n_freq)
            Real-valued bicoherence magnitude computed on the input signal.
            n_freq = fft_length // 2 + 1. Entries with f1 + f2 beyond the
            last frequency bin are not meaningful (see `plot`).

        Raises
        ------
        ValueError
            If the parameters are invalid or the method is unknown.
        IndexError
            If the signals contain no sample.
        """
        fft_length, step = self.check_params()
        if method not in ('sigl', 'nagashima', 'hagihira', 'bispectrum'):
            raise ValueError("Method '%s' unkown." % method)
        self.method = method

        sigs = np.atleast_2d(sigs)
        n_epochs, n_points = sigs.shape
        if n_points == 0:
            raise IndexError(
                'bicoherence: first block needs %d samples but sigs has shape '
                '%s' % (self.block_length, sigs.shape))
        block_length = min(self.block_length, n_points)
        window = self.wfunc(block_length)
        n_freq = fft_length // 2 + 1

        bicoherence = np.zeros((n_freq, n_freq), dtype=np.complex128)
        normalization = np.zeros((n_freq, n_freq), dtype=np.float64)

        # mask[i, j] = i + j if i + j < n_freq, else 0 (hankel pads with
        # zeros); it is loop-invariant, so it is built only once.
        mask = hankel(np.arange(n_freq))

        # iterate on blocks
        count = 0
        for i_epoch in range(n_epochs):
            block = np.arange(block_length)
            while block[-1] < n_points:
                F = sp.fft.fft(window * sigs[i_epoch, block],
                               fft_length, 0)[:n_freq]
                F1 = F[None, :]
                F2 = F1.T
                F12 = np.conjugate(F)[mask]
                product = F1 * F2 * F12
                bicoherence += product

                if method == 'sigl':
                    normalization += square(F1) * square(F2) * square(F12)
                elif method == 'nagashima':
                    # NOTE(review): assuming `square(x)` is |x|**2, this
                    # accumulates the same quantity as 'sigl'; confirm the
                    # intended Nagashima normalization.
                    normalization += square(F1 * F2) * square(F12)
                elif method == 'hagihira':
                    normalization += np.abs(product)
                # 'bispectrum': no normalization accumulated

                count += 1
                block = block + step

        if count == 0:
            raise IndexError(
                'bicoherence: first block needs %d samples but sigs has shape '
                '%s' % (self.block_length, sigs.shape))

        bicoherence = np.abs(bicoherence)
        if method in ('sigl', 'nagashima'):
            normalization = np.sqrt(normalization)
        if method != 'bispectrum':
            bicoherence /= normalization
        else:
            bicoherence = np.log(bicoherence)

        self.bicoherence = bicoherence
        return bicoherence

    def plot(self, fig=None, ax=None):
        """Display the last computed bicoherence as an image.

        Parameters
        ----------
        fig : matplotlib.Figure or None
            Figure to draw on; a new one is created if None.
        ax : matplotlib.Axes or None
            Axes to draw on; ``fig.gca()`` is used if None.

        Returns
        -------
        ax : matplotlib.Axes
            Axes used in plotting.
        """
        if fig is None:
            fig = plt.figure()
        if ax is None:
            ax = fig.gca()

        fmax = self.fs / 2.0
        bicoherence = np.copy(self.bicoherence)
        n_freq = bicoherence.shape[0]
        # flipud returns a view: this zeroes the entries with i + j >= n_freq,
        # for which `fit` used index 0 instead of f1 + f2.
        np.flipud(bicoherence)[np.triu_indices(n_freq, 1)] = 0
        # keep only the lower triangle (the estimate is symmetric in f1, f2)
        bicoherence[np.triu_indices(n_freq, 1)] = 0
        bicoherence = bicoherence[:, :n_freq // 2 + 1]

        vmin, vmax = compute_vmin_vmax(bicoherence, tick=1e-15, percentile=1)

        # true division: `fmax // 2` would floor (e.g. 0.0 for fs=1.)
        ax.imshow(bicoherence, cmap=plt.cm.viridis, aspect='auto',
                  vmin=vmin, vmax=vmax, origin='lower',
                  extent=(0, fmax / 2.0, 0, fmax), interpolation='none')
        ax.set_title('Bicoherence (%s)' % self.method)
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Frequency (Hz)')
        # add_colorbar(fig, cax, vmin, vmax, unit='', ax=ax)
        return ax

    def main_frequency(self):
        """Not Implemented: returns None."""
        pass
def phase_amplitude(signals, phase=True, amplitude=True):
    """Extract instantaneous phase and amplitude with Hilbert transform

    Parameters
    ----------
    signals : array-like, shape (n_points, ) or (n_epochs, n_points)
        Real signals.
    phase : boolean
        If True, compute the instantaneous phase.
    amplitude : boolean
        If True, compute the instantaneous amplitude.

    Returns
    -------
    sig_phase : array with the same shape as signals, or None
        Angle of the analytic signal (None if not `phase`).
    sig_amplitude : array with the same shape as signals, or None
        Modulus of the analytic signal (None if not `amplitude`).

    Raises
    ------
    ValueError
        If signals has neither 1 nor 2 dimensions.
    """
    signals = np.asarray(signals)
    # one dimension array
    if signals.ndim == 1:
        signals = signals[None, :]
        one_dim = True
    elif signals.ndim == 2:
        one_dim = False
    else:
        raise ValueError('Impossible to compute phase_amplitude with ndim ='
                         ' %s.' % (signals.ndim, ))

    n_points = signals.shape[1]
    n_fft = compute_n_fft(signals)
    sig_phase = np.empty(signals.shape) if phase else None
    sig_amplitude = np.empty(signals.shape) if amplitude else None
    for i, sig in enumerate(signals):
        # FFT length n_fft may exceed n_points; crop back to the signal
        sig_complex = hilbert(sig, n_fft)[:n_points]
        if phase:
            sig_phase[i] = np.angle(sig_complex)
        if amplitude:
            sig_amplitude[i] = np.abs(sig_complex)

    # one dimension array
    if one_dim:
        if phase:
            sig_phase = sig_phase[0]
        if amplitude:
            sig_amplitude = sig_amplitude[0]

    return sig_phase, sig_amplitude


def crop_for_fast_hilbert(signals):
    """Crop the signal to have a good prime decomposition, for hilbert filter.

    The last axis (time) is shortened until its length has no prime factor
    greater than 20. Lengths 0 and 1 are returned unchanged.

    Parameters
    ----------
    signals : array-like, shape (..., n_points)
        Signals to crop.

    Returns
    -------
    signals : array, shape (..., n_cropped)
        View of the input cropped along the last axis.
    """
    signals = np.asarray(signals)
    tmax = signals.shape[-1]
    # the guard avoids inspecting the (presumably empty) factor list of 0/1
    while tmax > 1 and max(prime_factors(tmax)) > 20:
        tmax -= 1
    return signals[..., :tmax]