# -*- coding: utf-8 -*-
"""Termite plot: a tabular dot-grid visualization of a terms-by-topics matrix."""
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np

# matplotlib is an optional dependency: a failed import is tolerated here and
# reported (as ImportError) only when a plotting function is actually called.
try:
    import matplotlib.pyplot as plt
except ImportError:
    pass

# rc settings applied (temporarily) while the termite figure is being built
RC_PARAMS = {
    'axes.axisbelow': True,
    'axes.edgecolor': '.8',
    'axes.facecolor': 'white',
    'axes.grid': False,
    'axes.labelcolor': '.15',
    'axes.linewidth': 1.0,
    'axes.labelpad': 10.0,
    'figure.facecolor': 'white',
    'font.family': ['sans-serif'],
    'font.sans-serif': ['Arial', 'Liberation Sans', 'sans-serif'],
    'grid.color': '.8',
    'grid.linestyle': '-',
    'image.cmap': 'Greys',
    'legend.frameon': False,
    'legend.numpoints': 1,
    'legend.scatterpoints': 1,
    'lines.solid_capstyle': 'round',
    'text.color': '1.0',
    'xtick.color': '1.0',
    'xtick.direction': 'out',
    'xtick.major.size': 0.0,
    'xtick.minor.size': 0.0,
    'xtick.major.pad': 5,
    'ytick.color': '1.0',
    'ytick.direction': 'out',
    'ytick.major.size': 0.0,
    'ytick.minor.size': 0.0,
    'axes.ymargin': 0.9,
    'ytick.major.pad': 5,
}

# default highlight palette: six (light, dark) RGB pairs
COLOR_PAIRS = (
    ((0.65098041296005249, 0.80784314870834351, 0.89019608497619629),
     (0.12572087695201239, 0.47323337360924367, 0.707327968232772)),
    ((0.68899655751153521, 0.8681737867056154, 0.54376011946622071),
     (0.21171857311445125, 0.63326415104024547, 0.1812226118410335)),
    ((0.98320646005518297, 0.5980161709820524, 0.59423301088459368),
     (0.89059593116535862, 0.10449827132271793, 0.11108035462744099)),
    ((0.99175701702342312, 0.74648213716698619, 0.43401768935077328),
     (0.99990772780250103, 0.50099192647372981, 0.0051211073118098693)),
    ((0.78329874347238004, 0.68724338552531095, 0.8336793640080622),
     (0.42485198495434734, 0.2511495584950722, 0.60386007743723258)),
    ((0.99760092286502611, 0.99489427150464516, 0.5965244373854468),
     (0.69411766529083252, 0.3490196168422699, 0.15686275064945221)),
)


def _resolve_highlights(highlight_cols, highlight_colors):
    """Normalize the highlight arguments into ``(columns, {column: color pair})``.

    Returns ``(None, {})`` when no columns are to be highlighted.
    """
    if highlight_colors is None:
        highlight_colors = COLOR_PAIRS
    if highlight_cols is None:
        return None, {}
    # accept a single index, including NumPy integer scalars
    if isinstance(highlight_cols, (int, np.integer)):
        highlight_cols = (highlight_cols,)
    else:
        highlight_cols = tuple(highlight_cols)
    # checked for the single-index case too, so an empty palette raises the
    # documented ValueError rather than an IndexError further down
    if len(highlight_cols) > len(highlight_colors):
        msg = 'no more than {} columns may be highlighted at once'.format(
            len(highlight_colors))
        raise ValueError(msg)
    # pair the i-th highlighted column with the i-th color pair of the palette
    # actually in effect (caller-supplied or default)
    colors_by_col = {hc: highlight_colors[i]
                     for i, hc in enumerate(highlight_cols)}
    return highlight_cols, colors_by_col


