topicModelingTickets/draw.py

166 lines
7.4 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
try:
import matplotlib.pyplot as plt
except ImportError:
pass
# Matplotlib rc overrides applied (via ``plt.rc_context``) while a termite plot
# is drawn. Numeric strings such as '.8' or '1.0' are matplotlib gray levels.
RC_PARAMS = {
    # axes
    'axes.axisbelow': True,
    'axes.edgecolor': '.8',
    'axes.facecolor': 'white',
    'axes.grid': False,
    'axes.labelcolor': '.15',
    'axes.linewidth': 1.0,
    'axes.labelpad': 10.0,
    # figure and fonts
    'figure.facecolor': 'white',
    'font.family': ['sans-serif'],
    'font.sans-serif': ['Arial', 'Liberation Sans', 'sans-serif'],
    # grid, images, legend, lines
    'grid.color': '.8',
    'grid.linestyle': '-',
    'image.cmap': 'Greys',
    'legend.frameon': False,
    'legend.numpoints': 1,
    'legend.scatterpoints': 1,
    'lines.solid_capstyle': 'round',
    # default text color
    'text.color': '1.0',
    # x ticks: zero-length tick marks, labels padded away from the axis
    'xtick.color': '1.0',
    'xtick.direction': 'out',
    'xtick.major.size': 0.0,
    'xtick.minor.size': 0.0,
    'xtick.major.pad': 5,
    # y ticks: same treatment as the x ticks
    'ytick.color': '1.0',
    'ytick.direction': 'out',
    'ytick.major.size': 0.0,
    'ytick.minor.size': 0.0,
    'axes.ymargin': 0.9,
    'ytick.major.pad': 5,
}
# Default highlight palette: six (light, dark) pairs of RGB triples in [0, 1].
# ``draw_termite`` uses the light shade for marker fills and gridlines and the
# dark shade for marker edges and tick labels.
COLOR_PAIRS = (
    # blues
    (
        (0.65098041296005249, 0.80784314870834351, 0.89019608497619629),
        (0.12572087695201239, 0.47323337360924367, 0.707327968232772),
    ),
    # greens
    (
        (0.68899655751153521, 0.8681737867056154, 0.54376011946622071),
        (0.21171857311445125, 0.63326415104024547, 0.1812226118410335),
    ),
    # reds
    (
        (0.98320646005518297, 0.5980161709820524, 0.59423301088459368),
        (0.89059593116535862, 0.10449827132271793, 0.11108035462744099),
    ),
    # oranges
    (
        (0.99175701702342312, 0.74648213716698619, 0.43401768935077328),
        (0.99990772780250103, 0.50099192647372981, 0.0051211073118098693),
    ),
    # purples
    (
        (0.78329874347238004, 0.68724338552531095, 0.8336793640080622),
        (0.42485198495434734, 0.2511495584950722, 0.60386007743723258),
    ),
    # pale yellow / brown
    (
        (0.99760092286502611, 0.99489427150464516, 0.5965244373854468),
        (0.69411766529083252, 0.3490196168422699, 0.15686275064945221),
    ),
)
def draw_termite(values_mat, col_labels, row_labels,
                 highlight_cols=None, highlight_colors=None,
                 save=False, pow_x=0.66, pow_y=0.8):
    """
    Make a "termite" plot, typically used for assessing topic models with a tabular
    layout that promotes comparison of terms both within and across topics.

    Args:
        values_mat (``np.ndarray`` or matrix): matrix of values with shape
            (# row labels, # col labels) used to size the dots on the grid
        col_labels (seq[str]): labels used to identify x-axis ticks on the grid
        row_labels (seq[str]): labels used to identify y-axis ticks on the grid
        highlight_cols (int or seq[int], optional): indices for columns
            to visually highlight in the plot with contrasting colors; an
            empty sequence is treated the same as ``None``
        highlight_colors (tuple of 2-tuples): each 2-tuple corresponds to a pair
            of (light/dark) matplotlib-friendly colors used to highlight a single
            column; pairs are matched to ``highlight_cols`` by position; if not
            specified (default), a good set of 6 pairs are used
        save (str, optional): give the full /path/to/fname on disk to save figure
        pow_x (float, optional): exponent applied to the number of rows to
            obtain the figure height
        pow_y (float, optional): exponent applied to the number of columns to
            obtain the figure width

    Returns:
        ``matplotlib.axes.Axes``: axes on which the termite plot is drawn

    Raises:
        ImportError: if matplotlib could not be imported
        ValueError: if more columns are selected for highlighting than colors
            or if any of the inputs' dimensions don't match

    References:
        .. Chuang, Jason, Christopher D. Manning, and Jeffrey Heer. "Termite:
            Visualization techniques for assessing textual topic models."
            Proceedings of the International Working Conference on Advanced
            Visual Interfaces. ACM, 2012.

    .. seealso:: :func:`TopicModel.termite_plot <textacy.tm.TopicModel.termite_plot>`
    """
    # The module-level matplotlib import is best-effort; fail loudly only
    # when plotting is actually requested.
    try:
        plt
    except NameError:
        # Adjacent literals (not a backslash continuation inside the string)
        # so that source indentation never leaks into the message.
        raise ImportError(
            "matplotlib is not installed, so textacy.viz won't work; install it "
            "individually, or along with textacy via `pip install textacy[viz]`")

    n_rows, n_cols = values_mat.shape
    max_val = np.max(values_mat)
    if n_rows != len(row_labels):
        msg = "values_mat and row_labels dimensions don't match: {} vs. {}".format(
            n_rows, len(row_labels))
        raise ValueError(msg)
    if n_cols != len(col_labels):
        msg = "values_mat and col_labels dimensions don't match: {} vs. {}".format(
            n_cols, len(col_labels))
        raise ValueError(msg)

    # An all-zero matrix would otherwise yield 0/0 = NaN marker sizes.
    size_denom = max_val if max_val != 0 else 1

    if highlight_colors is None:
        highlight_colors = COLOR_PAIRS
    if highlight_cols is not None:
        # Accept a single index, including NumPy integer scalars.
        if isinstance(highlight_cols, (int, np.integer)):
            highlight_cols = (highlight_cols,)
        if len(highlight_cols) == 0:
            # Nothing to highlight; avoids max() over an empty sequence below.
            highlight_cols = None
        elif len(highlight_cols) > len(highlight_colors):
            msg = 'no more than {} columns may be highlighted at once'.format(
                len(highlight_colors))
            raise ValueError(msg)
        else:
            # Map column index -> (light, dark) pair, honouring the colors the
            # caller passed in rather than always using the default palette.
            highlight_colors = {hc: highlight_colors[i]
                                for i, hc in enumerate(highlight_cols)}

    with plt.rc_context(RC_PARAMS):
        # figsize is (width, height): columns drive the width, rows the height.
        fig, ax = plt.subplots(
            figsize=(pow(n_cols, pow_y), pow(n_rows, pow_x)))

        ax.set_yticks(range(n_rows))
        yticklabels = ax.set_yticklabels(row_labels,
                                         fontsize=14, color='gray')
        if highlight_cols is not None:
            # Color each row label after the highlighted column holding the
            # row's largest (positive) value; on ties the last column wins.
            for i, ticklabel in enumerate(yticklabels):
                max_tick_val = max(values_mat[i, hc] for hc in highlight_cols)
                if not max_tick_val > 0:
                    continue
                for hc in highlight_cols:
                    if values_mat[i, hc] == max_tick_val:
                        ticklabel.set_color(highlight_colors[hc][1])

        ax.get_xaxis().set_ticks_position('top')
        ax.set_xticks(range(n_cols))
        xticklabels = ax.set_xticklabels(col_labels,
                                         fontsize=14, color='gray',
                                         rotation=30, ha='left')
        if highlight_cols is not None:
            gridlines = ax.get_xgridlines()
            for i, ticklabel in enumerate(xticklabels):
                if i in highlight_cols:
                    ticklabel.set_color(highlight_colors[i][1])
                    gridlines[i].set_color(highlight_colors[i][0])
                    gridlines[i].set_alpha(0.5)

        # One scatter call per column; marker area is proportional to the
        # value relative to the matrix maximum.
        row_positions = list(range(n_rows))
        for col_ind in range(n_cols):
            if highlight_cols is not None and col_ind in highlight_cols:
                pair = highlight_colors[col_ind]
                face_color, edge_color = pair[0], pair[1]
            else:
                face_color, edge_color = 'black', 'gray'
            ax.scatter([col_ind] * n_rows,
                       row_positions,
                       s=600 * (values_mat[:, col_ind] / size_denom),
                       alpha=0.5, linewidth=1,
                       color=face_color, edgecolor=edge_color)

        ax.set_xlim(left=-1, right=n_cols)
        ax.set_ylim(bottom=-1, top=n_rows)
        ax.invert_yaxis()  # otherwise, values/labels go from bottom to top

    if save:
        fig.savefig(save, bbox_inches='tight', dpi=100)

    return ax