00001 """
00002 Utilities for plotting ROOT histograms in matplotlib.
00003 """
00004
00005 __license__ = '''\
00006 Copyright (c) 2009-2010 Jeff Klukas <klukas@wisc.edu>
00007
00008 Permission is hereby granted, free of charge, to any person obtaining a copy
00009 of this software and associated documentation files (the "Software"), to deal
00010 in the Software without restriction, including without limitation the rights
00011 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
00012 copies of the Software, and to permit persons to whom the Software is
00013 furnished to do so, subject to the following conditions:
00014
00015 The above copyright notice and this permission notice shall be included in
00016 all copies or substantial portions of the Software.
00017
00018 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
00019 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
00020 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
00021 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
00022 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
00023 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
00024 THE SOFTWARE.
00025 '''
00026
00027
00028
00029 import math
00030 import ROOT
00031 import re
00032 import copy
00033 import array
00034 from rootplot import utilities
00035 import matplotlib as mpl
00036 import matplotlib.pyplot as plt
00037 import matplotlib.transforms as transforms
00038 import numpy as np
00039
00040
00041
# Compiled pattern that, used with .match() (which anchors at the start of
# the string), succeeds only when the whole string is empty or whitespace.
# replace() uses it to collapse such results to "".
_all_whitespace_string = re.compile(r'\s*$')
00043
00044
00045
00046
class Hist2D(utilities.Hist2D):
    """
    A container to hold the parameters from a 2D ROOT histogram.

    An optional ``replacements`` keyword is stored on the instance; all other
    constructor arguments are forwarded to :class:`utilities.Hist2D`.
    """
    def __init__(self, *args, **kwargs):
        # Remove our own keyword before the base class sees it.
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist2D.__init__(self, *args, **kwargs)

    def _apply_binlabels(self):
        """Put any text bin labels on the axes at the per-bin coordinates."""
        # Ticks go at self.x / self.y -- the same coordinates the data is
        # drawn at -- so every label sits on its own bin rather than at
        # integer positions that only coincide with bins for unit-width
        # histograms starting at -0.5.
        if self.binlabelsx is not None:
            plt.xticks(self.x, self.binlabelsx)
        if self.binlabelsy is not None:
            plt.yticks(self.y, self.binlabelsy)

    def contour(self, **kwargs):
        """
        Draw a contour plot with inline contour labels.

        Keyword arguments are passed to :func:`matplotlib.pyplot.contour`.
        Returns the contour set.
        """
        cs = plt.contour(self.x, self.y, self.content, **kwargs)
        plt.clabel(cs, inline=1, fontsize=10)
        self._apply_binlabels()
        return cs

    def col(self, **kwargs):
        """
        Draw a colored box plot using :func:`matplotlib.pyplot.imshow`.

        A ``cmap`` given by name is resolved with
        :func:`matplotlib.pyplot.get_cmap`.  The image spans the histogram's
        outer bin edges.  Returns the image object.
        """
        if 'cmap' in kwargs:
            kwargs['cmap'] = plt.get_cmap(kwargs['cmap'])
        extent = [self.xedges[0], self.xedges[-1],
                  self.yedges[0], self.yedges[-1]]
        # content is indexed [iy][ix]; origin='lower' puts row 0 at the bottom.
        return plt.imshow(self.content, interpolation='nearest',
                          extent=extent, aspect='auto', origin='lower',
                          **kwargs)

    def colz(self, **kwargs):
        """
        Draw a colored box plot with a colorbar using
        :func:`matplotlib.pyplot.imshow`.
        """
        plot = self.col(**kwargs)
        plt.colorbar(plot)
        return plot

    def box(self, maxsize=40, **kwargs):
        """
        Draw a box plot with size indicating content using
        :func:`matplotlib.pyplot.scatter`.

        The data will be normalized, with the largest box using a marker of
        size maxsize (in points).
        """
        # One marker per bin, in the same row-major (y outer, x inner) order
        # as the flattened content.
        x = np.tile(self.x, self.nbinsy)
        y = np.repeat(self.y, self.nbinsx)
        content = np.asarray(self.content, dtype=float).ravel()
        maxvalue = np.max(content)
        if maxvalue == 0:
            maxvalue = 1.
        # Float dtype above avoids Python 2 integer (floor) division here.
        sizes = content / maxvalue * maxsize
        return plt.scatter(x, y, sizes, marker='s', **kwargs)

    def TH2F(self, name=""):
        """Return a ROOT.TH2F object with contents of this Hist2D."""
        # Double-precision edges so the bin boundaries are not truncated to
        # float32 on the way into ROOT.
        th2f = ROOT.TH2F(name, "",
                         self.nbinsx, array.array('d', self.xedges),
                         self.nbinsy, array.array('d', self.yedges))
        th2f.SetTitle("%s;%s;%s" % (self.title, self.xlabel, self.ylabel))
        for ix in range(self.nbinsx):
            for iy in range(self.nbinsy):
                # ROOT bins are 1-based (bin 0 is underflow).
                th2f.SetBinContent(ix + 1, iy + 1, self.content[iy][ix])
        return th2f
