CMS 3D CMS Logo

/data/doxygen/doxygen-1.7.3/gen/CMSSW_4_2_8/src/PhysicsTools/PythonAnalysis/python/rootplot/root2matplotlib.py

Go to the documentation of this file.
00001 """
00002 Utilities for plotting ROOT histograms in matplotlib.
00003 """
00004 
00005 __license__ = '''\
00006 Copyright (c) 2009-2010 Jeff Klukas <klukas@wisc.edu>
00007 
00008 Permission is hereby granted, free of charge, to any person obtaining a copy
00009 of this software and associated documentation files (the "Software"), to deal
00010 in the Software without restriction, including without limitation the rights
00011 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
00012 copies of the Software, and to permit persons to whom the Software is
00013 furnished to do so, subject to the following conditions:
00014 
00015 The above copyright notice and this permission notice shall be included in
00016 all copies or substantial portions of the Software.
00017 
00018 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
00019 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
00020 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
00021 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
00022 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
00023 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
00024 THE SOFTWARE.
00025 '''
00026 
00027 ################ Import python libraries
00028 
00029 import math
00030 import ROOT
00031 import re
00032 import copy
00033 import array
00034 from rootplot import utilities
00035 import matplotlib as mpl
00036 import matplotlib.pyplot as plt
00037 import matplotlib.transforms as transforms
00038 import numpy as np
00039 
00040 ################ Define constants
00041 
00042 _all_whitespace_string = re.compile(r'\s*$')  # matches a string that is empty or whitespace-only; used by replace()
00043 
00044 
00045 ################ Define classes
00046 
class Hist2D(utilities.Hist2D):
    """Matplotlib drawing front-end for the contents of a 2D ROOT histogram."""

    def __init__(self, *args, **kwargs):
        """
        Strip the optional *replacements* keyword (kept on the instance,
        ``None`` when absent) and hand every other argument to
        :class:`utilities.Hist2D`.
        """
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist2D.__init__(self, *args, **kwargs)

    def contour(self, **kwargs):
        """
        Draw labelled contour lines with :func:`matplotlib.pyplot.contour`.

        Keyword arguments are forwarded to the contour call.  When text bin
        labels exist for an axis they are installed as tick labels at
        positions ``0 .. nbins - 1``.  Returns the contour set.
        """
        contours = plt.contour(self.x, self.y, self.content, **kwargs)
        plt.clabel(contours, inline=1, fontsize=10)
        xlabels = self.binlabelsx
        if xlabels is not None:
            plt.xticks(np.arange(self.nbinsx), xlabels)
        ylabels = self.binlabelsy
        if ylabels is not None:
            plt.yticks(np.arange(self.nbinsy), ylabels)
        return contours

    def col(self, **kwargs):
        """
        Render the content as a colored image via
        :func:`matplotlib.pyplot.imshow`, spanning the first to the last
        edge on both axes.  Returns the image object.
        """
        extent = [self.xedges[0], self.xedges[-1],
                  self.yedges[0], self.yedges[-1]]
        return plt.imshow(self.content, interpolation='nearest',
                          extent=extent, aspect='auto', origin='lower',
                          **kwargs)

    def colz(self, **kwargs):
        """
        Same as :meth:`col`, with a colorbar attached to the image through
        :func:`matplotlib.pyplot.colorbar`.
        """
        image = self.col(**kwargs)
        plt.colorbar(image)
        return image

    def box(self, maxsize=40, **kwargs):
        """
        Draw square markers whose size tracks the content, using
        :func:`matplotlib.pyplot.scatter`.

        Sizes are scaled so that the largest content maps to *maxsize*; a
        maximum of zero is treated as one to avoid dividing by zero.
        """
        # One (x, y) pair per cell, rows laid out one after another.
        xs = np.hstack([self.x] * self.nbinsy)
        ys = np.hstack([[ycoord] * self.nbinsx for ycoord in self.y])
        peak = np.max(self.content)
        if peak == 0:
            peak = 1
        marker_sizes = np.array(self.content).flatten() / peak * maxsize
        return plt.scatter(xs, ys, marker_sizes, marker='s', **kwargs)

    def TH2F(self, name=""):
        """
        Build and return a ``ROOT.TH2F`` carrying this object's edges,
        title, axis labels and bin contents.
        """
        xbins = array.array('f', self.xedges)
        ybins = array.array('f', self.yedges)
        root_hist = ROOT.TH2F(name, "", self.nbinsx, xbins,
                              self.nbinsy, ybins)
        root_hist.SetTitle("%s;%s;%s" % (self.title, self.xlabel, self.ylabel))
        # ROOT bin numbers start at 1; content is indexed [row y][column x].
        for xbin in range(1, self.nbinsx + 1):
            for ybin in range(1, self.nbinsy + 1):
                root_hist.SetBinContent(xbin, ybin,
                                        self.content[ybin - 1][xbin - 1])
        return root_hist
