Exercise #7: Contextual Compressors

Objective

Learn to use LangChain Contextual Compressor classes.

References

  • LangChain Contextual Compression Retriever
  • LangChain Contextual Compressors

Flow

[Figure: contextual compressors flow]

Steps

Create a new notebook and copy/paste the following code step by step. Make sure to go through the code and understand what it is doing.

Import the required packages

from langchain_community.document_loaders import DirectoryLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_cohere import CohereEmbeddings

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline

import warnings 
# Set warnings to be ignored
warnings.filterwarnings('ignore') 

1. Create an LLM and embeddings for use by the compressors

You must adjust the location of the key file in the code below. The sample uses the Cohere command model and Cohere embeddings, but you may use a different LLM.

  • The LLM will be used by the compression strategy classes
from dotenv import load_dotenv
import sys
import json

# Load the file that contains the API keys, e.g., COHERE_API_KEY
load_dotenv('C:\\Users\\raj\\.jupyter\\.env')

# setting path
sys.path.append('../')

from utils.create_chat_llm import create_gpt_chat_llm, create_cohere_chat_llm

# Create the Cohere chat LLM (swap in create_gpt_chat_llm to try GPT instead)
llm = create_cohere_chat_llm()

llm_embeddings = CohereEmbeddings()
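
If you do not have the course's utils helper module, a minimal alternative sketch is shown below. It assumes the langchain_cohere package is installed and a COHERE_API_KEY is set in your environment; the embedding model name is illustrative.

# Hypothetical direct setup without the utils helper module
# (assumes COHERE_API_KEY is set in the environment)
from langchain_cohere import ChatCohere

llm = ChatCohere()
llm_embeddings = CohereEmbeddings(model="embed-english-v3.0")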

2. Utility function

  • Pretty prints the documents before/after compression
def print_documents(docs):
    for i, doc in enumerate(docs):
        print("#",i)
        print(doc.page_content)

def dump_before_after_compression(base_retriever, compression_retriever, question):
    results_before = base_retriever.invoke(question)
    results_after = compression_retriever.invoke(question)
    
    print("BEFORE. Doc count = ", len(results_before))
    print("--------------------------------------------------")
    print_documents(results_before)
    print("--------------------------------------------------")
    print("AFTER. Doc count = ", len(results_after))
    print_documents(results_after)
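
To sanity-check the pretty-printer, you can run it on a hand-made document; the content below is a made-up example.

# Quick check of the pretty-printer with an illustrative document
print_documents([Document(page_content="RAG stands for Retrieval Augmented Generation.")])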

3. Setup base retriever

  • Uses ChromaDB as the vector store behind the base retriever
# Create the Chroma vector store
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vector_store = Chroma(collection_name="full_documents", embedding_function=embedding_function) 

# Load sample docs
loader = DirectoryLoader('./util', glob="**/*.txt")
docs = loader.load()

# Chunking
doc_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
chunked_documents = doc_splitter.split_documents(docs)

# Add to vector DB
vector_store.add_documents(chunked_documents)

# Base retriever
vector_store_retriever = vector_store.as_retriever()
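
Before adding compression, you can verify that the base retriever returns results. The query below assumes the ./util folder contains text files that mention RAG.

# Optional sanity check of the base retriever
sample_results = vector_store_retriever.invoke("what is rag?")
print("Retrieved", len(sample_results), "chunks")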

4. LLMChainExtractor

  • Uses an LLM to extract relevant parts of a document.
# Create the compressor
llm_chain_extractor_compressor = LLMChainExtractor.from_llm(llm)

# Create the retriever
llm_chain_extractor_compressor_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever, 
    base_compressor=llm_chain_extractor_compressor)
Test
  • Apply compression to retrieved results
  • Print the before/after results for comparison
question = "what is rag?"

dump_before_after_compression(vector_store_retriever, llm_chain_extractor_compressor_retriever, question)
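You can also invoke the compression retriever directly to inspect only the compressed (extracted) output.

# Inspect just the compressed output from the extractor
compressed_docs = llm_chain_extractor_compressor_retriever.invoke(question)
print_documents(compressed_docs)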

5. LLM Chain Filter

Drops documents that are not relevant to the query.

https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.chain_filter.LLMChainFilter.html

# Create the compressor
llm_chain_filter_compressor = LLMChainFilter.from_llm(llm)

# Create the retriever
llm_chain_filter_compressor_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever, 
    base_compressor=llm_chain_filter_compressor)
Test
  • Apply compression to retrieved results
  • Print the before/after results for comparison
question = "what is rag?"

dump_before_after_compression(vector_store_retriever, llm_chain_filter_compressor_retriever, question)
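The filter can also be applied to a list of documents directly, without a retriever. A small sketch using the chunks created earlier:

# Apply the filter directly to a handful of chunks (no retriever involved)
some_chunks = chunked_documents[:4]
kept = llm_chain_filter_compressor.compress_documents(some_chunks, question)
print("Kept", len(kept), "of", len(some_chunks), "documents")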

6. Embeddings Filter

Uses embeddings to drop documents unrelated to the query.

https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.embeddings_filter.EmbeddingsFilter.html

Making an extra LLM call over each retrieved document is expensive and slow. The EmbeddingsFilter provides a cheaper and faster option by embedding the documents and query and only returning those documents which have sufficiently similar embeddings to the query.

# Create the compressor
# Play with the threshold to understand the behavior
similarity_threshold = 0.5
embeddings_filter = EmbeddingsFilter(embeddings=llm_embeddings, similarity_threshold=similarity_threshold)

# Create the retriever
llm_embeddings_filter_compressor_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever, 
    base_compressor=embeddings_filter)
Test
question = "what is rag?"

dump_before_after_compression(vector_store_retriever, llm_embeddings_filter_compressor_retriever, question)
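To see why chunks pass or fail the threshold, here is a rough sketch of the comparison the EmbeddingsFilter performs. The assumption here is cosine similarity between the query embedding and each document embedding, which is the filter's default similarity function.

import numpy as np

# Rough sketch of the EmbeddingsFilter comparison (cosine similarity assumed)
def cosine_sim(a, b):
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query_vec = llm_embeddings.embed_query(question)
doc_vecs = llm_embeddings.embed_documents(
    [d.page_content for d in chunked_documents[:5]])
for i, doc_vec in enumerate(doc_vecs):
    print("Chunk", i, "similarity:", round(cosine_sim(query_vec, doc_vec), 3))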

7. Compressor pipeline

  • A document compressor that runs documents through a pipeline of transformers and compressors.

https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.document_compressors.base.DocumentCompressorPipeline.html

transformers = [llm_chain_filter_compressor, embeddings_filter]
pipeline_compressor = DocumentCompressorPipeline(transformers=transformers)
# Create the retriever
pipeline_compressor_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever, 
    base_compressor=pipeline_compressor)
Test
question = "what is rag?"

dump_before_after_compression(vector_store_retriever, pipeline_compressor_retriever, question)
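Note that EmbeddingsRedundantFilter was imported at the top but is not used above. Below is a sketch of a richer pipeline that also re-splits documents and drops near-duplicate chunks before filtering for relevance; the chunk size and separator are illustrative choices.

from langchain_text_splitters import CharacterTextSplitter

# Illustrative pipeline: re-split, drop near-duplicates, then drop off-topic chunks
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=llm_embeddings)
relevant_filter = EmbeddingsFilter(embeddings=llm_embeddings,
                                   similarity_threshold=similarity_threshold)

richer_pipeline = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter])

richer_pipeline_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever,
    base_compressor=richer_pipeline)

dump_before_after_compression(vector_store_retriever, richer_pipeline_retriever, question)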

Solution

The solution to the exercise is available in the following notebook.

[Image: ex-7-contextual-compressors-solution]