# Source code for neurodsp.plts.timefrequency

"""Plotting functions for neurodsp.timefrequency."""

import numpy as np

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig

###################################################################################################
###################################################################################################

@savefig
@style_plot
def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs):
    """Plot a time-frequency representation of data.

    Parameters
    ----------
    times : 1d array
        The time dimension for the time-frequency representation.
    freqs : 1d array
        The frequency dimension for the time-frequency representation.
    powers : 2d array
        Power values to plot. Rows are drawn along the frequency (y) axis and columns
        along the time (x) axis, i.e. the expected shape is (n_freqs, n_times).
        If array is complex, the absolute value (magnitude) is taken for plotting.
    x_ticks, y_ticks : int or array_like
        Defines the tick labels to add to the plot.
        If int, is the number of evenly spaced labels to add to the plot, spanning from
        the first to the last value of `times` / `freqs`.
        If array_like, is a set of values to label; each is placed at the nearest
        sample of `times` / `freqs`.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments for customizing the plot.

    Examples
    --------
    Plot a Morlet transformation:

    >>> import numpy as np
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> from neurodsp.timefrequency.wavelets import compute_wavelet_transform
    >>> fs=1000
    >>> sig = sim_bursty_oscillation(n_seconds=10, fs=fs, freq=10)
    >>> times = np.arange(0, len(sig)/fs, 1/fs)
    >>> freqs = np.arange(1, 50, 1)
    >>> mwt = compute_wavelet_transform(sig, fs, freqs)
    >>> plot_timefrequency(times, freqs, mwt)
    """

    ax = check_ax(ax, figsize=kwargs.pop('figsize', None))

    # Complex-valued input is plotted as its magnitude
    if np.iscomplexobj(powers):
        powers = np.abs(powers)

    ax.imshow(powers, aspect='auto', **kwargs)
    # imshow draws row 0 at the top; flip so the first row (freqs[0]) is at the bottom
    ax.invert_yaxis()

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    x_tick_pos, x_tick_labels = _compute_tick_positions(times, x_ticks)
    ax.set(xticks=x_tick_pos, xticklabels=x_tick_labels)

    y_tick_pos, y_tick_labels = _compute_tick_positions(freqs, y_ticks)
    ax.set(yticks=y_tick_pos, yticklabels=y_tick_labels)


def _compute_tick_positions(values, ticks):
    """Compute image-index tick positions and their labels for one plot axis.

    Parameters
    ----------
    values : 1d array_like
        Coordinate value (e.g. time or frequency) of each sample along the axis.
    ticks : int or array_like
        If int, the number of evenly spaced ticks from values[0] to values[-1].
        If array_like, the values to label, each placed at the nearest sample.

    Returns
    -------
    positions : 1d array or list of int
        Tick positions, in image (sample index) coordinates.
    labels : 1d array or array_like
        Tick labels, one per position.
    """

    values = np.asarray(values)

    # Accept numpy integer scalars (e.g. np.int64) as well as builtin ints
    if isinstance(ticks, (int, np.integer)):
        # imshow centers sample i at coordinate i, so the last sample sits at size - 1;
        # spanning to `size` would push the last tick off the image and misalign the rest
        positions = np.linspace(0, values.size - 1, ticks)
        labels = np.round(np.linspace(values[0], values[-1], ticks), 2)
    else:
        positions = [np.argmin(np.abs(values - val)) for val in ticks]
        labels = ticks

    return positions, labels