import argparse
import hashlib
import io
import json
import os
import fitz # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
# Tesseract must be installed for the OCR fallback to work. The binary is
# looked up at /usr/local/bin/tesseract by default; set the TESSERACT_CMD
# environment variable to point pytesseract at a different installation
# without editing this file.
pytesseract.pytesseract.tesseract_cmd = os.environ.get(
    'TESSERACT_CMD', r'/usr/local/bin/tesseract'
)
def extract_text_from_pdf(pdf_path):
    """Return the text PyPDF2 can read from the PDF at ``pdf_path``.

    The text of every page is concatenated in page order. Any exception
    raised while opening or parsing the file is reported on stdout and
    signalled to the caller by returning ``None`` instead of raising.
    """
    try:
        with open(pdf_path, 'rb') as handle:
            reader = PdfReader(handle)
            # ``or ""`` keeps the join safe when a page yields a falsy result.
            page_texts = [page.extract_text() or "" for page in reader.pages]
        combined = "".join(page_texts)
        print(f"Extracted {len(combined)} characters from the PDF using PyPDF2.")
        return combined
    except Exception as err:
        print(f"Error extracting text from PDF: {err}")
        return None
def perform_ocr_on_pdf(pdf_path):
    """Extract text from a PDF by rendering each page and running OCR on it.

    Each page is rendered to a pixmap with PyMuPDF, loaded as a PIL image and
    passed to Tesseract through ``pytesseract``. The per-page results are
    concatenated in page order.

    Returns the recognised text, or ``None`` if any step raises (the error is
    printed to stdout rather than propagated).
    """
    try:
        # Open the document in a ``with`` block so its handle is released even
        # when rendering or OCR fails part-way through the pages.
        with fitz.open(pdf_path) as doc:
            page_texts = []
            for page in doc:
                pix = page.get_pixmap()
                img = Image.open(io.BytesIO(pix.tobytes()))
                page_texts.append(pytesseract.image_to_string(img))
        text = "".join(page_texts)
        print(f"Extracted {len(text)} characters from the PDF using OCR.")
        return text
    except Exception as e:
        print(f"Error performing OCR on PDF: {e}")
        return None
def get_pdf_text(pdf_path):
    """Return the text of a PDF, falling back to OCR when needed.

    Direct extraction with PyPDF2 is tried first. If that yields nothing but
    whitespace (or fails), the pages are run through OCR instead.

    Returns the extracted text, or ``None`` when neither method produces any
    non-whitespace text, so callers can rely on a simple truthiness check.
    """
    text = extract_text_from_pdf(pdf_path)
    if text and text.strip():
        print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
        return text

    print("No text found using PyPDF2, performing OCR...")
    ocr_text = perform_ocr_on_pdf(pdf_path)
    if ocr_text and ocr_text.strip():
        print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
        return ocr_text

    # Never hand back a whitespace-only string: it is truthy and would be
    # treated as real content by the caller.
    return None
