#!/usr/bin/env python
"""Plotting methods.
This module contains various plotting methods. There are two types of
plotting methods: the Primitives and the Composites. The Primitives are
the most basic and offer simple, single-plot representations. The
Composites are composed of several primitives and offer more complex
representations.
The primitive plots are those whose names begin with ``ax_`` (e.g.
``ax_scalp``).
In order to get more reasonable defaults for colors etc., you can call
the module's :func:`beautify` function::
from wyrm import plot
plot.beautify()
.. warning::
This module needs heavy reworking! We have yet to find a consistent
way to handle primitive and composite plots, deal with the fact that
some plots just manipulate axes, while others operate on figures and
have to decide on which layer of matplotlib we want to deal with
(i.e. pyplot, artist or even pylab).
The API of this module will change and you should not rely on any
method here.
"""
from __future__ import division
import math
import numpy as np
from scipy import interpolate
import matplotlib as mpl
from matplotlib import axes
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
from matplotlib import pyplot as plt
from matplotlib import ticker
from matplotlib.patches import Rectangle
from wyrm import processing as proc
from wyrm.processing import CHANNEL_10_20
from wyrm.types import Data
# ############# OLD FUNCTIONS ############################################
def plot_channels(dat, ncols=8, chanaxis=-1, otheraxis=-2):
    """Plot all channels for a continuous or epo.

    In case of an epoched Data object, the classwise average is
    calculated, and for each channel the respective classes are plotted.

    Parameters
    ----------
    dat : Data
        continuous or epoched Data object. Data with three dimensions
        is treated as epoched.
    ncols : int, optional
        the number of columns in the grid. The number of rows is
        calculated depending on ``ncols`` and the number of channels
    chanaxis : int, optional
        index of the axis of ``dat`` that holds the channels
    otheraxis : int, optional
        index of the axis of ``dat`` whose values are used as the
        x-coordinates of the plotted lines

    """
    is_epo = dat.data.ndim == 3
    class_dats = []
    if is_epo:
        dat = proc.calculate_classwise_average(dat)
        # the per-class selection does not depend on the channel, so it
        # is done once instead of once per channel
        class_dats = [(name, proc.select_classes(dat, [j]))
                      for j, name in enumerate(dat.class_names)]
    n_channels = dat.data.shape[chanaxis]
    nrows = int(np.ceil(n_channels / float(ncols)))
    # squeeze=False keeps the returned array two-dimensional, even when
    # the grid has only a single row or a single column
    _, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                         squeeze=False)
    for i, chan in enumerate(dat.axes[chanaxis]):
        a = ax[i // ncols, i % ncols]
        if is_epo:
            for name, cnt in class_dats:
                a.plot(cnt.axes[otheraxis],
                       cnt.data.take([i], chanaxis).squeeze(),
                       label=name)
        else:
            a.plot(dat.axes[otheraxis], dat.data.take([i], chanaxis).squeeze())
        a.set_title(chan)
        # mark the origin of both axes
        a.axvline(x=0, color='black')
        a.axhline(y=0, color='black')
    plt.legend()
def plot_spatio_temporal_r2_values(dat):
    """Calculate the signed r^2 values and plot them in a heatmap.

    The values are multiplied by -1 before plotting and the color range
    is symmetric around zero.

    Parameters
    ----------
    dat : Data
        epoched data

    """
    def index_formatter(labels):
        """Return a formatter that maps a tick position to ``labels``."""
        def fmt(x, pos=None):
            i = int(x + 0.5)
            if x < 0 or i >= len(labels):
                return ''
            return labels[i]
        return ticker.FuncFormatter(fmt)

    r2 = proc.calculate_signed_r_square(dat)
    r2 *= -1
    limit = np.max(np.abs(r2))
    plt.imshow(r2.T, aspect='auto', interpolation='none', vmin=-limit,
               vmax=limit, cmap='RdBu')
    ax = plt.gca()
    # TODO: sort front-back, left-right
    # use the locators to fine-tune the ticks: only channels whose name
    # ends with 'z' get a tick on the y-axis
    mask = [chan.endswith('z') for chan in dat.axes[-1]]
    ax.yaxis.set_major_locator(ticker.FixedLocator(np.nonzero(mask)[0]))
    ax.yaxis.set_major_formatter(index_formatter(dat.axes[-1]))
    # a tick spacing of zero is invalid, so never go below one
    spacing = np.max(dat.axes[-2]) // 100
    if spacing <= 0:
        spacing = 1
    ax.xaxis.set_major_locator(ticker.MultipleLocator(spacing))
    ax.xaxis.set_major_formatter(
        index_formatter(['%.1f' % i for i in dat.axes[-2]]))
    plt.xlabel('%s [%s]' % (dat.names[-2], dat.units[-2]))
    plt.ylabel('%s [%s]' % (dat.names[-1], dat.units[-1]))
    # ``pad`` is passed by keyword; the former positional ``True`` is 1
    plt.tight_layout(pad=1)
    plt.colorbar()
    plt.grid(True)
def plot_spectrogram(spectrogram, freqs):
    """Plot a spectrogram as an image with a colorbar.

    Parameters
    ----------
    spectrogram : array
        the values to draw. The array is transposed before it is drawn,
        and its length sets the upper bound of the x-axis.
    freqs : sequence
        the first and the last element set the lower and upper bound of
        the y-axis

    """
    bounds = (0, len(spectrogram), freqs[0], freqs[-1])
    image = spectrogram.transpose()
    plt.imshow(image, aspect='auto', origin='lower', extent=bounds,
               interpolation='none')
    plt.colorbar()
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time')
# ############# COMPOSITE PLOTS ##########################################
def plot_timeinterval(data, r_square=None, highlights=None, hcolors=None,
                      legend=True, reg_chans=None, position=None):
    """Plot a simple time interval.

    Draws all channels of either continuous data or the mean of epoched
    data into a single timeinterval plot.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot. Data with more than two
        dimensions is averaged over its first axis.
    r_square : [values], optional
        List containing r_squared values to be plotted beneath the main
        plot (default: None).
    highlights : [[int, int)]
        List of tuples containing the start point (included) and end
        point (excluded) of each area to be highlighted (default: None).
    hcolors : [colors], optional
        A list of colors to use for the highlighted areas (default:
        None).
    legend : Boolean, optional
        Flag to switch plotting of the legend on or off (default: True).
    reg_chans : [regular expression], optional
        A list of regular expressions. The plot will be limited to those
        channels matching the regular expressions. (default: None).
    position : [x, y, width, height], optional
        A Rectangle that limits the plot to its boundaries. If None, a
        new figure is created (default: None).

    Returns
    -------
    Matplotlib.Axes or (Matplotlib.Axes, Matplotlib.Axes)
        The Matplotlib.Axes corresponding to the plotted timeinterval
        and, if provided, the Axes corresponding to r_squared values.

    Examples
    --------
    Plots all channels contained in data with a legend.

    >>> plot_timeinterval(data)

    Same as above, but without the legend.

    >>> plot_timeinterval(data, legend=False)

    Adds r-square values to the plot.

    >>> plot_timeinterval(data, r_square=[values])

    Adds a highlighted area to the plot.

    >>> plot_timeinterval(data, highlights=[[200, 400]])

    To specify the colors of the highlighted areas use 'hcolors'.

    >>> plot_timeinterval(data, highlights=[[200, 400]], hcolors=['red'])

    """
    with_r2 = r_square is not None
    dcopy = data.copy()

    # the timeinterval gives up its lower part when r^2 values are drawn
    rect_ti = [.07, .12, .9, .85] if with_r2 else [.07, .07, .9, .9]
    rect_r2 = [.07, .07, .9, .05]

    if position is None:
        plt.figure()
        pos_ti, pos_r2 = rect_ti, rect_r2
    else:
        pos_ti = _transform_rect(position, rect_ti)
        pos_r2 = _transform_rect(position, rect_r2)

    if reg_chans is not None:
        dcopy = proc.select_channels(dcopy, reg_chans)

    # turn epoched data into continuous data using the mean
    if data.data.ndim > 2:
        last_two = (-2, -1)
        dcopy = Data(np.mean(dcopy.data, axis=0),
                     [dcopy.axes[i] for i in last_two],
                     [dcopy.names[i] for i in last_two],
                     [dcopy.units[i] for i in last_two])

    ax0 = _subplot_timeinterval(dcopy, position=pos_ti, epoch=-1,
                                highlights=highlights, hcolors=hcolors,
                                legend=legend)
    ax0.xaxis.labelpad = 0

    ax1 = None
    if with_r2:
        ax1 = _subplot_r_square(r_square, position=pos_r2)
        ax0.tick_params(axis='x', direction='in', pad=30 * pos_ti[3])

    plt.grid(True)

    return (ax0, ax1) if with_r2 else ax0
def plot_tenten(data, highlights=None, hcolors=None, legend=False, scale=True,
                reg_chans=None):
    """Plots channels on a grid system.

    Iterates over every channel in the data structure. If the
    channelname matches a channel in the tenten-system it will be
    plotted in a grid of rectangles. The grid is structured like the
    tenten-system itself, but in a simplified manner. The rows, in which
    channels appear, are predetermined, the channels are ordered
    automatically within their respective row. Areas to highlight can be
    specified, those areas will be marked with colors in every
    timeinterval plot.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot.
    highlights : [[int, int)]
        List of tuples containing the start point (included) and end
        point (excluded) of each area to be highlighted (default: None).
    hcolors : [colors], optional
        A list of colors to use for the highlight areas (default: None).
    legend : Boolean, optional
        Flag to switch plotting of the legend on or off (default:
        False).
    scale : Boolean, optional
        Flag to switch plotting of a scale in the top right corner of
        the grid (default: True)
    reg_chans : [regular expressions]
        A list of regular expressions. The plot will be limited to those
        channels matching the regular expressions.

    Returns
    -------
    [Matplotlib.Axes], Matplotlib.Axes
        Returns the plotted timeinterval axes as a list of
        Matplotlib.Axes and the plotted scale as a single
        Matplotlib.Axes. The second element is ``None`` if ``scale`` is
        False.

    Raises
    ------
    ValueError
        If none of the channels is part of the tenten-system.

    Examples
    --------
    Plotting of all channels within a Data object

    >>> plot_tenten(data)

    Plotting of all channels with a highlighted area

    >>> plot_tenten(data, highlights=[[200, 400]])

    Plotting of all channels beginning with 'A'

    >>> plot_tenten(data, reg_chans=['A.*'])

    """
    dcopy = data.copy()
    if reg_chans is not None:
        dcopy = proc.select_channels(dcopy, reg_chans)

    # this dictionary determines which y-position corresponds with which
    # row in the grid
    ordering = {4.0: 0,
                3.5: 0,
                3.0: 1,
                2.5: 2,
                2.0: 3,
                1.5: 4,
                1.0: 5,
                0.5: 6,
                0.0: 7,
                -0.5: 8,
                -1.0: 9,
                -1.5: 10,
                -2.0: 11,
                -2.5: 12,
                -2.6: 12,
                -3.0: 13,
                -3.5: 14,
                -4.0: 15,
                -4.5: 15,
                -5.0: 16}

    # all the channels with their x- and y-position
    system = dict(CHANNEL_10_20)

    # one list for every potential row of channels
    channel_lists = [[] for _ in range(max(ordering.values()) + 1)]

    # distribute the channels to the lists by their y-position
    for index, name in enumerate(dcopy.axes[-1]):
        if name in system:
            # entries: (<channel name>, <x-position>, <position in Data>)
            row = ordering[system[name][1]]
            channel_lists[row].append((name, system[name][0], index))

    # drop the unused rows and sort the others by x-position
    rows = [sorted(chans, key=lambda entry: entry[1])
            for chans in channel_lists if chans]
    if not rows:
        raise ValueError('None of the channels is part of the tenten-system.')

    # calculate the needed dimensions of the grid
    columns = [len(chans) for chans in rows]
    if scale:
        # add another axes to the first row for the scale
        columns[0] += 1

    plt.figure()
    grid = calc_centered_grid(columns, hpad=.01, vpad=.01)

    # axis used for sharing axes between channels
    masterax = None
    ax = []
    k = 0
    scale_position = None
    for row_index, chans in enumerate(rows):
        for name, _, chan_index in chans:
            new_ax = _subplot_timeinterval(dcopy, grid[k], epoch=-1,
                                           highlights=highlights,
                                           hcolors=hcolors, labels=False,
                                           legend=legend, channel=chan_index,
                                           shareaxis=masterax)
            ax.append(new_ax)
            if masterax is None:
                masterax = new_ax
            # hide the axes labeling. Booleans are used because the
            # strings 'off' are truthy.
            new_ax.tick_params(axis='both', which='both', labelbottom=False,
                               labeltop=False, labelleft=False,
                               labelright=False, top=False, right=False)
            # at this moment just to show what's what
            new_ax.annotate(name, (0.05, 0.05), xycoords='axes fraction')
            k += 1
        if scale and row_index == 0:
            # the cell after the last axes of the first row holds the scale
            scale_position = grid[k]
            k += 1

    sc = None
    if scale:
        # plot the scale axes, labeled with the last value of the first axis
        xtext = dcopy.axes[0][-1]
        sc = _subplot_scale(str(xtext) + ' ms', r"$\mu$V",
                            position=scale_position)
    return ax, sc
def plot_scalp(v, channels, levels=25, colormap=None, norm=None, ticks=None,
               annotate=True, position=None):
    """Plots the values 'v' for channels 'channels' on a scalp.

    Calculates the interpolation of the values v for the corresponding
    channels 'channels' and plots it as a contour plot on a scalp,
    together with a colorbar. Both are added to the current figure.

    Parameters
    ----------
    v : [value]
        List containing the values of the channels.
    channels : [String]
        List containing the channel names.
    levels : int, optional
        Currently not used by the implementation; kept for backwards
        compatibility (default: 25).
    colormap : matplotlib.colors.colormap, optional
        A colormap to define the color transitions (default: None).
    norm : matplotlib.colors.norm, optional
        A norm to define the min and max values. If None, a range
        symmetric around zero is used, bounded by the largest absolute
        value in ``v`` (default: None).
    ticks : array([ints]), optional
        An array with values to define the ticks on the colorbar. If
        None, 3 ticks are placed at the lower bound, the center and the
        upper bound of ``norm`` (default: None).
    annotate : Boolean, optional
        Flag to switch channel annotations on or off (default: True).
    position : [x, y, width, height], optional
        A Rectangle that limits the plot to its boundaries (default:
        None).

    Returns
    -------
    (Matplotlib.Axes, Matplotlib.Axes)
        Returns a pair of Matplotlib.Axes. The first contains the
        plotted scalp, the second the corresponding colorbar.

    Examples
    --------
    Plots the values v for channels 'channels' on a scalp

    >>> plot_scalp(v, channels)

    This plot has a norm and ticks from 0 to 10

    >>> n = matplotlib.colors.Normalize(vmin=0, vmax=10, clip=False)
    >>> t = np.linspace(0.0, 10.0, 3, endpoint=True)
    >>> plot_scalp(v, channels, norm=n, ticks=t)

    """
    rect_scalp = [.05, .05, .8, .9]
    rect_colorbar = [.9, .05, .05, .9]
    fig = plt.gcf()
    if position is not None:
        rect_scalp = _transform_rect(position, rect_scalp)
        rect_colorbar = _transform_rect(position, rect_colorbar)

    if norm is None:
        vmax = np.abs(v).max()
        norm = Normalize(-vmax, vmax, clip=False)
    if ticks is None:
        ticks = np.linspace(norm.vmin, norm.vmax, 3)

    scalp_ax = fig.add_axes(rect_scalp)
    ax0 = ax_scalp(v, channels, ax=scalp_ax, annotate=annotate,
                   vmin=norm.vmin, vmax=norm.vmax, colormap=colormap)

    colorbar_ax = fig.add_axes(rect_colorbar)
    ax_colorbar(norm.vmin, norm.vmax, ax=colorbar_ax, ticks=ticks,
                colormap=colormap, label='')
    # the colorbar is drawn into the axes created here, so that axes is
    # returned explicitly instead of the result of the helper call
    return ax0, colorbar_ax
def plot_scalp_ti(v, channels, data, interval, scale_ti=.1, levels=25, colormap=None,
                  norm=None, ticks=None, annotate=True, position=None):
    """Plots a scalp with channels on top

    Plots the values v for channels 'channels' on a scalp as a contour
    plot. Additionally plots the channels of ``data`` as timeintervals
    on top of the scalp plot. The individual channels are placed over
    their position on the scalp.

    Parameters
    ----------
    v : [value]
        List containing the values of the channels.
    channels : [String]
        List containing the channel names.
    data : wyrm.types.Data
        Data object containing the continuous data for the overlaying
        timeinterval plots.
    interval : [begin, end)
        Tuple of ints to specify the range of the overlaying
        timeinterval plots. Both values have to be present in the first
        axis of ``data``. If None, the whole data is plotted.
    scale_ti : float, optional
        The percentage to scale the overlaying timeinterval plots
        (default: 0.1).
    levels : int, optional
        Currently not used by the implementation; kept for backwards
        compatibility (default: 25).
    colormap : matplotlib.colors.colormap, optional
        A colormap to define the color transitions (default: the
        colormap registered as 'RdBu').
    norm : matplotlib.colors.norm, optional
        A norm to define the min and max values. If 'None', values from
        -10 to 10 are assumed (default: None).
    ticks : array([ints]), optional
        An array with values to define the ticks on the colorbar. If
        None, 3 ticks are placed at the lower bound, the center and the
        upper bound of ``norm`` (default: None).
    annotate : Boolean, optional
        Flag to switch channel annotations on or off (default: True).
    position : [x, y, width, height], optional
        A Rectangle that limits the plot to its boundaries (default:
        None).

    Returns
    -------
    ((Matplotlib.Axes, Matplotlib.Axes), [Matplotlib.Axes])
        Returns a tuple of first a tuple with the plotted scalp and its
        colorbar, then a list of all on top plotted timeintervals.

    """
    rect_scalp = [.05, .05, .8, .9]
    rect_colorbar = [.9, .05, .05, .9]
    fig = plt.gcf()
    if position is None:
        pos_scalp = rect_scalp
        pos_colorbar = rect_colorbar
    else:
        pos_scalp = _transform_rect(position, rect_scalp)
        pos_colorbar = _transform_rect(position, rect_colorbar)

    # colormap and norm are handed to both the scalp and the colorbar so
    # the two always show the same mapping from values to colors
    if colormap is None:
        colormap = plt.get_cmap('RdBu')
    if norm is None:
        norm = Normalize(-10, 10, clip=False)
    if ticks is None:
        ticks = np.linspace(norm.vmin, norm.vmax, 3, endpoint=True)

    scalp_ax = fig.add_axes(pos_scalp)
    ax0 = ax_scalp(v, channels, ax=scalp_ax, annotate=annotate,
                   vmin=norm.vmin, vmax=norm.vmax, colormap=colormap)
    ax1 = fig.add_axes(pos_colorbar)
    ax_colorbar(norm.vmin, norm.vmax, ax=ax1, ticks=ticks, colormap=colormap)

    # modification of internally used data if a specific interval is specified
    cdat = data.copy()
    if interval is not None:
        startindex = np.where(cdat.axes[0] == interval[0])[0][0]
        endindex = np.where(cdat.axes[0] == interval[1])[0][0]
        cdat.axes[0] = cdat.axes[0][startindex:endindex]
        cdat.data = cdat.data[startindex:endindex, :]

    # dirty: these are meant to be the x and y limits of the scalp axes
    # NOTE(review): ax_scalp sets its lower y limit to -1.05, not -1.10
    # -- confirm which value is intended
    minx = -1.15
    maxx = 1.15
    miny = -1.10
    maxy = 1.15

    tis = []
    for channelindex, c in enumerate(cdat.axes[1]):
        points = get_channelpos(c)
        if points is None:
            print('The channel "' + c + '" was not found in the tenten-system.')
            continue
        # transformation of cartesian to relative coordinates
        rel_x = (points[0] + np.abs(minx)) * (1 / (np.abs(minx) + maxx))
        rel_y = (points[1] + np.abs(miny)) * (1 / (np.abs(miny) + maxy))
        # center a square with an edge length of scale_ti on the channel
        pos_c = [rel_x - (scale_ti / 2), rel_y - (scale_ti / 2), scale_ti, scale_ti]
        # transformation to fit into the scalp part of the plot
        pos_c = _transform_rect(pos_scalp, pos_c)
        tis.append(_subplot_timeinterval(cdat, position=pos_c, epoch=-1,
                                         highlights=None, legend=False,
                                         channel=channelindex, shareaxis=None))
    return (ax0, ax1), tis
# ############# TOOLS ####################################################
def set_highlights(highlights, hcolors=None, set_axes=None):
    """Sets highlights in form of vertical boxes to axes.

    Parameters
    ----------
    highlights : [(start, end)]
        List of tuples containing the start point (included) and end
        point (excluded) of each area to be highlighted. If None,
        nothing is drawn.
    hcolors : [colors], optional
        A list of colors to use for the highlight areas (e.g. 'b',
        '#eeefff' or [R, G, B] for R, G, B = [0..1]. If left as None or
        empty, the colors blue, green, red, cyan, magenta and yellow are
        used. The colors are reused from the start if there are more
        highlights than colors.
    set_axes : [matplotlib.axes.Axes], optional
        List of axes to highlights (default: None, all axes of the
        current figure will be highlighted).

    Examples
    ---------
    To create two highlighted areas in all axes of the currently active
    figure. The first area from 200ms - 300ms in blue and the second
    area from 500ms - 600ms in green.

    >>> set_highlights([[200, 300], [500, 600]])

    """
    if highlights is None:
        return
    if set_axes is None:
        set_axes = plt.gcf().axes
    # fall back to a standard variety of colors if nothing usable is
    # specified; an empty list would otherwise break the modulo below
    if hcolors is None or len(hcolors) == 0:
        hcolors = ['b', 'g', 'r', 'c', 'm', 'y']
    # one color per span, cycling through the available colors
    colormask = [hcolors[index % len(hcolors)]
                 for index in range(len(highlights))]
    for axis in set_axes:
        for span, color in zip(highlights, colormask):
            # the edges of the box are at the moment white. transparent
            # edges would be better.
            axis.axvspan(span[0], span[1], edgecolor='w', facecolor=color,
                         alpha=.5)
def calc_centered_grid(cols_list, hpad=.05, vpad=.05):
    """Calculates a centered grid of Rectangles and their positions.

    All rectangles have the same size. The width is chosen so that the
    longest row fills the available width; shorter rows are centered
    horizontally.

    Parameters
    ----------
    cols_list : [int]
        List of ints. Every entry represents a row with as many channels
        as the value.
    hpad : float, optional
        The amount of horizontal padding (default: 0.05).
    vpad : float, optional
        The amount of vertical padding (default: 0.05).

    Returns
    -------
    [[float, float, float, float]]
        A list of all rectangle positions in the form of [x, y, width,
        height] sorted from top left to bottom right. The list is empty
        if ``cols_list`` is empty or contains only zeros.

    Examples
    --------
    Calculates a centered grid with 3 rows of 4, 3 and 2 columns

    >>> calc_centered_grid([4, 3, 2])

    Calculates a centered grid with more padding

    >>> calc_centered_grid([5, 4], hpad=.1, vpad=.75)

    """
    n_rows = len(cols_list)
    max_cols = max(cols_list) if n_rows else 0
    if max_cols == 0:
        # nothing to lay out; also avoids a division by zero below
        return []
    h = (1 - ((n_rows + 1) * vpad)) / float(n_rows)
    w = (1 - ((max_cols + 1) * hpad)) / float(max_cols)
    grid = []
    for row, n_cols in enumerate(cols_list, 1):
        yi = 1 - ((row * vpad) + (row * h))
        # the margin on both sides is the same for the whole row
        m = .5 - (((n_cols * w) + ((n_cols - 1) * hpad)) / 2)
        for i in range(n_cols):
            xi = m + (i * hpad) + (i * w)
            grid.append([xi, yi, w, h])
    return grid
# ############# PRIMITIVE PLOTS ##########################################
def _subplot_timeinterval(data, position, epoch, highlights=None, hcolors=None,
                          labels=True, legend=True, channel=None, shareaxis=None):
    """Creates a new axes with a timeinterval plot.

    Creates a matplotlib.axes.Axes within the rectangle specified by
    'position' in the current figure and fills it with a timeinterval
    plot defined by the channels and values contained in 'data'.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot.
    position : Rectangle
        The rectangle (x, y, width, height) where the axes will be
        created.
    epoch : int
        The epoch to be plotted. If there are no epochs this has to be
        '-1'.
    highlights : [[int, int)]
        List of tuples containing the start point (included) and end
        point (excluded) of each area to be highlighted (default: None).
    hcolors : [colors], optional
        A list of colors to use for the highlights areas (default:
        None).
    labels : Boolean, optional
        Flag to switch plotting of the usual labels on or off (default:
        True)
    legend : Boolean, optional
        Flag to switch plotting of the legend on or off (default: True).
    channel : int, optional
        This can be used to plot only a single channel. 'channel' has to
        be the index of the desired channel in data.axes[-1] (default:
        None)
    shareaxis : matplotlib.axes.Axes, optional
        An axes to share x- and y-axis with the new axes (default:
        None).

    Returns
    -------
    matplotlib.axes.Axes

    """
    fig = plt.gcf()
    if shareaxis is None:
        ax = fig.add_axes(position)
    else:
        ax = axes.Axes(fig, position, sharex=shareaxis, sharey=shareaxis)
        fig.add_axes(ax)

    # epoch is -1 when there are no epochs
    if epoch == -1:
        time_index = 0
        values = data.data if channel is None else data.data[:, channel]
    else:
        time_index = -2
        if channel is None:
            values = data.data[epoch]
        else:
            # the channels are on the last axis, so the time axis has to
            # be kept completely
            values = data.data[epoch, :, channel]
    ax.plot(data.axes[time_index], values)

    # plotting of highlights
    if highlights is not None:
        set_highlights(highlights, hcolors=hcolors, set_axes=[ax])

    # labeling of axes; the x label belongs to the plotted time axis
    if labels:
        ax.set_xlabel(data.units[time_index])
        ax.set_ylabel(r"$\mu$V")

    # labeling of channels
    if legend:
        channel_names = data.axes[-1]
        if channel is None:
            ax.legend(channel_names)
        else:
            ax.legend([channel_names[channel]])

    ax.grid(True)
    return ax
def _subplot_r_square(data, position):
    """Creates a new axes with colored r-square values.

    The axes is added to the current figure; both of its axis objects
    are hidden.

    Parameters
    ----------
    data : [float]
        A list of floats that will be evenly distributed as colored
        tiles.
    position : Rectangle
        The rectangle (x, y, width, height) where the axes will be
        created.

    Returns
    -------
    matplotlib.axes.Axes

    """
    tile_axes = plt.gcf().add_axes(position)
    # np.tile with (1, 1) gives the values at least two dimensions, as
    # needed for an image
    tiles = np.tile(data, (1, 1))
    tile_axes.imshow(tiles, aspect='auto', interpolation='none')
    for getter in (tile_axes.get_xaxis, tile_axes.get_yaxis):
        getter().set_visible(False)
    return tile_axes
def _subplot_scale(xvalue, yvalue, position):
    """Creates a new axes with a simple scale.

    The scale consists of a horizontal and a vertical black bar with a
    text next to each of them. The axes is added to the current figure;
    the background patches of the figure and of the new axes are
    hidden.

    Parameters
    ----------
    xvalue : String
        The text to be presented beneath the x-axis.
    yvalue : String
        The text to be presented next to the y-axis.
    position : Rectangle
        The rectangle (x, y, width, height) where the axes will be created.

    Returns
    -------
    matplotlib.axes.Axes

    """
    figure = plt.gcf()
    scale_axes = figure.add_axes(position)
    figure.patch.set_visible(False)
    scale_axes.patch.set_visible(False)
    scale_axes.axis('off')
    # (lower left corner, width, height) of the horizontal and the
    # vertical bar
    bars = (((1, 1), 3, .2),
            ((1, 1), .1, 2))
    for corner, width, height in bars:
        scale_axes.add_patch(Rectangle(corner, width, height, color='black'))
    plt.text(1.5, 2, yvalue)
    plt.text(1.5, .25, xvalue)
    scale_axes.set_ylim([0, 4])
    scale_axes.set_xlim([0, 5])
    return scale_axes
def _transform_rect(rect, template):
"""Calculates the position of a relative notated rectangle within
another rectangle.
Parameters
----------
rect : Rectangle
The container rectangle to contain the other reactangle.
template : Rectangle
the rectangle to be contained in the other rectangle.
"""
assert len(rect) == len(template) == 4, "Wrong inputs : [x, y, width, height]"
x = rect[0] + (template[0] * rect[2])
y = rect[1] + (template[1] * rect[3])
w = rect[2] * template[2]
h = rect[3] * template[3]
return [x, y, w, h]
###############################################################################
# Primitives
###############################################################################
def ax_scalp(v, channels, ax=None, annotate=False, vmin=None, vmax=None, colormap=None):
    """Draw a scalp plot.

    Draws a scalp plot on an existing axes. The method takes an array of
    values and an array of the corresponding channel names. It matches
    the channel names with an internal list of known channels and their
    positions to project them correctly on the scalp.

    Parameters
    ----------
    v : 1d-array of floats
        The values for the channels
    channels : 1d array of strings
        The corresponding channel names for the values in ``v``
    ax : Axes, optional
        The axes to draw the scalp plot on. If not provided, the
        currently activated axes (i.e. ``gca()``) will be taken
    annotate : Boolean, optional
        Draw the channel names next to the channel markers.
    vmin, vmax : float, optional
        The display limits for the values in ``v``. If the data in ``v``
        contains values between -3..3 and ``vmin`` and ``vmax`` are set
        to -1 and 1, all values smaller than -1 and bigger than 1 will
        appear the same as -1 and 1. If not set, the maximum absolute
        value in ``v`` is taken to calculate both values.
    colormap : matplotlib.colors.colormap, optional
        A colormap to define the color transitions.

    Returns
    -------
    ax : Axes
        the axes on which the plot was drawn

    Raises
    ------
    ValueError
        If ``channels`` contains a channel with an unknown position.

    See Also
    --------
    ax_colorbar

    """
    if ax is None:
        ax = plt.gca()

    points = [get_channelpos(c) for c in channels]
    # a channel without a position cannot be placed on the scalp
    unknown = [c for c, p in zip(channels, points) if p is None]
    if unknown:
        raise ValueError('Unknown channel(s): %s' % ', '.join(map(str, unknown)))

    # symmetric default limits, derived from the largest absolute value
    if vmin is None or vmax is None:
        limit = np.abs(v).max()
        if vmin is None:
            vmin = -limit
        if vmax is None:
            vmax = limit

    x = [p[0] for p in points]
    y = [p[1] for p in points]
    z = v
    # interpolate the in-between values on a regular 500x500 grid
    xx = np.linspace(min(x), max(x), 500)
    yy = np.linspace(min(y), max(y), 500)
    xx, yy = np.meshgrid(xx, yy)
    f = interpolate.LinearNDInterpolator(list(zip(x, y)), z)
    zz = f(xx, yy)
    # draw the contour map
    ctr = ax.contourf(xx, yy, zz, 20, vmin=vmin, vmax=vmax, cmap=colormap)
    ax.contour(xx, yy, zz, 5, colors="k", vmin=vmin, vmax=vmax, linewidths=.1)
    # paint the head
    ax.add_artist(plt.Circle((0, 0), 1, linestyle='solid', linewidth=2, fill=False))
    # add a nose
    ax.plot([-0.1, 0, 0.1], [1, 1.1, 1], 'k-')
    # add markers at channels positions
    ax.plot(x, y, 'k+')
    # set the axes limits, so the figure is centered on the scalp
    ax.set_ylim([-1.05, 1.15])
    ax.set_xlim([-1.15, 1.15])
    # hide the frame and axes
    # hiding the axes might be too much, as this will also hide the x/y
    # labels :/
    ax.set_frame_on(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # draw the channel names
    if annotate:
        for name, xy in zip(channels, zip(x, y)):
            ax.annotate(" " + name, xy)
    ax.set_aspect(1)
    # the contour set can only become the current image of the current
    # axes, so make sure ``ax`` is the current one
    plt.sca(ax)
    plt.sci(ctr)
    return ax
def ax_colorbar(vmin, vmax, ax=None, label=None, ticks=None, colormap=None):
    """Draw a color bar

    Draws a color bar on an existing axes. The range of the colors is
    defined by ``vmin`` and ``vmax``.

    .. note::

        Unlike the colorbar method from matplotlib, this method does not
        automatically create a new axis for the colorbar. It will paint
        in the currently active axis instead, overwriting any existing
        plots in that axis. Make sure to create a new axis for the
        colorbar.

    Parameters
    ----------
    vmin, vmax : float
        The minimum and maximum values for the colorbar.
    ax : Axes, optional
        The axes to draw the colorbar on. If not provided, the
        currently activated axes (i.e. ``gca()``) will be taken
    label : string, optional
        The label for the colorbar
    ticks : list, optional
        The tick positions
    colormap : matplotlib.colors.colormap, optional
        A colormap to define the color transitions.

    Returns
    -------
    ax : Axes
        the axes on which the plot was drawn

    """
    if ax is None:
        ax = plt.gca()
    ColorbarBase(ax, norm=Normalize(vmin, vmax), label=label, ticks=ticks, cmap=colormap)
    # return the axes as documented
    return ax
###############################################################################
# Utility Functions
###############################################################################
def get_channelpos(channame):
    """Return the x/y position of a channel.

    Looks the channel up in ``CHANNEL_10_20`` and calculates its
    stereographic projection, suitable for a scalp plot.

    Parameters
    ----------
    channame : str
        Name of the channel, the search is case insensitive.

    Returns
    -------
    x, y : float or None
        The projected point on the plane if the point is known,
        otherwise ``None``

    Examples
    --------

    >>> plot.get_channelpos('C2')
    (0.1720792096741632, 0.0)
    >>> # the channels are case insensitive
    >>> plot.get_channelpos('c2')
    (0.1720792096741632, 0.0)
    >>> # lookup for an invalid channel
    >>> plot.get_channelpos('foo')
    None

    """
    wanted = channame.lower()
    for entry in CHANNEL_10_20:
        if entry[0].lower() != wanted:
            continue
        # the position is given in steps of 90/4 degrees; convert it to
        # radians first
        angles = entry[1]
        ea = angles[0] * (90 / 4)
        eb = angles[1] * (90 / 4)
        ea = ea * math.pi / 180
        eb = eb * math.pi / 180
        # the point on the unit sphere
        x = math.sin(ea) * math.cos(eb)
        y = math.sin(eb)
        z = math.cos(ea) * math.cos(eb)
        # Calculate the stereographic projection.
        # Given a unit sphere with radius ``r = 1`` and center at the
        # origin. Project the point ``p = (x, y, z)`` from the sphere's
        # South pole (0, 0, -1) on a plane on the sphere's North pole
        # (0, 0, 1).
        #
        # The formula is:
        #
        # P' = P * (2r / (r + z))
        #
        # The values were changed to move the point of projection
        # further below the south pole
        mu = 1 / (1.3 + z)
        return x * mu, y * mu
    return None
def beautify():
    """Set reasonable defaults for matplotlib.

    This method replaces matplotlib's default rgb/cmyk colors with the
    solarized colors. It also does:

    * re-orders the default color cycle
    * sets the default linewidth
    * replaces the default 'RdBu' cmap
    * sets the default cmap to 'RdBu'

    Examples
    --------

    You can safely call ``beautify`` right after you've imported the
    ``plot`` module.

    >>> from wyrm import plot
    >>> plot.beautify()

    """
    def to_mpl_format(r, g, b):
        """Convert 0..255 to 0..1."""
        # 255 is the largest input value and has to map to exactly 1
        return r / 255.0, g / 255.0, b / 255.0

    # The solarized color palette
    solarized = {
        'base03': to_mpl_format(0, 43, 54),
        'base02': to_mpl_format(7, 54, 66),
        'base01': to_mpl_format(88, 110, 117),
        'base00': to_mpl_format(101, 123, 131),
        'base0': to_mpl_format(131, 148, 150),
        'base1': to_mpl_format(147, 161, 161),
        'base2': to_mpl_format(238, 232, 213),
        'base3': to_mpl_format(253, 246, 227),
        'yellow': to_mpl_format(181, 137, 0),
        'orange': to_mpl_format(203, 75, 22),
        'red': to_mpl_format(220, 50, 47),
        'magenta': to_mpl_format(211, 54, 130),
        'violet': to_mpl_format(108, 113, 196),
        'blue': to_mpl_format(38, 139, 210),
        'cyan': to_mpl_format(42, 161, 152),
        'green': to_mpl_format(133, 153, 0),
    }
    blue = solarized['blue']
    magenta = solarized['magenta']
    # pure white instead of base3
    white = (1, 1, 1)
    black = solarized['base03']

    # Overwrite the default color values with our new ones. Those
    # single-letter colors are used all over the place in matplotlib, so
    # this setting has a huge effect.
    # NOTE(review): newer matplotlib versions may not read this class
    # attribute when resolving colors -- confirm against the targeted
    # matplotlib version.
    mpl.colors.ColorConverter.colors = {
        'b': blue,
        'c': solarized['cyan'],
        'g': solarized['green'],
        'k': black,
        'm': magenta,
        'r': solarized['red'],
        'w': white,
        'y': solarized['yellow']
    }

    # Redefine the existing 'RdBu' (Red-Blue) colormap, with our new
    # colors for red and blue. Every channel goes from blue over white
    # to magenta.
    channels = ('red', 'green', 'blue')
    cdict = dict(
        (channel, ((0., blue[i], blue[i]),
                   (0.5, white[i], white[i]),
                   (1., magenta[i], magenta[i])))
        for i, channel in enumerate(channels))
    # NOTE(review): register_cmap and its ``data`` argument were changed
    # in later matplotlib releases -- confirm against the targeted
    # matplotlib version.
    mpl.cm.register_cmap('RdBu', data=cdict)

    # Reorder the default color cycle. 'axes.color_cycle' was replaced
    # by 'axes.prop_cycle', so the newer key is used where available.
    color_cycle = ['b', 'm', 'g', 'r', 'c', 'y', 'k']
    if 'axes.prop_cycle' in mpl.rcParams:
        mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=color_cycle)
    else:
        mpl.rcParams['axes.color_cycle'] = color_cycle
    # Set linewidth in plots to 2
    mpl.rcParams['lines.linewidth'] = 2
    # Set default cmap
    mpl.rcParams['image.cmap'] = 'RdBu'