Source code for arsenal.viz.util

from os import environ
# Whether a display is configured; flipped to False below when the DISPLAY
# environment variable is unset or empty.
DISPLAY = True
if not environ.get('DISPLAY'):
    # No display: switch matplotlib to the 'Agg' backend.  This runs before
    # `matplotlib.pyplot` is imported further down in this module.
    import matplotlib
    #print 'Not a display environment.'
    matplotlib.use('Agg')
    DISPLAY = False

import pandas as pd, numpy as np
import matplotlib.pyplot as pl
#from sys import stderr
from collections import defaultdict
from contextlib import contextmanager
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.mplot3d import Axes3D
from arsenal.terminal import colors
#from arsenal.viz.covariance_ellipse import covariance_ellipse
from arsenal.misc import ddict


#from palettable.colorbrewer import qualitative


# Default color sequence used by `name2color` (single-letter color codes).
simple_palette = ['r','g','b','y','c','m','k']
#default_palette = np.array(qualitative.Set1_6.mpl_colors)


#def name2color(palette = default_palette):
def name2color(palette = simple_palette):
    """Create a mapping from names to matplotlib colors.

    The result is a `defaultdict`: the first lookup of an unseen name assigns
    it the next color of `palette`, wrapping around to the start once every
    color has been handed out.  Later lookups of the same name return the
    color it was given the first time.
    """
    available = list(palette)   # private copy; also accepts arrays/iterables
    size = len(available)
    handed_out = 0

    def next_color():
        nonlocal handed_out
        choice = available[handed_out % size]
        handed_out += 1
        return choice

    return defaultdict(next_color)
def save_plots(pdf):
    """Save all open plots to a single multi-page pdf.

    Args:
        pdf: Output target handed to `PdfPages`; it is also interpolated into
            the confirmation message printed at the end.

    Every figure number reported by `pl.get_fignums()` is made current and
    saved as one page, so the last figure saved is left as the current figure.
    """
    # The `with` block closes (and thereby finalizes) the PDF even when saving
    # one of the figures raises; previously the file was left open in that case.
    with PdfPages(pdf) as pp:
        for i in pl.get_fignums():
            pl.figure(i)
            pl.savefig(pp, format='pdf')
    print(colors.yellow % 'saved plots to "%s"' % pdf)
# global reference to all of the plots
def newax():
    "Open a new figure and return its single (1x1 grid) subplot axes."
    fig = pl.figure()
    return fig.add_subplot(111)
# AX: plot name -> axes; a missing name gets a fresh figure/axes from `newax`.
AX = defaultdict(newax)
# DATA: plot name -> list of recorded values; a missing name starts out empty.
DATA = defaultdict(list)


# TODO: [2018-04-18 Wed] Add support for visualizing constraints. Apparently,
# I'm constantly plotting constraints these days.
#
# - There may be better ways to implement this (better defaults at least)
#
# - Look into rasbt's decision boundary plotting script for some example usage
#   of all these things.
#   https://github.com/rasbt/mlxtend/blob/62aea2a9fb6fafdecedfa041a2121c002e47dac9/mlxtend/plotting/decision_regions.py
#
# - Figure out the difference between contour and contourf.
#
# - Look into value interpolation strategies
#   https://matplotlib.org/gallery/images_contours_and_fields/triinterp_demo.html#sphx-glr-gallery-images-contours-and-fields-triinterp-demo-py.
#
# - For constraints look at value mask:
#   https://matplotlib.org/gallery/images_contours_and_fields/contour_corner_mask.html#sphx-glr-gallery-images-contours-and-fields-contour-corner-mask-py
#
# - For rendering constraints, we might want to use hatching instead of
#   something opaque.
#
def contour_plot(f, xdomain, ydomain, color='viridis', alpha=0.5, levels=None, ax=None):
    """Contour plot of a function of two variables.

    Args:
        f: Callable evaluated at each grid point on a length-2 array `[x, y]`.
        xdomain, ydomain: Three-element sequences `(min, max, num)`; each is
            unpacked into `np.linspace` to build the evaluation grid.
        color: Colormap for the shaded background image; `None` skips the
            image and draws the labelled contour lines only.
        alpha: Opacity of the background image.
        levels: Forwarded to `ax.contour` as its `levels` keyword.
        ax: Axes to draw on; defaults to the current pyplot axes.
    """
    from arsenal import iterview

    if ax is None:
        ax = pl.gca()

    xmin, xmax, _ = xdomain
    ymin, ymax, _ = ydomain
    X, Y = np.meshgrid(np.linspace(*xdomain), np.linspace(*ydomain))

    # NOTE(review): `iterview` presumably reports progress while iterating --
    # confirm in `arsenal`.
    grid_points = iterview(zip(X.flat, Y.flat), length=len(X.flat))
    values = [f(np.array([x, y])) for (x, y) in grid_points]
    Z = np.array(values).reshape(X.shape)

    # NOTE(review): a positional level count (20) and `levels=` are both passed;
    # which one wins when `levels` is given is up to matplotlib -- confirm.
    contours = ax.contour(X, Y, Z, 20, colors='black', levels=levels)
    ax.clabel(contours, inline=True, fontsize=8)

    if color is not None:
        bounds = [xmin, xmax, ymin, ymax]
        ax.imshow(Z, extent=bounds, origin='lower', cmap=color, alpha=alpha)

    #ax.axis(aspect='scalar')
    ax.figure.tight_layout()
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
# TODO: No need to say "plot" we're already in a module called "viz". The whole # point is reduce clutter when plotting. contour = contour_plot # TODO: Create an alias which case-analyzes and plots 3d vs 2d accordingly? # TODO: also support interactive sliders and animation for when there are more parameters. use the same range notation.
def plot3d(f, xdomain, ydomain, ax=None):
    """3d surface plot of a function of two variables.

    Args:
        f: Callable evaluated at each grid point on a length-2 array `[x, y]`.
        xdomain, ydomain: Argument tuples unpacked into `np.linspace`
            (e.g. `(min, max, num)`) to build the evaluation grid.
        ax: Axes providing `plot_surface`; when None a new figure with a 3d
            subplot is created.

    Returns:
        The axes the surface was drawn on.
    """
    X, Y = np.meshgrid(np.linspace(*xdomain), np.linspace(*ydomain))
    Z = np.array([f(np.array([x,y])) for (x,y) in zip(X.flat, Y.flat)]).reshape(X.shape)
    if ax is None:
        # `Figure.gca(projection='3d')` was deprecated in matplotlib 3.4 and
        # removed in 3.6; `add_subplot` is the supported way to get 3d axes.
        ax = pl.figure().add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, cmap='viridis', linewidth=0, antialiased=True)
    return ax
