Learn how to use LangChain's `EnsembleRetriever` class.
Ensemble Retriever LangChain Ensemble Retriever
Create a new notebook and copy/paste the following code step by step. Make sure you read through the code to understand what it is doing.
The code uses `BM25Retriever`, which requires the `rank_bm25` package to be installed (`!pip install --upgrade --quiet rank_bm25`).
!pip install --upgrade --quiet rank_bm25
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.retrievers import EnsembleRetriever
## 1. Create a test corpus
corpus = [
    "RAG addresses hallucinations",
    "Symptoms are hallucinations",
    "RAG is easier than fine tuning",
    "Use a RAG to clean it",
    "Retrieval Augmented Generation",
]

# Add metadata: wrap each text in a Document tagged with a "source" id
# ("doc-0", "doc-1", ...). The ensemble retriever is configured with
# id_key="source", so this id is what identifies a document across retrievers,
# and it is also what the ranked-list printout shows.
corpus_docs = [
    Document(page_content=text, metadata={"source": f"doc-{i}"})
    for i, text in enumerate(corpus)
]

# Print corpus (a bare trailing expression is displayed as the cell output)
corpus_docs
# Create the BM25 (keyword-based) retriever; k=3 returns the top 3 matches.
bm25_retriever = BM25Retriever.from_documents(corpus_docs, k=3)

# Create a ChromaDB vector store and add the corpus, embedded with the
# "all-MiniLM-L6-v2" sentence-transformer model.
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vector_store = Chroma(collection_name="full_documents", embedding_function=embedding_function)
# Use each document's "source" as its Chroma id. Without explicit ids a fresh
# random id is generated per add, so re-running this cell in the same kernel
# would insert duplicate copies; with fixed ids the entries are overwritten.
vector_store.add_documents(
    corpus_docs,
    ids=[doc.metadata["source"] for doc in corpus_docs],
)

# https://api.python.langchain.com/en/latest/vectorstores/langchain_community.vectorstores.chroma.Chroma.html#langchain_community.vectorstores.chroma.Chroma.as_retriever
chromadb_retriever = vector_store.as_retriever(search_kwargs={"k": 3})

# Combine both retrievers. EnsembleRetriever fuses their ranked lists using
# weighted rank fusion: BM25 is weighted 0.4 and the vector retriever 0.6.
# id_key="source" tells it to treat results with the same metadata["source"]
# as the same document instead of comparing page_content.
retrievers = [bm25_retriever, chromadb_retriever]
retriever_weights = [0.4, 0.6]
ensemble_retriever = EnsembleRetriever(
    retrievers=retrievers,
    weights=retriever_weights,
    id_key="source",
)
# Utility function to print the list of ranked documents
def dump_doc_source(result_documents, key="source"):
    """Print one metadata field of each document, one per line, in rank order.

    Args:
        result_documents: Iterable of documents (objects with a ``metadata``
            dict), e.g. the list returned by ``retriever.invoke(...)``.
        key: Metadata field to print. Defaults to ``"source"``, the id
            assigned to each document when the corpus was built.

    Raises:
        KeyError: If a document's metadata has no ``key`` entry.
    """
    for doc in result_documents:
        print(doc.metadata[key])
    # print("\n") emits two newlines, leaving a blank gap before the next section.
    print("\n")
# Test queries. Named `queries` rather than `input` so the builtin input()
# is not shadowed for later notebook cells.
queries = [
    "rag is cheaper",
    "benefits of rag",
    "piece of cloth",  # presumably an off-topic control query
]

# Change the query index (0 .. len(queries) - 1) to try a different query.
ndx = 0
query = queries[ndx]
print(f"Input: {query}\n")

# Dump the ranked list for BM25
print("BM25")
print("----")
results_bm25 = bm25_retriever.invoke(query)
dump_doc_source(results_bm25)

# Dump the ranked list for ChromaDB
print("ChromaDB")
print("--------")
results_chromadb = chromadb_retriever.invoke(query)
dump_doc_source(results_chromadb)

# Dump the fused ranked list produced by the Ensemble Retriever
print("Ensemble Retriever")
print("------------------")
results = ensemble_retriever.invoke(query)
dump_doc_source(results)

# Display the full ensemble result Documents as the cell output
results
The solution to the exercise is available in the following notebook.