CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
root2matplotlib.py
Go to the documentation of this file.
1 """
2 Utilities for plotting ROOT histograms in matplotlib.
3 """
4 
5 __license__ = '''\
6 Copyright (c) 2009-2010 Jeff Klukas <klukas@wisc.edu>
7 
8 Permission is hereby granted, free of charge, to any person obtaining a copy
9 of this software and associated documentation files (the "Software"), to deal
10 in the Software without restriction, including without limitation the rights
11 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 copies of the Software, and to permit persons to whom the Software is
13 furnished to do so, subject to the following conditions:
14 
15 The above copyright notice and this permission notice shall be included in
16 all copies or substantial portions of the Software.
17 
18 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 THE SOFTWARE.
25 '''
26 
27 ################ Import python libraries
28 
29 import math
30 import ROOT
31 import re
32 import copy
33 import array
34 from rootplot import utilities
35 import matplotlib as mpl
36 import matplotlib.pyplot as plt
37 import matplotlib.transforms as transforms
38 import numpy as np
39 
40 ################ Define constants
41 
42 _all_whitespace_string = re.compile(r'\s*$')
43 
44 
45 ################ Define classes
46 
# NOTE(review): the ``class`` statement is absent from this listing; the name
# and base class below are reconstructed from the explicit
# utilities.Hist2D.__init__ call and from RootFile.get, which passes Hist2D
# by name -- confirm against the original module.
class Hist2D(utilities.Hist2D):
    """A container to hold the parameters from a 2D ROOT histogram."""

    def __init__(self, *args, **kwargs):
        """
        Store the optional *replacements* keyword and initialize the base.

        *replacements* is removed from *kwargs* (defaulting to None) and kept
        on the instance; every other argument is forwarded unchanged to
        ``utilities.Hist2D.__init__``.
        """
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist2D.__init__(self, *args, **kwargs)

    def contour(self, **kwargs):
        """
        Draw a contour plot with inline level labels.

        All keyword arguments are passed to :func:`matplotlib.pyplot.contour`.
        When text bin labels are present they are used as tick labels.
        Returns the object returned by :func:`matplotlib.pyplot.contour`.
        """
        cs = plt.contour(self.x, self.y, self.content, **kwargs)
        plt.clabel(cs, inline=1, fontsize=10)
        # Place the text labels at the same coordinates the contour data is
        # drawn at (self.x / self.y), the way Hist._prepare_xaxis does for 1D
        # histograms, instead of at the bare indices 0..nbins-1.
        if self.binlabelsx is not None:
            plt.xticks(self.x, self.binlabelsx)
        if self.binlabelsy is not None:
            plt.yticks(self.y, self.binlabelsy)
        return cs

    def col(self, **kwargs):
        """
        Draw a colored box plot using :func:`matplotlib.pyplot.imshow`.

        The image spans the first to the last x and y edge.  Additional
        keyword arguments are passed to imshow, whose return value is
        returned.
        """
        extent = [self.xedges[0], self.xedges[-1],
                  self.yedges[0], self.yedges[-1]]
        return plt.imshow(self.content, interpolation='nearest',
                          extent=extent, aspect='auto', origin='lower',
                          **kwargs)

    def colz(self, **kwargs):
        """
        Draw a colored box plot with a colorbar using
        :func:`matplotlib.pyplot.imshow`.

        Keyword arguments are handed to :meth:`col`; the image it returns is
        also the return value here.
        """
        plot = self.col(**kwargs)
        plt.colorbar(plot)
        return plot

    def box(self, maxsize=40, **kwargs):
        """
        Draw a box plot with size indicating content using
        :func:`matplotlib.pyplot.scatter`.

        The content is normalized to its maximum, so the fullest bin gets a
        marker size of *maxsize*.  The sizes are passed as scatter's third
        positional argument; see the matplotlib documentation for the units
        of that argument.  Additional keyword arguments are passed to
        scatter, whose return value is returned.
        """
        # One (x, y) pair per bin, ordered row by row so the pairs line up
        # with the flattened content below.
        x = np.tile(self.x, self.nbinsy)
        y = np.repeat(self.y, self.nbinsx)
        maxvalue = np.max(self.content)
        if maxvalue == 0:
            # Avoid dividing by zero when every bin is empty
            maxvalue = 1
        # Normalize in floating point so integer content is not truncated
        content = np.asarray(self.content, dtype=float).flatten()
        sizes = content / maxvalue * maxsize
        return plt.scatter(x, y, sizes, marker='s', **kwargs)

    def TH2F(self, name=""):
        """Return a ROOT.TH2F object with contents of this Hist2D."""
        th2f = ROOT.TH2F(name, "",
                         self.nbinsx, array.array('f', self.xedges),
                         self.nbinsy, array.array('f', self.yedges))
        th2f.SetTitle("%s;%s;%s" % (self.title, self.xlabel, self.ylabel))
        for ix in range(self.nbinsx):
            for iy in range(self.nbinsy):
                # SetBinContent is called with indices shifted by one;
                # content is indexed [iy][ix].
                th2f.SetBinContent(ix + 1, iy + 1, self.content[iy][ix])
        return th2f
