diff options
author | Matthew Lemon <y@yulqen.org> | 2024-08-28 14:18:44 +0100 |
---|---|---|
committer | Matthew Lemon <y@yulqen.org> | 2024-08-28 14:18:44 +0100 |
commit | 5dbf9ae7052175aa56a774f5b1124dc7170765b0 (patch) | |
tree | cbf296ac53b1ddd7d4930aff70b5b7757b956e8d /pdf_reader.py | |
parent | 9a6b649674dff9f3ec89eeb658a0347addb5eed2 (diff) |
Logs output to a file and limits the context
Diffstat (limited to 'pdf_reader.py')
-rw-r--r-- | pdf_reader.py | 112 |
1 files changed, 65 insertions, 47 deletions
diff --git a/pdf_reader.py b/pdf_reader.py index cbf7112..5f3b818 100644 --- a/pdf_reader.py +++ b/pdf_reader.py @@ -3,18 +3,19 @@ import hashlib import io import json import os +from datetime import datetime from typing import Any import fitz # PyMuPDF import pytesseract from PIL import Image from PyPDF2 import PdfReader +from langchain.memory import ConversationBufferMemory +from langchain.docstore.document import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.chains import create_history_aware_retriever, create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain.docstore.document import Document -from langchain.memory import ConversationBufferMemory from langchain.prompts import ChatPromptTemplate -from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_chroma import Chroma from langchain_community.embeddings import OllamaEmbeddings from langchain_community.llms import Ollama @@ -23,7 +24,6 @@ from langchain_core.language_models import LLM # Make sure Tesseract is installed and accessible pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract' # Update this path based on your tesseract installation - def extract_text_from_pdf(pdf_path): """Extract text from a PDF file.""" try: @@ -39,7 +39,6 @@ def extract_text_from_pdf(pdf_path): print(f"Error extracting text from PDF: {e}") return None - def perform_ocr_on_pdf(pdf_path): """Perform OCR on a PDF file to extract text.""" try: @@ -56,7 +55,6 @@ def perform_ocr_on_pdf(pdf_path): print(f"Error performing OCR on PDF: {e}") return None - def get_pdf_text(pdf_path): """Determine if OCR is necessary and extract text from PDF.""" text = extract_text_from_pdf(pdf_path) @@ -70,7 +68,6 @@ def get_pdf_text(pdf_path): print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}") return ocr_text - def compute_pdf_hash(pdf_path): """Compute a unique hash for the PDF file to identify if it's already processed.""" hasher = hashlib.sha256() @@ -79,7 +76,6 @@ def compute_pdf_hash(pdf_path): hasher.update(buf) return hasher.hexdigest() - def load_metadata(persist_directory): """Load metadata from a JSON file.""" metadata_path = os.path.join(persist_directory, 'metadata.json') @@ -89,14 +85,12 @@ def load_metadata(persist_directory): else: return {'processed_pdfs': []} - def save_metadata(persist_directory, metadata): """Save metadata to a JSON file.""" metadata_path = os.path.join(persist_directory, 'metadata.json') with open(metadata_path, 'w') as f: json.dump(metadata, f) - def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory): """Prepare a PDF for question answering, using a unique Chroma persistence directory for each PDF.""" @@ -140,7 +134,6 @@ def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory): return vector_store - def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any: """ Create a QA chain with conversational memory for answering questions using the provided vector store and LLM. @@ -208,8 +201,18 @@ def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any: return convo_qa_chain -def format_answer(response): - """Format the answer and context to ensure plain text output without any special characters.""" +def format_answer(response, max_documents=3): + """ + Format the answer and context to ensure plain text output without any special characters, + and limit the number of context documents shown. + + Args: + response (dict): The response dictionary containing 'answer' and 'context'. + max_documents (int): Maximum number of context documents to display. + + Returns: + str: Formatted output with limited context. + """ # Extract the main answer answer = response.get('answer', '') @@ -219,8 +222,11 @@ def format_answer(response): # Extract and summarize context if available context = response.get('context', []) if context and isinstance(context, list): - # Concatenate page_content from each Document object in the context - context_summary = "\n\n".join(doc.page_content for doc in context) + # Limit the number of documents to max_documents + limited_context = context[:max_documents] + + # Concatenate page_content from each Document object in the limited context + context_summary = "\n\n".join(doc.page_content for doc in limited_context) # Clean up the answer text if isinstance(answer, str): @@ -231,17 +237,15 @@ def format_answer(response): context_summary = context_summary.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip() # Combine the answer and context summary - formatted_output = f"Answer:\n\n{answer}\n\nContext:\n\n{context_summary}" + formatted_output = f"Answer:\n\n{answer}\n\nContext (showing up to {max_documents} documents):\n\n{context_summary}" return formatted_output - def main(): # Parse command line arguments parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama') parser.add_argument('pdf_path', type=str, help='Path to the PDF file') - parser.add_argument('--persist', type=str, default='db', - help='Base directory to save or load persisted vector stores') + parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores') args = parser.parse_args() # Check if the PDF file exists @@ -266,33 +270,47 @@ def main(): # Create the QA chain with memory qa_chain = create_qa_chain_with_memory(vector_store, llm) - # Interactive mode for asking questions - print("PDF processing complete. You can now ask questions about the content.") - print("Type 'exit' or 'quit' to end the session.") - - while True: - question = input("Enter your question: ").strip() - if question.lower() in ["exit", "quit"]: - print("Exiting the session. Goodbye!") - break - - # Get the answer - response = qa_chain.invoke( - { - "input": question, - "chat_history": [], # Initially, pass an empty chat history - } - ) - - # Format the answer - answer = format_answer(response) - - # Check if the answer is empty or only contains newlines - if answer.strip(): - print(f"Answer:\n\n{answer}\n") - else: - print("No relevant information found for your question. Please try asking a different question.") - + # Prepare the filename for the conversation log + base_filename = os.path.basename(args.pdf_path).replace(' ', '_') + date_str = datetime.now().strftime('%Y-%m-%d') + log_filename = f"{date_str}_{base_filename}.md" + + # Open the file in write mode + with open(log_filename, 'w') as log_file: + # Start the conversation log + log_file.write(f"# Conversation Log for {base_filename}\n") + log_file.write(f"## Date: {date_str}\n\n") + + # Interactive mode for asking questions + print("PDF processing complete. You can now ask questions about the content.") + print("Type 'exit' or 'quit' to end the session.") + + while True: + question = input("Enter your question: ").strip() + if question.lower() in ["exit", "quit"]: + print("Exiting the session. Goodbye!") + break + + # Get the answer + response = qa_chain.invoke( + { + "input": question, + "chat_history": [], # Initially, pass an empty chat history + } + ) + + # Format the answer + answer = format_answer(response, max_documents=1) + + # Log the question and answer to the file + log_file.write(f"### Question:\n{question}\n\n") + log_file.write(f"### Answer:\n{answer}\n\n") + + # Display the answer to the user + if answer.strip(): + print(f"Answer:\n\n{answer}\n") + else: + print("No relevant information found for your question. Please try asking a different question.") if __name__ == "__main__": - main() + main()
\ No newline at end of file |