summary refs log blame commit diff stats
path: root/pdf_reader.py
blob: 7f9e2b393eff5a91ce50bc6fae4e22b0b295593a (plain) (tree)


































































































































































                                                                                                                            
import argparse
import io
import os
import shutil

import fitz  # PyMuPDF
import pytesseract
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from PIL import Image
from PyPDF2 import PdfReader

# Locate the Tesseract executable that pytesseract shells out to for OCR.
# Resolution order:
#   1. the TESSERACT_CMD environment variable, if set;
#   2. /usr/local/bin/tesseract (the previously hard-coded location), if it exists;
#   3. the first "tesseract" found on PATH.
# If none of these resolves, pytesseract's own default command is left untouched
# rather than being overwritten with a path that does not exist.
_TESSERACT_CMD = (
    os.environ.get("TESSERACT_CMD")
    or (r'/usr/local/bin/tesseract' if os.path.isfile(r'/usr/local/bin/tesseract') else None)
    or shutil.which("tesseract")
)
if _TESSERACT_CMD:
    pytesseract.pytesseract.tesseract_cmd = _TESSERACT_CMD

def extract_text_from_pdf(pdf_path):
    """Extract the embedded text layer of a PDF using PyPDF2.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The text of all pages joined with newlines, or None if reading the
        PDF raised an error. The result may be empty or whitespace-only when
        the PDF has no text layer (e.g. scanned pages).
    """
    try:
        with open(pdf_path, 'rb') as file:
            pdf = PdfReader(file)
            # Separate pages with a newline so the last word of one page does
            # not fuse with the first word of the next; join() also avoids the
            # quadratic cost of repeated string +=.
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)
    except Exception as e:
        # Best-effort: the caller treats None as "no text" and falls back to OCR.
        print(f"Error extracting text from PDF: {e}")
        return None
    print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
    return text

def perform_ocr_on_pdf(pdf_path, dpi=300):
    """Run Tesseract OCR over every page of a PDF.

    Each page is rasterized with PyMuPDF and the resulting image is passed to
    pytesseract.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Rasterization resolution. PyMuPDF renders at 72 DPI by default,
            which is usually too coarse for reliable OCR; Tesseract's own
            guidance is to use at least 300 DPI.

    Returns:
        The OCR text of all pages joined with newlines, or None if opening,
        rendering, or OCR raised an error.
    """
    try:
        # PyMuPDF page coordinates are in 72-dpi points, so scale by dpi/72.
        zoom = dpi / 72
        matrix = fitz.Matrix(zoom, zoom)
        page_texts = []
        # The context manager closes the document even if OCR fails midway.
        with fitz.open(pdf_path) as doc:
            for page in doc:
                pix = page.get_pixmap(matrix=matrix)
                img = Image.open(io.BytesIO(pix.tobytes()))
                page_texts.append(pytesseract.image_to_string(img))
        # Newline between pages so words at page boundaries stay separate.
        text = "\n".join(page_texts)
    except Exception as e:
        # Best-effort: the caller treats None as "OCR failed".
        print(f"Error performing OCR on PDF: {e}")
        return None
    print(f"Extracted {len(text)} characters from the PDF using OCR.")
    return text

def get_pdf_text(pdf_path):
    """Return the text of a PDF, falling back to OCR when it has no text layer.

    PyPDF2 extraction is tried first; if it fails or yields only whitespace,
    the pages are OCR'd instead.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The extracted text, or None when neither PyPDF2 nor OCR produced any
        non-whitespace text.
    """
    text = extract_text_from_pdf(pdf_path)
    # None (error), "" and whitespace-only all mean "no usable text layer".
    if text and text.strip():
        print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
        return text

    print("No text found using PyPDF2, performing OCR...")
    ocr_text = perform_ocr_on_pdf(pdf_path)
    if ocr_text and ocr_text.strip():
        print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
        return ocr_text

    # Whitespace-only OCR output is as useless as a failure; report it the
    # same way so callers only need to test for None.
    print("OCR did not produce any text.")
    return None