00104 
class Hist(utilities.Hist):
    """
    A container to hold the parameters from a ROOT histogram.

    Adds matplotlib drawing methods on top of :class:`utilities.Hist`.
    Every drawing method accepts an optional *replacements* keyword (see
    :func:`replace`) that takes precedence over the instance-level
    ``replacements`` when building the legend label.  The keyword is
    consumed here and is never forwarded to matplotlib.
    """

    def __init__(self, *args, **kwargs):
        """
        Store the optional *replacements* keyword (``None`` when absent) and
        pass all remaining arguments to :class:`utilities.Hist`.
        """
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist.__init__(self, *args, **kwargs)

    def _pop_replacements(self, kwargs):
        """
        Remove 'replacements' from *kwargs* and return the list to use.

        The key has to be removed (not merely read) because *kwargs* is
        afterwards forwarded verbatim to a matplotlib function, which does
        not know this keyword.  Falls back to ``self.replacements``.
        """
        return kwargs.pop('replacements', None) or self.replacements

    def _prepare_xaxis(self, rotation=0, alignment='center'):
        """
        Apply bounds and text labels on x axis.

        Bin labels, when present, are placed at ``self.x`` with the given
        text *rotation* and horizontal *alignment*; the axis range is set to
        the first and last bin edge.
        """
        if self.binlabels is not None:
            plt.xticks(self.x, self.binlabels,
                       rotation=rotation, ha=alignment)
        plt.xlim(self.xedges[0], self.xedges[-1])

    def _prepare_yaxis(self, rotation=0, alignment='center'):
        """
        Apply bounds and text labels on y axis (used by horizontal plots).

        Bin labels, when present, are placed at ``self.x`` with the given
        text *rotation* and vertical *alignment*; the axis range is set to
        the first and last bin edge.
        """
        if self.binlabels is not None:
            plt.yticks(self.x, self.binlabels,
                       rotation=rotation, va=alignment)
        plt.ylim(self.xedges[0], self.xedges[-1])

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels to the current figure.

        An optional *replacements* keyword overrides ``self.replacements``;
        any other keyword is ignored.
        """
        replacements = kwargs.get('replacements', None) or self.replacements
        plt.title(replace(self.title, replacements))
        plt.xlabel(replace(self.xlabel, replacements))
        plt.ylabel(replace(self.ylabel, replacements))

    def hist(self, label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib hist figure.

        One entry per bin is filled at ``self.x`` and weighted by ``self.y``.
        A 'fmt' keyword is discarded.  All additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.hist`, whose return value is
        returned.
        """
        kwargs.pop('fmt', None)
        replacements = self._pop_replacements(kwargs)
        weights = self.y
        # Kludge to avoid mpl bug when plotting all zeros.  Element-wise
        # test so that it works for lists and numpy arrays alike.
        if all(value == 0 for value in self.y):
            weights = [1.e-10] * self.nbins
        plot = plt.hist(self.x, weights=weights, bins=self.xedges,
                        label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return plot

    def errorbar(self, xerr=False, yerr=False, label_rotation=0,
                 label_alignment='center', **kwargs):
        """
        Generate a matplotlib errorbar figure.

        *xerr* / *yerr* are flags: when true, ``self.xerr`` / ``self.yerr``
        are drawn.  All additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`, whose return value is returned.
        """
        if xerr:
            kwargs['xerr'] = self.xerr
        if yerr:
            kwargs['yerr'] = self.yerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.x, self.y,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return errorbar

    def errorbarh(self, xerr=False, yerr=False, label_rotation=0,
                  label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib errorbar figure.

        The roles of x and y are exchanged: ``self.y`` is drawn along the
        horizontal axis, and the *xerr* flag selects ``self.yerr`` (and vice
        versa).  All additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`, whose return value is returned.
        """
        if xerr:
            kwargs['xerr'] = self.yerr
        if yerr:
            kwargs['yerr'] = self.xerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.y, self.x,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return errorbar

    def bar(self, xerr=False, yerr=False, xoffset=0., width=0.8,
            label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib bar figure.

        Each bar starts at its bin's lower edge shifted by *xoffset* bin
        widths and is *width* bin widths wide.  *xerr* / *yerr* are flags
        selecting ``self.av_xerr()`` / ``self.av_yerr()``.  A 'fmt' keyword
        is discarded.  All additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`, whose return value is returned.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_xerr()
        if yerr:
            kwargs['yerr'] = self.av_yerr()
        replacements = self._pop_replacements(kwargs)
        # The positions below are left edges.  matplotlib >= 2.0 centres bars
        # by default, so request edge alignment explicitly unless the caller
        # chose otherwise.
        kwargs.setdefault('align', 'edge')
        lefts = [edge + binwidth * xoffset
                 for edge, binwidth in zip(self.xedges[:-1], self.width)]
        bar_widths = [binwidth * width for binwidth in self.width]
        bar = plt.bar(lefts, self.y, bar_widths,
                      label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return bar

    def barh(self, xerr=False, yerr=False, yoffset=0., width=0.8,
             label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib bar figure.

        Each bar starts at its bin's lower edge shifted by *yoffset* bin
        widths and is *width* bin widths thick.  The *xerr* flag selects
        ``self.av_yerr()`` and *yerr* selects ``self.av_xerr()`` because the
        axes are exchanged.  A 'fmt' keyword is discarded.  All additional
        keyword arguments will be passed to :func:`matplotlib.pyplot.barh`,
        whose return value is returned.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_yerr()
        if yerr:
            kwargs['yerr'] = self.av_xerr()
        replacements = self._pop_replacements(kwargs)
        # Positions are lower edges; see the note in bar() about alignment.
        kwargs.setdefault('align', 'edge')
        bottoms = [edge + binwidth * yoffset
                   for edge, binwidth in zip(self.xedges[:-1], self.width)]
        bar_heights = [binwidth * width for binwidth in self.width]
        barh = plt.barh(bottoms, self.y, bar_heights,
                        label=replace(self.label, replacements),
                        **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return barh
00224 
class HistStack(utilities.HistStack):
    """
    A container to hold Hist objects for plotting together.

    When plotting, the title and the x and y labels of the last Hist added
    will be used unless specified otherwise in the constructor.
    """

    def __init__(self, *args, **kwargs):
        """
        Store the optional *replacements* keyword (``None`` when absent, as
        in Hist and Hist2D) and pass all remaining arguments to
        :class:`utilities.HistStack`.
        """
        self.replacements = kwargs.pop('replacements', None)
        utilities.HistStack.__init__(self, *args, **kwargs)

    def _apply_stack_settings(self, index, hist, kwargs, swap_axes=False):
        """
        Push stack-level titles onto *hist* and build its keyword arguments.

        Stack title and axis labels that are not None overwrite those of
        *hist*; with *swap_axes* the x and y labels are exchanged (for
        horizontal plots).  Returns a copy of *kwargs* updated with the
        per-histogram entry ``self.kwargs[index]``, which takes precedence.
        """
        if self.title is not None:
            hist.title = self.title
        if swap_axes:
            if self.xlabel is not None:
                hist.ylabel = self.xlabel
            if self.ylabel is not None:
                hist.xlabel = self.ylabel
        else:
            if self.xlabel is not None:
                hist.xlabel = self.xlabel
            if self.ylabel is not None:
                hist.ylabel = self.ylabel
        merged = copy.copy(kwargs)
        merged.update(self.kwargs[index])
        return merged

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels of the last Hist in the stack.

        Keyword arguments (e.g. *replacements*) are forwarded to that
        Hist's ``show_titles``.
        """
        self.hists[-1].show_titles(**kwargs)

    def hist(self, label_rotation=0, **kwargs):
        """
        Make a matplotlib hist plot.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`, which allows a vast array of
        possibilities.  Particularly, the *histtype* values such as
        ``'barstacked'`` and ``'stepfilled'`` give substantially different
        results.  You will probably want to include a transparency value
        (i.e. *alpha* = 0.5).

        *label_rotation* is accepted but currently unused.  Returns the
        value returned by :func:`matplotlib.pyplot.hist`.
        """
        contents = np.dstack([hist.y for hist in self.hists])
        xedges = self.hists[0].xedges
        x = np.dstack([hist.x for hist in self.hists])
        labels = [hist.label for hist in self.hists]
        try:
            # Best effort: use the per-histogram colors when every entry of
            # self.kwargs provides one; otherwise keep the default cycle.
            clist = [item['color'] for item in self.kwargs]
            axes = plt.gca()
            if hasattr(axes, 'set_prop_cycle'):
                axes.set_prop_cycle(color=clist)
            else:
                # Older matplotlib releases only offer set_color_cycle.
                axes.set_color_cycle(clist)
        except Exception:
            pass
        plot = plt.hist(x, weights=contents, bins=xedges,
                        label=labels, **kwargs)
        return plot

    def bar3d(self, **kwargs):
        """
        Draw each Hist as a row of bars in a new 3D axes (experimental).

        Returns the list of objects returned by the axes' ``bar`` calls.
        """
        #### Not yet ready for primetime
        from mpl_toolkits.mplot3d import Axes3D
        fig = plt.figure()
        ax = Axes3D(fig)
        plots = []
        labels = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs)
            labels.append(hist.label)
            bar = ax.bar(hist.x, hist.y, zs=i, zdir='y', width=hist.width,
                         **all_kwargs)
            plots.append(bar)
        from matplotlib.ticker import FixedLocator
        locator = FixedLocator(range(len(labels)))
        ax.w_yaxis.set_major_locator(locator)
        ax.w_yaxis.set_ticklabels(labels)
        ax.set_ylim3d(-1, len(labels))
        return plots

    def barstack(self, **kwargs):
        """
        Make a matplotlib bar plot, with each Hist stacked upon the last.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  Returns the list of bar plots.
        """
        bottom = None # if this is set to zeroes, it fails for log y
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs)
            plots.append(hist.bar(bottom=bottom, **all_kwargs))
            if bottom is None:
                bottom = [0.] * self.hists[0].nbins
            # Running total: the next Hist is drawn on top of this one.
            bottom = [sum(pair) for pair in zip(bottom, hist.y)]
        return plots

    def barcluster(self, width=0.8, **kwargs):
        """
        Make a clustered bar plot.

        *width* is the fraction of each bin shared by the whole cluster.
        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  Returns the list of bar plots
        (empty for an empty stack).
        """
        plots = []
        if len(self.hists) == 0:
            return plots
        spacer = (1. - width) / 2
        # float() keeps an integer width from truncating under Python 2.
        width = float(width) / len(self.hists)
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs)
            bar = hist.bar(xoffset=width*i + spacer, width=width, **all_kwargs)
            plots.append(bar)
        return plots

    def barh(self, width=0.8, **kwargs):
        """
        Make a horizontal clustered matplotlib bar plot.

        *width* is the fraction of each bin shared by the whole cluster;
        the stack's x and y labels are exchanged on each Hist.  Any
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.barh`.  Returns the list of bar plots
        (empty for an empty stack).
        """
        plots = []
        if len(self.hists) == 0:
            return plots
        spacer = (1. - width) / 2
        # float() keeps an integer width from truncating under Python 2.
        width = float(width) / len(self.hists)
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs,
                                                    swap_axes=True)
            bar = hist.barh(yoffset=width*i + spacer, width=width, **all_kwargs)
            plots.append(bar)
        return plots

    def bar(self, **kwargs):
        """
        Make a bar plot, with all Hists in the stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  You will probably want to set a
        transparency value (i.e. *alpha*=0.5).  Returns the list of bar
        plots.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs)
            plots.append(hist.bar(**all_kwargs))
        return plots

    def errorbar(self, offset=False, **kwargs):
        """
        Make a matplotlib errorbar plot, with all Hists in the stack overlaid.

        Passing 'offset=True' will slightly offset each dataset so overlapping
        errorbars are still visible.  Any additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.errorbar`.  Returns the list
        of errorbar plots.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs)
            transform = plt.gca().transData
            if offset:
                # Spread the datasets symmetrically about the true position,
                # in steps of 1/72 of the figure's dpi_scale_trans unit.
                index_offset = (len(self.hists) - 1)/2.
                shift = 1./72 * (i - index_offset)
                translation = transforms.ScaledTranslation(
                    shift, 0, plt.gcf().dpi_scale_trans)
                transform = plt.gca().transData + translation
            errorbar = hist.errorbar(transform=transform, **all_kwargs)
            plots.append(errorbar)
        return plots

    def errorbarh(self, **kwargs):
        """
        Make a horizontal matplotlib errorbar plot, with all Hists in the
        stack overlaid.

        The stack's x and y labels are exchanged on each Hist.  Any
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.  Returns the list of errorbar
        plots.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._apply_stack_settings(i, hist, kwargs,
                                                    swap_axes=True)
            plots.append(hist.errorbarh(**all_kwargs))
        return plots
00403 
00404 ################ Define functions and classes for navigating within ROOT
00405 
class RootFile(utilities.RootFile):
    """A wrapper for TFiles, allowing easier access to methods."""

    def get(self, object_name, path=None):
        """
        Fetch *object_name* (optionally under *path*) from the file.

        Delegates to :meth:`utilities.RootFile.get`, passing this module's
        Hist and Hist2D classes as its final two arguments.

        Raises ReferenceError when the underlying lookup raises one.
        """
        try:
            return utilities.RootFile.get(self, object_name, path,
                                          Hist, Hist2D)
        except ReferenceError as e:
            # "except ... as" is valid on Python 2.6+ and Python 3; the old
            # comma form is a SyntaxError on Python 3.
            # NOTE(review): the error is re-raised as a fresh ReferenceError
            # wrapping the original; the reason is not visible in this file.
            raise ReferenceError(e)
00414 
00415 ################ Define additional helping functions
00416 
def replace(string, replacements):
    """
    Apply a sequence of (pattern, substitution) pairs to a string.

    Each entry of replacements is a two-item tuple: the text to look for and
    the text to put in its place.  Plain substring replacement is used by
    default; if any entry has 'use_regexp' as its first item, that entry is
    skipped and every other pattern is applied with re.sub instead.

    An empty or None replacements leaves the string untouched.  A result
    that is empty or only whitespace is returned as "".
    """
    if not replacements:
        return string
    search_keys = [search for search, _ in replacements]
    if 'use_regexp' in search_keys:
        for entry in replacements:
            if entry[0] == 'use_regexp':
                # Marker entry only: switches modes, nothing to substitute.
                continue
            regexp, substitution = entry
            string = re.sub(regexp, substitution, string)
    else:
        for search, substitution in replacements:
            string = string.replace(search, substitution)
    if _all_whitespace_string.match(string):
        return ""
    return string
00439