"""Ask questions about a PDF from the command line.

Extracts text with PyPDF2 (falling back to Tesseract OCR for scanned pages),
embeds the chunks into a per-document Chroma store, and answers questions with
a history-aware retrieval chain backed by a local Ollama model, logging the
conversation to a Markdown file.
"""

import argparse
import hashlib
import io
import json
import os
import sys
from datetime import datetime
from typing import Any

import fitz  # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_core.language_models import LLM
from langchain_core.messages import AIMessage, HumanMessage

# Make sure Tesseract is installed and accessible
pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'  # Update this path based on your tesseract installation
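
# Example invocation (a sketch; assumes a local Ollama server with the
# "llama3.1" and "nomic-embed-text" models already pulled, as used in main()):
#
#   python pdf_reader.py paper.pdf --persist db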

def extract_text_from_pdf(pdf_path):
    """Extract text from a PDF file."""
    try:
        # Try to extract text using PyPDF2
        text = ""
        with open(pdf_path, 'rb') as file:
            pdf = PdfReader(file)
            for page in pdf.pages:
                text += page.extract_text() or ""
        print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
        return text
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
        return None

def perform_ocr_on_pdf(pdf_path):
    """Perform OCR on a PDF file to extract text."""
    try:
        doc = fitz.open(pdf_path)  # Open the PDF with PyMuPDF
        text = ""
        for page in doc:
            pix = page.get_pixmap()
            img = Image.open(io.BytesIO(pix.tobytes()))
            ocr_text = pytesseract.image_to_string(img)
            text += ocr_text
        print(f"Extracted {len(text)} characters from the PDF using OCR.")
        return text
    except Exception as e:
        print(f"Error performing OCR on PDF: {e}")
        return None
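
# Note: page.get_pixmap() renders at 72 dpi by default. For low-quality scans,
# a higher resolution (e.g. page.get_pixmap(dpi=300), available in recent
# PyMuPDF releases) usually improves OCR accuracy at the cost of more memory.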

def get_pdf_text(pdf_path):
    """Determine if OCR is necessary and extract text from PDF."""
    text = extract_text_from_pdf(pdf_path)
    if text and text.strip():  # Check if text is not None and contains non-whitespace characters
        print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
        return text
    else:
        print("No text found using PyPDF2, performing OCR...")
        ocr_text = perform_ocr_on_pdf(pdf_path)
        if ocr_text and ocr_text.strip():
            print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
            return ocr_text
        print("OCR produced no text; the PDF may be empty or unreadable.")
        return None

def compute_pdf_hash(pdf_path):
    """Compute a unique hash for the PDF file to identify if it's already processed."""
    hasher = hashlib.sha256()
    with open(pdf_path, 'rb') as f:
        buf = f.read()
        hasher.update(buf)
    return hasher.hexdigest()

def load_metadata(persist_directory):
    """Load metadata from a JSON file."""
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r') as f:
            return json.load(f)
    else:
        return {'processed_pdfs': []}

def save_metadata(persist_directory, metadata):
    """Save metadata to a JSON file."""
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f)
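
# The metadata file is a small JSON document of the form
#   {"processed_pdfs": ["<sha256 of pdf>", ...]}
# which lets repeated runs on the same file skip re-embedding.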

def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
    """Prepare a PDF for question answering, using a unique Chroma persistence directory for each PDF."""

    pdf_hash = compute_pdf_hash(pdf_path)
    persist_directory = os.path.join(base_persist_directory, pdf_hash)  # Use hash to create a unique directory

    # Load or initialize metadata
    metadata = load_metadata(base_persist_directory)

    # Check if this PDF has already been processed
    if pdf_hash in metadata['processed_pdfs']:
        print(f"This PDF has already been processed. Loading embeddings from the database at {persist_directory}.")
        vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        return vector_store

    # Get text from PDF before touching the store, so a failed extraction
    # does not leave an empty Chroma directory behind
    text = get_pdf_text(pdf_path)

    if not text:  # Text extraction and OCR both failed
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    # Initialize the Chroma vector store for this specific PDF
    vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Split text into manageable chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")

    # Convert texts into Document objects for embedding
    documents = [Document(page_content=chunk) for chunk in texts]

    # Add the new documents to the existing vector store
    print("Adding new documents to the vector store and persisting...")
    vector_store.add_documents(documents)

    # Update the metadata to include this processed PDF hash
    metadata['processed_pdfs'].append(pdf_hash)
    save_metadata(base_persist_directory, metadata)

    return vector_store

def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
    """
    Create a conversational QA chain over the provided vector store and LLM.

    Conversation memory is supplied by the caller: pass the accumulated
    chat_history messages in each invoke() call (see main()).

    Args:
        vector_store: The database of vectors representing text embeddings.
        llm: The Large Language Model instance.

    Returns:
        A retrieval chain object that can be used to answer questions by querying the LLM with user input.
    """
    # Use the vector store retriever directly
    vector_store_retriever = vector_store.as_retriever()

    # Condense question prompt template
    condense_question_system_template = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )

    condense_question_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", condense_question_system_template),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )

    # Create a history-aware retriever
    history_aware_retriever = create_history_aware_retriever(
        llm, vector_store_retriever, condense_question_prompt
    )

    # System prompt for question answering
    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        "{context}"
    )

    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )

    # Create a stuff documents chain for question answering
    qa_chain = create_stuff_documents_chain(llm, qa_prompt)

    # Create a retrieval chain with the history-aware retriever and QA chain
    convo_qa_chain = create_retrieval_chain(history_aware_retriever, qa_chain)

    return convo_qa_chain
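
# Usage sketch: the chain returns a dict whose "answer" key holds the generated
# text and whose "context" key holds the retrieved Document objects, e.g.
#
#   chain = create_qa_chain_with_memory(vector_store, llm)
#   result = chain.invoke({"input": "What is this document about?", "chat_history": []})
#   print(result["answer"])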


def format_answer(response, max_documents=3):
    """
    Format the answer and context to ensure plain text output without any special characters,
    and limit the number of context documents shown.

    Args:
        response (dict): The response dictionary containing 'answer' and 'context'.
        max_documents (int): Maximum number of context documents to display.

    Returns:
        str: Formatted output with limited context.
    """
    # Extract the main answer
    answer = response.get('answer', '')

    # Initialize context summary
    context_summary = ""

    # Extract and summarize context if available
    context = response.get('context', [])
    if context and isinstance(context, list):
        # Limit the number of documents to max_documents
        limited_context = context[:max_documents]

        # Concatenate page_content from each Document object in the limited context
        context_summary = "\n\n".join(doc.page_content for doc in limited_context)

    # Clean up the answer text
    if isinstance(answer, str):
        answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()

    # Clean up the context summary text
    if isinstance(context_summary, str):
        context_summary = context_summary.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()

    # Combine the answer and context summary
    formatted_output = f"Answer:\n\n{answer}\n\nContext (showing up to {max_documents} documents):\n\n{context_summary}"

    return formatted_output

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
    args = parser.parse_args()

    # Check if the PDF file exists
    if not os.path.exists(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist.")
        sys.exit(1)

    # Initialize LLM (LLaMA 3.1 model hosted on Ollama)
    llm = Ollama(model="llama3.1")

    # Initialize Ollama embeddings model using nomic-embed-text:latest
    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

    # Process the PDF and prepare it for QA
    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist)

    if vector_store is None:
        print("Processing failed. Exiting.")
        sys.exit(1)

    # Create the QA chain with memory
    qa_chain = create_qa_chain_with_memory(vector_store, llm)

    # Prepare the filename for the conversation log
    base_filename = os.path.splitext(os.path.basename(args.pdf_path))[0].replace(' ', '_')
    date_str = datetime.now().strftime('%Y-%m-%d')
    log_filename = f"{date_str}_{base_filename}.md"

    # Open the file in write mode
    with open(log_filename, 'w') as log_file:
        # Start the conversation log
        log_file.write(f"# Conversation Log for {base_filename}\n")
        log_file.write(f"## Date: {date_str}\n\n")

        # Interactive mode for asking questions
        print("PDF processing complete. You can now ask questions about the content.")
        print("Type 'exit' or 'quit' to end the session.")

        chat_history = []  # Accumulated conversation turns, passed to the chain each time

        while True:
            question = input("Enter your question: ").strip()
            if question.lower() in ["exit", "quit"]:
                print("Exiting the session. Goodbye!")
                break

            # Get the answer, letting the history-aware retriever see prior turns
            response = qa_chain.invoke(
                {
                    "input": question,
                    "chat_history": chat_history,
                }
            )

            # Record this turn so follow-up questions can reference it
            chat_history.extend(
                [HumanMessage(content=question), AIMessage(content=response.get("answer", ""))]
            )

            # Format the answer
            answer = format_answer(response, max_documents=1)

            # Log the question and answer to the file
            log_file.write(f"### Question:\n{question}\n\n")
            log_file.write(f"### Answer:\n{answer}\n\n")

            # Display the answer to the user
            if answer.strip():
                print(f"Answer:\n\n{answer}\n")
            else:
                print("No relevant information found for your question. Please try asking a different question.")

if __name__ == "__main__":
    main()