import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from ..utils.spectrum import Spectrum
from ..utils.fir import BandPassFilter, LowPassFilter
from ..utils.arma import Arma
from ..utils.validation import check_random_state
def _decimate(x, q):
    """
    Downsample the signal after low-pass filtering to avoid aliasing.

    An order 16 Chebyshev type I filter (0.025 dB passband ripple, cutoff at
    0.98 / q of the Nyquist frequency) is applied forward and backward with
    ``scipy.signal.filtfilt``; then one sample out of q is kept along the
    last axis.

    Parameters
    ----------
    x : ndarray
        The signal to be downsampled, as an N-dimensional array. Filtering
        and downsampling are applied along the last axis.
    q : int
        The downsampling factor (Python or NumPy integer).

    Returns
    -------
    y : ndarray
        The down-sampled signal.

    Raises
    ------
    TypeError
        If q is not an integer.
    """
    if not isinstance(q, (int, np.integer)):
        raise TypeError("q must be an integer")
    # signal.filter_design is a deprecated namespace; use the public one
    b, a = signal.cheby1(16, 0.025, 0.98 / q)
    y = signal.filtfilt(b, a, x, axis=-1)
    # keep every q-th sample of the last axis; indexing with a list of
    # slices (the previous form) is rejected by NumPy >= 1.23
    return y[..., ::q]
def decimate(sig, fs, decimation_factor):
    """Decimates the signal:
    Downsampling after low-pass filtering to avoid aliasing.

    The signal is first centered (its global mean is subtracted). Factors
    larger than 7 are split into two successive decimation stages, given by
    the lookup tables below.

    Parameters
    ----------
    sig : array
        Raw input signal
    fs : float
        Sampling frequency of the input
    decimation_factor : int > 0
        Ratio of sampling frequencies (old/new). Supported values are 1 (no
        decimation, only centering) and every integer from 2 to 30 except
        11, 13, 17, 19, 22, 23, 26 and 29.

    Returns
    -------
    sig : array
        Decimated signal, as float32
    fs : float
        Sampling frequency of the output

    Raises
    ------
    ValueError
        If decimation_factor is not supported.
    """
    # decimation could be performed in two steps for better performance:
    # decimation_factor == dec_1st[decimation_factor] * dec_2nd[...]
    # 0 in dec_1st means the factor is not supported,
    # 0 in dec_2nd means there is no second stage
    dec_1st = [
        0, 0, 2, 3, 4, 5, 6, 7, 2, 3, 2, 0, 3, 0, 2, 3, 4, 0, 3, 0, 4, 3, 0, 0,
        4, 5, 0, 3, 4, 0, 5
    ]
    dec_2nd = [
        0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 5, 0, 4, 0, 7, 5, 4, 0, 6, 0, 5, 7, 0, 0,
        6, 5, 0, 9, 7, 0, 6
    ]
    # validate before indexing: a negative factor would silently wrap around
    # the tables, and a too large one would raise an IndexError
    if not 0 < decimation_factor < len(dec_1st):
        raise ValueError('cannot decimate by %d' % decimation_factor)

    # -------- center the signal
    sig = sig - np.mean(sig)

    if decimation_factor == 1:
        # nothing to resample; keep the same output dtype as other factors
        return sig.astype(np.float32), fs / decimation_factor

    # -------- resample
    d1 = dec_1st[decimation_factor]
    if d1 == 0:
        raise ValueError('cannot decimate by %d' % decimation_factor)
    sig = _decimate(sig, d1)
    sig = sig.astype(np.float32)
    d2 = dec_2nd[decimation_factor]
    if d2 > 0:
        sig = _decimate(sig, d2)
        sig = sig.astype(np.float32)
    fs = fs / decimation_factor
    return sig, fs
