"""Time-frequency decompositions using wavelets."""
import numpy as np
from scipy.signal import morlet
from neurodsp.utils.data import create_freqs
from neurodsp.utils.checks import check_n_cycles, check_param_options
from neurodsp.utils.decorators import multidim
###################################################################################################
###################################################################################################
@multidim()
def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm='sss'):
    """Convolve a signal with a complex wavelet.

    Parameters
    ----------
    sig : 1d array
        Time series to filter.
    fs : float
        Sampling rate, in Hz.
    freq : float
        Center frequency of bandpass filter.
        Only used to compute the wavelet length when `wavelet_len` is not given.
    n_cycles : float, optional, default: 7
        Length of the filter, as the number of cycles of the oscillation with specified frequency.
        Also passed to :func:`scipy.signal.morlet` as its ``w`` parameter.
    scaling : float, optional, default: 0.5
        Scaling factor for the morlet wavelet (the ``s`` parameter of :func:`scipy.signal.morlet`).
    wavelet_len : int, optional
        Length of the wavelet, in samples. If defined, this overrides the length that would
        otherwise be computed from `freq` and `n_cycles` (``int(n_cycles * fs / freq)``).
        Note that `n_cycles` is still used as the morlet ``w`` parameter.
    norm : {'sss', 'amp'}, optional
        Normalization method:

        * 'sss' - divide by the square root of the sum of squares
        * 'amp' - divide by the sum of amplitudes

    Returns
    -------
    array
        Complex time series, the same length as `sig`.

    Raises
    ------
    ValueError
        If the wavelet length is less than one sample, or greater than the length of the signal.

    Notes
    -----
    * The real part of the returned array is the filtered signal.
    * Taking np.abs() of output gives the analytic amplitude.
    * Taking np.angle() of output gives the analytic phase.
    * Per the :func:`scipy.signal.morlet` documentation, the wavelet's fundamental frequency is
      ``2 * scaling * n_cycles * fs / wavelet_len``. When `wavelet_len` is derived from `freq`,
      this is approximately ``2 * scaling * freq``, i.e. `freq` for the default ``scaling=0.5``.

    Examples
    --------
    Convolve a complex wavelet with a simulated signal:

    >>> from neurodsp.sim import sim_combined
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}})
    >>> cts = convolve_wavelet(sig, fs=500, freq=10)
    """

    check_param_options(norm, 'norm', ['sss', 'amp'])

    if wavelet_len is None:
        wavelet_len = int(n_cycles * fs / freq)

    # int() truncates toward zero, so e.g. freq > n_cycles * fs gives a zero-length wavelet.
    # An empty kernel would otherwise divide 0/0 during normalization and then fail inside
    # np.convolve with an unclear message, so reject it explicitly here.
    if wavelet_len < 1:
        raise ValueError('The length of the wavelet must be at least one sample. '
                         'Check the values of fs, freq, n_cycles, and wavelet_len.')

    # np.convolve(mode='same') returns max(len(sig), len(kernel)) samples, so a kernel longer
    # than the signal would change the output length.
    if wavelet_len > sig.shape[-1]:
        raise ValueError('The length of the wavelet is greater than the signal. Can not proceed.')

    # NOTE(review): scipy.signal.morlet is deprecated since SciPy 1.12 and removed in 1.15;
    # this call (and the module-level import) will need a replacement when upgrading SciPy.
    morlet_f = morlet(wavelet_len, w=n_cycles, s=scaling)

    if norm == 'sss':
        morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f)**2))
    elif norm == 'amp':
        morlet_f = morlet_f / np.sum(np.abs(morlet_f))

    # Convolve with the real and imaginary parts separately, then recombine into a complex series.
    mwt_real = np.convolve(sig, np.real(morlet_f), mode='same')
    mwt_imag = np.convolve(sig, np.imag(morlet_f), mode='same')

    return mwt_real + 1j * mwt_imag