Finding Common Topics

How do you find thematic clusters in a large corpus of text documents? The techniques baked into sklearn (e.g. nonnegative matrix factorization, latent Dirichlet allocation) give you some intuition about common themes, but contemporary NLP has largely moved on from the bag-of-words representations they rely on. We can do better with transformer models!
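
For reference, here is a minimal sketch of that classic baseline: the top words per LDA topic over raw token counts. The function name and parameters are illustrative, not part of the pipeline below.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

def lda_top_words(docs, n_topics=4, n_words=10):
    # Bag-of-words counts, dropping very common and very rare terms.
    vec = CountVectorizer(max_df=0.95, min_df=2, stop_words='english')
    counts = vec.fit_transform(docs)
    lda = LatentDirichletAllocation(n_components=n_topics, random_state=42).fit(counts)
    words = vec.get_feature_names_out()
    # Highest-weight words in each topic, most important first.
    return [[words[i] for i in topic.argsort()[-n_words:][::-1]]
            for topic in lda.components_]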

For demonstration purposes, I'll use four categories from the standard 20-newsgroups dataset. Ideally, we should be able to recover all four (atheism, computer graphics, space, and religion).

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.cluster import KMeans
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
from IPython.display import Markdown
categories = [
    "alt.atheism",
    "talk.religion.misc",
    "comp.graphics",
    "sci.space",
]

dataset = fetch_20newsgroups(
    remove=("headers", "footers", "quotes"),
    subset="all",
    categories=categories,
    shuffle=True,
    random_state=42,
)

Some of the documents in the dataset are only a few words long; I only want to deal with documents that are at least a couple hundred characters.

# Keep only documents longer than 200 characters after stripping whitespace.
X = np.array([d.strip() for d in dataset.data if len(d.strip()) > 200])

First, I'll map each document to its embedding using the all-MiniLM-L6-v2 sentence-transformer model, with the usual mean-pooling-over-tokens recipe.

# 'mps' targets Apple Silicon GPUs; fall back to CPU elsewhere.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
minilm_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
minilm = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)
def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, ignoring padding positions.
    token_embeddings = model_output[0]  # first element holds the token-level embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(
        token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded,
                     1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def get_embeddings(X):
    loader = DataLoader(X, batch_size=16)
    embeddings = []
    for batch in loader:
        toks = minilm_tokenizer(batch, padding=True, truncation=True,
                                return_tensors='pt')
        with torch.no_grad():
            model_output = minilm(**toks.to(device))
            # Mean-pool, then L2-normalize so Euclidean distance tracks cosine similarity.
            result = F.normalize(mean_pooling(model_output,
                                              toks['attention_mask']), p=2, dim=1)
            embeddings.append(result.cpu())
    return torch.cat(embeddings)
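
As a quick sanity check: all-MiniLM-L6-v2 produces 384-dimensional vectors, and the normalization above makes them unit length.

emb = get_embeddings(X[:8])
emb.shape, emb.norm(dim=1)  # (torch.Size([8, 384]), norms all ≈ 1)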

Next, I'll cluster the embeddings with the standard k-means algorithm. There are far more sophisticated clustering techniques in sklearn, but this should be sufficient for the toy problem.

def get_clusters(embeddings):
    neural_kmeans = KMeans(n_clusters=4, n_init=25)
    neural_kmeans.fit(embeddings)
    # Count how many documents landed in each cluster.
    docs_per_label = pd.Series(neural_kmeans.labels_).value_counts()
    return neural_kmeans, docs_per_label
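
I hardcode n_clusters=4 here because I know the dataset has four categories. Without that prior, one rough way to choose k is to scan a range and keep the best silhouette score; here is a sketch with illustrative parameters.

from sklearn.metrics import silhouette_score

def pick_k(embeddings, candidates=range(2, 9)):
    # Higher silhouette means tighter, better-separated clusters.
    scores = {k: silhouette_score(embeddings,
                                  KMeans(n_clusters=k, n_init=10).fit_predict(embeddings))
              for k in candidates}
    return max(scores, key=scores.get)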

Finally, for each cluster I'll take a random sample of the documents closest to its center and ask Llama to find a title for the collection.

def top_per_cluster(X, embeddings, kmeans, k=25, m=8):
    # For each cluster: rank its documents by squared distance to the
    # centroid, keep the k closest, and sample m of them.
    samples = []
    for i, c in enumerate(kmeans.cluster_centers_):
        members = kmeans.labels_ == i
        dists = ((embeddings[members] - c) ** 2).sum(axis=-1)
        closest = X[members][np.argsort(dists)[:k]]
        samples.append(np.random.choice(closest, m))
    return samples
llama = ChatOllama(model="llama3", temperature=0)

I'll let the LLM contemplate common themes to itself before deciding on a title, and require that the results come back in a structured output format.

class SampleAnalysis(BaseModel):
    analysis: str = Field(description='Analysis of the texts.')
    category: str = Field(description='Category of the cluster.')

def llama_summarize(strs):
    prompt = [SystemMessage("""
Your task is to understand why the given documents were assigned to the same cluster.
- First analyze the documents in the cluster for common topics.
- Then, propose a category for the cluster containing these documents based on the analysis.""")]
    # Wrap each document in a HumanMessage rather than passing raw strings.
    prompt.extend(HumanMessage(str(s)) for s in strs)
    return llama.with_structured_output(SampleAnalysis).invoke(prompt)
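
A quick smoke test on a few arbitrary documents (the exact wording depends on the model):

result = llama_summarize(list(X[:3]))
print(result.category)
print(result.analysis)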

Let's try it out!

def get_topics(X):
    embeddings = get_embeddings(X)
    neural_kmeans, docs_per_label = get_clusters(embeddings)
    top_docs = top_per_cluster(X, embeddings, neural_kmeans)
    results = [llama_summarize(list(docs)) for docs in top_docs]
    return pd.DataFrame({
        'category': [r.category for r in results],
        'n_docs': [int(docs_per_label[label]) for label in range(len(docs_per_label))]
    }).sort_values(by='n_docs', ascending=False)
Markdown(get_topics(X).to_markdown())
|    | category | n_docs |
|---:|:---------|-------:|
|  0 | Computer Graphics | 706 |
|  3 | Space Exploration and Development | 706 |
|  2 | Debates about the existence of God and the nature of human reason, with a focus on criticizing Christian beliefs and practices. | 632 |
|  1 | Social Commentary/Philosophy | 613 |

Sounds about right!