def compute_pdf_hash(pdf_path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of the file at ``pdf_path``.

    The digest identifies a PDF by its content, so the same file is recognised
    as already processed regardless of its name or location.

    Args:
        pdf_path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration (default 1 MiB). The
            file is streamed so large PDFs are never held in memory at once.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    hasher = hashlib.sha256()
    with open(pdf_path, 'rb') as f:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
def load_metadata(persist_directory):
    """Read ``metadata.json`` from ``persist_directory``.

    Returns the parsed JSON content when the file exists; otherwise returns a
    fresh ``{'processed_pdfs': []}`` structure.
    """
    metadata_file = os.path.join(persist_directory, 'metadata.json')
    if not os.path.exists(metadata_file):
        # Nothing has been recorded yet: start with an empty record.
        return {'processed_pdfs': []}
    with open(metadata_file, 'r') as fh:
        stored = json.load(fh)
    return stored
def save_metadata(persist_directory, metadata):
    """Write ``metadata`` as JSON to ``metadata.json`` in ``persist_directory``.

    The directory is created first if it does not exist yet, so saving works
    even before anything else has written to the persistence location.

    Args:
        persist_directory: Directory that holds ``metadata.json``.
        metadata: JSON-serialisable object to store.
    """
    # An empty string means "current directory"; os.makedirs('') would raise.
    if persist_directory:
        os.makedirs(persist_directory, exist_ok=True)
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f)
def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
    """Prepare a PDF for question answering and return its Chroma vector store.

    Each PDF gets its own persistence directory, named after the SHA-256 hash
    of the file, below ``base_persist_directory``. Hashes of PDFs that were
    fully processed are recorded in ``metadata.json`` in the base directory,
    so a PDF seen before is loaded from disk instead of being embedded again.

    Args:
        pdf_path: Path to the PDF file.
        embeddings: Embedding function handed to Chroma.
        base_persist_directory: Root directory for all per-PDF stores and the
            shared metadata file.

    Returns:
        The Chroma vector store for this PDF, or ``None`` when no text could
        be extracted from it.
    """
    pdf_hash = compute_pdf_hash(pdf_path)
    # The content hash doubles as the name of this PDF's store directory.
    persist_directory = os.path.join(base_persist_directory, pdf_hash)

    metadata = load_metadata(base_persist_directory)
    # Tolerate a metadata file that lacks the key instead of raising KeyError.
    processed_pdfs = metadata.setdefault('processed_pdfs', [])

    if pdf_hash in processed_pdfs:
        print(f"This PDF has already been processed. Loading embeddings from the database at {persist_directory}.")
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Extract the text before touching the vector store, so a PDF that
    # cannot be read does not leave an empty store behind.
    text = get_pdf_text(pdf_path)
    if not text:
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")
    if not texts:
        # E.g. whitespace-only text: there is nothing to embed.
        print("No text chunks were produced from the PDF. Please check the file.")
        return None

    documents = [Document(page_content=chunk) for chunk in texts]

    vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    print("Adding new documents to the vector store and persisting...")
    vector_store.add_documents(documents)

    # Record the hash only after the documents were added, so an interrupted
    # run is not mistaken for a completed one.
    processed_pdfs.append(pdf_hash)
    save_metadata(base_persist_directory, metadata)
    return vector_store
def create_qa_chain(vector_store, llm):
    """Build a ``RetrievalQA`` chain on top of the given vector store.

    The store's default retriever supplies the documents and ``llm`` produces
    the answer; the chain is created with ``chain_type="stuff"``.
    """
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )
def format_answer(response):
    """Turn a QA chain response into plain, printable text.

    * dict  -> its ``'result'`` value (empty string when the key is absent)
    * list  -> the ``'result'`` values of all items that have one, separated
      by blank lines
    * other -> ``str(response)``

    Literal escape sequences (backslash-n, backslash-quote) are then replaced
    by the characters they stand for and surrounding whitespace is stripped.
    """
    if isinstance(response, dict):
        raw = response.get('result', '')
    elif isinstance(response, list):
        parts = [entry.get('result', '') for entry in response if 'result' in entry]
        raw = "\n\n".join(parts)
    else:
        raw = str(response)

    # Applied in this fixed order, matching one replacement after another.
    unescape_table = (
        ("\\n", "\n"),
        ("\\'", "'"),
        ('\\"', '"'),
    )
    for escaped, literal in unescape_table:
        raw = raw.replace(escaped, literal)
    return raw.strip()
def main():
    """Command-line entry point: index a PDF, then answer questions about it.

    Parses the PDF path and the optional ``--persist`` base directory, builds
    (or reloads) the vector store for the PDF and starts an interactive
    question loop. The loop ends on ``exit``/``quit``, on end-of-input, or on
    Ctrl-C.

    Raises:
        SystemExit: With code 1 when the PDF path is not an existing file or
            the PDF could not be processed.
    """
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
    args = parser.parse_args()

    # isfile() also rejects directories, which would otherwise fail later
    # while hashing the "PDF".
    if not os.path.isfile(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist.")
        # Raise SystemExit directly: the exit() builtin is injected by the
        # site module and is not guaranteed to exist in every environment.
        raise SystemExit(1)

    # LLM and embedding model, both served by Ollama.
    llm = Ollama(model="llama3.1")
    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist)
    if vector_store is None:
        print("Processing failed. Exiting.")
        raise SystemExit(1)

    qa_chain = create_qa_chain(vector_store, llm)

    print("PDF processing complete. You can now ask questions about the content.")
    print("Type 'exit' or 'quit' to end the session.")
    while True:
        try:
            question = input("Enter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: leave the session without a traceback.
            print("\nExiting the session. Goodbye!")
            break

        if not question:
            # Do not spend an LLM call on an empty line.
            continue
        if question.lower() in ("exit", "quit"):
            print("Exiting the session. Goodbye!")
            break

        response = qa_chain.invoke(question)
        answer = format_answer(response)
        if answer.strip():
            print(f"Answer:\n\n{answer}\n")
        else:
            print("No relevant information found for your question. Please try asking a different question.")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()