import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from .spectrum import Spectrum
class FIR(object):
    """FIR filter

    Parameters
    ----------
    fir : array-like, (default None)
        Finite impulse response (FIR) filter. It is converted to a numpy
        array. If None, the identity filter ``np.ones(1)`` is used.
    fs : float, (default 1.0)
        Sampling frequency, used to label the frequency axis in ``plot``.

    Examples
    --------
    >>> from pactools.utils.fir import FIR
    >>> f = FIR(fir=[0.2, 0.6, 0.2])
    >>> f.plot()
    >>> signal_out = f.transform(signal_in)
    """

    def __init__(self, fir=None, fs=1.0):
        # Build a fresh default per instance rather than sharing one array
        # created at definition time; convert lists so that ``.shape`` and
        # ``.size`` (used in ``plot``) are available.
        self.fir = np.ones(1) if fir is None else np.asarray(fir)
        self.fs = fs

    def transform(self, sigin, out=None, out_imag=None):
        """Apply this filter to a signal

        The signal is convolved with ``self.fir`` using
        ``scipy.signal.fftconvolve`` in 'same' mode, so each output row has
        the same length as the corresponding input row.

        Parameters
        ----------
        sigin : array-like, shape (n_points, ) or (n_signals, n_points)
            Input signal
        out : array or None, shape (n_points, ) or (n_signals, n_points)
            Optional preallocated output array of float or complex dtype,
            filled in place. If None, a new array is allocated; it is
            complex if the signal or the filter is complex, float64
            otherwise.
        out_imag : None
            Accepted for signature compatibility; ignored by this class.

        Returns
        -------
        out : array, shape (n_points, ) or (n_signals, n_points)
            Filtered signal
        """
        sigin = np.asarray(sigin)
        sigin_ndim = sigin.ndim
        sigin = np.atleast_2d(sigin)
        fir = np.asarray(self.fir)

        if out is None:
            # keep the imaginary part when either operand is complex
            dtype = np.result_type(sigin.dtype, fir.dtype, np.float64)
            out = np.empty(sigin.shape, dtype=dtype)
        else:
            out = np.atleast_2d(out)
            assert out.dtype.kind in 'fc'

        for i, sig in enumerate(sigin):
            out[i] = signal.fftconvolve(sig, fir, 'same')

        if sigin_ndim == 1:
            out = out[0]
        else:
            out = np.asarray(out)
        return out

    def plot(self, axs=None, fscale='log', colors=None):
        """Plot the transfer function and the impulse response of the filter.

        Parameters
        ----------
        axs : sequence of matplotlib Axes or None, (default None)
            At least two axes: the transfer function is drawn on ``axs[0]``
            and the impulse response on ``axs[1]``. If None, a new figure
            with two rows of axes is created.
        fscale : str, (default 'log')
            Frequency scale, forwarded to ``Spectrum.plot``.
        colors : sequence or None, (default None)
            Colors forwarded to ``Spectrum.plot``; ``colors[0]`` is also used
            for the impulse response. If None, default colors are used.

        Returns
        -------
        fig : matplotlib Figure
            The figure holding the axes.

        Raises
        ------
        TypeError
            If an element of ``axs`` is not a matplotlib Axes.
        ValueError
            If fewer than two axes are given.
        """
        fir = np.asarray(self.fir)

        # validate the axes (or create them)
        fig_passed = axs is not None
        if axs is None:
            fig, axs = plt.subplots(nrows=2)
        else:
            axs = np.atleast_1d(axs)
            not_axes = [ax for ax in axs if not isinstance(ax, plt.Axes)]
            if not_axes:
                raise TypeError('axs must be a list of matplotlib Axes, got {}'
                                ' instead.'.format(type(not_axes[0])))
            if len(axs) < 2:
                raise ValueError('Passed figure must have at least two axes'
                                 ', given figure has {}.'.format(len(axs)))
            fig = axs[0].figure

        # transfer function: periodogram of the impulse response, with an FFT
        # length equal to the next power of two of the filter length, and at
        # least 2048
        fft_length = max(int(2 ** np.ceil(np.log2(fir.shape[0]))), 2048)
        s = Spectrum(fft_length=fft_length, block_length=fir.size,
                     step=None, fs=self.fs, wfunc=np.ones, donorm=False)
        s.periodogram(fir)
        s.plot('Transfer function of FIR filter', fscale=fscale, axes=axs[0],
               colors=colors)

        # impulse response; None lets matplotlib pick the default color
        impulse_color = None if colors is None else colors[0]
        axs[1].plot(fir, color=impulse_color)
        axs[1].set_title('Impulse response of FIR filter')
        axs[1].set_xlabel('Samples')
        axs[1].set_ylabel('Amplitude')

        if not fig_passed:
            fig.tight_layout()
        return fig
