AndCh

Reputation: 339

Implement filtering in RetrievalQA chain

I have been working on implementing the tutorial using RetrievalQA from Langchain with LLM from Azure OpenAI API. I've made progress with my implementation, and below is the code snippet I've been working on:

import os
# env variables
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "<YOUR_API_VERSION>"
os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<SPACE_NAME>.openai.azure.com/"

# library imports
import pandas as pd

from langchain.prompts import PromptTemplate
from langchain.chains.router.llm_router import LLMRouterChain,RouterOutputParser
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import AzureOpenAI
from langchain.chat_models import AzureChatOpenAI
from langchain.chains import RetrievalQA
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import (RecursiveCharacterTextSplitter, 
                                            CharacterTextSplitter)
from langchain.vectorstores import Chroma
from langchain.vectorstores import utils as chromautils
from langchain.embeddings import (HuggingFaceEmbeddings, OpenAIEmbeddings, 
                                  SentenceTransformerEmbeddings)
from langchain.callbacks import get_openai_callback

# toy = 'Search in the documents and find a toy that teaches about color to kids'
toy = 'Search in the documents and find a toy with cards that has monsters'


all_docs = pd.read_csv(data) # data is the dataset from the tutorial (see above)

print('Model init \u2713')

print('---->  Azure OpenAI \u2713') 
llm_open = AzureChatOpenAI(
                           model="GPT3",
                           max_tokens = 100
                          )
print('Create docs \u2713')

loader = DataFrameLoader(all_docs, 
                         page_content_column='description' # column description in data
                        )
my_docs = loader.load()
print('Create splits \u2713')
text_splitter = CharacterTextSplitter(chunk_size=512, 
                                      chunk_overlap=0
                                      )
all_splits = text_splitter.split_documents(my_docs)
print('Init embeddings \u2713')

chroma_docs = chromautils.filter_complex_metadata(all_splits)
# embeddings = HuggingFaceEmbeddings()
my_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = SentenceTransformerEmbeddings(model_name=my_model_name)

print('Create Chromadb \u2713')
vectorstore = Chroma.from_documents(all_splits, 
                                    embeddings,
                                   # metadatas=[{"source": f"{i}-pl"} for i in \
                                             # range(len(all_splits))]
                                   )
print('Create QA chain \u2713')
qa_chain = RetrievalQA.from_chain_type(llm=llm_open,
                                       chain_type="stuff",
                                       retriever=vectorstore.as_retriever(search_kwargs={"k": 10}),
                                       verbose=True)

print('*** YOUR ANSWER: ***')

with get_openai_callback() as cb:
    llm_res = qa_chain.run(toy)
    plpy.notice(f'{llm_res}')
    plpy.notice(f'Total Tokens: {cb.total_tokens}')
    plpy.notice(f'Prompt Tokens: {cb.prompt_tokens}')
    plpy.notice(f'Completion Tokens: {cb.completion_tokens}')
    plpy.notice(f'Total Cost (USD): ${cb.total_cost}')

In the tutorial, there's a section that filters products based on minimum and maximum prices using a SQL query. However, I'm unsure how to achieve similar functionality using RetrievalQA in Langchain while also retrieving the sources. The specific section in the tutorial that I'm referring to is:

results = await conn.fetch("""
         WITH vector_matches AS (
                 SELECT product_id, 
                        1 - (embedding <=> $1) AS similarity
                 FROM product_embeddings
                 WHERE 1 - (embedding <=> $1) > $2
                 ORDER BY similarity DESC
                 LIMIT $3
         )
         SELECT product_name, 
                list_price, 
                description 
         FROM products
         WHERE product_id IN (SELECT product_id FROM vector_matches)
               AND list_price >= $4 AND list_price <= $5
         """, 
         qe, similarity_threshold, num_matches, min_price, max_price)

How can I implement this filtering functionality using the RetrievalQA chain in Langchain and also retrieve the sources associated with the filtered products?

Upvotes: 0

Views: 260

Answers (1)

AndCh

Reputation: 339

I have found a way to replicate the tutorial using a retriever object. First, I had to convert the 'list_price' column to numeric:

all_docs['list_price'] = pd.to_numeric(all_docs['list_price'],
                                       errors='coerce')

Then I can filter on the metadata using the retriever:

retriever = vectorstore.as_retriever(
    search_kwargs={
        "k": 5,
        "filter": {'$and': [{'list_price': {'$gt': 4.0}},
                            {'list_price': {'$lt': 5.0}}]}
    }
)
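
To also get the sources asked about in the question, a minimal sketch (reusing llm_open and the filtered retriever above; not tested against this exact setup) is to set return_source_documents=True on the chain and call it with a dict instead of .run():

# Sketch: plug the filtered retriever back into the QA chain and return sources.
qa_chain = RetrievalQA.from_chain_type(llm=llm_open,
                                       chain_type="stuff",
                                       retriever=retriever,
                                       return_source_documents=True,
                                       verbose=True)

result = qa_chain({"query": toy})
print(result["result"])                 # the LLM answer
for doc in result["source_documents"]:  # the filtered documents used as sources
    print(doc.metadata)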

Although this is a working solution, I would prefer to put the filter in the prompt itself and have the documents filtered from a query like:

toy = 'Is there a toy in the documents that teaches about color to kids and the price is between 15 and 50 dollars?'
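
One direction that might work for this (an untested sketch, assuming the lark package is installed and reusing llm_open and vectorstore from the question) is Langchain's SelfQueryRetriever, which asks the LLM to turn the price constraint in the query into a metadata filter:

from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever

# Describe the metadata field the LLM is allowed to filter on.
metadata_field_info = [
    AttributeInfo(name="list_price",
                  description="Price of the toy in US dollars",
                  type="float"),
]

# The retriever translates the natural-language query into a semantic
# query plus a structured metadata filter for Chroma.
self_query_retriever = SelfQueryRetriever.from_llm(llm=llm_open,
                                                   vectorstore=vectorstore,
                                                   document_contents="Description of a toy product",
                                                   metadata_field_info=metadata_field_info,
                                                   verbose=True)

docs = self_query_retriever.get_relevant_documents(toy)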

Upvotes: 0