class plot_xsection:
    """Plot of a function along the straight segment between two points.

    Constructing an instance evaluates and draws the curve immediately.
    Calling the instance with another function draws that function over the
    same segment on the same axes.
    """

    def __init__(self, f, a, b, n, ax=None, opts=None):
        """
        Plot a cross section of `f` by interpolating from `a` to `b` by `n`
        evenly spaced points.

        `opts` is an optional dict of keyword arguments for `ax.plot`; `ax`
        defaults to the current pyplot axes.
        """
        plot_kwargs = {} if opts is None else opts
        target = pl.gca() if ax is None else ax
        self.n, self.a, self.b = n, a, b
        self.ts = np.linspace(0, 1, n)
        self.fs = list(map(f, self.curve()))
        target.plot(self.ts, self.fs, **plot_kwargs)
        target.set_xlabel('interpolation coefficient')
        self.ax = target

    def __call__(self, f, opts=None):
        "Plot another function `f` over the same segment, on the same axes."
        return plot_xsection(f, self.a, self.b, self.n, ax=self.ax, opts=opts)

    def curve(self):
        "Sweep a curve in parameter space which is a convex combination of `a` and `b`."
        for weight_b in self.ts:
            weight_a = 1 - weight_b
            yield self.a*weight_a + self.b*weight_b
class NumericalDebug:
    """Incrementally builds a DataFrame, includes plotting and comparison method.

    The quickest way to use it is

      >>> from arsenal.viz import DEBUG
      >>> d = DEBUG['test1']
      >>> _ = d.update(want=1, have=1)
      >>> _ = d.update(want=1, have=1.01)
      >>> _ = d.update(want=1, have=0.99)
      >>> d.df
         want  have
      0     1  1.00
      1     1  1.01
      2     1  0.99

    To plot and run numerical comparison tests,

      >>> d.compare()   # doctest: +SKIP

    """

    def __init__(self, name):
        self.name = name
        self._data = []             # one dict of column values per row
        # Start from an empty frame (not None) so `df` and `compare` are
        # usable before the first `update`.
        self._df = pd.DataFrame()
        self.ax = None              # axes for `compare`, created on first use
        self.uptodate = True        # False while `_df` is stale w.r.t. `_data`

    @property
    def df(self):
        "Lazily make DataFrame from _data; rebuilt only after new rows arrive."
        if not self.uptodate:
            self._df = pd.DataFrame(self._data)
            self.uptodate = True
        return self._df

    def update(self, **kw):
        "Pass in column values for the row by name as keyword arguments; returns self."
        self._data.append(kw)
        self.uptodate = False
        return self

    def compare(self, want='want', have='have', show_regression=1, scatter=1, **kw):
        """Plot the `want` column against the `have` column.

        Delegates to `arsenal.maths.compare(want, have, data=self.df)` and
        plots the result on this debugger's own axes; extra keyword arguments
        are forwarded to its `plot` method.  Does nothing beyond creating the
        axes when no rows have been recorded.  `show_regression` and `scatter`
        are accepted but not used by this method.
        """
        from arsenal.maths import compare
        if self.ax is None:
            self.ax = pl.figure().add_subplot(111)
        if self.df.empty:
            return
        with update_ax(self.ax):
            compare(want, have, data=self.df).plot(ax=self.ax, **kw)