def draw_termite(values_mat, col_labels, row_labels,
                 highlight_cols=None, highlight_colors=None,
                 save=False, pow_x=0.66, pow_y=0.8):
    """
    Make a "termite" plot, typically used for assessing topic models with a
    tabular layout that promotes comparison of terms both within and across topics.

    Args:
        values_mat (``np.ndarray`` or matrix): matrix of values with shape
            (# row labels, # col labels) used to size the dots on the grid
        col_labels (seq[str]): labels used to identify x-axis ticks on the grid
        row_labels (seq[str]): labels used to identify y-axis ticks on the grid
        highlight_cols (int or seq[int], optional): indices for columns
            to visually highlight in the plot with contrasting colors
        highlight_colors (tuple of 2-tuples): each 2-tuple corresponds to a pair
            of (light/dark) matplotlib-friendly colors used to highlight a single
            column; if not specified (default), a good set of 6 pairs are used
        save (str, optional): give the full /path/to/fname on disk to save figure
        pow_x (float, optional): exponent applied to the number of rows to get
            the figure height, i.e. height = ``n_rows ** pow_x``
        pow_y (float, optional): exponent applied to the number of columns to
            get the figure width, i.e. width = ``n_cols ** pow_y``

    Returns:
        ``matplotlib.axes.Axes``: axes on which the termite plot is drawn

    Raises:
        ImportError: if matplotlib could not be imported
        ValueError: if more columns are selected for highlighting than colors
            or if any of the inputs' dimensions don't match

    References:
        .. Chuang, Jason, Christopher D. Manning, and Jeffrey Heer. "Termite:
            Visualization techniques for assessing textual topic models."
            Proceedings of the International Working Conference on Advanced
            Visual Interfaces. ACM, 2012.

    .. seealso:: :func:`TopicModel.termite_plot`
    """
    # `plt` is only bound if the optional module-level import succeeded
    try:
        plt
    except NameError:
        raise ImportError(
            "matplotlib is not installed, so textacy.viz won't work; install it "
            "individually, or along with textacy via `pip install textacy[viz]`")

    n_rows, n_cols = values_mat.shape
    max_val = np.max(values_mat)

    if n_rows != len(row_labels):
        msg = "values_mat and row_labels dimensions don't match: {} vs. {}".format(
            n_rows, len(row_labels))
        raise ValueError(msg)
    if n_cols != len(col_labels):
        msg = "values_mat and col_labels dimensions don't match: {} vs. {}".format(
            n_cols, len(col_labels))
        raise ValueError(msg)

    highlight_cols, colors_by_col = _resolve_highlights(
        highlight_cols, highlight_colors)

    with plt.rc_context(RC_PARAMS):
        # figure size is set here: width and height grow sub-linearly with the
        # number of columns and rows, respectively
        fig, ax = plt.subplots(figsize=(pow(n_cols, pow_y), pow(n_rows, pow_x)))

        # rows (y-axis)
        ax.set_yticks(range(n_rows))
        yticklabels = ax.set_yticklabels(row_labels, fontsize=14, color='gray')
        if highlight_cols is not None:
            # color each row label after the highlighted column in which the
            # row has its largest (strictly positive) value; on ties the last
            # such column wins
            for i, ticklabel in enumerate(yticklabels):
                max_tick_val = max(values_mat[i, hc] for hc in highlight_cols)
                if not max_tick_val > 0:
                    continue
                for hc in highlight_cols:
                    if values_mat[i, hc] == max_tick_val:
                        ticklabel.set_color(colors_by_col[hc][1])

        # columns (x-axis), labelled along the top edge
        ax.get_xaxis().set_ticks_position('top')
        ax.set_xticks(range(n_cols))
        xticklabels = ax.set_xticklabels(
            col_labels, fontsize=14, color='gray', rotation=30, ha='left')
        if highlight_cols is not None:
            gridlines = ax.get_xgridlines()
            for i, ticklabel in enumerate(xticklabels):
                if i in highlight_cols:
                    ticklabel.set_color(colors_by_col[i][1])
                    gridlines[i].set_color(colors_by_col[i][0])
                    gridlines[i].set_alpha(0.5)

        # one scatter call per column; dot area is proportional to the value
        # relative to the matrix maximum
        row_positions = list(range(n_rows))
        can_scale = max_val != 0
        for col_ind in range(n_cols):
            if highlight_cols is not None and col_ind in highlight_cols:
                face_color, edge_color = colors_by_col[col_ind][0], colors_by_col[col_ind][1]
            else:
                face_color, edge_color = 'black', 'gray'
            if can_scale:
                sizes = 600 * (values_mat[:, col_ind] / max_val)
            else:
                # an all-zero matrix would otherwise divide by zero and hand
                # NaN sizes to scatter
                sizes = np.zeros(n_rows)
            ax.scatter([col_ind] * n_rows, row_positions,
                       s=sizes, alpha=0.5, linewidth=1,
                       color=face_color, edgecolor=edge_color)

        ax.set_xlim(left=-1, right=n_cols)
        ax.set_ylim(bottom=-1, top=n_rows)
        ax.invert_yaxis()  # otherwise, values/labels go from bottom to top

    if save:
        fig.savefig(save, bbox_inches='tight', dpi=100)

    return ax