CMS 3D CMS Logo

/data/refman/pasoursint/CMSSW_5_3_3/src/PhysicsTools/PythonAnalysis/python/rootplot/root2matplotlib.py

Go to the documentation of this file.
00001 """
00002 Utilities for plotting ROOT histograms in matplotlib.
00003 """
00004 
00005 __license__ = '''\
00006 Copyright (c) 2009-2010 Jeff Klukas <klukas@wisc.edu>
00007 
00008 Permission is hereby granted, free of charge, to any person obtaining a copy
00009 of this software and associated documentation files (the "Software"), to deal
00010 in the Software without restriction, including without limitation the rights
00011 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
00012 copies of the Software, and to permit persons to whom the Software is
00013 furnished to do so, subject to the following conditions:
00014 
00015 The above copyright notice and this permission notice shall be included in
00016 all copies or substantial portions of the Software.
00017 
00018 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
00019 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
00020 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
00021 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
00022 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
00023 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
00024 THE SOFTWARE.
00025 '''
00026 
00027 ################ Import python libraries
00028 
00029 import math
00030 import ROOT
00031 import re
00032 import copy
00033 import array
00034 from rootplot import utilities
00035 import matplotlib as mpl
00036 import matplotlib.pyplot as plt
00037 import matplotlib.transforms as transforms
00038 import numpy as np
00039 
00040 ################ Define constants
00041 
# Compiled pattern that, when applied with re.match (which anchors at the
# start of the string), succeeds only for strings that are empty or consist
# entirely of whitespace.  replace() uses it to turn a label that became pure
# whitespace after substitution into "".
_all_whitespace_string = re.compile(r'\s*$')
00043 
00044 
00045 ################ Define classes
00046 
class Hist2D(utilities.Hist2D):
    """
    A container to hold the parameters from a 2D ROOT histogram.

    Takes the same arguments as :class:`utilities.Hist2D`, plus an optional
    *replacements* keyword which is stored on the instance (default None) and
    is not forwarded to the base class.
    """
    def __init__(self, *args, **kwargs):
        # Strip our own keyword before delegating so the base class never
        # sees it.
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist2D.__init__(self, *args, **kwargs)

    def _apply_bin_labels(self):
        """Put any text bin labels on the axes at the plotted coordinates."""
        # Ticks must sit at the same data coordinates that were handed to
        # matplotlib (self.x / self.y), not at bin indices 0..n-1; otherwise
        # the labels only line up when those coordinates happen to be 0..n-1.
        if self.binlabelsx is not None:
            plt.xticks(self.x, self.binlabelsx)
        if self.binlabelsy is not None:
            plt.yticks(self.y, self.binlabelsy)

    def contour(self, **kwargs):
        """
        Draw a labelled contour plot using :func:`matplotlib.pyplot.contour`.

        Keyword arguments are passed through to ``plt.contour``.  Contour
        levels are annotated with :func:`matplotlib.pyplot.clabel`, and any
        text bin labels are placed at the plotted x/y coordinates.

        Returns the contour set created by ``plt.contour``.
        """
        cs = plt.contour(self.x, self.y, self.content, **kwargs)
        plt.clabel(cs, inline=1, fontsize=10)
        self._apply_bin_labels()
        return cs

    def col(self, **kwargs):
        """
        Draw a colored box plot using :func:`matplotlib.pyplot.imshow`.

        The image spans ``xedges[0]..xedges[-1]`` by ``yedges[0]..yedges[-1]``
        with the first content row at the bottom (``origin='lower'``).  A
        *cmap* keyword is normalised through :func:`matplotlib.pyplot.get_cmap`;
        all other keywords go straight to ``plt.imshow``.

        Returns the image object created by ``plt.imshow``.
        """
        if 'cmap' in kwargs:
            kwargs['cmap'] = plt.get_cmap(kwargs['cmap'])
        plot = plt.imshow(self.content, interpolation='nearest',
                          extent=[self.xedges[0], self.xedges[-1],
                                  self.yedges[0], self.yedges[-1]],
                          aspect='auto', origin='lower', **kwargs)
        return plot

    def colz(self, **kwargs):
        """
        Draw a colored box plot with a colorbar using
        :func:`matplotlib.pyplot.imshow`.

        Keyword arguments are handled exactly as in :meth:`col`.  Returns the
        image object.
        """
        plot = self.col(**kwargs)
        plt.colorbar(plot)
        return plot

    def box(self, maxsize=40, **kwargs):
        """
        Draw a box plot with size indicating content using
        :func:`matplotlib.pyplot.scatter`.

        The data are normalised so that the bin with the largest content is
        drawn with a square marker whose scatter size (the *s* argument, which
        matplotlib interprets as an area in points**2) equals *maxsize*.  If
        the maximum content is 0, no normalisation is applied.

        Returns the collection created by ``plt.scatter``.
        """
        # One marker per bin: x repeats for every y row, y is constant along
        # each row, matching the row-major flattening of self.content below.
        x = np.hstack([self.x for i in range(self.nbinsy)])
        y = np.hstack([[yval for i in range(self.nbinsx)] for yval in self.y])
        maxvalue = np.max(self.content)
        if maxvalue == 0:
            # Avoid dividing by zero for an empty histogram.
            maxvalue = 1
        sizes = np.array(self.content).flatten() / maxvalue * maxsize
        plot = plt.scatter(x, y, sizes, marker='s', **kwargs)
        return plot

    def TH2F(self, name=""):
        """
        Return a ROOT.TH2F object with contents of this Hist2D.

        The histogram uses this object's (possibly variable-width) bin edges,
        and its title string is ``"title;xlabel;ylabel"``.  ``self.content``
        is read as ``content[iy][ix]``; ROOT bins are 1-based.
        """
        th2f = ROOT.TH2F(name, "",
                         self.nbinsx, array.array('f', self.xedges),
                         self.nbinsy, array.array('f', self.yedges))
        th2f.SetTitle("%s;%s;%s" % (self.title, self.xlabel, self.ylabel))
        for ix in range(self.nbinsx):
            for iy in range(self.nbinsy):
                th2f.SetBinContent(ix + 1, iy + 1, self.content[iy][ix])
        return th2f
