author     Matthew Lemon <y@yulqen.org>    2024-08-28 13:03:09 +0100
committer  Matthew Lemon <y@yulqen.org>    2024-08-28 13:03:09 +0100
commit     9a6b649674dff9f3ec89eeb658a0347addb5eed2 (patch)
tree       de8c2e0a2f2d0b53c8392a0e7da2c83343323fe9
parent     fcadab179c845f2b644904dc4e38e4b37c4cd0a2 (diff)
First attempt at adding context to conversations
-rw-r--r--  pdf_reader.py | 116
1 file changed, 91 insertions, 25 deletions
diff --git a/pdf_reader.py b/pdf_reader.py
index 62d79fb..cbf7112 100644
--- a/pdf_reader.py
+++ b/pdf_reader.py
@@ -9,8 +9,11 @@ import fitz  # PyMuPDF
 import pytesseract
 from PIL import Image
 from PyPDF2 import PdfReader
-from langchain.chains import RetrievalQA
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.docstore.document import Document
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import ChatPromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
 from langchain_community.embeddings import OllamaEmbeddings
@@ -137,51 +140,108 @@ def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
     return vector_store
 
-def create_qa_chain(vector_store: Any, llm: LLM) -> RetrievalQA:
+
+def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
     """
-    Create a QA chain for answering questions using the provided vector store and Large Language Model (LLM).
+    Create a QA chain with conversational memory for answering questions using the provided vector store and LLM.
 
     Args:
         vector_store: The database of vectors representing text embeddings.
         llm: The Large Language Model instance.
 
     Returns:
-        A RetrievalQA object that can be used to answer questions by querying the LLM with user input.
+        A retrieval chain object that can be used to answer questions by querying the LLM with user input.
     """
     # Use the vector store retriever directly
     vector_store_retriever = vector_store.as_retriever()
 
-    # Create a RetrievalQA chain using the new approach
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=vector_store_retriever  # Use a more specific variable name
+    # Initialize conversation memory
+    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
+    # Condense question prompt template
+    condense_question_system_template = (
+        "Given a chat history and the latest user question "
+        "which might reference context in the chat history, "
+        "formulate a standalone question which can be understood "
+        "without the chat history. Do NOT answer the question, "
+        "just reformulate it if needed and otherwise return it as is."
+    )
+
+    condense_question_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", condense_question_system_template),
+            ("placeholder", "{chat_history}"),
+            ("human", "{input}"),
+        ]
+    )
+
+    # Create a history-aware retriever
+    history_aware_retriever = create_history_aware_retriever(
+        llm, vector_store_retriever, condense_question_prompt
+    )
+
+    # System prompt for question answering
+    system_prompt = (
+        "You are an assistant for question-answering tasks. "
+        "Use the following pieces of retrieved context to answer "
+        "the question. If you don't know the answer, say that you "
+        "don't know. Use three sentences maximum and keep the "
+        "answer concise."
+        "\n\n"
+        "{context}"
+    )
+
+    qa_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),
+            ("placeholder", "{chat_history}"),
+            ("human", "{input}"),
+        ]
     )
 
-    return qa_chain
+    # Create a stuff documents chain for question answering
+    qa_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+    # Create a retrieval chain with the history-aware retriever and QA chain
+    convo_qa_chain = create_retrieval_chain(history_aware_retriever, qa_chain)
+
+    return convo_qa_chain
 
 
 def format_answer(response):
-    """Format the answer to ensure plain text output without any special characters."""
-    # Handle different response formats
-    if isinstance(response, dict):
-        answer = response.get('result', '')
-    elif isinstance(response, list):
-        answer = "\n\n".join(item.get('result', '') for item in response if 'result' in item)
-    else:
-        answer = str(response)
+    """Format the answer and context to ensure plain text output without any special characters."""
+    # Extract the main answer
+    answer = response.get('answer', '')
+
+    # Initialize context summary
+    context_summary = ""
 
-    # Clean up the text: Remove excess newlines and strip whitespace
-    answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+    # Extract and summarize context if available
+    context = response.get('context', [])
+    if context and isinstance(context, list):
+        # Concatenate page_content from each Document object in the context
+        context_summary = "\n\n".join(doc.page_content for doc in context)
 
-    return answer
+    # Clean up the answer text
+    if isinstance(answer, str):
+        answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+
+    # Clean up the context summary text
+    if isinstance(context_summary, str):
+        context_summary = context_summary.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+
+    # Combine the answer and context summary
+    formatted_output = f"Answer:\n\n{answer}\n\nContext:\n\n{context_summary}"
+
+    return formatted_output
 
 
 def main():
     # Parse command line arguments
     parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
     parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
-    parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
+    parser.add_argument('--persist', type=str, default='db',
+                        help='Base directory to save or load persisted vector stores')
 
     args = parser.parse_args()
 
     # Check if the PDF file exists
@@ -203,8 +263,8 @@ def main():
         print("Processing failed. Exiting.")
         exit(1)
 
-    # Create the QA chain
-    qa_chain = create_qa_chain(vector_store, llm)
+    # Create the QA chain with memory
+    qa_chain = create_qa_chain_with_memory(vector_store, llm)
 
     # Interactive mode for asking questions
     print("PDF processing complete. You can now ask questions about the content.")
@@ -217,7 +277,12 @@ def main():
             break
 
         # Get the answer
-        response = qa_chain.invoke(question)
+        response = qa_chain.invoke(
+            {
+                "input": question,
+                "chat_history": [],  # Initially, pass an empty chat history
+            }
+        )
 
         # Format the answer
         answer = format_answer(response)
@@ -228,5 +293,6 @@ def main():
     else:
         print("No relevant information found for your question. Please try asking a different question.")
 
+
 if __name__ == "__main__":
     main()