# pdf_reader.py
# NOTE(review): removed non-Python residue (cgit web-interface navigation text,
# blob header, and line-number gutter) that had been scraped in above the
# module; it was not valid Python and prevented the file from being imported.
import argparse
import hashlib
import io
import json
import os

import fitz  # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama

# Point pytesseract at the Tesseract binary used for the OCR fallback path.
# NOTE(review): hard-coded absolute path — adjust to the local installation
# (e.g. the output of `which tesseract`); OCR extraction fails if it is wrong.
pytesseract.pytesseract.tesseract_cmd = r'/usr/local/bin/tesseract'  # Update this path based on your tesseract installation


def extract_text_from_pdf(pdf_path):
    """Extract the embedded text layer from a PDF using PyPDF2.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        The concatenated text of all pages (may be empty), or None if any
        error occurred while opening or parsing the file.
    """
    try:
        with open(pdf_path, 'rb') as handle:
            reader = PdfReader(handle)
            # Pages with no extractable text yield None; map those to "".
            pages = [page.extract_text() or "" for page in reader.pages]
        text = "".join(pages)
        print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
        return text
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
        return None


def perform_ocr_on_pdf(pdf_path):
    """Rasterize each PDF page with PyMuPDF and OCR it with Tesseract.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        The concatenated OCR text of all pages, or None if any error
        occurred while rendering or recognizing.
    """
    try:
        document = fitz.open(pdf_path)
        page_texts = []
        for page in document:
            # Render the page to an image, then run it through Tesseract.
            pixmap = page.get_pixmap()
            image = Image.open(io.BytesIO(pixmap.tobytes()))
            page_texts.append(pytesseract.image_to_string(image))
        text = "".join(page_texts)
        print(f"Extracted {len(text)} characters from the PDF using OCR.")
        return text
    except Exception as e:
        print(f"Error performing OCR on PDF: {e}")
        return None


def get_pdf_text(pdf_path):
    """Extract PDF text, falling back to OCR when the text layer is empty.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        The extracted text (direct or via OCR), or None when both
        extraction methods failed.
    """
    direct_text = extract_text_from_pdf(pdf_path)
    # A non-empty, non-whitespace text layer means OCR is unnecessary.
    if direct_text and direct_text.strip():
        print(f"Successfully extracted text from PDF. Total characters: {len(direct_text)}")
        return direct_text

    print("No text found using PyPDF2, performing OCR...")
    ocr_text = perform_ocr_on_pdf(pdf_path)
    if ocr_text and ocr_text.strip():
        print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
    return ocr_text


def compute_pdf_hash(pdf_path):
    """Compute the SHA-256 hex digest of a PDF file's bytes.

    Used as a fingerprint to detect PDFs that have already been processed.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        The hexadecimal SHA-256 digest of the file contents.
    """
    hasher = hashlib.sha256()
    # Stream in fixed-size chunks instead of f.read() so large PDFs are not
    # loaded into memory all at once; the digest is identical either way.
    with open(pdf_path, 'rb') as f:
        for chunk in iter(lambda: f.read(65536), b''):
            hasher.update(chunk)
    return hasher.hexdigest()


def load_metadata(persist_directory):
    """Read processing metadata from ``<persist_directory>/metadata.json``.

    Args:
        persist_directory: Directory holding the persisted vector store.

    Returns:
        The parsed metadata dict, or a fresh ``{'processed_pdfs': []}``
        when no metadata file exists yet.
    """
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    # First run: no metadata file has been written yet.
    if not os.path.exists(metadata_path):
        return {'processed_pdfs': []}
    with open(metadata_path, 'r') as handle:
        return json.load(handle)


def save_metadata(persist_directory, metadata):
    """Write processing metadata to ``<persist_directory>/metadata.json``.

    Args:
        persist_directory: Directory holding the persisted vector store.
        metadata: Metadata dict to serialize as JSON.
    """
    target = os.path.join(persist_directory, 'metadata.json')
    with open(target, 'w') as handle:
        json.dump(metadata, handle)