def process_pdf_for_qa(pdf_path, embeddings, chunk_size=500, chunk_overlap=100):
    """Build a Chroma vector store over the text of a PDF.

    Args:
        pdf_path: Path to the PDF file.
        embeddings: Embedding model passed to ``Chroma.from_documents``.
        chunk_size: Maximum characters per chunk for the text splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The vector store, or None if no usable text could be extracted.
    """
    text = get_pdf_text(pdf_path)

    # Reject whitespace-only text too: it would split into zero chunks.
    if not text or not text.strip():
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    # Split text into manageable chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")

    # Convert texts into Document objects for embedding
    documents = [Document(page_content=chunk) for chunk in texts]
    if not documents:
        # An empty store is useless for QA (and chromadb rejects empty inserts).
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    # Create a vector store from the documents using Chroma
    print("Creating a vector store using Chroma...")
    vector_store = Chroma.from_documents(documents, embeddings)

    return vector_store

def create_qa_chain(vector_store, llm):
    """Wire an LLM and a vector store into a RetrievalQA chain.

    The chain uses the "stuff" chain type with the vector store's default
    retriever.

    Args:
        vector_store: Store whose ``as_retriever()`` supplies context documents.
        llm: Language model that generates the answer.

    Returns:
        The configured ``RetrievalQA`` chain.
    """
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )

def format_answer(response):
    """Format the answer to ensure plain text output without any special characters."""

    # Handle different response formats
    if isinstance(response, dict):
        answer = response.get('result', '')
    elif isinstance(response, list):
        answer = "\n\n".join(item.get('query', '') for item in response if 'text' in item)
    else:
        answer = str(response)
    
    # Clean up the text: Remove excess newlines and strip whitespace
    answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()

    return answer

def main():
    """Command-line entry point: index a PDF, then answer questions interactively.

    Exits with status 1 if the PDF path is not a regular file or if no usable
    text could be extracted from it.
    """
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    parser.add_argument('--model', default='llama3.1',
                        help='Ollama model used to answer questions (default: %(default)s)')
    parser.add_argument('--embedding-model', default='nomic-embed-text:latest',
                        help='Ollama model used for embeddings (default: %(default)s)')
    args = parser.parse_args()

    # isfile rather than exists: a directory path would otherwise pass here
    # and only fail later inside the PDF readers.
    if not os.path.isfile(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist or is not a regular file.")
        # SystemExit instead of exit(): exit() is a site-module helper meant
        # for the interactive shell and is not guaranteed to exist.
        raise SystemExit(1)

    llm = Ollama(model=args.model)
    embeddings = OllamaEmbeddings(model=args.embedding_model)

    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings)

    if vector_store is None:
        print("Processing failed. Exiting.")
        raise SystemExit(1)

    qa_chain = create_qa_chain(vector_store, llm)

    print("PDF processing complete. You can now ask questions about the content.")
    print("Type 'exit' or 'quit' to end the session.")

    while True:
        try:
            question = input("Enter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of with a traceback.
            print("\nExiting the session. Goodbye!")
            break

        if not question:
            # Nothing to ask; don't send an empty query to the chain.
            continue
        if question.lower() in ("exit", "quit"):
            print("Exiting the session. Goodbye!")
            break

        try:
            response = qa_chain.invoke(question)
        except Exception as e:
            # e.g. the Ollama server is unreachable; keep the session alive
            # so the user can retry or quit.
            print(f"Error answering question: {e}")
            continue

        answer = format_answer(response)

        # Empty after cleanup means the chain produced nothing useful.
        if answer.strip():
            print(f"Answer:\n\n{answer}\n")
        else:
            print("No relevant information found for your question. Please try asking a different question.")


if __name__ == "__main__":
    main()