user1583016
user1583016

Reputation: 79

Hidden Markov Model for Topical Text Segmentation

I'm attempting to write a function that splits a long document into shorter segments of text, splitting the text into the topics discussed as a step in a data processing pipeline prior to embedding the shorter segments of text for vector search.

I'm attempting to use v1.0 of the pomegranate python library as I get the impression it will be more performant than some of the other common options I looked at. Below is my code.

# import pomegranate library and numpy
import pomegranate
from pomegranate.hmm import DenseHMM
import numpy
import torch

# import LabelEncoder from sklearn
from sklearn.preprocessing import LabelEncoder

def segment_document(document):
  # define the range of possible topics
  min_topics = 2
  max_topics = 20
  
  # initialize a list to store the models and their scores
  models = []
  
  # create an encoder object
  encoder = LabelEncoder()

  # fit the encoder to the vocabulary
  vocab = numpy.unique(numpy.char.split(document))[0]

  encoder.fit(vocab)

  # transform the document into numeric values
  document = encoder.transform(document.split())
  document = document.reshape(1, -1) # reshape into a 2D array with one row
  # add another dimension to document for DenseHMM input
  document = numpy.expand_dims(document, axis=-1)
  # convert document to torch tensor of float type
  document = torch.from_numpy(document).float()

  vocab_size = len(vocab)
  # loop over the number of topics
  for num_topics in range(min_topics, max_topics + 1):
    # initialize state priors and transition probabilities
    state_priors = numpy.random.dirichlet(numpy.ones(num_topics))
    trans_probs = numpy.random.dirichlet(numpy.ones(num_topics), size=num_topics)
    
    # initialize emission factors with random or uniform probabilities
    emission_probs = []
    for i in range(num_topics):
      # create a categorical distribution from random integers and their counts
      ints = numpy.random.randint(0, vocab_size, size=100) # generate 100 random integers in [0, vocab_size)
      counts = numpy.bincount(ints) # count the occurrences of each integer
      probs = counts / counts.sum() # normalize the counts to get probabilities
      print(f"probs before adjustment: {probs}, sum: {numpy.sum(probs)}")
      assert numpy.isclose(numpy.sum(probs), 1) # check if the sum is close to 1 within a tolerance
      print(f"probs after adjustment: {probs}, sum: {numpy.sum(probs)}")
      probs = probs.reshape(1, -1) # reshape into a 2D array with one row
      emission_probs.append(pomegranate.distributions.Categorical(probs)) # create a categorical distribution
    
    # create model object
    model = DenseHMM()
    
    # add distributions to model
    model.add_distributions(emission_probs)
    
    # set starts and ends probabilities for model
    model.starts = state_priors
    model.ends = numpy.zeros(num_topics)
    
    #train model using Baum-Welch algorithm
    model.fit([document])
    
    # calculate the log probability of the document under the model
    logprob = model.log_probability(document)
    
    # calculate the number of parameters of the model
    n_params = num_topics * (num_topics - 1) + num_topics * vocab_size
    
    # calculate the BIC score of the model
    bic = -2 * logprob + n_params * numpy.log(len(document))
    
    # calculate the AIC score of the model
    aic = -2 * logprob + 2 * n_params
    
    # append the model and its scores to the list
    models.append((model, bic, aic))
  
  # sort the models by their BIC scores in ascending order
  models.sort(key=lambda x: x[1])
  
  # get the best model by BIC and its viterbi states
  best_model_bic, best_bic, _ = models[0]
  viterbi_states_bic, viterbi_logprob_bic = best_model_bic.viterbi(document)
  
  # sort the models by their AIC scores in ascending order
  models.sort(key=lambda x: x[2])
  
  # get the best model by AIC and its viterbi states
  best_model_aic, _, best_aic = models[0]
  viterbi_states_aic, viterbi_logprob_aic = best_model_aic.viterbi(document)
  
  # compare the BIC and AIC scores and choose the best one
  if best_bic < best_aic:
    # use BIC as the criterion and segment document based on viterbi states
    segments = []
    current_segment = []
    current_state = viterbi_states_bic[0][1].name
    for i in range(1, len(viterbi_states_bic)):
      word = document[i-1]
      state = viterbi_states_bic[i][1].name
      if state == current_state:
        # add word to current segment
        current_segment.append(word)
      else:
        # start a new segment
        segments.append((current_state, current_segment))
        current_segment = [word]
        current_state = state
    
    # add last segment
    segments.append((current_state, current_segment))
    
    # return segments and BIC score
    return segments, best_bic
  
  else:
    # use AIC as the criterion and segment document based on viterbi states
    segments = []
    current_segment = []
    current_state = viterbi_states_aic[0][1].name
    for i in range(1, len(viterbi_states_aic)):
      word = document[i-1]
      state = viterbi_states_aic[i][1].name
      if state == current_state:
        # add word to current segment
        current_segment.append(word)
      else:
        # start a new segment
        segments.append((current_state, current_segment))
        current_segment = [word]
        current_state = state
    
    # add last segment
    segments.append((current_state, current_segment))
    
    # return segments and AIC score
    return segments, best_aic