def extract_and_fill(sig, fs, fc, n_cycles=None, bandwidth=1.0, fill=0,
                     draw='', random_noise=None, extract_complex=True,
                     low_pass=False, random_state=None):
    """Creates a FIR bandpass filter, applies this filter to a signal to obtain
    the filtered signal low_sig and its complement high_sig.
    Also fills the frequency gap in high_sig.

    Parameters
    ----------
    sig : array
        Input signal
    fs : float
        Sampling frequency
    fc : float
        Center frequency of the bandpass filter (cutoff frequency of the
        lowpass filter when low_pass is True)
    n_cycles : float
        Number of cycles in the bandpass filter
        Should be None if bandwidth is not None
    bandwidth : float
        Bandwidth of the bandpass filter
        Should be None if n_cycles is not None
    fill : int in {0, 1, 2}
        Filling strategy for high_sig
        0 : keep the signal unchanged: high_sig = sig
        1 : remove (the bandpass filtered signal): high_sig = sig - low_sig
        2 : remove and replace by bandpass filtered Gaussian white noise
    draw : string
        List of plots to draw ('e' or 'z' draws the extraction step)
    random_noise : array or None
        White noise filtered to build the filling signal when fill == 2.
        If None, len(sig) samples are drawn from random_state.
    extract_complex : boolean
        Use a complex wavelet
    low_pass : boolean
        Use a lowpass filter at fc instead of a bandpass filter centered at fc
    random_state : None, int or np.random.RandomState instance
        Seed or random number generator for the surrogate analysis

    Returns
    -------
    low_sig : array
        Bandpass filtered signal
    high_sig : array
        Processed fullband signal
    low_sig_imag : array (returned only if extract_complex is True)
        Imaginary part of the bandpass filtered signal

    Raises
    ------
    NotImplementedError
        If both low_pass and extract_complex are True.
    ValueError
        If fill is not in {0, 1, 2}.
    """
    # validate the options before doing any filtering work
    if low_pass and extract_complex:
        raise NotImplementedError('extract_complex incompatible with '
                                  'low_pass filter.')
    if fill not in (0, 1, 2):
        raise ValueError('Invalid fill parameter: %s' % str(fill))

    rng = check_random_state(random_state)
    # the noise is drawn whatever the fill strategy, so that a RandomState
    # passed by the caller always advances by the same amount
    if random_noise is None:
        random_noise = rng.randn(len(sig))

    if low_pass:
        filt = LowPassFilter(fs=fs, fc=fc, bandwidth=bandwidth)
    else:
        filt = BandPassFilter(fs=fs, fc=fc, n_cycles=n_cycles,
                              bandwidth=bandwidth, zero_mean=True,
                              extract_complex=extract_complex)
    plot_extraction = 'e' in draw or 'z' in draw
    if plot_extraction:
        filt.plot(fscale='lin')

    if extract_complex:
        low_sig, low_sig_imag = filt.transform(sig)
    else:
        low_sig = filt.transform(sig)

    if fill == 2:
        # replacing driver by a white noise
        high_sig = sig - low_sig
        if extract_complex:
            fill_sig, _ = filt.transform(random_noise)
        else:
            fill_sig = filt.transform(random_noise)
        fill_sig.shape = sig.shape
        # adjust the power of the filling signal and add it to high_sig
        if low_pass:
            # the gap spans [0, fc]
            fa = fc / 2.
            dfa = fc / 2.
        else:
            # NOTE(review): bandwidth is None when only n_cycles is given,
            # which makes the gap bounds below fail; confirm intended width.
            fa = fc
            dfa = bandwidth
        high_sig = fill_gap(high_sig, fs, fgap=(fa - dfa, fa + dfa),
                            draw=draw, fill_sig=fill_sig)
        if plot_extraction:
            _plot_multiple_spectrum(
                [sig, low_sig, sig - low_sig, fill_sig, high_sig],
                labels=None, fs=fs, colors='bggrr')
            # one label per plotted curve (five curves)
            plt.legend(['input', 'driver', 'input-driver', 'filling',
                        'output'], loc=0)
    else:
        # fill == 0 keeps the driver in high_sig, fill == 1 subtracts it
        high_sig = sig if fill == 0 else sig - low_sig
        if plot_extraction:
            _plot_multiple_spectrum([sig, low_sig, high_sig], labels=None,
                                    fs=fs, colors='bgr')
            plt.legend(['input', 'driver', 'output'], loc=0)

    if extract_complex:
        return low_sig, high_sig, low_sig_imag
    return low_sig, high_sig