104 
# NOTE(review): the ``class`` statement is absent from this listing; the name
# and base class below are reconstructed from the explicit
# utilities.Hist.__init__ call and from RootFile.get, which passes Hist by
# name -- confirm against the original module.
class Hist(utilities.Hist):
    """A container to hold the parameters from a ROOT histogram."""

    def __init__(self, *args, **kwargs):
        """
        Store the optional *replacements* keyword and initialize the base.

        *replacements* is removed from *kwargs* (defaulting to None) and kept
        on the instance; every other argument is forwarded unchanged to
        ``utilities.Hist.__init__``.
        """
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist.__init__(self, *args, **kwargs)

    def _pop_label(self, kwargs):
        """
        Return this histogram's label after text replacements.

        ``'replacements'`` is removed from *kwargs* so it is never forwarded
        to matplotlib; when it is missing or empty the instance-level
        replacements are used instead.
        """
        replacements = kwargs.pop('replacements', None) or self.replacements
        return replace(self.label, replacements)

    def _prepare_xaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on x axis."""
        if self.binlabels is not None:
            plt.xticks(self.x, self.binlabels,
                       rotation=rotation, ha=alignment)
        plt.xlim(self.xedges[0], self.xedges[-1])

    def _prepare_yaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on y axis (for horizontal plots)."""
        if self.binlabels is not None:
            plt.yticks(self.x, self.binlabels,
                       rotation=rotation, va=alignment)
        plt.ylim(self.xedges[0], self.xedges[-1])

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels to the current figure.

        A ``replacements`` keyword, when given and non-empty, overrides the
        replacements stored on the instance.
        """
        replacements = kwargs.get('replacements', None) or self.replacements
        plt.title(replace(self.title, replacements))
        plt.xlabel(replace(self.xlabel, replacements))
        plt.ylabel(replace(self.ylabel, replacements))

    def hist(self, label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib hist figure.

        *label_rotation* and *label_alignment* control the text bin labels,
        if any.  A ``fmt`` keyword is discarded and ``replacements`` is
        consumed here; all additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`, whose return value is returned.
        """
        kwargs.pop('fmt', None)
        label = self._pop_label(kwargs)
        weights = self.y
        # Kludge to avoid mpl bug when plotting all zeros.  The element-wise
        # test works whether self.y is a list or an array.
        if all(value == 0 for value in self.y):
            weights = [1.e-10] * self.nbins
        plot = plt.hist(self.x, weights=weights, bins=self.xedges,
                        label=label, **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return plot

    def errorbar(self, xerr=False, yerr=False, label_rotation=0,
                 label_alignment='center', **kwargs):
        """
        Generate a matplotlib errorbar figure.

        Truthy *xerr* / *yerr* attach ``self.xerr`` / ``self.yerr`` as the
        error bars.  ``replacements`` is consumed here; all additional
        keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`, whose return value is returned.
        """
        if xerr:
            kwargs['xerr'] = self.xerr
        if yerr:
            kwargs['yerr'] = self.yerr
        label = self._pop_label(kwargs)
        errorbar = plt.errorbar(self.x, self.y, label=label, **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return errorbar

    def errorbarh(self, xerr=False, yerr=False, label_rotation=0,
                  label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib errorbar figure.

        The axes are swapped (``self.y`` is drawn along the horizontal axis),
        so truthy *xerr* attaches ``self.yerr`` and truthy *yerr* attaches
        ``self.xerr``.  ``replacements`` is consumed here; all additional
        keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`, whose return value is returned.
        """
        if xerr:
            kwargs['xerr'] = self.yerr
        if yerr:
            kwargs['yerr'] = self.xerr
        label = self._pop_label(kwargs)
        errorbar = plt.errorbar(self.y, self.x, label=label, **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return errorbar

    def bar(self, xerr=False, yerr=False, xoffset=0., width=0.8,
            label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib bar figure.

        *xoffset* and *width* are fractions of each bin's width: every bar
        starts *xoffset* bin widths after the bin's lower edge and is *width*
        bin widths wide.  Truthy *xerr* / *yerr* attach ``self.av_xerr()`` /
        ``self.av_yerr()``.  A ``fmt`` keyword is discarded and
        ``replacements`` is consumed here; all additional keyword arguments
        will be passed to :func:`matplotlib.pyplot.bar`, whose return value
        is returned.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_xerr()
        if yerr:
            kwargs['yerr'] = self.av_yerr()
        label = self._pop_label(kwargs)
        # The positions below are left edges, so pin the alignment instead of
        # relying on matplotlib's default (the caller may still override it).
        kwargs.setdefault('align', 'edge')
        lefts = [edge + binwidth * xoffset
                 for edge, binwidth in zip(self.xedges[:-1], self.width)]
        widths = [binwidth * width for binwidth in self.width]
        bar = plt.bar(lefts, self.y, widths, label=label, **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return bar

    def barh(self, xerr=False, yerr=False, yoffset=0., width=0.8,
             label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib bar figure.

        *yoffset* and *width* are fractions of each bin's width, as in
        :meth:`bar`.  The axes are swapped, so truthy *xerr* attaches
        ``self.av_yerr()`` and truthy *yerr* attaches ``self.av_xerr()``.
        A ``fmt`` keyword is discarded and ``replacements`` is consumed here;
        all additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.barh`, whose return value is returned.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_yerr()
        if yerr:
            kwargs['yerr'] = self.av_xerr()
        label = self._pop_label(kwargs)
        # The positions below are lower edges, so pin the alignment instead
        # of relying on matplotlib's default (the caller may override it).
        kwargs.setdefault('align', 'edge')
        bottoms = [edge + binwidth * yoffset
                   for edge, binwidth in zip(self.xedges[:-1], self.width)]
        heights = [binwidth * width for binwidth in self.width]
        barh = plt.barh(bottoms, self.y, heights, label=label, **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return barh
224 
# NOTE(review): the ``class`` statement is absent from this listing; the base
# class below is taken from the explicit utilities.HistStack.__init__ call and
# the subclass name is assumed to match it -- confirm against the original
# module.
class HistStack(utilities.HistStack):
    """
    A container to hold Hist objects for plotting together.

    When plotting, the title and the x and y labels of the last Hist added
    will be used unless specified otherwise in the constructor.
    """

    def __init__(self, *args, **kwargs):
        """
        Store the optional *replacements* keyword and initialize the base.

        *replacements* is removed from *kwargs* (defaulting to None, so the
        attribute always exists); every other argument is forwarded unchanged
        to ``utilities.HistStack.__init__``.
        """
        self.replacements = kwargs.pop('replacements', None)
        utilities.HistStack.__init__(self, *args, **kwargs)

    def _setup_hist(self, index, hist, kwargs, swap_axes=False):
        """
        Push stack-level titles onto *hist* and build its keyword arguments.

        Stack-level title and axis labels that are not None overwrite those
        of *hist*; with *swap_axes* the x and y labels are exchanged (used by
        the horizontal plots).  Returns a copy of *kwargs* updated with the
        per-histogram ``self.kwargs[index]``.
        """
        if self.title is not None:
            hist.title = self.title
        if swap_axes:
            if self.xlabel is not None:
                hist.ylabel = self.xlabel
            if self.ylabel is not None:
                hist.xlabel = self.ylabel
        else:
            if self.xlabel is not None:
                hist.xlabel = self.xlabel
            if self.ylabel is not None:
                hist.ylabel = self.ylabel
        merged = copy.copy(kwargs)
        merged.update(self.kwargs[index])
        return merged

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels of the last Hist in the stack.

        Keyword arguments (e.g. ``replacements``) are forwarded to that
        Hist's ``show_titles``.
        """
        self.hists[-1].show_titles(**kwargs)

    def hist(self, label_rotation=0, **kwargs):
        """
        Make a matplotlib hist plot.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`, which allows a vast array of
        possibilities.  Particularly, the *histtype* values such as
        ``'barstacked'`` and ``'stepfilled'`` give substantially different
        results.  You will probably want to include a transparency value
        (i.e. *alpha* = 0.5).

        *label_rotation* is accepted but not used by this method.  The
        return value of :func:`matplotlib.pyplot.hist` is returned.
        """
        # One column per Hist, which is how pyplot.hist expects several
        # datasets to be supplied as a 2D array.
        contents = np.column_stack([hist.y for hist in self.hists])
        xedges = self.hists[0].xedges
        x = np.column_stack([hist.x for hist in self.hists])
        labels = [hist.label for hist in self.hists]
        # Best effort: use the per-histogram colors when every entry of
        # self.kwargs defines one; otherwise keep matplotlib's own cycle.
        try:
            clist = [item['color'] for item in self.kwargs]
            axes = plt.gca()
            if hasattr(axes, 'set_prop_cycle'):
                axes.set_prop_cycle(color=clist)
            else:
                # Older matplotlib without property cycles
                axes.set_color_cycle(clist)
        except (AttributeError, KeyError, TypeError, ValueError):
            pass
        plot = plt.hist(x, weights=contents, bins=xedges,
                        label=labels, **kwargs)
        return plot

    def bar3d(self, **kwargs):
        """
        Draw each Hist as a row of bars at its own depth in a 3D axes.

        Experimental.  Keyword arguments are merged with the per-histogram
        ``self.kwargs`` and passed to the 3D axes' ``bar``.  Returns the list
        of objects returned by those calls.
        """
        #### Not yet ready for primetime
        # NOTE(review): Axes3D(fig) and ax.w_yaxis are older mplot3d entry
        # points; verify them against the matplotlib version in use.
        from mpl_toolkits.mplot3d import Axes3D
        from matplotlib.ticker import FixedLocator
        fig = plt.figure()
        ax = Axes3D(fig)
        plots = []
        labels = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs)
            labels.append(hist.label)
            bar = ax.bar(hist.x, hist.y, zs=i, zdir='y', width=hist.width,
                         **all_kwargs)
            plots.append(bar)
        # One tick per Hist along the depth axis, labelled with its label
        locator = FixedLocator(range(len(labels)))
        ax.w_yaxis.set_major_locator(locator)
        ax.w_yaxis.set_ticklabels(labels)
        ax.set_ylim3d(-1, len(labels))
        return plots

    def barstack(self, **kwargs):
        """
        Make a matplotlib bar plot, with each Hist stacked upon the last.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  Returns the list of objects returned
        by each Hist's ``bar``.
        """
        bottom = None  # if this is set to zeroes, it fails for log y
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs)
            plots.append(hist.bar(bottom=bottom, **all_kwargs))
            if bottom is None:
                bottom = [0.] * self.hists[0].nbins
            # The next Hist starts where this one ended
            bottom = [low + height for low, height in zip(bottom, hist.y)]
        return plots

    def barcluster(self, width=0.8, **kwargs):
        """
        Make a clustered bar plot.

        *width* is the fraction of each bin filled by the whole cluster; it
        is divided evenly among the Hists and the remainder is split into
        equal margins on both sides.  Any additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.bar`.  Returns the list of
        objects returned by each Hist's ``bar``.
        """
        plots = []
        spacer = (1. - width) / 2
        width = width / len(self.hists)
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs)
            plots.append(hist.bar(xoffset=width * i + spacer, width=width,
                                  **all_kwargs))
        return plots

    def barh(self, width=0.8, **kwargs):
        """
        Make a horizontal clustered matplotlib bar plot.

        *width* is treated as in :meth:`barcluster`, and the stack-level x
        and y labels are exchanged on each Hist.  Any additional keyword
        arguments will be passed to :func:`matplotlib.pyplot.bar`.  Returns
        the list of objects returned by each Hist's ``barh``.
        """
        plots = []
        spacer = (1. - width) / 2
        width = width / len(self.hists)
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs, swap_axes=True)
            plots.append(hist.barh(yoffset=width * i + spacer, width=width,
                                   **all_kwargs))
        return plots

    def bar(self, **kwargs):
        """
        Make a bar plot, with all Hists in the stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  You will probably want to set a
        transparency value (i.e. *alpha*=0.5).  Returns the list of objects
        returned by each Hist's ``bar``.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs)
            plots.append(hist.bar(**all_kwargs))
        return plots

    def errorbar(self, offset=False, **kwargs):
        """
        Make a matplotlib errorbar plot, with all Hists in the stack overlaid.

        Passing 'offset=True' will slightly offset each dataset so overlapping
        errorbars are still visible.  Any additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.errorbar`.  Returns the list of
        objects returned by each Hist's ``errorbar``.
        """
        plots = []
        # Center the group of shifted datasets on the unshifted position
        index_offset = (len(self.hists) - 1) / 2.
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs)
            transform = plt.gca().transData
            if offset:
                # Shift by 1/72 per index step, in the units of the figure's
                # dpi_scale_trans, on top of the data transform
                shift = 1. / 72 * (i - index_offset)
                transform = transform + transforms.ScaledTranslation(
                    shift, 0, plt.gcf().dpi_scale_trans)
            plots.append(hist.errorbar(transform=transform, **all_kwargs))
        return plots

    def errorbarh(self, **kwargs):
        """
        Make a horizontal matplotlib errorbar plot, with all Hists in the
        stack overlaid.

        The stack-level x and y labels are exchanged on each Hist.  Any
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.  Returns the list of objects
        returned by each Hist's ``errorbarh``.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            all_kwargs = self._setup_hist(i, hist, kwargs, swap_axes=True)
            plots.append(hist.errorbarh(**all_kwargs))
        return plots
403 
404 ################ Define functions and classes for navigating within ROOT
405 
# NOTE(review): the ``class`` statement is absent from this listing; the base
# class below is taken from the explicit utilities.RootFile.get call and the
# subclass name is assumed to match it -- confirm against the original module.
class RootFile(utilities.RootFile):
    """A wrapper for TFiles, allowing easier access to methods."""

    def get(self, object_name, path=None):
        """
        Fetch *object_name* through ``utilities.RootFile.get``.

        The call passes this module's Hist and Hist2D classes along with
        *object_name* and *path* -- presumably so results come back as those
        types; verify against utilities.RootFile.get.  A ReferenceError from
        the base implementation is re-raised as a fresh ReferenceError
        wrapping the original.
        """
        try:
            return utilities.RootFile.get(self, object_name, path,
                                          Hist, Hist2D)
        except ReferenceError as e:
            raise ReferenceError(e)
414 
415 ################ Define additional helping functions
416 
def replace(string, replacements):
    """
    Apply a sequence of (pattern, substitution) pairs to a string.

    *replacements* is a list of two-entry tuples: the text to look for and
    the text to put in its place.  If any pair has the pattern
    'use_regexp', that pair acts as a switch and every other pattern is
    applied as a regular expression with re.sub; otherwise plain
    str.replace is used.  With no replacements the string is returned
    untouched.  After substitution, a result made up only of whitespace is
    returned as the empty string.
    """
    if not replacements:
        return string
    patterns = [pattern for pattern, _ in replacements]
    if 'use_regexp' in patterns:
        for pattern, substitute in replacements:
            if pattern == 'use_regexp':
                # The marker pair only selects regexp mode
                continue
            string = re.sub(pattern, substitute, string)
    else:
        for pattern, substitute in replacements:
            string = string.replace(pattern, substitute)
    return "" if re.match(_all_whitespace_string, string) else string
439