topicModelingTickets/draw1.py

106 lines
4.1 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import numpy as np
from sklearn.decomposition import NMF, LatentDirichletAllocation, TruncatedSVD
from sklearn.externals import joblib
from textacy import viz
import draw
LOGGER = logging.getLogger(__name__)
def _check_option(name, value, allowed):
    """Raise ValueError if ``value`` is not one of the ``allowed`` option values."""
    if value not in allowed:
        raise ValueError('invalid {} value; must be in {}'.format(name, allowed))


def _topic_term_weights(components, topic_inds, term_inds):
    """Return the (len(term_inds), len(topic_inds)) matrix of topic-term weights.

    Row ``i`` holds the weight of term ``term_inds[i]`` in each selected topic,
    read from ``components[topic_ind][term_ind]``. The explicit reshape keeps
    the result 2-D even when no topics or no terms are selected.
    """
    weights = np.array([components[topic_ind][term_inds] for topic_ind in topic_inds])
    return weights.reshape(len(topic_inds), len(term_inds)).T


def _seriation_order(term_topic_weights):
    """Return a permutation of row indices that puts rows with similar profiles together.

    Spectral seriation: build a nonnegative row-similarity matrix from row dot
    products, form its graph Laplacian and order the rows by the eigenvector of
    the second-smallest eigenvalue (the Fiedler vector). Fewer than two rows are
    returned in their original order, since there is no second eigenvector.
    """
    term_topic_weights = np.asarray(term_topic_weights)
    n_rows = term_topic_weights.shape[0]
    if n_rows < 2:
        return list(range(n_rows))
    similarity = np.dot(term_topic_weights, term_topic_weights.T)
    # Shift so every similarity is >= 0, as a graph Laplacian needs nonnegative weights.
    similarity = similarity - similarity.min()
    # Degree = row sum. (The builtin ``sum(x, 1)`` would add 1 to every degree.)
    laplacian = np.diag(similarity.sum(axis=1)) - similarity
    eigvals, eigvecs = np.linalg.eigh(laplacian)
    # Sort once and pick the column belonging to the second-smallest eigenvalue.
    fiedler = eigvecs[:, np.argsort(eigvals)[1]]
    return [int(i) for i in np.argsort(fiedler)]


def termite_plot(model, doc_term_matrix, id2term,
                 topics=-1, sort_topics_by='index', highlight_topics=None,
                 n_terms=25, rank_terms_by='topic_weight', sort_terms_by='seriation',
                 save=False, pow_x=0.66, pow_y=0.8):
    """Draw a termite plot of topic-term weights for a fitted topic model.

    Selects topics (columns) and top-ranked terms (rows), orders both, and
    passes the resulting term x topic weight matrix plus labels to
    ``draw.draw_termite``.

    Args:
        model: fitted topic-model wrapper exposing ``n_topics``,
            ``transform(doc_term_matrix)``, ``topic_weights(doc_topic_matrix)``
            and an estimator at ``model.model`` whose ``components_`` is
            indexed ``[topic][term]`` (presumably a textacy ``TopicModel`` --
            confirm against callers).
        doc_term_matrix: document-term matrix. Used when
            ``sort_topics_by='weight'`` and ``rank_terms_by='corpus_weight'``.
        id2term: mapping from term (column) index to term string.
        topics (int or sequence of int): topic index or indices to plot.
            ``-1`` (default) plots all ``model.n_topics`` topics.
        sort_topics_by ({'index', 'weight'}): column order. Either by topic
            index, or by descending topic weight over ``doc_term_matrix``.
        highlight_topics (int or sequence of int, optional): topics whose
            columns are highlighted. At most 6.
        n_terms (int): number of top-ranked terms to plot.
        rank_terms_by ({'topic_weight', 'corpus_weight'}): how to pick the top
            terms. Either the sum of ``components_`` over *all* topics (not only
            the selected ones), or the column sums of ``doc_term_matrix``.
        sort_terms_by ({'seriation', 'weight', 'index', 'alphabetical'}): row
            order of the chosen terms. 'seriation' places terms with similar
            topic-weight profiles next to each other.
        save, pow_x, pow_y: passed unchanged to ``draw.draw_termite``.

    Returns:
        Whatever ``draw.draw_termite`` returns.

    Raises:
        ValueError: if more than 6 topics are highlighted or an option value
            is not recognised.
    """
    if highlight_topics is not None:
        if isinstance(highlight_topics, (int, np.integer)):
            highlight_topics = (highlight_topics,)
        elif len(highlight_topics) > 6:
            raise ValueError('no more than 6 topics may be highlighted at once')

    # Validate every option before any (potentially expensive) model work.
    _check_option('sort_topics_by', sort_topics_by, {'index', 'weight'})
    _check_option('rank_terms_by', rank_terms_by, {'corpus_weight', 'topic_weight'})
    _check_option('sort_terms_by', sort_terms_by,
                  {'weight', 'index', 'alphabetical', 'seriation'})

    # Select topic indices. Test for a scalar first so an array-valued
    # ``topics`` never reaches an ambiguous ``topics == -1`` comparison.
    if isinstance(topics, (int, np.integer)):
        topic_inds = tuple(range(model.n_topics)) if topics == -1 else (topics,)
    else:
        topic_inds = tuple(topics)

    # Order the topic columns.
    if sort_topics_by == 'index':
        topic_inds = tuple(sorted(topic_inds))
    else:  # 'weight': descending overall topic weight, restricted to the selection
        selected = set(topic_inds)
        topic_weights = model.topic_weights(model.transform(doc_term_matrix))
        topic_inds = tuple(topic_ind for topic_ind in np.argsort(topic_weights)[::-1]
                           if topic_ind in selected)

    # Column positions (in plot order) of the topics to highlight.
    highlight_cols = None
    if highlight_topics is not None:
        highlight_cols = tuple(col for col, topic_ind in enumerate(topic_inds)
                               if topic_ind in highlight_topics)

    # Rank terms and keep the n_terms heaviest, heaviest first.
    components = model.model.components_
    if rank_terms_by == 'corpus_weight':
        term_weights = np.ravel(doc_term_matrix.sum(axis=0))
    else:  # 'topic_weight'
        term_weights = components.sum(axis=0)
    term_inds = np.argsort(term_weights)[::-1][:n_terms]

    # Order the term rows ('weight' keeps the ranking order from above).
    if sort_terms_by == 'index':
        term_inds = sorted(term_inds)
    elif sort_terms_by == 'alphabetical':
        term_inds = sorted(term_inds, key=lambda term_ind: id2term[term_ind])
    elif sort_terms_by == 'seriation':
        seriation_weights = _topic_term_weights(components, topic_inds, term_inds)
        term_inds = [term_inds[i] for i in _seriation_order(seriation_weights)]

    topic_labels = tuple('topic {}'.format(topic_ind) for topic_ind in topic_inds)
    term_labels = tuple(id2term[term_ind] for term_ind in term_inds)
    # Term x topic weights, in final row/column order, used to size the dots.
    term_topic_weights = _topic_term_weights(components, topic_inds, term_inds)
    return draw.draw_termite(
        term_topic_weights, topic_labels, term_labels,
        highlight_cols=highlight_cols, save=save, pow_x=pow_x, pow_y=pow_y)