"""PDF question answering over a local Ollama model.

Extracts text from a PDF with PyPDF2, falls back to OCR (PyMuPDF + Tesseract)
for scanned documents, embeds the text into a per-PDF Chroma store, and
answers questions interactively through a history-aware LangChain chain.
"""
import argparse
import hashlib
import io
import json
import os
import sys
from datetime import datetime
from typing import Any

import fitz  # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_core.language_models import LLM
from langchain_core.messages import AIMessage, HumanMessage

# Make sure Tesseract is installed and accessible.
pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'  # Update this path based on your tesseract installation
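# (Assumed typical install locations.) On Windows this is usually
# r'C:\Program Files\Tesseract-OCR\tesseract.exe'; on macOS with Homebrew,
# '/opt/homebrew/bin/tesseract' or '/usr/local/bin/tesseract'.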


def extract_text_from_pdf(pdf_path):
    """Extract text from a PDF file."""
    try:
        # Try to extract the embedded text layer using PyPDF2.
        text = ""
        with open(pdf_path, 'rb') as file:
            pdf = PdfReader(file)
            for page in pdf.pages:
                text += page.extract_text() or ""
        print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
        return text
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
        return None
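
# Note: page.extract_text() typically returns an empty string (or None) for
# scanned/image-only pages, which is exactly the case the OCR fallback below
# handles.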


def perform_ocr_on_pdf(pdf_path):
    """Perform OCR on a PDF file to extract text."""
    try:
        doc = fitz.open(pdf_path)  # Open the PDF with PyMuPDF
        text = ""
        for page in doc:
            # Render the page to an image, then OCR it with Tesseract.
            pix = page.get_pixmap()
            img = Image.open(io.BytesIO(pix.tobytes()))
            ocr_text = pytesseract.image_to_string(img)
            text += ocr_text
        print(f"Extracted {len(text)} characters from the PDF using OCR.")
        return text
    except Exception as e:
        print(f"Error performing OCR on PDF: {e}")
        return None
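
# OCR quality often improves if pages are rendered at a higher resolution
# first, e.g. page.get_pixmap(matrix=fitz.Matrix(2, 2)) for a 2x scale,
# at the cost of slower processing and more memory.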


def get_pdf_text(pdf_path):
    """Determine if OCR is necessary and extract text from the PDF."""
    text = extract_text_from_pdf(pdf_path)
    if text and text.strip():  # Text layer found: no OCR needed.
        print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
        return text
    print("No text found using PyPDF2, performing OCR...")
    ocr_text = perform_ocr_on_pdf(pdf_path)
    if ocr_text and ocr_text.strip():
        print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
        return ocr_text
    # Both extraction paths failed; make the empty result explicit.
    return None


def compute_pdf_hash(pdf_path):
    """Compute a unique hash for the PDF file to identify if it's already processed."""
    hasher = hashlib.sha256()
    with open(pdf_path, 'rb') as f:
        # Read in chunks so large PDFs do not have to fit in memory at once.
        for chunk in iter(lambda: f.read(8192), b''):
            hasher.update(chunk)
    return hasher.hexdigest()
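
# Because the hash covers the raw file bytes, any change to the PDF
# (even a metadata-only edit) yields a new hash and a fresh embedding run.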


def load_metadata(persist_directory):
    """Load metadata from a JSON file."""
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            return json.load(f)
    return {'processed_pdfs': []}


def save_metadata(persist_directory, metadata):
    """Save metadata to a JSON file."""
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f)
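
# The metadata file is a small JSON document of the form:
#     {"processed_pdfs": ["<sha256 hex digest>", ...]}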


def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
    """Prepare a PDF for question answering, using a unique Chroma persistence directory for each PDF."""
    pdf_hash = compute_pdf_hash(pdf_path)
    persist_directory = os.path.join(base_persist_directory, pdf_hash)  # Use hash to create a unique directory

    # Load or initialize metadata (ensure the base directory exists first,
    # since save_metadata writes into it).
    os.makedirs(base_persist_directory, exist_ok=True)
    metadata = load_metadata(base_persist_directory)

    # Check if this PDF has already been processed.
    if pdf_hash in metadata['processed_pdfs']:
        print(f"This PDF has already been processed. Loading embeddings from the database at {persist_directory}.")
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Initialize a Chroma vector store for this specific PDF.
    vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Get text from the PDF.
    text = get_pdf_text(pdf_path)
    if not text:  # Text extraction and OCR both failed.
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    # Split text into manageable chunks. Smaller chunks give more precise
    # retrieval; the overlap preserves context across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")

    # Convert chunks into Document objects for embedding.
    documents = [Document(page_content=chunk) for chunk in texts]

    # Add the new documents to the vector store; the persistent Chroma client
    # writes them to disk automatically.
    print("Adding new documents to the vector store and persisting...")
    vector_store.add_documents(documents)

    # Record this PDF's hash so it is not re-embedded on the next run.
    metadata['processed_pdfs'].append(pdf_hash)
    save_metadata(base_persist_directory, metadata)

    return vector_store
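
# The returned store can also be queried directly without the LLM, e.g.
# for a quick retrieval sanity check (illustrative query string):
#     docs = vector_store.similarity_search("What is this document about?", k=3)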


def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
    """
    Create a QA chain for answering questions using the provided vector store and LLM.

    The chain is history-aware: the caller supplies the accumulated chat
    history on each invocation, which is used both to condense follow-up
    questions into standalone ones and to answer them.

    Args:
        vector_store: The database of vectors representing text embeddings.
        llm: The Large Language Model instance.

    Returns:
        A retrieval chain object that can be used to answer questions by querying the LLM with user input.
    """
    # Use the vector store retriever directly.
    vector_store_retriever = vector_store.as_retriever()

    # Condense-question prompt: rewrites a follow-up question into a
    # standalone one using the chat history.
    condense_question_system_template = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    condense_question_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", condense_question_system_template),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )

    # Create a history-aware retriever.
    history_aware_retriever = create_history_aware_retriever(
        llm, vector_store_retriever, condense_question_prompt
    )

    # System prompt for question answering.
    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )

    # Create a stuff-documents chain for question answering.
    qa_chain = create_stuff_documents_chain(llm, qa_prompt)

    # Combine the history-aware retriever and the QA chain.
    convo_qa_chain = create_retrieval_chain(history_aware_retriever, qa_chain)

    return convo_qa_chain
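
# Illustrative invocation (placeholder contents): the returned chain takes
# an "input" string plus a "chat_history" list of messages; its result dict
# includes "answer" and "context" keys.
#     chain = create_qa_chain_with_memory(vector_store, llm)
#     result = chain.invoke({"input": "Who wrote it?",
#                            "chat_history": [HumanMessage(content="Hi"),
#                                             AIMessage(content="Hello!")]})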


def format_answer(response, max_documents=3):
    """
    Format the answer and context to ensure plain text output without any special characters,
    and limit the number of context documents shown.

    Args:
        response (dict): The response dictionary containing 'answer' and 'context'.
        max_documents (int): Maximum number of context documents to display.

    Returns:
        str: Formatted output with limited context.
    """
    # Extract the main answer.
    answer = response.get('answer', '')

    # Summarize the retrieved context, if available.
    context_summary = ""
    context = response.get('context', [])
    if context and isinstance(context, list):
        # Limit the number of documents and join their text.
        limited_context = context[:max_documents]
        context_summary = "\n\n".join(doc.page_content for doc in limited_context)

    # Clean up escaped characters in the answer and context text.
    if isinstance(answer, str):
        answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
    context_summary = context_summary.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()

    # Combine the answer and context summary.
    return (
        f"Answer:\n\n{answer}\n\n"
        f"Context (showing up to {max_documents} documents):\n\n{context_summary}"
    )
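
# Example input shape (illustrative values):
#     format_answer({"answer": "Paris.",
#                    "context": [Document(page_content="... Paris is ...")]},
#                   max_documents=1)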


def main():
    # Parse command line arguments.
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
    args = parser.parse_args()

    # Check that the PDF file exists.
    if not os.path.exists(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist.")
        sys.exit(1)

    # Initialize the LLM (LLaMA 3.1 model hosted on Ollama).
    llm = Ollama(model="llama3.1")

    # Initialize the Ollama embeddings model using nomic-embed-text:latest.
    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

    # Process the PDF and prepare it for QA.
    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist)
    if vector_store is None:
        print("Processing failed. Exiting.")
        sys.exit(1)

    # Create the history-aware QA chain.
    qa_chain = create_qa_chain_with_memory(vector_store, llm)

    # Prepare the filename for the conversation log.
    base_filename = os.path.basename(args.pdf_path).replace(' ', '_')
    date_str = datetime.now().strftime('%Y-%m-%d')
    log_filename = f"{date_str}_{base_filename}.md"

    # Accumulated chat history, passed to the chain on every turn so that
    # follow-up questions can reference earlier ones.
    chat_history = []

    # Open the conversation log in write mode.
    with open(log_filename, 'w') as log_file:
        # Start the conversation log.
        log_file.write(f"# Conversation Log for {base_filename}\n")
        log_file.write(f"## Date: {date_str}\n\n")

        # Interactive mode for asking questions.
        print("PDF processing complete. You can now ask questions about the content.")
        print("Type 'exit' or 'quit' to end the session.")
        while True:
            question = input("Enter your question: ").strip()
            if question.lower() in ["exit", "quit"]:
                print("Exiting the session. Goodbye!")
                break

            # Get the answer, passing the accumulated chat history.
            response = qa_chain.invoke(
                {
                    "input": question,
                    "chat_history": chat_history,
                }
            )

            # Record this turn so the next question can build on it.
            chat_history.append(HumanMessage(content=question))
            chat_history.append(AIMessage(content=response.get('answer', '')))

            # Format the answer.
            answer = format_answer(response, max_documents=1)

            # Log the question and answer to the file.
            log_file.write(f"### Question:\n{question}\n\n")
            log_file.write(f"### Answer:\n{answer}\n\n")

            # Display the answer to the user.
            if answer.strip():
                print(f"Answer:\n\n{answer}\n")
            else:
                print("No relevant information found for your question. Please try asking a different question.")


if __name__ == "__main__":
    main()