# -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function, unicode_literals import logging import numpy as np from sklearn.decomposition import NMF, LatentDirichletAllocation, TruncatedSVD from sklearn.externals import joblib from textacy import viz import draw LOGGER = logging.getLogger(__name__) def termite_plot(model, doc_term_matrix, id2term, topics=-1, sort_topics_by='index', highlight_topics=None, n_terms=25, rank_terms_by='topic_weight', sort_terms_by='seriation', save=False, pow_x = 0.66, pow_y = 0.8): if highlight_topics is not None: if isinstance(highlight_topics, int): highlight_topics = (highlight_topics,) elif len(highlight_topics) > 6: raise ValueError('no more than 6 topics may be highlighted at once') # get topics indices if topics == -1: topic_inds = tuple(range(model.n_topics)) elif isinstance(topics, int): topic_inds = (topics,) else: topic_inds = tuple(topics) # get topic indices in sorted order if sort_topics_by == 'index': topic_inds = sorted(topic_inds) elif sort_topics_by == 'weight': topic_inds = tuple(topic_ind for topic_ind in np.argsort(model.topic_weights(model.transform(doc_term_matrix)))[::-1] if topic_ind in topic_inds) else: msg = 'invalid sort_topics_by value; must be in {}'.format( {'index', 'weight'}) raise ValueError(msg) # get column index of any topics to highlight in termite plot if highlight_topics is not None: highlight_cols = tuple(i for i in range(len(topic_inds)) if topic_inds[i] in highlight_topics) else: highlight_cols = None # get top term indices if rank_terms_by == 'corpus_weight': term_inds = np.argsort(np.ravel(doc_term_matrix.sum(axis=0)))[:-n_terms - 1:-1] elif rank_terms_by == 'topic_weight': term_inds = np.argsort(model.model.components_.sum(axis=0))[:-n_terms - 1:-1] else: msg = 'invalid rank_terms_by value; must be in {}'.format( {'corpus_weight', 'topic_weight'}) raise ValueError(msg) # get top term indices in sorted order if sort_terms_by == 'weight': pass elif sort_terms_by == 'index': term_inds = sorted(term_inds) elif sort_terms_by == 'alphabetical': term_inds = sorted(term_inds, key=lambda x: id2term[x]) elif sort_terms_by == 'seriation': topic_term_weights_mat = np.array( np.array([model.model.components_[topic_ind][term_inds] for topic_ind in topic_inds])).T # calculate similarity matrix topic_term_weights_sim = np.dot(topic_term_weights_mat, topic_term_weights_mat.T) # substract minimum of sim mat in order to keep sim mat nonnegative topic_term_weights_sim = topic_term_weights_sim - topic_term_weights_sim.min() # compute Laplacian matrice and its 2nd eigenvector L = np.diag(sum(topic_term_weights_sim, 1)) - topic_term_weights_sim D, V = np.linalg.eigh(L) D = D[np.argsort(D)] V = V[:, np.argsort(D)] fiedler = V[:, 1] # get permutation corresponding to sorting the 2nd eigenvector term_inds = [term_inds[i] for i in np.argsort(fiedler)] else: msg = 'invalid sort_terms_by value; must be in {}'.format( {'weight', 'index', 'alphabetical', 'seriation'}) raise ValueError(msg) # get topic and term labels topic_labels = tuple('topic {}'.format(topic_ind) for topic_ind in topic_inds) term_labels = tuple(id2term[term_ind] for term_ind in term_inds) # get topic-term weights to size dots term_topic_weights = np.array([model.model.components_[topic_ind][term_inds] for topic_ind in topic_inds]).T return draw.draw_termite( term_topic_weights, topic_labels, term_labels, highlight_cols=highlight_cols, save=save, pow_x = pow_x, pow_y = pow_y)