LangChain Contextual Compression Retriever
Learn to use the LangChain contextual compressor classes.
Create a new notebook and copy/paste the following code step-by-step. Make sure to go through the code to understand what it is doing.
from langchain_community.document_loaders import DirectoryLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_cohere import CohereEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
import warnings
# Set warnings to be ignored
warnings.filterwarnings('ignore')
You must adjust the location of the key file in the code. The sample uses the Cohere command model and Cohere embeddings, but you may use a different LLM.
from dotenv import load_dotenv
import sys
import json
# Load the file that contains the API keys - COHERE_API_KEY / OPENAI_API_KEY
load_dotenv('C:\\Users\\raj\\.jupyter\\.env')
# setting path
sys.path.append('../')
from utils.create_chat_llm import create_gpt_chat_llm, create_cohere_chat_llm
# Create the chat LLM - Cohere is used here; swap in create_gpt_chat_llm() to try GPT
llm = create_cohere_chat_llm()
# Cohere embeddings are used by the embeddings-based compressors below
llm_embeddings = CohereEmbeddings()
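If you do not have the utils helper module, you can construct the chat model and embeddings directly from langchain_cohere. This is a minimal sketch, assuming a COHERE_API_KEY is present in the loaded environment; the model names are illustrative and not part of the original exercise.
from langchain_cohere import ChatCohere, CohereEmbeddings
# Assumes COHERE_API_KEY is set in the environment loaded above;
# model names are illustrative and may differ for your account
llm = ChatCohere(model="command-r")
llm_embeddings = CohereEmbeddings(model="embed-english-v3.0")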
def print_documents(docs):
    for i, doc in enumerate(docs):
        print("#", i)
        print(doc.page_content)

# Run the same question against the base retriever and the compression retriever
# and print the documents returned by each, for a before/after comparison
def dump_before_after_compression(base_retriever, compression_retriever, question):
    results_before = base_retriever.invoke(question)
    results_after = compression_retriever.invoke(question)
    print("BEFORE. Doc count = ", len(results_before))
    print("--------------------------------------------------")
    print_documents(results_before)
    print("--------------------------------------------------")
    print("AFTER. Doc count = ", len(results_after))
    print_documents(results_after)
# Create the Chroma vector store
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vector_store = Chroma(collection_name="full_documents", embedding_function=embedding_function)
# Load sample docs
loader = DirectoryLoader('./util', glob="**/*.txt")
docs = loader.load()
# Chunking
doc_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
chunked_documents = doc_splitter.split_documents(docs)
# Add to vector DB
vector_store.add_documents(chunked_documents)
# Base retriever
vector_store_retriever = vector_store.as_retriever()
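By default the Chroma retriever returns roughly the top 4 chunks. If you want the compressors below to work over a wider candidate set, you can raise k when creating the retriever. A minimal sketch; the value 8 is an assumption, not part of the original exercise.
# Optional: retrieve more candidates for the compressors to narrow down (k=8 is illustrative)
wider_retriever = vector_store.as_retriever(search_kwargs={"k": 8})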
# Create the compressor
# LLMChainExtractor uses the LLM to extract, from each retrieved document,
# only the parts that are relevant to the query
llm_chain_extractor_compressor = LLMChainExtractor.from_llm(llm)
# Create the retriever
llm_chain_extractor_compressor_retriever = ContextualCompressionRetriever(
base_retriever=vector_store_retriever,
base_compressor=llm_chain_extractor_compressor)
question = "what is rag?"
dump_before_after_compression(vector_store_retriever, llm_chain_extractor_compressor_retriever, question)
The LLMChainFilter uses the LLM to drop retrieved documents that are not relevant to the query; the documents it keeps are returned unchanged.
# Create the compressor
llm_chain_filter_compressor = LLMChainFilter.from_llm(llm)
# Create the retriever
llm_chain_filter_compressor_retriever = ContextualCompressionRetriever(
base_retriever=vector_store_retriever,
base_compressor=llm_chain_filter_compressor)
question = "what is rag?"
dump_before_after_compression(vector_store_retriever, llm_chain_filter_compressor_retriever, question)
The EmbeddingsFilter uses embeddings to drop documents unrelated to the query. Making an extra LLM call over each retrieved document is expensive and slow; the EmbeddingsFilter provides a cheaper and faster option by embedding the documents and the query and returning only the documents whose embeddings are sufficiently similar to the query.
# Create the compressor
# Play with the threshold to understand the behavior
similarity_threshold = 0.5
embeddings_filter = EmbeddingsFilter(embeddings=llm_embeddings, similarity_threshold=similarity_threshold)
# Create the retriever
llm_embeddings_filter_compressor_retriever = ContextualCompressionRetriever(
base_retriever=vector_store_retriever,
base_compressor=embeddings_filter)
question = "what is rag?"
dump_before_after_compression(vector_store_retriever, llm_embeddings_filter_compressor_retriever, question)
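Instead of a similarity threshold, the EmbeddingsFilter can also keep a fixed number of the most similar documents via its k parameter. A minimal sketch; k=2 is an illustrative value, not part of the original exercise.
# Keep only the 2 most similar chunks rather than filtering by a threshold
top_k_filter = EmbeddingsFilter(embeddings=llm_embeddings, k=2)
top_k_filter_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever,
    base_compressor=top_k_filter)
dump_before_after_compression(vector_store_retriever, top_k_filter_retriever, question)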
# DocumentCompressorPipeline chains multiple compressors/transformers;
# each one is applied in sequence to the output of the previous one
transformers = [llm_chain_filter_compressor, embeddings_filter]
pipeline_compressor = DocumentCompressorPipeline(transformers=transformers)
# Create the retriever
pipeline_compressor_retriever = ContextualCompressionRetriever(
base_retriever=vector_store_retriever,
base_compressor=pipeline_compressor)
question = "what is rag?"
dump_before_after_compression(vector_store_retriever, pipeline_compressor_retriever, question)
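The EmbeddingsRedundantFilter imported earlier is not used above; it can be added to the pipeline to drop near-duplicate chunks before the relevance filtering step, along the lines of the LangChain documentation example. A minimal sketch; the ordering and threshold values are illustrative.
# Drop near-duplicate chunks first, then keep only query-relevant ones
redundant_filter = EmbeddingsRedundantFilter(embeddings=llm_embeddings)
relevant_filter = EmbeddingsFilter(embeddings=llm_embeddings, similarity_threshold=0.5)
pipeline_with_dedup = DocumentCompressorPipeline(
    transformers=[redundant_filter, relevant_filter])
pipeline_with_dedup_retriever = ContextualCompressionRetriever(
    base_retriever=vector_store_retriever,
    base_compressor=pipeline_with_dedup)
dump_before_after_compression(vector_store_retriever, pipeline_with_dedup_retriever, question)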
The solution to the exercise is available in the following notebook.