def low_pass_and_fill(sig, fs, fc=1.0, draw='', bandwidth=1.,
                      random_state=None):
    """Remove the low-pass part of a signal and fill the gap with noise.

    The signal is low-pass filtered at fc and this low-pass part is
    subtracted. The resulting gap [0, fc] is then filled with low-pass
    filtered Gaussian white noise, whose power is adjusted by fill_gap.

    Parameters
    ----------
    sig : array
        Input signal
    fs : float
        Sampling frequency
    fc : float
        Cutoff frequency of the low-pass filter
    draw : string
        List of plots to draw (only forwarded to fill_gap)
    bandwidth : float
        Bandwidth parameter of the low-pass filters
    random_state : None, int or np.random.RandomState instance
        Seed or random number generator for the filling noise

    Returns
    -------
    filled_sig : array
        sig minus its low-pass part, with the [0, fc] gap filled with noise
    """
    # only the complement of the low-pass filtered signal is needed
    _, high_sig = extract_and_fill(
        sig, fs, fc, fill=1, low_pass=True, bandwidth=bandwidth,
        random_state=random_state, extract_complex=False)
    rng = check_random_state(random_state)
    random_noise = rng.randn(*sig.shape)
    filt = LowPassFilter(fs=fs, fc=fc, bandwidth=bandwidth)
    fill_sig = filt.transform(random_noise)
    # adjust power of fill_sig and add it to high_sig
    filled_sig = fill_gap(sig=high_sig, fs=fs, fgap=(0, fc),
                          draw=draw, fill_sig=fill_sig)
    return filled_sig
def _plot_multiple_spectrum(signals, fs, labels, colors):
    """Overlay the periodograms of several signals on a single plot.

    Parameters
    ----------
    signals : list of arrays
        Signals to analyze. The periodogram block length is the size of the
        first signal, capped at 2048.
    fs : float
        Sampling frequency.
    labels : list of str or None
        Curve labels, forwarded to Spectrum.plot.
    colors : sequence of color codes
        Curve colors, forwarded to Spectrum.plot.
    """
    block_length = min(2048, signals[0].size)
    spectrum = Spectrum(block_length=block_length, fs=fs, wfunc=np.blackman)
    for current in signals:
        spectrum.periodogram(current, hold=True)
    spectrum.plot(labels=labels, colors=colors, fscale='lin')