00106 
class Hist(utilities.Hist):
    """
    A container to hold the parameters from a ROOT histogram.

    Takes the same arguments as :class:`utilities.Hist`, plus an optional
    *replacements* keyword (a list of ``(pattern, substitution)`` tuples, see
    :func:`replace`) stored on the instance.

    Every plotting method also accepts a ``replacements`` keyword.  It is
    consumed here to rewrite the legend label and is never forwarded to
    matplotlib; when it is omitted, the instance's own ``replacements`` are
    used.
    """
    def __init__(self, *args, **kwargs):
        # Strip our own keyword before delegating so the base class never
        # sees it.
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist.__init__(self, *args, **kwargs)

    def _pop_replacements(self, kwargs):
        """Remove 'replacements' from *kwargs*; fall back to the instance's."""
        # Must be popped, not just read: anything left in kwargs is passed on
        # to matplotlib, which rejects unknown keyword arguments.
        return kwargs.pop('replacements', None) or self.replacements

    def _prepare_xaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on x axis."""
        if self.binlabels is not None:
            plt.xticks(self.x, self.binlabels,
                       rotation=rotation, ha=alignment)
        plt.xlim(self.xedges[0], self.xedges[-1])

    def _prepare_yaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on y axis (for horizontal plots)."""
        if self.binlabels is not None:
            plt.yticks(self.x, self.binlabels,
                       rotation=rotation, va=alignment)
        plt.ylim(self.xedges[0], self.xedges[-1])

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels to the current figure.

        A ``replacements`` keyword, if given, overrides the instance's own
        replacements for this call.
        """
        replacements = kwargs.get('replacements', None) or self.replacements
        plt.title(replace(self.title, replacements))
        plt.xlabel(replace(self.xlabel, replacements))
        plt.ylabel(replace(self.ylabel, replacements))

    def hist(self, label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib hist figure.

        The bin contents are used as weights at the bin positions, with this
        histogram's edges as bins.  A ``fmt`` keyword is discarded.  All
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`.

        Returns whatever ``plt.hist`` returns.
        """
        kwargs.pop('fmt', None)
        replacements = self._pop_replacements(kwargs)
        weights = self.y
        # Kludge to avoid mpl bug when plotting all zeros.  list() makes the
        # comparison work for lists, tuples and numpy arrays alike.
        if list(self.y) == [0] * self.nbins:
            weights = [1.e-10] * self.nbins
        plot = plt.hist(self.x, weights=weights, bins=self.xedges,
                        label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return plot

    def errorbar(self, xerr=False, yerr=False, label_rotation=0,
                 label_alignment='center', **kwargs):
        """
        Generate a matplotlib errorbar figure.

        *xerr* / *yerr* select whether this histogram's x / y errors are
        drawn.  All additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.

        Returns whatever ``plt.errorbar`` returns.
        """
        if xerr:
            kwargs['xerr'] = self.xerr
        if yerr:
            kwargs['yerr'] = self.yerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.x, self.y,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return errorbar

    def errorbarh(self, xerr=False, yerr=False, label_rotation=0,
                  label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib errorbar figure.

        The axes are swapped relative to :meth:`errorbar`: bin contents go on
        the horizontal axis, so *xerr* draws the y errors and vice versa.  All
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.

        Returns whatever ``plt.errorbar`` returns.
        """
        if xerr:
            kwargs['xerr'] = self.yerr
        if yerr:
            kwargs['yerr'] = self.xerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.y, self.x,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return errorbar

    def bar(self, xerr=False, yerr=False, xoffset=0., width=0.8,
            label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib bar figure.

        Each bar starts at its bin's low edge plus ``xoffset`` bin widths and
        is ``width`` bin widths wide.  *xerr* / *yerr* draw the averaged
        errors.  A ``fmt`` keyword is discarded.  All additional keyword
        arguments will be passed to :func:`matplotlib.pyplot.bar`.

        Returns whatever ``plt.bar`` returns.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_xerr()
        if yerr:
            kwargs['yerr'] = self.av_yerr()
        replacements = self._pop_replacements(kwargs)
        # The positions below are left edges, so the bars must be anchored at
        # their edge.  Newer matplotlib defaults to align='center', which
        # would shift every bar by half its width; keep the caller's choice
        # if one was given.
        kwargs.setdefault('align', 'edge')
        lefts = [self.xedges[i] + self.width[i] * xoffset
                 for i in range(len(self.xedges) - 1)]
        widths = [w * width for w in self.width]
        bar = plt.bar(lefts, self.y, widths,
                      label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return bar

    def barh(self, xerr=False, yerr=False, yoffset=0., width=0.8,
             label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib bar figure.

        Each bar starts at its bin's low edge plus ``yoffset`` bin widths and
        is ``width`` bin widths tall.  *xerr* / *yerr* are swapped as in
        :meth:`errorbarh`.  A ``fmt`` keyword is discarded.  All additional
        keyword arguments will be passed to :func:`matplotlib.pyplot.barh`.

        Returns whatever ``plt.barh`` returns.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_yerr()
        if yerr:
            kwargs['yerr'] = self.av_xerr()
        replacements = self._pop_replacements(kwargs)
        # Positions are bin low edges; see the note in bar().
        kwargs.setdefault('align', 'edge')
        bottoms = [self.xedges[i] + self.width[i] * yoffset
                   for i in range(len(self.xedges) - 1)]
        heights = [w * width for w in self.width]
        barh = plt.barh(bottoms, self.y, heights,
                        label=replace(self.label, replacements),
                        **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return barh
00226 
class HistStack(utilities.HistStack):
    """
    A container to hold Hist objects for plotting together.

    When plotting, the title and the x and y labels of the last Hist added
    will be used unless specified otherwise in the constructor.

    An optional *replacements* keyword (see :func:`replace`) is stored on the
    instance (default None) and not forwarded to the base class.
    """
    def __init__(self, *args, **kwargs):
        # Always define the attribute, as Hist and Hist2D do, so later
        # lookups never raise AttributeError when no replacements were given.
        self.replacements = kwargs.pop('replacements', None)
        utilities.HistStack.__init__(self, *args, **kwargs)

    def show_titles(self, **kwargs):
        """
        Print the last Hist's title and axis labels to the current figure.

        Keyword arguments (e.g. ``replacements``) are forwarded to
        :meth:`Hist.show_titles`.
        """
        self.hists[-1].show_titles(**kwargs)

    def _set_color_cycle(self):
        """Best effort: make the current axes cycle through per-Hist colors."""
        try:
            colors = [item['color'] for item in self.kwargs]
            axes = plt.gca()
            if hasattr(axes, 'set_prop_cycle'):
                # Axes.set_color_cycle was deprecated in favour of
                # set_prop_cycle and removed in later matplotlib releases.
                axes.set_prop_cycle(color=colors)
            else:
                axes.set_color_cycle(colors)
        except (KeyError, TypeError, ValueError):
            # Not every Hist specified a color (or a color was unusable):
            # keep matplotlib's default color cycle.
            pass

    def hist(self, label_rotation=0, **kwargs):
        """
        Make a matplotlib hist plot.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`, which allows a vast array of
        possibilities.  Particularly, the *histtype* values such as
        ``'barstacked'`` and ``'stepfilled'`` give substantially different
        results.  You will probably want to include a transparency value
        (i.e. *alpha* = 0.5).

        *label_rotation* rotates any text bin labels of the first Hist.  A
        ``fmt`` keyword is discarded, and ``replacements`` (falling back to
        the stack's, then each Hist's own) rewrites the legend labels.

        Returns whatever ``plt.hist`` returns.
        """
        kwargs.pop('fmt', None)
        replacements = kwargs.pop('replacements', None) or self.replacements
        # plt.hist takes one column per dataset, for both the sample
        # positions and the weights, so both must be 2-D (nbins, nhists).
        # np.dstack would give a 3-D (1, nbins, nhists) array for the weights.
        contents = np.column_stack([hist.y for hist in self.hists])
        x = np.column_stack([hist.x for hist in self.hists])
        xedges = self.hists[0].xedges
        labels = [replace(hist.label,
                          replacements or getattr(hist, 'replacements', None))
                  for hist in self.hists]
        self._set_color_cycle()
        plot = plt.hist(x, weights=contents, bins=xedges,
                        label=labels, **kwargs)
        self.hists[0]._prepare_xaxis(label_rotation)
        return plot

    def bar3d(self, **kwargs):
        """
        Draw each Hist as a row of 3D bars, one row per Hist along y.

        Experimental.  Creates a new figure.  Keyword arguments, updated with
        each Hist's own stored kwargs, are passed to ``Axes3D.bar``.  Returns
        the list of per-Hist bar results.
        """
        #### Not yet ready for primetime
        from mpl_toolkits.mplot3d import Axes3D
        fig = plt.figure()
        ax = Axes3D(fig)
        plots = []
        labels = []
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.xlabel = self.xlabel
            if self.ylabel is not None: hist.ylabel = self.ylabel
            labels.append(hist.label)
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            bar = ax.bar(hist.x, hist.y, zs=i, zdir='y', width=hist.width,
                         **all_kwargs)
            plots.append(bar)
        from matplotlib.ticker import FixedLocator
        # One tick per Hist along the y (depth) axis, labelled with its label.
        locator = FixedLocator(range(len(labels)))
        ax.w_yaxis.set_major_locator(locator)
        ax.w_yaxis.set_ticklabels(labels)
        ax.set_ylim3d(-1, len(labels))
        return plots

    def barstack(self, **kwargs):
        """
        Make a matplotlib bar plot, with each Hist stacked upon the last.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  Returns a list of the per-Hist results.
        """
        bottom = None # if this is set to zeroes, it fails for log y
        plots = []
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.xlabel = self.xlabel
            if self.ylabel is not None: hist.ylabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            bar = hist.bar(bottom=bottom, **all_kwargs)
            plots.append(bar)
            if bottom is None:
                bottom = [0.] * self.hists[0].nbins
            # The next Hist sits on top of everything drawn so far.
            bottom = [sum(pair) for pair in zip(bottom, hist.y)]
        return plots

    def histstack(self, **kwargs):
        """
        Make a matplotlib hist plot, with each Hist stacked upon the last.

        Each Hist is drawn as the cumulative sum of itself and all earlier
        Hists, with earlier (smaller) sums given a higher zorder so they stay
        visible.  Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`.  Returns a list of the per-Hist results.
        """
        plots = []
        cumhist = None
        for i, hist in enumerate(self.hists):
            if cumhist is None:
                cumhist = copy.copy(hist)
            else:
                cumhist = hist + cumhist
            if self.title  is not None: cumhist.title  = self.title
            if self.xlabel is not None: cumhist.xlabel = self.xlabel
            if self.ylabel is not None: cumhist.ylabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            zorder = 0 + float(len(self) - i)/len(self) # plot in reverse order
            plot = cumhist.hist(zorder=zorder, **all_kwargs)
            plots.append(plot)
        return plots

    def barcluster(self, width=0.8, **kwargs):
        """
        Make a clustered bar plot.

        The fraction *width* of each bin is shared equally among the Hists and
        centred in the bin.  Any additional keyword arguments will be passed
        to :func:`matplotlib.pyplot.bar`.  Returns a list of the per-Hist
        results.
        """
        plots = []
        spacer = (1. - width) / 2
        width = width / len(self.hists)
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.xlabel = self.xlabel
            if self.ylabel is not None: hist.ylabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            bar = hist.bar(xoffset=width*i + spacer, width=width, **all_kwargs)
            plots.append(bar)
        return plots

    def barh(self, width=0.8, **kwargs):
        """
        Make a horizontal clustered matplotlib bar plot.

        Same layout as :meth:`barcluster`, with the axes (and axis labels)
        swapped.  Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.barh`.  Returns a list of the per-Hist
        results.
        """
        plots = []
        spacer = (1. - width) / 2
        width = width / len(self.hists)
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.ylabel = self.xlabel
            if self.ylabel is not None: hist.xlabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            bar = hist.barh(yoffset=width*i + spacer, width=width, **all_kwargs)
            plots.append(bar)
        return plots

    def bar(self, **kwargs):
        """
        Make a bar plot, with all Hists in the stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  You will probably want to set a
        transparency value (i.e. *alpha* = 0.5).  Returns a list of the
        per-Hist results.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.xlabel = self.xlabel
            if self.ylabel is not None: hist.ylabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            bar = hist.bar(**all_kwargs)
            plots.append(bar)
        return plots

    def errorbar(self, offset=False, **kwargs):
        """
        Make a matplotlib errorbar plot, with all Hists in the stack overlaid.

        Passing 'offset=True' will slightly offset each dataset so overlapping
        errorbars are still visible.  Any additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.errorbar`.  Returns a list of
        the per-Hist results.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.xlabel = self.xlabel
            if self.ylabel is not None: hist.ylabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            transform = plt.gca().transData
            if offset:
                # Shift dataset i horizontally by (i - centre) * 1/72 inch
                # (one typographic point), symmetric about the middle dataset.
                index_offset = (len(self.hists) - 1)/2.
                pixel_offset = 1./72 * (i - index_offset)
                transform = transforms.ScaledTranslation(
                    pixel_offset, 0, plt.gcf().dpi_scale_trans)
                transform = plt.gca().transData + transform
            errorbar = hist.errorbar(transform=transform, **all_kwargs)
            plots.append(errorbar)
        return plots

    def errorbarh(self, **kwargs):
        """
        Make a horizontal matplotlib errorbar plot, with all Hists in the
        stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.  Returns a list of the per-Hist
        results.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            if self.title  is not None: hist.title  = self.title
            if self.xlabel is not None: hist.ylabel = self.xlabel
            if self.ylabel is not None: hist.xlabel = self.ylabel
            all_kwargs = copy.copy(kwargs)
            all_kwargs.update(self.kwargs[i])
            errorbar = hist.errorbarh(**all_kwargs)
            plots.append(errorbar)
        return plots
00429 
00430 ################ Define functions and classes for navigating within ROOT
00431 
class RootFile(utilities.RootFile):
    """A wrapper for TFiles, allowing easier access to methods."""
    def get(self, object_name, path=None):
        """
        Retrieve *object_name* (optionally under *path*) from the file.

        Delegates to :meth:`utilities.RootFile.get`, passing this module's
        :class:`Hist` and :class:`Hist2D` classes; presumably the base class
        uses them to wrap 1D and 2D histograms (confirm in ``utilities``).
        A ReferenceError raised by the base implementation propagates with
        its original traceback.
        """
        # The old ``except ReferenceError, e: raise ReferenceError(e)`` only
        # re-raised the same exception type (discarding the traceback) and is
        # a syntax error on Python 3, so the call is made directly.
        return utilities.RootFile.get(self, object_name, path, Hist, Hist2D)
00440 
00441 ################ Define additional helping functions
00442 
def replace(string, replacements):
    """
    Apply a sequence of ``(pattern, substitution)`` pairs to *string*.

    *replacements* is a list of two-item tuples.  By default each pattern is
    a literal substring, substituted with :meth:`str.replace` in list order.
    If any tuple's first item is exactly ``'use_regexp'``, that marker tuple
    is skipped and every other pattern is applied with :func:`re.sub`
    instead.

    Returns *string* untouched when *replacements* is empty or None, and
    returns ``""`` when the substituted result is empty or all whitespace.
    """
    if not replacements:
        return string
    patterns = [pattern for pattern, _ in replacements]
    regex_mode = 'use_regexp' in patterns
    if regex_mode:
        for pattern, repl in [pair for pair in replacements
                              if pair[0] != 'use_regexp']:
            string = re.sub(pattern, repl, string)
    else:
        for pattern, repl in replacements:
            string = string.replace(pattern, repl)
    # Collapse a label that is now nothing but whitespace to "".
    return "" if _all_whitespace_string.match(string) else string
00465