106 lines
4.1 KiB
Python
106 lines
4.1 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||
|
|
||
|
import logging
|
||
|
|
||
|
import numpy as np
|
||
|
from sklearn.decomposition import NMF, LatentDirichletAllocation, TruncatedSVD
|
||
|
from sklearn.externals import joblib
|
||
|
|
||
|
from textacy import viz
|
||
|
import draw
|
||
|
|
||
|
# Logger named after this module (resolved via the module's ``__name__``).
LOGGER = logging.getLogger(name=__name__)
|
||
|
|
||
|
|
||
|
|
||
|
def _term_topic_weights(model, topic_inds, term_inds):
    """Return the topic-term weights for the selected terms and topics.

    Builds one row per topic from ``model.model.components_[topic][term_inds]``
    and transposes, so rows follow ``term_inds`` and columns follow
    ``topic_inds``.
    """
    components = model.model.components_
    return np.array([components[topic_ind][term_inds]
                     for topic_ind in topic_inds]).T


def _seriation_order(weights):
    """Return a row permutation of ``weights`` that places similar rows together.

    Spectral seriation: form the row-similarity matrix ``W = weights . weights.T``,
    shift it to be nonnegative, build the graph Laplacian
    ``L = diag(W.sum(axis=1)) - W`` and sort rows by their entry in the
    eigenvector of the second-smallest eigenvalue (the Fiedler vector).
    Callers must pass at least two rows; a second eigenvector does not exist
    otherwise.
    """
    sim = np.dot(weights, weights.T)
    # shift so every similarity is >= 0; topic-term weights are not guaranteed
    # nonnegative for every model type (e.g. SVD-based models)
    sim = sim - sim.min()
    # use the row sums (axis=1) for the degree matrix; the builtin
    # ``sum(x, 1)`` would instead treat 1 as a start value
    laplacian = np.diag(sim.sum(axis=1)) - sim
    # eigh returns eigenvalues in ascending order with matching eigenvector
    # columns, so column 1 is the Fiedler vector without any re-sorting
    _, eigvecs = np.linalg.eigh(laplacian)
    return np.argsort(eigvecs[:, 1])


def termite_plot(model, doc_term_matrix, id2term,
                 topics=-1, sort_topics_by='index', highlight_topics=None,
                 n_terms=25, rank_terms_by='topic_weight', sort_terms_by='seriation',
                 save=False, pow_x=0.66, pow_y=0.8):
    """Draw a "termite" plot of topic-term weights for a fitted topic model.

    Args:
        model: topic-model wrapper. This function reads ``model.n_topics`` and
            ``model.model.components_`` (indexed ``[topic][term]``), and calls
            ``model.transform`` and ``model.topic_weights``.
        doc_term_matrix: document-term matrix. It is passed to
            ``model.transform`` when ``sort_topics_by='weight'`` and summed
            over axis 0 when ``rank_terms_by='corpus_weight'``.
        id2term: mapping (or sequence) from term index to term label.
        topics (int or iterable of int): topic index(es) to include; ``-1``
            selects all ``model.n_topics`` topics.
        sort_topics_by (str): ``'index'`` (ascending topic index) or
            ``'weight'`` (descending ``model.topic_weights`` of the
            transformed ``doc_term_matrix``).
        highlight_topics (int or iterable of int, optional): at most 6 topic
            indexes whose plot columns are highlighted.
        n_terms (int): number of top-ranked terms to show.
        rank_terms_by (str): ``'corpus_weight'`` (column sums of
            ``doc_term_matrix``) or ``'topic_weight'`` (column sums of
            ``model.model.components_`` over all topics).
        sort_terms_by (str): ``'weight'`` (keep rank order), ``'index'``,
            ``'alphabetical'`` (by ``id2term`` label) or ``'seriation'``
            (spectral ordering grouping terms with similar topic profiles).
        save: forwarded unchanged to ``draw.draw_termite``.
        pow_x, pow_y: forwarded unchanged to ``draw.draw_termite``
            (presumably display scaling exponents - confirm in ``draw``).

    Returns:
        Whatever ``draw.draw_termite`` returns.

    Raises:
        ValueError: if more than 6 topics are highlighted, or if
            ``sort_topics_by``, ``rank_terms_by`` or ``sort_terms_by`` is not
            one of the allowed values.
    """
    # normalize highlight_topics to a tuple of topic indexes
    if highlight_topics is not None:
        if isinstance(highlight_topics, (int, np.integer)):
            highlight_topics = (highlight_topics,)
        elif len(highlight_topics) > 6:
            raise ValueError('no more than 6 topics may be highlighted at once')

    # resolve topic indexes; check the type first so array-like ``topics``
    # never reaches an ambiguous elementwise ``== -1`` comparison
    if isinstance(topics, (int, np.integer)):
        topic_inds = tuple(range(model.n_topics)) if topics == -1 else (topics,)
    else:
        topic_inds = tuple(topics)

    # put topic indexes in display order
    if sort_topics_by == 'index':
        topic_inds = sorted(topic_inds)
    elif sort_topics_by == 'weight':
        corpus_topic_weights = model.topic_weights(model.transform(doc_term_matrix))
        topic_inds = tuple(topic_ind
                           for topic_ind in np.argsort(corpus_topic_weights)[::-1]
                           if topic_ind in topic_inds)
    else:
        msg = 'invalid sort_topics_by value; must be in {}'.format(
            {'index', 'weight'})
        raise ValueError(msg)

    # plot-column positions of the topics to highlight
    if highlight_topics is not None:
        highlight_cols = tuple(col for col, topic_ind in enumerate(topic_inds)
                               if topic_ind in highlight_topics)
    else:
        highlight_cols = None

    # indexes of the n_terms highest-ranked terms, highest first
    if rank_terms_by == 'corpus_weight':
        term_inds = np.argsort(np.ravel(doc_term_matrix.sum(axis=0)))[:-n_terms - 1:-1]
    elif rank_terms_by == 'topic_weight':
        term_inds = np.argsort(model.model.components_.sum(axis=0))[:-n_terms - 1:-1]
    else:
        msg = 'invalid rank_terms_by value; must be in {}'.format(
            {'corpus_weight', 'topic_weight'})
        raise ValueError(msg)

    # put term indexes in display order
    if sort_terms_by == 'weight':
        pass
    elif sort_terms_by == 'index':
        term_inds = sorted(term_inds)
    elif sort_terms_by == 'alphabetical':
        term_inds = sorted(term_inds, key=lambda term_ind: id2term[term_ind])
    elif sort_terms_by == 'seriation':
        # seriation needs a second eigenvector, i.e. at least two terms;
        # with 0 or 1 terms any order is already "seriated"
        if len(term_inds) > 1:
            order = _seriation_order(
                _term_topic_weights(model, topic_inds, term_inds))
            term_inds = [term_inds[i] for i in order]
    else:
        msg = 'invalid sort_terms_by value; must be in {}'.format(
            {'weight', 'index', 'alphabetical', 'seriation'})
        raise ValueError(msg)

    topic_labels = tuple('topic {}'.format(topic_ind) for topic_ind in topic_inds)
    term_labels = tuple(id2term[term_ind] for term_ind in term_inds)

    # (n_terms, n_topics) weights used to size the plot's dots
    term_topic_weights = _term_topic_weights(model, topic_inds, term_inds)

    return draw.draw_termite(
        term_topic_weights, topic_labels, term_labels,
        highlight_cols=highlight_cols, save=save, pow_x=pow_x, pow_y=pow_y)
|