author     Matthew Lemon <y@yulqen.org>  2024-08-28 13:03:09 +0100
committer  Matthew Lemon <y@yulqen.org>  2024-08-28 13:03:09 +0100
commit     9a6b649674dff9f3ec89eeb658a0347addb5eed2 (patch)
tree       de8c2e0a2f2d0b53c8392a0e7da2c83343323fe9
parent     fcadab179c845f2b644904dc4e38e4b37c4cd0a2 (diff)
First attempt at adding context to conversations
-rw-r--r--  pdf_reader.py  116
1 file changed, 91 insertions, 25 deletions
diff --git a/pdf_reader.py b/pdf_reader.py
index 62d79fb..cbf7112 100644
--- a/pdf_reader.py
+++ b/pdf_reader.py
@@ -9,8 +9,11 @@ import fitz # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
-from langchain.chains import RetrievalQA
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
@@ -137,51 +140,108 @@ def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
    return vector_store
-def create_qa_chain(vector_store: Any, llm: LLM) -> RetrievalQA:
+
+def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
    """
-    Create a QA chain for answering questions using the provided vector store and Large Language Model (LLM).
+    Create a QA chain with conversational memory for answering questions using the provided vector store and LLM.
    Args:
        vector_store: The database of vectors representing text embeddings.
        llm: The Large Language Model instance.
    Returns:
-        A RetrievalQA object that can be used to answer questions by querying the LLM with user input.
+        A retrieval chain object that can be used to answer questions by querying the LLM with user input.
    """
    # Use the vector store retriever directly
    vector_store_retriever = vector_store.as_retriever()
-    # Create a RetrievalQA chain using the new approach
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=vector_store_retriever # Use a more specific variable name
+    # Initialize conversation memory
+    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
+    # Condense question prompt template
+    condense_question_system_template = (
+        "Given a chat history and the latest user question "
+        "which might reference context in the chat history, "
+        "formulate a standalone question which can be understood "
+        "without the chat history. Do NOT answer the question, "
+        "just reformulate it if needed and otherwise return it as is."
+    )
+
+    condense_question_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", condense_question_system_template),
+            ("placeholder", "{chat_history}"),
+            ("human", "{input}"),
+        ]
+    )
+
+    # Create a history-aware retriever
+    history_aware_retriever = create_history_aware_retriever(
+        llm, vector_store_retriever, condense_question_prompt
+    )
+
+    # System prompt for question answering
+    system_prompt = (
+        "You are an assistant for question-answering tasks. "
+        "Use the following pieces of retrieved context to answer "
+        "the question. If you don't know the answer, say that you "
+        "don't know. Use three sentences maximum and keep the "
+        "answer concise."
+        "\n\n"
+        "{context}"
+    )
+
+    qa_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),
+            ("placeholder", "{chat_history}"),
+            ("human", "{input}"),
+        ]
    )
-    return qa_chain
+    # Create a stuff documents chain for question answering
+    qa_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+    # Create a retrieval chain with the history-aware retriever and QA chain
+    convo_qa_chain = create_retrieval_chain(history_aware_retriever, qa_chain)
+
+    return convo_qa_chain
def format_answer(response):
-    """Format the answer to ensure plain text output without any special characters."""
-    # Handle different response formats
-    if isinstance(response, dict):
-        answer = response.get('result', '')
-    elif isinstance(response, list):
-        answer = "\n\n".join(item.get('result', '') for item in response if 'result' in item)
-    else:
-        answer = str(response)
+    """Format the answer and context to ensure plain text output without any special characters."""
+    # Extract the main answer
+    answer = response.get('answer', '')
+
+    # Initialize context summary
+    context_summary = ""
-    # Clean up the text: Remove excess newlines and strip whitespace
-    answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+    # Extract and summarize context if available
+    context = response.get('context', [])
+    if context and isinstance(context, list):
+        # Concatenate page_content from each Document object in the context
+        context_summary = "\n\n".join(doc.page_content for doc in context)
-    return answer
+    # Clean up the answer text
+    if isinstance(answer, str):
+        answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+
+    # Clean up the context summary text
+    if isinstance(context_summary, str):
+        context_summary = context_summary.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+
+    # Combine the answer and context summary
+    formatted_output = f"Answer:\n\n{answer}\n\nContext:\n\n{context_summary}"
+
+    return formatted_output
def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
-    parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
+    parser.add_argument('--persist', type=str, default='db',
+                        help='Base directory to save or load persisted vector stores')
    args = parser.parse_args()
    # Check if the PDF file exists
@@ -203,8 +263,8 @@ def main():
        print("Processing failed. Exiting.")
        exit(1)
-    # Create the QA chain
-    qa_chain = create_qa_chain(vector_store, llm)
+    # Create the QA chain with memory
+    qa_chain = create_qa_chain_with_memory(vector_store, llm)
    # Interactive mode for asking questions
    print("PDF processing complete. You can now ask questions about the content.")
@@ -217,7 +277,12 @@ def main():
            break
        # Get the answer
-        response = qa_chain.invoke(question)
+        response = qa_chain.invoke(
+            {
+                "input": question,
+                "chat_history": [], # Initially, pass an empty chat history
+            }
+        )
        # Format the answer
        answer = format_answer(response)
@@ -228,5 +293,6 @@ def main():
        else:
            print("No relevant information found for your question. Please try asking a different question.")
+
if __name__ == "__main__":
    main()
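
Note: as the commit message says, this is a first attempt. The ConversationBufferMemory created in create_qa_chain_with_memory is not yet wired into the chain, and main() passes an empty chat_history on every invoke, so follow-up questions do not yet see earlier turns. Below is a minimal sketch of how the question loop could accumulate history between turns, assuming the chain returned by create_qa_chain_with_memory above; the prompt text and variable names are illustrative, not part of this commit.

# Sketch only -- not part of this commit. Assumes qa_chain was built by
# create_qa_chain_with_memory() and format_answer() is the function defined above.
from langchain_core.messages import AIMessage, HumanMessage

chat_history = []  # running message list fed to the history-aware retriever

while True:
    question = input("Ask a question (or 'exit' to quit): ")  # illustrative prompt text
    if question.lower() == "exit":
        break

    response = qa_chain.invoke(
        {
            "input": question,
            "chat_history": chat_history,  # pass the accumulated history instead of []
        }
    )

    # create_retrieval_chain returns a dict with "input", "context" and "answer" keys
    print(format_answer(response))

    # Record this exchange so later questions can refer back to it
    chat_history.append(HumanMessage(content=question))
    chat_history.append(AIMessage(content=response.get("answer", "")))

With this pattern the unused ConversationBufferMemory could either be dropped, or the same messages could be written into its chat_memory via add_user_message / add_ai_message.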