class BandPassFilter(FIR):
    """Band-pass FIR filter

    Designs a band-pass FIR filter centered on frequency fc.

    Parameters
    ----------
    fs : float
        Sampling frequency
    fc : float
        Center frequency of the bandpass filter
    n_cycles : float or None, (default 7.0)
        Number of oscillations in the wavelet. None if bandwidth is used.
    bandwidth : float or None, (default None)
        Bandwidth of the FIR wavelet filter. None if n_cycles is used.
    zero_mean : boolean, (default True)
        If True, the mean of the FIR is subtracted, i.e. fir.sum() = 0.
    extract_complex : boolean, (default False)
        If True, the wavelet filter is complex and ``transform`` returns two
        signals, filtered with the real and the imaginary part of the filter.

    Examples
    --------
    >>> from pactools.utils import BandPassFilter
    >>> f = BandPassFilter(fs=100., fc=5., bandwidth=1., n_cycles=None)
    >>> f.plot()
    >>> signal_out = f.transform(signal_in)
    """

    def __init__(self, fs, fc, n_cycles=7.0, bandwidth=None, zero_mean=True,
                 extract_complex=False):
        self.fc = fc
        self.fs = fs
        self.n_cycles = n_cycles
        self.bandwidth = bandwidth
        self.zero_mean = zero_mean
        self.extract_complex = extract_complex
        self._design()

    def _design(self):
        """Design the FIR filter.

        The filter is a Blackman-windowed cosine at fc, optionally made
        zero-mean, then scaled so that its response at fc equals 1. With
        ``extract_complex``, the matching windowed sine (same scaling) is
        stored in ``self.fir_imag``.

        Raises
        ------
        ValueError
            If the designed filter has zero gain at fc (e.g. a single-tap
            zero-mean filter), which would otherwise yield a NaN filter.
        """
        # the length of the filter
        order = self._get_order()
        half_order = (order - 1) // 2
        window = np.blackman(order)
        # integer sample offsets centered on 0
        t = np.linspace(-half_order, half_order, order)
        phase = (2.0 * np.pi * self.fc / self.fs) * t
        carrier = np.cos(phase)
        fir = window * carrier

        # the filter must be symmetric, in order to be zero-phase
        assert np.all(np.abs(fir - fir[::-1]) < 1e-15)

        # remove the constant component by forcing fir.sum() = 0
        if self.zero_mean:
            fir -= fir.sum() / order

        # For a symmetric filter centered on 0, the frequency response at fc
        # is real and equal to sum(fir * cos(phase)).
        gain = np.sum(fir * carrier)
        if gain == 0:
            raise ValueError('fir.BandPassFilter: the designed filter has '
                             'zero gain at fc=%s (order=%d). Check fc, '
                             'n_cycles and bandwidth.' % (self.fc, order))
        self.fir = fir * (1.0 / gain)

        # add the imaginary part to have a complex wavelet
        if self.extract_complex:
            fir_imag = window * np.sin(phase)
            self.fir_imag = fir_imag * (1.0 / gain)
        return self

    def _get_order(self):
        """Return the (odd) number of taps of the filter.

        Raises
        ------
        ValueError
            If n_cycles and bandwidth are both None, or both not None.
        """
        if self.bandwidth is None and self.n_cycles is not None:
            # the filter spans about n_cycles periods of fc
            half_order = int(float(self.n_cycles) / self.fc * self.fs / 2)
        elif self.bandwidth is not None and self.n_cycles is None:
            # NOTE(review): 1.65 is an empirical length/bandwidth factor;
            # its origin is not documented here.
            half_order = int(1.65 * self.fs / self.bandwidth) // 2
        else:
            raise ValueError('fir.BandPassFilter: n_cycles and bandwidth '
                             'cannot be both None, or both not None. Got '
                             '%s and %s' % (self.n_cycles, self.bandwidth, ))
        # an odd length gives the filter an integer center sample
        order = half_order * 2 + 1
        return order

    def transform(self, sigin, out=None, out_imag=None):
        """Apply this filter to a signal

        Parameters
        ----------
        sigin : array, shape (n_points, ) or (n_signals, n_points)
            Input signal
        out : array or None
            Optional preallocated output for the real-part filtering.
        out_imag : array or None
            Optional preallocated output for the imaginary-part filtering,
            only used when extract_complex is true.

        Returns
        -------
        filtered : array, shape (n_points, ) or (n_signals, n_points)
            Filtered signal
        (filtered_imag) : array, shape (n_points, ) or (n_signals, n_points)
            Only when extract_complex is true.
            Filtered signal with the imaginary part of the filter
        """
        filtered = super(BandPassFilter, self).transform(sigin, out=out)
        if self.extract_complex:
            fir = FIR(fir=self.fir_imag, fs=self.fs)
            filtered_imag = fir.transform(sigin, out=out_imag)
            return filtered, filtered_imag
        else:
            return filtered

    def plot(self, axs=None, fscale='log', colors=None):
        """Plot the impulse response and the transfer function of the filter.

        When extract_complex is true, the imaginary part of the filter is
        drawn on the same axes as the real part.

        Returns
        -------
        fig : matplotlib Figure
            The figure holding the axes.
        """
        fig = super(BandPassFilter, self).plot(axs=axs, fscale=fscale,
                                               colors=colors)
        if self.extract_complex:
            if axs is None:
                axs = fig.axes
            fir = FIR(fir=self.fir_imag, fs=self.fs)
            fir.plot(axs=axs, fscale=fscale, colors=colors)
        return fig