if __name__ == "__main__":
  # Get the document as input from the user.
  document = input("Enter a document: ")

  # Segment the document into distinct segments.
  segments = segment_document(document)

  # Print the segments.
  for segment in segments:
    print("Segment:", segment)

And the full traceback on error I'm currently getting is:

---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-70-38334fe21de7> in <cell line: 144>()
    147 
    148   # Segment the document into distinct segments.
--> 149   segments = segment_document(document)
    150 
    151   # Print the segments.

6 frames

<ipython-input-70-38334fe21de7> in segment_document(document)
     63 
     64     #train model using Baum-Welch algorithm
---> 65     model.fit([document])
     66 
     67     # calculate the log probability of the document under the model

/usr/local/lib/python3.10/dist-packages/pomegranate/hmm/_base.py in fit(self, X, sample_weight, priors)
    604                                 p_ = None if priors is None else priors[j]
    605 
--> 606                                 logp += self.summarize(X_, sample_weight=w_, priors=p_).sum()
    607 
    608                         # Calculate and check improvement and optionally print it

/usr/local/lib/python3.10/dist-packages/pomegranate/hmm/dense_hmm.py in summarize(self, X, sample_weight, emissions, priors)
    541         """
    542 
--> 543         X, emissions, sample_weight = super().summarize(X, 
    544             sample_weight=sample_weight, emissions=emissions, priors=priors)
    545 

/usr/local/lib/python3.10/dist-packages/pomegranate/hmm/_base.py in summarize(self, X, sample_weight, emissions, priors)
    681         X = _check_parameter(_cast_as_tensor(X), "X", ndim=3, 
    682             shape=(-1, -1, self.d), check_parameter=self.check_data)
--> 683                 emissions = _check_inputs(self, X, emissions, priors)
    684 
    685                 if sample_weight is None:

/usr/local/lib/python3.10/dist-packages/pomegranate/hmm/_base.py in _check_inputs(model, X, emissions, priors)
     26         ndim=3)
     27         if emissions is None:
---> 28                 emissions = model._emission_matrix(X, priors=priors)
     29 
     30         return emissions

/usr/local/lib/python3.10/dist-packages/pomegranate/hmm/_base.py in _emission_matrix(self, X, priors)
    285 
    286                 for i, node in enumerate(self.distributions):
--> 287                         logp = node.log_probability(X)
    288                         if isinstance(logp, torch.masked.MaskedTensor):
    289                                 logp = logp._masked_data

/usr/local/lib/python3.10/dist-packages/pomegranate/distributions/categorical.py in log_probability(self, X)
    173                 logps = torch.zeros(X.shape[0], dtype=self.probs.dtype)
    174                 for i in range(self.d):
--> 175                         logps += self._log_probs[i][X[:, i]]
    176 
    177                 return logps

IndexError: tensors used as indices must be long, int, byte or bool tensors

Upvotes: 0

Views: 334

Answers (1)

Flob
Flob

Reputation: 1

Try changing this line:

document = document.reshape(1, -1) # reshape into a 2D array with one row

To:

document = numpy.reshape(document, (1, -1))

This should reshape the document into a 2D array with one row, and might fix the issue

Upvotes: 0

Related Questions