topicModelingTickets/java_LabledLDA/src/jgibblda/Model.java

/*
* Copyright (C) 2007 by
*
* Xuan-Hieu Phan
* hieuxuan@ecei.tohoku.ac.jp or pxhieu@gmail.com
* Graduate School of Information Sciences
* Tohoku University
*
* Cam-Tu Nguyen
* ncamtu@gmail.com
* College of Technology
* Vietnam National University, Hanoi
*
* JGibbsLDA is a free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published
* by the Free Software Foundation; either version 2 of the License,
* or (at your option) any later version.
*
* JGibbsLDA is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with JGibbsLDA; if not, write to the Free Software Foundation,
* Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
*/
package jgibblda;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.StringTokenizer;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntObjectHashMap;
public class Model {
//---------------------------------------------------------------
// Class Variables
//---------------------------------------------------------------
public static String tassignSuffix = ".tassign.gz"; // suffix for topic assignment file
public static String thetaSuffix = ".theta.gz"; // suffix for theta (document - topic distribution) file
public static String phiSuffix = ".phi.gz"; // suffix for phi (topic - word distribution) file
public static String othersSuffix = ".others.gz"; // suffix for file containing other parameters
public static String twordsSuffix = ".twords.gz"; // suffix for file containing top words per topic
public static String wordMapSuffix = ".wordmap.gz"; // suffix for file containing word-to-id map
//---------------------------------------------------------------
// Model Parameters and Variables
//---------------------------------------------------------------
public String dir = "./";
public String dfile = "trndocs.dat";
public boolean unlabeled = false;
public String modelName = "model";
public LDADataset data; // link to a dataset
public int M = 0; // dataset size (i.e., number of docs)
public int V = 0; // vocabulary size
public int K = 100; // number of topics
public double alpha; // LDA hyperparameter (document-topic smoothing)
public double beta = 0.01; // LDA hyperparameter (topic-word smoothing)
public int niters = 1000; // number of Gibbs sampling iterations
public int nburnin = 500; // number of Gibbs sampling burn-in iterations
public int samplingLag = 5; // Gibbs sampling sample lag
public int numSamples = 1; // number of samples taken
public int liter = 0; // the iteration at which the model was saved
public int twords = 20; // number of top words to print per topic
// Estimated/inferred parameters
public double[][] theta = null; // theta: document - topic distributions, size M x K
public double[][] phi = null; // phi: topic-word distributions, size K x V
// Temp variables while sampling
public TIntArrayList[] z = null; // topic assignments for words, size M x doc.size()
protected int[][] nw = null; // nw[i][j]: number of instances of word/term i assigned to topic j, size V x K
protected int[][] nd = null; // nd[i][j]: number of words in document i assigned to topic j, size M x K
protected int[] nwsum = null; // nwsum[j]: total number of words assigned to topic j, size K
protected int[] ndsum = null; // ndsum[i]: total number of words in document i, size M
protected ArrayList<TIntObjectHashMap<int[]>> nw_inf = null; // nw_inf[m][i][j]: number of instances of word/term i assigned to topic j in doc m, size M x V x K
protected int[][] nwsum_inf = null; // nwsum_inf[m][j]: total number of words assigned to topic j in doc m, size M x K
// temp variables for sampling
protected double[] p = null;
//---------------------------------------------------------------
// Constructors
//---------------------------------------------------------------
public Model(LDACmdOption option) throws FileNotFoundException, IOException
{
this(option, null);
}
public Model(LDACmdOption option, Model trnModel) throws FileNotFoundException, IOException
{
modelName = option.modelName;
K = option.K;
alpha = option.alpha;
if (alpha < 0.0)
alpha = 50.0 / K;
if (option.beta >= 0)
beta = option.beta;
niters = option.niters;
nburnin = option.nburnin;
samplingLag = option.samplingLag;
dir = option.dir;
if (dir.endsWith(File.separator))
dir = dir.substring(0, dir.length() - 1);
dfile = option.dfile;
unlabeled = option.unlabeled;
twords = option.twords;
// initialize dataset
data = new LDADataset();
// process trnModel (if given)
if (trnModel != null) {
data.setDictionary(trnModel.data.localDict);
K = trnModel.K;
// use hyperparameters from model (if not overridden in options)
if (option.alpha < 0.0)
alpha = trnModel.alpha;
if (option.beta < 0.0)
beta = trnModel.beta;
}
// read in data
data.readDataSet(dir + File.separator + dfile, unlabeled);
}
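// Usage sketch (hypothetical values; the fields are exactly those the
// constructor above consumes):
//   LDACmdOption option = new LDACmdOption();
//   option.dir = "models"; option.dfile = "trndocs.dat";
//   option.K = 100; option.alpha = -1; option.beta = 0.01; // alpha < 0 => 50/K
//   Model model = new Model(option); // reads <dir>/<dfile> into model.data
//   model.init(true);                // random init for estimation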
//---------------------------------------------------------------
// Init Methods
//---------------------------------------------------------------
/**
* Init parameters for estimation or inference
*/
public boolean init(boolean random)
{
if (random) {
M = data.M;
V = data.V;
z = new TIntArrayList[M];
} else {
if (!loadModel()) {
System.out.println("Fail to load word-topic assignment file of the model!");
return false;
}
// debug output
System.out.println("Model loaded:");
System.out.println("\talpha:" + alpha);
System.out.println("\tbeta:" + beta);
System.out.println("\tK:" + K);
System.out.println("\tM:" + M);
System.out.println("\tV:" + V);
}
p = new double[K];
initSS();
for (int m = 0; m < data.M; m++){
if (random) {
z[m] = new TIntArrayList();
}
// initialize topic assignments z for each word
int N = data.docs.get(m).length;
for (int n = 0; n < N; n++){
int w = data.docs.get(m).words[n];
int topic;
// random init a topic or load existing topic from z[m]
if (random) {
topic = (int)Math.floor(Math.random() * K);
z[m].add(topic);
} else {
topic = z[m].get(n);
}
nw[w][topic]++; // number of instances of word assigned to topic j
nd[m][topic]++; // number of words in document i assigned to topic j
nwsum[topic]++; // total number of words assigned to topic j
}
ndsum[m] = N; // total number of words in document i
}
theta = new double[M][K];
phi = new double[K][V];
return true;
}
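/**
* Init per-document inference counts (nw_inf, nwsum_inf) from the current
* topic assignments z, so that new documents can be sampled against a
* trained model without touching its global counts.
*/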
public boolean initInf()
{
nw_inf = new ArrayList<TIntObjectHashMap<int[]>>();
nwsum_inf = new int[M][K];
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
nwsum_inf[m][k] = 0;
}
}
for (int m = 0; m < data.M; m++){
nw_inf.add(m, new TIntObjectHashMap<int[]>());
// accumulate inference counts from the existing z
int N = data.docs.get(m).length;
for (int n = 0; n < N; n++){
int w = data.docs.get(m).words[n];
int topic = z[m].get(n);
if (!nw_inf.get(m).containsKey(w)) {
int[] nw_inf_m_w = new int[K];
for (int k = 0; k < K; k++) {
nw_inf_m_w[k] = 0;
}
nw_inf.get(m).put(w, nw_inf_m_w);
}
nw_inf.get(m).get(w)[topic]++; // number of instances of word assigned to topic j in doc m
//nw_inf[m][w][topic]++; // number of instances of word assigned to topic j in doc m
nwsum_inf[m][topic]++; // total number of words assigned to topic j in doc m
}
}
return true;
}
/**
* Init sufficient statistics (count arrays) for Gibbs sampling. Java
* zero-initializes int arrays, so the explicit zeroing below is redundant
* but harmless.
*/
protected void initSS()
{
nw = new int[V][K];
for (int w = 0; w < V; w++){
for (int k = 0; k < K; k++){
nw[w][k] = 0;
}
}
nd = new int[M][K];
for (int m = 0; m < M; m++){
for (int k = 0; k < K; k++){
nd[m][k] = 0;
}
}
nwsum = new int[K];
for (int k = 0; k < K; k++){
nwsum[k] = 0;
}
ndsum = new int[M];
for (int m = 0; m < M; m++){
ndsum[m] = 0;
}
}
//---------------------------------------------------------------
// Update Methods
//---------------------------------------------------------------
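// Both updateParams overloads are meant to be called once per retained sample
// (after the nburnin burn-in iterations, then every samplingLag iterations);
// numSamples counts how many samples the running means aggregate.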
public void updateParams()
{
updateTheta();
updatePhi();
numSamples++;
}
public void updateParams(Model trnModel)
{
updateTheta();
updatePhi(trnModel);
numSamples++;
}
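// theta[m][k] holds a running mean over retained samples of the point estimate
// (nd[m][k] + alpha) / (ndsum[m] + K*alpha): multiply the stored mean by
// (numSamples - 1) to recover the sum, add the new estimate, divide by numSamples.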
public void updateTheta()
{
double Kalpha = K * alpha;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
if (numSamples > 1) theta[m][k] *= numSamples - 1; // convert from mean to sum
theta[m][k] += (nd[m][k] + alpha) / (ndsum[m] + Kalpha);
if (numSamples > 1) theta[m][k] /= numSamples; // convert from sum to mean
}
}
}
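// phi[k][w] uses the same running-mean scheme with point estimate
// (nw[w][k] + beta) / (nwsum[k] + V*beta).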
public void updatePhi()
{
double Vbeta = V * beta;
for (int k = 0; k < K; k++) {
for (int w = 0; w < V; w++) {
if (numSamples > 1) phi[k][w] *= numSamples - 1; // convert from mean to sum
phi[k][w] += (nw[w][k] + beta) / (nwsum[k] + Vbeta);
if (numSamples > 1) phi[k][w] /= numSamples; // convert from sum to mean
}
}
}
// For inference on new documents: pool the training counts with the local counts,
// phi[k][w] = (nw_trn[gid][k] + nw[w][k] + beta) / (nwsum_trn[k] + nwsum[k] + Vbeta),
// where lid2gid maps the local word id w to its global (training) id gid and
// Vbeta uses the training vocabulary size. Words unseen in training are skipped.
public void updatePhi(Model trnModel)
{
double Vbeta = trnModel.V * beta;
for (int k = 0; k < K; k++) {
for (int _w = 0; _w < V; _w++) {
if (data.lid2gid.containsKey(_w)) {
int id = data.lid2gid.get(_w);
if (numSamples > 1) phi[k][_w] *= numSamples - 1; // convert from mean to sum
phi[k][_w] += (trnModel.nw[id][k] + nw[_w][k] + beta) / (trnModel.nwsum[k] + nwsum[k] + Vbeta);
if (numSamples > 1) phi[k][_w] /= numSamples; // convert from sum to mean
} // else ignore words that don't appear in training
} //end foreach word
} // end foreach topic
}
//---------------------------------------------------------------
// I/O Methods
//---------------------------------------------------------------
/**
* Save model
*/
public boolean saveModel()
{
return saveModel("");
}
public boolean saveModel(String modelPrefix)
{
if (!saveModelTAssign(dir + File.separator + modelPrefix + modelName + tassignSuffix)) {
return false;
}
if (!saveModelOthers(dir + File.separator + modelPrefix + modelName + othersSuffix)) {
return false;
}
if (!saveModelTheta(dir + File.separator + modelPrefix + modelName + thetaSuffix)) {
return false;
}
//if (!saveModelPhi(dir + File.separator + modelPrefix + modelName + phiSuffix)) {
// return false;
//}
if (twords > 0) {
if (!saveModelTwords(dir + File.separator + modelPrefix + modelName + twordsSuffix)) {
return false;
}
}
if (!data.localDict.writeWordMap(dir + File.separator + modelPrefix + modelName + wordMapSuffix)) {
return false;
}
return true;
}
/**
* Save word-topic assignments for this model.
* Format: one document per line, as space-separated "wordId:topicId" pairs.
*/
public boolean saveModelTAssign(String filename) {
int i, j;
try{
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
new GZIPOutputStream(
new FileOutputStream(filename)), "UTF-8"));
//write docs with topic assignments for words
for (i = 0; i < data.M; i++) {
for (j = 0; j < data.docs.get(i).length; ++j) {
writer.write(data.docs.get(i).words[j] + ":" + z[i].get(j) + " ");
}
writer.write("\n");
}
writer.close();
}
catch (Exception e) {
System.out.println("Error while saving model tassign: " + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
/**
* Save theta (topic distribution) for this model
*/
public boolean saveModelTheta(String filename) {
try{
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
new GZIPOutputStream(
new FileOutputStream(filename)), "UTF-8"));
for (int i = 0; i < M; i++) {
for (int j = 0; j < K; j++) {
if (theta[i][j] > 0) {
writer.write(j + ":" + theta[i][j] + " ");
}
}
writer.write("\n");
}
writer.close();
}
catch (Exception e){
System.out.println("Error while saving topic distribution file for this model: " + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
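// Hypothetical helper (a sketch, not part of the original JGibbLDA API):
// reads back a sparse theta file written by saveModelTheta above. Line m
// holds the "topicId:prob" pairs of document m; omitted topics stay 0.
public static double[][] loadTheta(String filename, int M, int K) throws IOException
{
double[][] theta = new double[M][K];
BufferedReader reader = new BufferedReader(new InputStreamReader(
new GZIPInputStream(
new FileInputStream(filename)), "UTF-8"));
for (int m = 0; m < M; m++) {
String line = reader.readLine();
StringTokenizer tknr = new StringTokenizer(line, " \t\r\n");
while (tknr.hasMoreTokens()) {
String[] pair = tknr.nextToken().split(":");
theta[m][Integer.parseInt(pair[0])] = Double.parseDouble(pair[1]);
}
}
reader.close();
return theta;
}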
/**
* Save word-topic distribution
*/
public boolean saveModelPhi(String filename)
{
try {
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
new GZIPOutputStream(
new FileOutputStream(filename)), "UTF-8"));
for (int i = 0; i < K; i++) {
for (int j = 0; j < V; j++) {
if (phi[i][j] > 0) {
writer.write(j + ":" + phi[i][j] + " ");
}
}
writer.write("\n");
}
writer.close();
}
catch (Exception e) {
System.out.println("Error while saving word-topic distribution:" + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
/**
* Save other information of this model
*/
public boolean saveModelOthers(String filename){
try{
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
new GZIPOutputStream(
new FileOutputStream(filename)), "UTF-8"));
writer.write("alpha=" + alpha + "\n");
writer.write("beta=" + beta + "\n");
writer.write("ntopics=" + K + "\n");
writer.write("ndocs=" + M + "\n");
writer.write("nwords=" + V + "\n");
writer.write("liters=" + liter + "\n");
writer.close();
}
catch(Exception e){
System.out.println("Error while saving model others:" + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
/**
* Save the most likely words for each topic: a "Topic k:" header followed
* by one tab-indented "word<TAB>prob" line per top word
*/
public boolean saveModelTwords(String filename){
try{
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
new GZIPOutputStream(
new FileOutputStream(filename)), "UTF-8"));
if (twords > V){
twords = V;
}
for (int k = 0; k < K; k++){
ArrayList<Pair> wordsProbsList = new ArrayList<Pair>();
for (int w = 0; w < V; w++){
Pair p = new Pair(w, phi[k][w], false);
wordsProbsList.add(p);
}//end foreach word
//print topic
writer.write("Topic " + k + ":\n");
Collections.sort(wordsProbsList);
for (int i = 0; i < twords; i++){
if (data.localDict.contains((Integer)wordsProbsList.get(i).first)){
String word = data.localDict.getWord((Integer)wordsProbsList.get(i).first);
writer.write("\t" + word + "\t" + wordsProbsList.get(i).second + "\n");
}
}
} //end foreach topic
writer.close();
}
catch(Exception e){
System.out.println("Error while saving model twords: " + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
/**
* Load saved model. readOthersFile must run first: readTAssignFile
* depends on the M and V it sets.
*/
public boolean loadModel(){
if (!readOthersFile(dir + File.separator + modelName + othersSuffix))
return false;
if (!readTAssignFile(dir + File.separator + modelName + tassignSuffix))
return false;
// read dictionary
Dictionary dict = new Dictionary();
if (!dict.readWordMap(dir + File.separator + modelName + wordMapSuffix))
return false;
data.localDict = dict;
return true;
}
/**
* Load "others" file to get parameters
*/
protected boolean readOthersFile(String otherFile){
try {
BufferedReader reader = new BufferedReader(new InputStreamReader(
new GZIPInputStream(
new FileInputStream(otherFile)), "UTF-8"));
String line;
while((line = reader.readLine()) != null){
StringTokenizer tknr = new StringTokenizer(line,"= \t\r\n");
int count = tknr.countTokens();
if (count != 2)
continue;
String optstr = tknr.nextToken();
String optval = tknr.nextToken();
if (optstr.equalsIgnoreCase("alpha")){
alpha = Double.parseDouble(optval);
}
else if (optstr.equalsIgnoreCase("beta")){
beta = Double.parseDouble(optval);
}
else if (optstr.equalsIgnoreCase("ntopics")){
K = Integer.parseInt(optval);
}
else if (optstr.equalsIgnoreCase("liter")){
liter = Integer.parseInt(optval);
}
else if (optstr.equalsIgnoreCase("nwords")){
V = Integer.parseInt(optval);
}
else if (optstr.equalsIgnoreCase("ndocs")){
M = Integer.parseInt(optval);
}
else {
// any more?
}
}
reader.close();
}
catch (Exception e){
System.out.println("Error while reading other file:" + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
/**
* Load word-topic assignments for this model
*/
protected boolean readTAssignFile(String tassignFile)
{
try {
int i,j;
BufferedReader reader = new BufferedReader(new InputStreamReader(
new GZIPInputStream(
new FileInputStream(tassignFile)), "UTF-8"));
String line;
z = new TIntArrayList[M];
data = new LDADataset();
data.setM(M);
data.V = V;
for (i = 0; i < M; i++){
line = reader.readLine();
StringTokenizer tknr = new StringTokenizer(line, " \t\r\n");
int length = tknr.countTokens();
TIntArrayList words = new TIntArrayList();
TIntArrayList topics = new TIntArrayList();
for (j = 0; j < length; j++){
String token = tknr.nextToken();
StringTokenizer tknr2 = new StringTokenizer(token, ":");
if (tknr2.countTokens() != 2){
System.out.println("Invalid word-topic assignment line\n");
return false;
}
words.add(Integer.parseInt(tknr2.nextToken()));
topics.add(Integer.parseInt(tknr2.nextToken()));
}//end for each topic assignment
//allocate and add new document to the corpus
Document doc = new Document(words);
data.setDoc(doc, i);
//assign values for z
z[i] = new TIntArrayList();
for (j = 0; j < topics.size(); j++){
z[i].add(topics.get(j));
}
}//end for each doc
reader.close();
}
catch (Exception e){
System.out.println("Error while loading model: " + e.getMessage());
e.printStackTrace();
return false;
}
return true;
}
}