class LowPassFilter(FIR):
    """Low-pass FIR filter

    Designs a FIR filter that is a low-pass filter, using a Kaiser window.

    Parameters
    ----------
    fs : float
        Sampling frequency
    fc : float
        Cut-off frequency of the low-pass filter
    bandwidth : float
        Width of the transition band around fc (same unit as fs)
    ripple_db : float (default 60.0)
        Positive number specifying the maximum deviation (in dB) from the
        ideal response outside the transition band, i.e. the passband ripple
        and the stopband attenuation of the Kaiser-window design.

    Examples
    --------
    >>> from pactools.utils import LowPassFilter
    >>> f = LowPassFilter(fs=100., fc=5., bandwidth=1.)
    >>> f.plot()
    >>> signal_out = f.transform(signal_in)
    """

    def __init__(self, fs, fc, bandwidth, ripple_db=60.0):
        self.fs = fs
        self.fc = fc
        self.bandwidth = bandwidth
        self.ripple_db = ripple_db
        self._design()

    def _design(self):
        """Design the filter and store it in ``self.fir``.

        The taps are normalized to sum to one, i.e. unit gain at 0 Hz.
        """
        # Compute the order and Kaiser parameter for the FIR filter
        # (frequencies are normalized so that 1 is the Nyquist frequency).
        numtaps, beta = signal.kaiserord(self.ripple_db,
                                         self.bandwidth / self.fs * 2)
        # An even number of taps puts the center of symmetry between two
        # samples, so the 'same'-mode convolution used by ``transform`` would
        # shift the output by half a sample. An odd length keeps it
        # zero-phase (one extra tap only adds attenuation).
        if numtaps % 2 == 0:
            numtaps += 1
        # Use firwin with a Kaiser window to create a lowpass FIR filter.
        fir = signal.firwin(numtaps, self.fc / self.fs * 2,
                            window=('kaiser', beta))
        # the filter must be symmetric, in order to be zero-phase
        assert np.all(np.abs(fir - fir[::-1]) < 1e-15)
        self.fir = fir / np.sum(fir)
        return self