def process_pdf_for_qa(pdf_path, embeddings, persist_directory):
    """Prepare a PDF for question answering, using Chroma's persistence.

    Computes a content hash of the PDF; if that hash is already recorded in
    the directory's metadata, the persisted vector store is returned as-is.
    Otherwise the PDF text is extracted (with OCR fallback), chunked,
    embedded into the store, and the hash is recorded.

    Args:
        pdf_path: Path to the PDF file on disk.
        embeddings: Embedding function handed to Chroma.
        persist_directory: Directory where Chroma persists the vector store.

    Returns:
        The Chroma vector store, or None when text extraction failed.
    """
    pdf_hash = compute_pdf_hash(pdf_path)

    # Load or initialize metadata
    metadata = load_metadata(persist_directory)

    # One store handle serves both the cached and the fresh path
    # (the original constructed the identical store twice).
    vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Skip re-embedding if this exact PDF was processed in a previous run.
    if pdf_hash in metadata['processed_pdfs']:
        print("This PDF has already been processed. Loading embeddings from the database.")
        return vector_store

    # Get text from PDF
    text = get_pdf_text(pdf_path)
    if not text:  # extraction and the OCR fallback both failed
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    # Split text into overlapping chunks sized for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")

    # Convert texts into Document objects for embedding
    documents = [Document(page_content=chunk) for chunk in texts]

    # Add the new documents to the existing vector store
    print("Adding new documents to the vector store and persisting...")
    vector_store.add_documents(documents)

    # Record the hash so subsequent runs reuse the persisted embeddings.
    metadata['processed_pdfs'].append(pdf_hash)
    save_metadata(persist_directory, metadata)

    return vector_store

def create_qa_chain(vector_store, llm):
    """Build a RetrievalQA chain over the vector store's retriever.

    Args:
        vector_store: Chroma store holding the PDF embeddings.
        llm: Language model used to synthesize the final answer.

    Returns:
        A RetrievalQA chain using the "stuff" combine strategy.
    """
    # The store doubles as the retriever; "stuff" packs retrieved chunks
    # directly into the prompt.
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )


def format_answer(response):
    """Normalize a QA-chain response into a plain-text answer string.

    Args:
        response: A chain response — a dict with a ``'result'`` key, a list
            of such dicts, or any other object (stringified as-is).

    Returns:
        The answer text with literal escape sequences unescaped and
        surrounding whitespace stripped (may be empty).
    """
    # Handle different response formats
    if isinstance(response, dict):
        answer = response.get('result', '')
    elif isinstance(response, list):
        # Guard against non-mapping items: `'result' in item` raises
        # TypeError on e.g. ints, so only join entries that are dicts.
        answer = "\n\n".join(
            item.get('result', '')
            for item in response
            if isinstance(item, dict) and 'result' in item
        )
    else:
        answer = str(response)

    # Clean up the text: unescape literal \n / \' / \" and strip whitespace.
    answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()

    return answer


def main():
    """CLI entry point: index a PDF, then run an interactive Q&A loop."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    parser.add_argument('--persist', type=str, default='db', help='Directory to save or load persisted vector store')
    args = parser.parse_args()

    # Fail fast on a bad path before loading any models.
    if not os.path.exists(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist.")
        # `exit()` is a site-module convenience not guaranteed in every
        # environment; raising SystemExit is the robust script equivalent.
        raise SystemExit(1)

    # Initialize LLM (LLaMA 3.1 model hosted on Ollama)
    llm = Ollama(model="llama3.1")

    # Initialize Ollama embeddings model using nomic-embed-text:latest
    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

    # Process the PDF and prepare it for QA
    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist)

    if vector_store is None:
        print("Processing failed. Exiting.")
        raise SystemExit(1)

    # Create the QA chain
    qa_chain = create_qa_chain(vector_store, llm)

    # Interactive mode for asking questions
    print("PDF processing complete. You can now ask questions about the content.")
    print("Type 'exit' or 'quit' to end the session.")

    while True:
        question = input("Enter your question: ").strip()
        if question.lower() in ["exit", "quit"]:
            print("Exiting the session. Goodbye!")
            break

        # Get the answer
        response = qa_chain.invoke(question)

        # Format the answer
        answer = format_answer(response)

        # Suppress blank answers so the user gets actionable feedback.
        if answer.strip():
            print(f"Answer:\n\n{answer}\n")
        else:
            print("No relevant information found for your question. Please try asking a different question.")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()