00106
class Hist(utilities.Hist):
    """
    A container to hold the parameters from a ROOT histogram.

    An optional ``replacements`` keyword (a list of ``(pattern,
    substitution)`` pairs, see :func:`replace`) is stored on the instance and
    applied to titles and legend labels; all other constructor arguments are
    forwarded to :class:`utilities.Hist`.
    """
    def __init__(self, *args, **kwargs):
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist.__init__(self, *args, **kwargs)

    def _pop_replacements(self, kwargs):
        """
        Remove ``replacements`` from *kwargs* and return it, falling back to
        this Hist's own replacements when it is absent or empty.

        Popping (rather than just reading) keeps the keyword from being
        forwarded to matplotlib, which rejects properties it does not know.
        """
        return kwargs.pop('replacements', None) or self.replacements

    def _prepare_xaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on x axis."""
        if self.binlabels is not None:
            plt.xticks(self.x, self.binlabels,
                       rotation=rotation, ha=alignment)
        plt.xlim(self.xedges[0], self.xedges[-1])

    def _prepare_yaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on y axis (for horizontal plots)."""
        if self.binlabels is not None:
            plt.yticks(self.x, self.binlabels,
                       rotation=rotation, va=alignment)
        plt.ylim(self.xedges[0], self.xedges[-1])

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels to the current figure.

        A ``replacements`` keyword overrides this Hist's own replacements.
        """
        replacements = self._pop_replacements(kwargs)
        plt.title(replace(self.title, replacements))
        plt.xlabel(replace(self.xlabel, replacements))
        plt.ylabel(replace(self.ylabel, replacements))

    def hist(self, label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib hist figure.

        All additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`.  Returns what that function returns.
        """
        kwargs.pop('fmt', None)  # errorbar-style format; not a hist option
        replacements = self._pop_replacements(kwargs)
        weights = self.y
        if not any(self.y):
            # All-zero histogram: substitute a tiny positive weight.
            # NOTE(review): presumably keeps matplotlib from failing or
            # drawing nothing (e.g. on a log scale) -- confirm.
            weights = [1.e-10] * self.nbins
        plot = plt.hist(self.x, weights=weights, bins=self.xedges,
                        label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return plot

    def errorbar(self, xerr=False, yerr=False, label_rotation=0,
                 label_alignment='center', **kwargs):
        """
        Generate a matplotlib errorbar figure.

        *xerr*/*yerr* attach this Hist's ``xerr``/``yerr``.  All additional
        keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.
        """
        if xerr:
            kwargs['xerr'] = self.xerr
        if yerr:
            kwargs['yerr'] = self.yerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.x, self.y,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return errorbar

    def errorbarh(self, xerr=False, yerr=False, label_rotation=0,
                  label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib errorbar figure.

        Bin contents run along the x axis and bin positions along the y axis,
        so *xerr* attaches the content errors (``yerr``) and *yerr* the bin
        errors (``xerr``).  All additional keyword arguments will be passed
        to :func:`matplotlib.pyplot.errorbar`.
        """
        if xerr:
            kwargs['xerr'] = self.yerr
        if yerr:
            kwargs['yerr'] = self.xerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.y, self.x,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return errorbar

    def bar(self, xerr=False, yerr=False, xoffset=0., width=0.8,
            label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib bar figure.

        Each bar starts at its bin's low edge shifted right by *xoffset* bin
        widths, and is *width* bin widths wide.  *xerr*/*yerr* attach
        ``av_xerr()``/``av_yerr()``.  All additional keyword arguments will be
        passed to :func:`matplotlib.pyplot.bar`.
        """
        kwargs.pop('fmt', None)  # errorbar-style format; not a bar option
        if xerr:
            kwargs['xerr'] = self.av_xerr()
        if yerr:
            kwargs['yerr'] = self.av_yerr()
        replacements = self._pop_replacements(kwargs)
        # The positions computed below are left edges; matplotlib >= 2.0
        # defaults to align='center', which would shift every bar by half a
        # width.  Callers can still override.
        kwargs.setdefault('align', 'edge')
        lefts = [low + binwidth * xoffset
                 for low, binwidth in zip(self.xedges[:-1], self.width)]
        widths = [binwidth * width for binwidth in self.width]
        bar = plt.bar(lefts, self.y, widths,
                      label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return bar

    def barh(self, xerr=False, yerr=False, yoffset=0., width=0.8,
             label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib bar figure.

        Each bar starts at its bin's low edge shifted up by *yoffset* bin
        widths, and is *width* bin widths tall.  *xerr* attaches
        ``av_yerr()`` and *yerr* attaches ``av_xerr()`` (axes are swapped).
        All additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.barh`.
        """
        kwargs.pop('fmt', None)  # errorbar-style format; not a bar option
        if xerr:
            kwargs['xerr'] = self.av_yerr()
        if yerr:
            kwargs['yerr'] = self.av_xerr()
        replacements = self._pop_replacements(kwargs)
        # Positions are bottom edges; see bar() for why 'edge' is requested.
        kwargs.setdefault('align', 'edge')
        bottoms = [low + binwidth * yoffset
                   for low, binwidth in zip(self.xedges[:-1], self.width)]
        heights = [binwidth * width for binwidth in self.width]
        barh = plt.barh(bottoms, self.y, heights,
                        label=replace(self.label, replacements),
                        **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return barh
00226
class HistStack(utilities.HistStack):
    """
    A container to hold Hist objects for plotting together.

    When plotting, the title and the x and y labels of the last Hist added
    will be used unless specified otherwise in the constructor.

    An optional ``replacements`` keyword (see :func:`replace`) is stored on
    the stack; all other constructor arguments are forwarded to
    :class:`utilities.HistStack`.
    """
    def __init__(self, *args, **kwargs):
        # Always define the attribute (None when not given) so methods that
        # read it never hit an AttributeError.
        self.replacements = kwargs.pop('replacements', None)
        utilities.HistStack.__init__(self, *args, **kwargs)

    def _apply_overrides(self, hist, horizontal=False):
        """
        Copy the stack's non-None title and axis-label overrides onto *hist*.

        For horizontal plots the stack's x label describes the hist's y axis
        and vice versa, so the two labels are swapped.
        """
        if self.title is not None:
            hist.title = self.title
        if horizontal:
            xattr, yattr = 'ylabel', 'xlabel'
        else:
            xattr, yattr = 'xlabel', 'ylabel'
        if self.xlabel is not None:
            setattr(hist, xattr, self.xlabel)
        if self.ylabel is not None:
            setattr(hist, yattr, self.ylabel)

    def _merged_kwargs(self, i, kwargs):
        """Return a copy of *kwargs* updated with the i-th Hist's own kwargs."""
        merged = copy.copy(kwargs)
        merged.update(self.kwargs[i])
        return merged

    def _apply_color_cycle(self):
        """
        Make the current axes cycle through the per-Hist colors, but only
        when every Hist's kwargs specify a ``color``.
        """
        try:
            colors = [item['color'] for item in self.kwargs]
        except (KeyError, TypeError):
            # Some Hist has no explicit color: keep matplotlib's defaults.
            return
        if not colors:
            return
        ax = plt.gca()
        if hasattr(ax, 'set_prop_cycle'):
            ax.set_prop_cycle(color=colors)  # matplotlib >= 1.5
        else:
            ax.set_color_cycle(colors)       # older matplotlib

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels of the last Hist to the current
        figure.

        A ``replacements`` keyword overrides the stack's own replacements.
        """
        kwargs['replacements'] = (kwargs.get('replacements') or
                                  self.replacements)
        self.hists[-1].show_titles(**kwargs)

    def hist(self, label_rotation=0, **kwargs):
        """
        Make a matplotlib hist plot.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`, which allows a vast array of
        possibilities. Particularly, the *histtype* values such as
        ``'barstacked'`` and ``'stepfilled'`` give substantially different
        results. You will probably want to include a transparency value
        (i.e. *alpha* = 0.5).

        *label_rotation* is accepted for backward compatibility and is
        unused.  Returns what :func:`matplotlib.pyplot.hist` returns.
        """
        replacements = kwargs.pop('replacements', None) or self.replacements
        # One column per Hist, i.e. shape (nbins, nhists): the layout
        # matplotlib's hist expects for several datasets.  The weights must
        # have exactly the same shape as x.
        x = np.column_stack([hist.x for hist in self.hists])
        weights = np.column_stack([hist.y for hist in self.hists])
        labels = [replace(hist.label, replacements) for hist in self.hists]
        self._apply_color_cycle()
        return plt.hist(x, weights=weights, bins=self.hists[0].xedges,
                        label=labels, **kwargs)

    def bar3d(self, **kwargs):
        """
        Make a 3D bar plot in a new figure, one row of bars per Hist.

        Hist *i* is drawn in the plane y = i and the y axis is labelled with
        the Hists' labels.  Additional keyword arguments (updated with each
        Hist's own kwargs) are passed to ``Axes3D.bar``.  Returns a list with
        the result of each ``bar`` call.
        """
        # Importing Axes3D registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D
        from matplotlib.ticker import FixedLocator
        fig = plt.figure()
        # add_subplot attaches the axes to the figure; constructing
        # Axes3D(fig) directly no longer does so on recent matplotlib.
        ax = fig.add_subplot(111, projection='3d')
        plots = []
        labels = []
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist)
            labels.append(hist.label)
            plots.append(ax.bar(hist.x, hist.y, zs=i, zdir='y',
                                width=hist.width,
                                **self._merged_kwargs(i, kwargs)))
        ax.yaxis.set_major_locator(FixedLocator(range(len(labels))))
        ax.yaxis.set_ticklabels(labels)
        ax.set_ylim3d(-1, len(labels))
        return plots

    def barstack(self, **kwargs):
        """
        Make a matplotlib bar plot, with each Hist stacked upon the last.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.
        """
        bottom = None
        plots = []
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist)
            plots.append(hist.bar(bottom=bottom,
                                  **self._merged_kwargs(i, kwargs)))
            if bottom is None:
                bottom = [0.] * self.hists[0].nbins
            # Running sum of everything drawn so far.
            bottom = [low + y for low, y in zip(bottom, hist.y)]
        return plots

    def histstack(self, **kwargs):
        """
        Make a matplotlib hist plot, with each Hist stacked upon the last.

        Hist *i* is drawn as the cumulative sum of Hists 0..i.  Any additional
        keyword arguments will be passed to :func:`matplotlib.pyplot.hist`.
        """
        plots = []
        cumhist = None
        nhists = len(self)
        for i, hist in enumerate(self.hists):
            if cumhist is None:
                cumhist = copy.copy(hist)
            else:
                cumhist = hist + cumhist
            self._apply_overrides(cumhist)
            # Earlier cumulative sums get a higher zorder so they are drawn
            # in front of the later ones that enclose them.
            zorder = float(nhists - i) / nhists
            plots.append(cumhist.hist(zorder=zorder,
                                      **self._merged_kwargs(i, kwargs)))
        return plots

    def barcluster(self, width=0.8, **kwargs):
        """
        Make a clustered bar plot.

        The Hists share *width* (a fraction of each bin) side by side,
        centred in the bin.  Any additional keyword arguments will be passed
        to :func:`matplotlib.pyplot.bar`.
        """
        plots = []
        spacer = (1. - width) / 2
        # float() avoids Python 2 integer division for an integer width.
        barwidth = float(width) / len(self.hists)
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist)
            plots.append(hist.bar(xoffset=barwidth * i + spacer,
                                  width=barwidth,
                                  **self._merged_kwargs(i, kwargs)))
        return plots

    def barh(self, width=0.8, **kwargs):
        """
        Make a horizontal clustered matplotlib bar plot.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.barh`.
        """
        plots = []
        spacer = (1. - width) / 2
        barwidth = float(width) / len(self.hists)
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist, horizontal=True)
            plots.append(hist.barh(yoffset=barwidth * i + spacer,
                                   width=barwidth,
                                   **self._merged_kwargs(i, kwargs)))
        return plots

    def bar(self, **kwargs):
        """
        Make a bar plot, with all Hists in the stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`. You will probably want to set a
        transparency value (i.e. *alpha* = 0.5).
        """
        plots = []
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist)
            plots.append(hist.bar(**self._merged_kwargs(i, kwargs)))
        return plots

    def errorbar(self, offset=False, **kwargs):
        """
        Make a matplotlib errorbar plot, with all Hists in the stack overlaid.

        Passing 'offset=True' will slightly offset each dataset so overlapping
        errorbars are still visible. Any additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.errorbar`.
        """
        plots = []
        center_index = (len(self.hists) - 1) / 2.
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist)
            transform = plt.gca().transData
            if offset:
                # dpi_scale_trans works in inches: shift dataset i by
                # (i - center) points (1/72 inch) so the set stays centred.
                shift = transforms.ScaledTranslation(
                    1. / 72 * (i - center_index), 0, plt.gcf().dpi_scale_trans)
                transform = plt.gca().transData + shift
            plots.append(hist.errorbar(transform=transform,
                                       **self._merged_kwargs(i, kwargs)))
        return plots

    def errorbarh(self, **kwargs):
        """
        Make a horizontal matplotlib errorbar plot, with all Hists in the
        stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            self._apply_overrides(hist, horizontal=True)
            plots.append(hist.errorbarh(**self._merged_kwargs(i, kwargs)))
        return plots
00429
00430
00431
class RootFile(utilities.RootFile):
    """A wrapper for TFiles, allowing easier access to methods."""
    def get(self, object_name, path=None):
        """
        Retrieve *object_name* (optionally under *path*) from the file.

        This module's :class:`Hist` and :class:`Hist2D` are passed to
        :meth:`utilities.RootFile.get`, presumably so that retrieved
        histograms are wrapped in these matplotlib-aware classes.

        A ``ReferenceError`` raised by the base lookup propagates unchanged,
        with its original traceback.
        """
        return utilities.RootFile.get(self, object_name, path, Hist, Hist2D)
00440
00441
00442
def replace(string, replacements):
    """
    Modify a string based on a list of patterns and substitutions.

    replacements should be a list of two-entry tuples, the first entry giving
    a string to search for and the second entry giving the string with which
    to replace it. If any entry's pattern is exactly 'use_regexp', that entry
    is skipped and all other patterns are treated as regular expressions
    using re.sub; otherwise plain substring replacement is used.

    If replacements is empty/None, or string is None, string is returned
    unchanged. Otherwise a result that is empty or all whitespace is
    returned as "".
    """
    if not replacements or string is None:
        return string
    use_regexp = any(pattern == 'use_regexp' for pattern, _ in replacements)
    if use_regexp:
        for pattern, repl in replacements:
            if pattern != 'use_regexp':
                string = re.sub(pattern, repl, string)
    else:
        for pattern, repl in replacements:
            string = string.replace(pattern, repl)
    if _all_whitespace_string.match(string):
        return ""
    return string
00465