# Global references to numerical debugger class. DEBUG = ddict(NumericalDebug)
@contextmanager
def lineplot(name, with_ax=False, halflife=20, xlabel=None, ylabel=None, title=None, **style):
    """Context manager around the named line plot `name`.

    Yields the list `DATA[name]` (or the pair `(data, ax)` when `with_ax` is
    true) for the caller to append to.  When the block finishes, the values
    are plotted against their index using `style`, and, if `halflife` is
    truthy, an exponentially weighted moving average with that half-life is
    drawn on top as a thick black line.
    """
    with axman(name, xlabel=xlabel, ylabel=ylabel, title=title) as ax:
        series = DATA[name]
        yield (series, ax) if with_ax else series

        positions = list(range(len(series)))
        ax.plot(positions, series, alpha=0.5, **style)
        if halflife:
            smoothed = pd.Series(series).ewm(halflife=halflife).mean()
            ax.plot(smoothed, alpha=0.5, c='k', lw=2)
@contextmanager
def axman(name, xlabel=None, ylabel=None, title=None, clear=True):
    """`axman` is axis manager. Manages clearing, updating and maintaining a global
    handle to a named plot.

    Args:
        name: Key into the global `AX` mapping; also the default plot title.
        xlabel, ylabel: Axis labels, applied after the managed block if truthy.
        title: Plot title; falls back to `name` when falsy.
        clear: Forwarded to `update_ax`.

    Yields:
        The axes registered under `name`, made current for the duration of the
        block.  The previously current axes are made current again afterwards,
        including when the block raises.
    """
    ax = AX[name]
    prev_ax = pl.gca()
    try:
        with update_ax(ax, clear=clear):
            _try_sca(ax)
            yield ax
            if xlabel: ax.set_xlabel(xlabel)
            if ylabel: ax.set_ylabel(ylabel)
            ax.set_title(title or name)   # `title` overrides `name`.
            #ax.figure.tight_layout()
    finally:
        # Restore the caller's current axes even if the block raised, so a
        # failing plot does not leave pyplot pointing at this managed axes.
        _try_sca(prev_ax)
def _try_sca(ax):
    "Make `ax` the current pyplot axes; a `ValueError` from `pl.sca` is ignored."
    try:
        pl.sca(ax)
    except ValueError:
        return
@contextmanager
def update_ax(ax, clear=True):
    """Manages clearing and updating a plot.

    On entry, `ax` is cleared when `clear` is true.  On normal exit the
    figure's canvas is redrawn and its events flushed, twice in a row; the
    first time a given axes is updated, `pl.show(block=False)` is also called.
    A `NotImplementedError` or `AttributeError` raised while refreshing is
    ignored.
    """
    if not hasattr(ax, '_did_show'):
        ax._did_show = False    # remembers whether `pl.show` was already called
    if clear:
        ax.clear()

    yield

    def refresh():
        canvas = ax.figure.canvas
        canvas.draw_idle()
        canvas.flush_events()
        if ax._did_show:
            return
        pl.show(block=False)
        ax._did_show = True

    for _attempt in (1, 2):
        try:
            refresh()
        except (NotImplementedError, AttributeError):
            #print >> stderr, 'warning failed to update plot.'
            pass
@contextmanager
def scatter_manager(name, with_ax=False, xlabel=None, ylabel=None, title=None, **style):
    """Context manager around the named scatter plot `name`.

    Yields the list `DATA[name]` (or the pair `(data, ax)` when `with_ax` is
    true); the caller appends `(x, y)` pairs to it.  When the block finishes,
    all recorded pairs are drawn with `ax.scatter` using `style`.  If nothing
    has been recorded yet, no points are drawn.
    """
    with axman(name, xlabel=xlabel, ylabel=ylabel, title=title) as ax:
        data = DATA[name]
        if with_ax:
            yield (data, ax)
        else:
            yield data
        # With no points, `zip(*data)` is empty and unpacking it into (x, y)
        # would raise ValueError, so skip drawing in that case.
        if len(data) > 0:
            x, y = zip(*data)
            ax.scatter(x, y, alpha=0.5, lw=0, **style)
def test():
    "Smoke test: record three rows under `DEBUG['test1']` and print the frame."
    d = DEBUG['test1']
    for observed in (1, 1.01, 0.99):
        d.update(want=1, have=observed)
    print(d.df)
# Run the smoke test when this module is executed as a script.
if __name__ == '__main__':
    test()