def whiten(sig, fs, ordar=8, draw='', enf=50.0, d_enf=1.0, zero_phase=True,
           **kwargs):
    """Use an AR model to whiten a signal.

    The whitening filter is not estimated around multiples of the electric
    network frequency: in the bands [k * enf - d_enf, k * enf + d_enf], the
    periodogram is replaced by a linear interpolation before the AR model is
    estimated.

    Parameters
    ----------
    sig : array
        Input signal
    fs : float
        Sampling frequency of input signal
    ordar : int
        Order of AR whitening filter
    draw : string
        List of plots ('w' or 'z' draws the whitening step)
    enf : float or None
        Electric network frequency. If None or <= 0, no band is excluded.
    d_enf : float
        Tolerance (half-width in Hz) on the electric network frequency
    zero_phase : boolean
        If True, apply half the whitening for sig(t) and sig(-t)
    **kwargs
        Unused by this function.

    Returns
    -------
    sigout : array
        Whitened signal, rescaled to the standard deviation of sig
    """
    # -------- create the AR model and its spectrum
    ar = Arma(ordar=ordar, ordma=0, fs=fs, block_length=min(1024, sig.size))
    ar.periodogram(sig)
    # duplicate to see the removal of the electric network frequency
    ar.periodogram(sig, hold=True)
    fft_length, _ = ar.check_params()

    # -------- remove the influence of the electric network frequency
    # Skipped when enf is None or <= 0: the harmonics k * enf would never
    # exceed fs / 2 and the loop below would not terminate.
    if enf is not None and enf > 0:
        k = 1
        # while the harmonic k is included in the spectrum
        while k * enf - d_enf < fs / 2.0:
            fmin = k * enf - d_enf
            fmax = k * enf + d_enf
            kmin = max((0, int(fft_length * fmin / fs)))
            kmax = min(fft_length // 2, int(fft_length * fmax / fs) + 1)
            Amin = ar.psd[-1][0, kmin]
            Amax = ar.psd[-1][0, kmax]
            # overwrite bins kmin..kmax-1 of the last periodogram with a
            # linear interpolation between (kmin, Amin) and (kmax, Amax)
            interpol = np.linspace(Amin, Amax, kmax - kmin, endpoint=False)
            ar.psd[-1][0, kmin:kmax] = interpol
            k += 1

    # -------- change psd for zero phase filtering
    # the filter is applied twice below (forward and backward), hence sqrt
    if zero_phase:
        ar.psd[-1] = np.sqrt(ar.psd[-1])

    # -------- estimate the model and apply it
    ar.estimate()
    # apply the whitening twice (forward and backward) for zero-phase filtering
    if zero_phase:
        sigout = ar.inverse(sig)
        sigout = sigout[::-1]
        sigout = ar.inverse(sigout)
        sigout = sigout[::-1]
    else:
        sigout = ar.inverse(sig)
    # restore the standard deviation of the input
    gain = np.std(sig) / np.std(sigout)
    sigout *= gain
    if 'w' in draw or 'z' in draw:
        ar.arma2psd(hold=True)
        ar.periodogram(sigout, hold=True)
        ar.plot('periodogram before/after whitening', labels=[
            'PSD', 'sqrt(PSD) without electric network', 'model AR',
            'whitened'
        ], fscale='lin')
        plt.legend(loc='lower left')
    return sigout
def fill_gap(sig, fs, fgap, draw='', fill_sig=None,
             random_state=None):
    """Fill a frequency gap with white noise.

    The filling noise is rescaled so that its periodogram, at a reference
    bin of the gap, matches the mean periodogram level of sig at the gap
    bounds. The reference bin is the gap center, or bin 0 / the last bin
    when the gap touches 0 Hz / the Nyquist frequency.

    Parameters
    ----------
    sig : array
        Signal with a frequency gap. It is not modified.
    fs : float
        Sampling frequency
    fgap : tuple (fmin, fmax)
        Frequency bounds of the gap
    draw : string
        List of plots ('g' or 'z' draws the filling step)
    fill_sig : array or None
        Noise used to fill the gap. If None, it is generated by filtering
        Gaussian white noise to the gap band. If given, it is rescaled in
        place.
    random_state : None, int or np.random.RandomState instance
        Seed or random number generator used when fill_sig is None

    Returns
    -------
    sig : array
        Copy of sig with the gap filled, or sig itself when the gap covers
        the whole spectrum.
    """
    rng = check_random_state(random_state)

    # -------- get the amplitude of the gap
    sp = Spectrum(block_length=min(512, sig.size), fs=fs, wfunc=np.blackman)
    fft_length, _ = sp.check_params()
    sp.periodogram(sig)
    fmin, fmax = fgap
    kmin = max((0, int(fft_length * fmin / fs)))
    kmax = min(fft_length // 2, int(fft_length * fmax / fs) + 1)
    Amin = sp.psd[-1][0, kmin]
    Amax = sp.psd[-1][0, kmax]
    if kmin == 0 and kmax == (fft_length // 2):
        # we can't fill the entire spectrum
        return sig
    if kmin == 0:
        # if the gap reaches zero, we only consider the right bound
        Amin = Amax
    if kmax == (fft_length // 2):
        # if the gap reaches fft_length / 2, we only consider the left bound
        Amax = Amin
    A_fa = (Amin + Amax) * 0.5

    # -------- bandpass filtering of white noise
    if fill_sig is None:
        white_noise = rng.randn(*sig.shape)
        if kmin == 0:
            # gap starting at 0 Hz: keep the noise below fmax
            fir = LowPassFilter(fs=fs, fc=fmax,
                                bandwidth=min(fmax, fs / 2 - fmax),
                                zero_mean=False)
            fill_sig = fir.transform(white_noise)
        elif kmax == (fft_length // 2):
            # gap reaching the Nyquist frequency: keep the noise above fmin
            # (white noise minus its low-pass part below fmin); the cutoff
            # must be the lower bound fmin, not fmax which is close to fs / 2
            fir = LowPassFilter(fs=fs, fc=fmin,
                                bandwidth=min(fmin, fs / 2 - fmin),
                                zero_mean=False)
            fill_sig = white_noise - fir.transform(white_noise)
        else:
            fc = (fmin + fmax) / 2.
            bandwidth = (fmax - fmin) / 2.
            fir = BandPassFilter(fs=fs, fc=fc, n_cycles=None,
                                 bandwidth=bandwidth,
                                 zero_mean=False)
            fill_sig = fir.transform(white_noise)

    # -------- compute the scale parameter
    sp.periodogram(fill_sig, hold=True)
    if kmin == 0:
        kfa = kmin
    elif kmax == (fft_length // 2):
        kfa = kmax
    else:
        kfa = int(fft_length * (fmin + fmax) / 2. / fs)
    scale = np.sqrt(A_fa / sp.psd[-1][0, kfa])
    fill_sig *= scale
    # add the noise to a copy, so that the caller's array is left untouched
    sig = sig.copy()
    sig += fill_sig

    if 'g' in draw or 'z' in draw:
        labels = ['signal', 'fill signal', 'gap filled']
        sp.periodogram(sig, hold=True)
        sp.plot(labels=labels, fscale='lin', title='fill')
    return sig
def _show_plot(draw):
    """Display the pending matplotlib figures when any plot was requested.

    Parameters
    ----------
    draw : string
        Plot selection string; nothing is shown when it is empty (falsy).
    """
    if not draw:
        return
    plt.show()
def multiple_extract_driver(sigs, fs, frequency_range, n_cycles=None,
                            bandwidth=1.0, fill=2, whitening='after', ordar=10,
                            normalize=False, extract_complex=True,
                            random_state=None, draw='', max_low_fq=None,
                            enf=50.):
    """Extract the driver for several bandpass center frequencies.

    This is a generator: one tuple is yielded per center frequency. The
    full-band signal high_sigs is computed once, with a low-pass filter at
    max_low_fq + 2 * bandwidth, and is the same for every yielded tuple.

    Parameters
    ----------
    sigs : array, shape (n_epochs, n_points)
        Input array to filter
    fs : float
        Sampling frequency
    frequency_range : float, list, or array, shape (n_frequencies, )
        List of center frequencies of the bandpass filters.
    n_cycles : float
        Number of cycles of the bandpass filters.
        Use it to have a bandwidth proportional to the center frequency.
        Should be None if bandwidth is not None.
    bandwidth : float
        Bandwidth of the bandpass filters.
        Use it to have a constant bandwidth for all filters.
        Should be None if n_cycles is not None.
    fill : in {0, 1, 2}
        Filling strategy for the full band signal high_sigs:
        0 : keep the signal unchanged: high_sigs = sigs
        1 : remove the low-pass filtered signal
        2 : remove and replace by filtered Gaussian white noise
    whitening : in {'before', 'after', None}
        Define when the whitening is done compared to the filtering.
    ordar : int >= 0
        Order of the AR model used for whitening
    normalize : boolean
        Whether to scale the signals to have unit norm high_sigs.
        The low_sigs are scaled with the same scales.
    extract_complex : boolean
        Whether to extract a complex driver (low_sigs and low_sigs_imag)
    random_state : None, int or np.random.RandomState instance
        Seed or random number generator for the white noise filling strategy.
    draw : string
        Add a letter to the string to draw the corresponding figures:
        - 'e' : extraction of the driver
        - 'g' : gap filling
        - 'w' : whitening step
        - 'z' : all
    max_low_fq : float or None
        Maximum low_fq over a potential cross-validation scheme.
        If None, max(frequency_range) is used.
    enf : float
        Electric network frequency, forwarded to the whitening step.

    Yields
    ------
    low_sigs : array, shape (n_epochs, n_points)
        Bandpass filtered signal (aka driver)
    high_sigs : array, shape (n_epochs, n_points)
        Bandstop filtered signal
    low_sigs_imag : array, shape (n_epochs, n_points)
        Imaginary part of the bandpass filtered signal
        Yielded only if extract_complex is True.

    Examples
    --------
    >>> for (low_sig, high_sig, low_sigs_imag) in multiple_extract_driver(
    ...         sigs, fs, [2., 3., 4.]):
    ...     pass
    """
    frequency_range = np.atleast_1d(frequency_range)
    sigs = np.atleast_2d(sigs)
    rng = check_random_state(random_state)

    if whitening == 'before':
        sigs = [whiten(epoch, fs=fs, ordar=ordar, draw=draw, enf=enf)
                for epoch in sigs]
        _show_plot(draw)

    # -------- full-band signal, computed once, independently of the driver
    if max_low_fq is None:
        max_low_fq = max(frequency_range)
    fc_low_pass = max_low_fq + bandwidth * 2  # arbitrary margin
    high_sigs = []
    for epoch in sigs:
        # one seed per epoch, drawn right before the call that uses it
        epoch_seed = rng.randint(np.iinfo(np.int32).max)
        _, epoch_high = extract_and_fill(
            epoch, fs=fs, fc=fc_low_pass, bandwidth=bandwidth, fill=fill,
            random_noise=None, draw=draw, extract_complex=False,
            low_pass=True, random_state=epoch_seed)
        high_sigs.append(epoch_high)

    if whitening == 'after':
        high_sigs = [whiten(epoch_high, fs=fs, ordar=ordar, draw=draw, enf=enf)
                     for epoch_high in high_sigs]

    scales = None
    if normalize:
        scales = [1.0 / np.std(epoch_high) for epoch_high in high_sigs]
        high_sigs = [epoch_high * scale
                     for epoch_high, scale in zip(high_sigs, scales)]

    # -------- driver extraction, one center frequency at a time
    # high_sigs is fixed from here on, so the driver extraction keeps the
    # signal unchanged (fill=0) and needs no filling noise
    for fc in frequency_range:
        results = [
            extract_and_fill(epoch, fs=fs, fc=fc, n_cycles=n_cycles,
                             bandwidth=bandwidth, fill=0, random_noise=None,
                             draw=draw, extract_complex=extract_complex,
                             random_state=random_state)
            for epoch in sigs
        ]
        low_sigs = [res[0] for res in results]
        if extract_complex:
            low_sigs_imag = [res[2] for res in results]

        if normalize:
            # same scales as high_sigs
            low_sigs = [low * scale for low, scale in zip(low_sigs, scales)]
            if extract_complex:
                low_sigs_imag = [low * scale
                                 for low, scale in zip(low_sigs_imag, scales)]
        _show_plot(draw)

        low_sigs = np.asarray(low_sigs)
        high_sigs = np.asarray(high_sigs)
        if extract_complex:
            yield low_sigs, high_sigs, np.array(low_sigs_imag)
        else:
            yield low_sigs, high_sigs