summaryrefslogtreecommitdiffstats
path: root/pdf_reader.py
diff options
context:
space:
mode:
authorMatthew Lemon <y@yulqen.org>2024-08-28 14:18:44 +0100
committerMatthew Lemon <y@yulqen.org>2024-08-28 14:18:44 +0100
commit5dbf9ae7052175aa56a774f5b1124dc7170765b0 (patch)
treecbf296ac53b1ddd7d4930aff70b5b7757b956e8d /pdf_reader.py
parent9a6b649674dff9f3ec89eeb658a0347addb5eed2 (diff)
Logs output to a file and limits the context
Diffstat (limited to 'pdf_reader.py')
-rw-r--r--pdf_reader.py112
1 files changed, 65 insertions, 47 deletions
diff --git a/pdf_reader.py b/pdf_reader.py
index cbf7112..5f3b818 100644
--- a/pdf_reader.py
+++ b/pdf_reader.py
@@ -3,18 +3,19 @@ import hashlib
import io
import json
import os
+from datetime import datetime
from typing import Any
import fitz # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
+from langchain.memory import ConversationBufferMemory
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.docstore.document import Document
-from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
-from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
@@ -23,7 +24,6 @@ from langchain_core.language_models import LLM
# Make sure Tesseract is installed and accessible
pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract' # Update this path based on your tesseract installation
-
def extract_text_from_pdf(pdf_path):
"""Extract text from a PDF file."""
try:
@@ -39,7 +39,6 @@ def extract_text_from_pdf(pdf_path):
print(f"Error extracting text from PDF: {e}")
return None
-
def perform_ocr_on_pdf(pdf_path):
"""Perform OCR on a PDF file to extract text."""
try:
@@ -56,7 +55,6 @@ def perform_ocr_on_pdf(pdf_path):
print(f"Error performing OCR on PDF: {e}")
return None
-
def get_pdf_text(pdf_path):
"""Determine if OCR is necessary and extract text from PDF."""
text = extract_text_from_pdf(pdf_path)
@@ -70,7 +68,6 @@ def get_pdf_text(pdf_path):
print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
return ocr_text
-
def compute_pdf_hash(pdf_path):
"""Compute a unique hash for the PDF file to identify if it's already processed."""
hasher = hashlib.sha256()
@@ -79,7 +76,6 @@ def compute_pdf_hash(pdf_path):
hasher.update(buf)
return hasher.hexdigest()
-
def load_metadata(persist_directory):
"""Load metadata from a JSON file."""
metadata_path = os.path.join(persist_directory, 'metadata.json')
@@ -89,14 +85,12 @@ def load_metadata(persist_directory):
else:
return {'processed_pdfs': []}
-
def save_metadata(persist_directory, metadata):
"""Save metadata to a JSON file."""
metadata_path = os.path.join(persist_directory, 'metadata.json')
with open(metadata_path, 'w') as f:
json.dump(metadata, f)
-
def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
"""Prepare a PDF for question answering, using a unique Chroma persistence directory for each PDF."""
@@ -140,7 +134,6 @@ def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
return vector_store
-
def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
"""
Create a QA chain with conversational memory for answering questions using the provided vector store and LLM.
@@ -208,8 +201,18 @@ def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
return convo_qa_chain
-def format_answer(response):
- """Format the answer and context to ensure plain text output without any special characters."""
+def format_answer(response, max_documents=3):
+ """
+ Format the answer and context to ensure plain text output without any special characters,
+ and limit the number of context documents shown.
+
+ Args:
+ response (dict): The response dictionary containing 'answer' and 'context'.
+ max_documents (int): Maximum number of context documents to display.
+
+ Returns:
+ str: Formatted output with limited context.
+ """
# Extract the main answer
answer = response.get('answer', '')
@@ -219,8 +222,11 @@ def format_answer(response):
# Extract and summarize context if available
context = response.get('context', [])
if context and isinstance(context, list):
- # Concatenate page_content from each Document object in the context
- context_summary = "\n\n".join(doc.page_content for doc in context)
+ # Limit the number of documents to max_documents
+ limited_context = context[:max_documents]
+
+ # Concatenate page_content from each Document object in the limited context
+ context_summary = "\n\n".join(doc.page_content for doc in limited_context)
# Clean up the answer text
if isinstance(answer, str):
@@ -231,17 +237,15 @@ def format_answer(response):
context_summary = context_summary.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
# Combine the answer and context summary
- formatted_output = f"Answer:\n\n{answer}\n\nContext:\n\n{context_summary}"
+ formatted_output = f"Answer:\n\n{answer}\n\nContext (showing up to {max_documents} documents):\n\n{context_summary}"
return formatted_output
-
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
- parser.add_argument('--persist', type=str, default='db',
- help='Base directory to save or load persisted vector stores')
+ parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
args = parser.parse_args()
# Check if the PDF file exists
@@ -266,33 +270,47 @@ def main():
# Create the QA chain with memory
qa_chain = create_qa_chain_with_memory(vector_store, llm)
- # Interactive mode for asking questions
- print("PDF processing complete. You can now ask questions about the content.")
- print("Type 'exit' or 'quit' to end the session.")
-
- while True:
- question = input("Enter your question: ").strip()
- if question.lower() in ["exit", "quit"]:
- print("Exiting the session. Goodbye!")
- break
-
- # Get the answer
- response = qa_chain.invoke(
- {
- "input": question,
- "chat_history": [], # Initially, pass an empty chat history
- }
- )
-
- # Format the answer
- answer = format_answer(response)
-
- # Check if the answer is empty or only contains newlines
- if answer.strip():
- print(f"Answer:\n\n{answer}\n")
- else:
- print("No relevant information found for your question. Please try asking a different question.")
-
+ # Prepare the filename for the conversation log
+ base_filename = os.path.basename(args.pdf_path).replace(' ', '_')
+ date_str = datetime.now().strftime('%Y-%m-%d')
+ log_filename = f"{date_str}_{base_filename}.md"
+
+ # Open the file in write mode
+ with open(log_filename, 'w') as log_file:
+ # Start the conversation log
+ log_file.write(f"# Conversation Log for {base_filename}\n")
+ log_file.write(f"## Date: {date_str}\n\n")
+
+ # Interactive mode for asking questions
+ print("PDF processing complete. You can now ask questions about the content.")
+ print("Type 'exit' or 'quit' to end the session.")
+
+ while True:
+ question = input("Enter your question: ").strip()
+ if question.lower() in ["exit", "quit"]:
+ print("Exiting the session. Goodbye!")
+ break
+
+ # Get the answer
+ response = qa_chain.invoke(
+ {
+ "input": question,
+ "chat_history": [], # Initially, pass an empty chat history
+ }
+ )
+
+ # Format the answer
+ answer = format_answer(response, max_documents=1)
+
+ # Log the question and answer to the file
+ log_file.write(f"### Question:\n{question}\n\n")
+ log_file.write(f"### Answer:\n{answer}\n\n")
+
+ # Display the answer to the user
+ if answer.strip():
+ print(f"Answer:\n\n{answer}\n")
+ else:
+ print("No relevant information found for your question. Please try asking a different question.")
if __name__ == "__main__":
- main()
+ main() \ No newline at end of file