"""Interactive question answering over a single PDF using LangChain, Ollama and Chroma.

Text is extracted with PyPDF2 (falling back to Tesseract OCR for scanned
PDFs), split into chunks, embedded with an Ollama embedding model and stored
in a per-PDF Chroma directory keyed by the file's SHA-256 hash.
"""
import argparse
import hashlib
import io
import json
import os
import sys
from typing import Any

import fitz  # PyMuPDF
import pytesseract
from PIL import Image
from PyPDF2 import PdfReader
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_core.language_models import LLM
from langchain_core.messages import AIMessage, HumanMessage

# Tesseract must be installed. Resolution order for the executable:
#   1. the TESSERACT_CMD environment variable, if set;
#   2. /usr/bin/tesseract (common Linux location), if it exists;
#   3. otherwise pytesseract's own default ("tesseract" looked up on PATH).
_DEFAULT_TESSERACT_CMD = r'/usr/bin/tesseract'
_tesseract_cmd = os.environ.get("TESSERACT_CMD")
if _tesseract_cmd:
    pytesseract.pytesseract.tesseract_cmd = _tesseract_cmd
elif os.path.exists(_DEFAULT_TESSERACT_CMD):
    pytesseract.pytesseract.tesseract_cmd = _DEFAULT_TESSERACT_CMD


def extract_text_from_pdf(pdf_path):
    """Extract the embedded text layer of a PDF with PyPDF2.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The text of every page joined by newlines (possibly empty or
        whitespace-only for scanned PDFs), or None if reading failed.
    """
    try:
        with open(pdf_path, 'rb') as file:
            pdf = PdfReader(file)
            # Newline-join so the last word of one page is not glued to the
            # first word of the next; join also avoids quadratic `+=`.
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)
    except Exception as e:  # PyPDF2 raises many error types on malformed files
        print(f"Error extracting text from PDF: {e}")
        return None
    print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
    return text


def perform_ocr_on_pdf(pdf_path, dpi=None):
    """Render each PDF page to an image and OCR it with Tesseract.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Optional render resolution passed to PyMuPDF's
            ``Page.get_pixmap``. ``None`` keeps PyMuPDF's default; higher
            values usually improve OCR accuracy at the cost of speed.

    Returns:
        The OCR text of all pages joined by newlines, or None on failure.
    """
    try:
        # The context manager guarantees the document is closed, even on error.
        with fitz.open(pdf_path) as doc:
            page_texts = []
            for page in doc:
                pix = page.get_pixmap(dpi=dpi) if dpi else page.get_pixmap()
                img = Image.open(io.BytesIO(pix.tobytes()))
                page_texts.append(pytesseract.image_to_string(img))
    except Exception as e:  # rendering/OCR failures are reported, not raised
        print(f"Error performing OCR on PDF: {e}")
        return None
    text = "\n".join(page_texts)
    print(f"Extracted {len(text)} characters from the PDF using OCR.")
    return text


def get_pdf_text(pdf_path):
    """Return the PDF's text, falling back to OCR when it has no text layer.

    Returns:
        The extracted text, or None if both PyPDF2 and OCR produced nothing.
    """
    text = extract_text_from_pdf(pdf_path)
    if text and text.strip():  # not None and contains non-whitespace characters
        print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
        return text

    print("No text found using PyPDF2, performing OCR...")
    ocr_text = perform_ocr_on_pdf(pdf_path)
    if ocr_text and ocr_text.strip():
        print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
        return ocr_text
    return None


def compute_pdf_hash(pdf_path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of the file, used as its cache key.

    The file is hashed in ``chunk_size``-byte pieces so large PDFs are not
    loaded into memory all at once.
    """
    hasher = hashlib.sha256()
    with open(pdf_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def load_metadata(persist_directory):
    """Load ``metadata.json`` from ``persist_directory``.

    Returns:
        A dict that always contains a ``'processed_pdfs'`` list, even when
        the file is missing or lacks that key.
    """
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except FileNotFoundError:
        metadata = {}
    metadata.setdefault('processed_pdfs', [])
    return metadata


def save_metadata(persist_directory, metadata):
    """Write ``metadata`` to ``metadata.json`` in ``persist_directory``.

    The directory is created if needed. The file is written to a temporary
    path and then atomically renamed, so a crash mid-write cannot leave a
    truncated JSON file behind.
    """
    os.makedirs(persist_directory, exist_ok=True)
    metadata_path = os.path.join(persist_directory, 'metadata.json')
    tmp_path = metadata_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f)
    os.replace(tmp_path, metadata_path)


def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
    """Prepare a PDF for question answering.

    Each PDF gets its own Chroma persistence directory named after its
    SHA-256 hash under ``base_persist_directory``. A PDF already recorded in
    the metadata file is loaded from disk instead of being re-embedded.

    Returns:
        The Chroma vector store, or None if no text could be extracted.
    """
    pdf_hash = compute_pdf_hash(pdf_path)
    persist_directory = os.path.join(base_persist_directory, pdf_hash)  # unique per PDF

    metadata = load_metadata(base_persist_directory)

    if pdf_hash in metadata['processed_pdfs']:
        print(f"This PDF has already been processed. Loading embeddings from the database at {persist_directory}.")
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Extract first so a failed extraction does not leave an empty store on disk.
    text = get_pdf_text(pdf_path)
    if not text:
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")

    documents = [Document(page_content=chunk) for chunk in texts]
    # Deterministic IDs: if a previous run crashed after embedding but before
    # the metadata was saved, re-adding overwrites chunks instead of duplicating them.
    ids = [f"{pdf_hash}-{index}" for index in range(len(documents))]

    vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    print("Adding new documents to the vector store and persisting...")
    vector_store.add_documents(documents, ids=ids)

    metadata['processed_pdfs'].append(pdf_hash)
    save_metadata(base_persist_directory, metadata)
    return vector_store


def create_qa_chain_with_memory(vector_store: Any, llm: LLM) -> Any:
    """
    Create a conversational retrieval QA chain over the provided vector store.

    The chain is stateless: the caller supplies the conversation so far on
    every call under the ``"chat_history"`` key (a list of messages), next to
    the new question under ``"input"``. The history is used both to rewrite
    follow-up questions into standalone queries for retrieval and as context
    for the answer.

    Args:
        vector_store: The database of vectors representing text embeddings.
        llm: The Large Language Model instance.

    Returns:
        A retrieval chain whose ``invoke`` result contains ``"answer"`` and
        the retrieved ``"context"`` documents.
    """
    vector_store_retriever = vector_store.as_retriever()

    condense_question_system_template = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    condense_question_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", condense_question_system_template),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )
    # Rewrites the question using chat history before querying the retriever.
    history_aware_retriever = create_history_aware_retriever(
        llm, vector_store_retriever, condense_question_prompt
    )

    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("placeholder", "{chat_history}"),
            ("human", "{input}"),
        ]
    )
    # "Stuff" all retrieved documents into the {context} slot of the prompt.
    qa_chain = create_stuff_documents_chain(llm, qa_prompt)

    return create_retrieval_chain(history_aware_retriever, qa_chain)


def _unescape(text):
    """Turn literal backslash escapes (\\n, \\', \\") into real characters and strip."""
    return text.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()


def format_answer(response):
    """Format a retrieval-chain response as plain text.

    Args:
        response: Dict returned by the retrieval chain; ``'answer'`` is the
            model's reply and ``'context'`` the list of retrieved Documents.

    Returns:
        ``"Answer:\\n\\n<answer>\\n\\nContext:\\n\\n<context>"``.
    """
    answer = response.get('answer', '')

    context_summary = ""
    context = response.get('context', [])
    if context and isinstance(context, list):
        context_summary = "\n\n".join(doc.page_content for doc in context)

    if isinstance(answer, str):
        answer = _unescape(answer)
    if isinstance(context_summary, str):
        context_summary = _unescape(context_summary)

    return f"Answer:\n\n{answer}\n\nContext:\n\n{context_summary}"


def main():
    """Parse CLI arguments, index the PDF and run an interactive Q&A loop."""
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    parser.add_argument('--persist', type=str, default='db', help='Base directory to save or load persisted vector stores')
    args = parser.parse_args()

    if not os.path.exists(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist.")
        sys.exit(1)

    # LLaMA 3.1 chat model and nomic-embed-text embeddings, both served by Ollama.
    llm = Ollama(model="llama3.1")
    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist)
    if vector_store is None:
        print("Processing failed. Exiting.")
        sys.exit(1)

    qa_chain = create_qa_chain_with_memory(vector_store, llm)

    print("PDF processing complete. You can now ask questions about the content.")
    print("Type 'exit' or 'quit' to end the session.")

    # Conversation so far, passed to the chain each turn so follow-up
    # questions can refer to earlier ones.
    chat_history = []
    while True:
        try:
            question = input("Enter your question: ").strip()
        except (EOFError, KeyboardInterrupt):  # Ctrl-D / Ctrl-C end the session cleanly
            print("\nExiting the session. Goodbye!")
            break
        if not question:
            continue
        if question.lower() in ("exit", "quit"):
            print("Exiting the session. Goodbye!")
            break

        response = qa_chain.invoke(
            {
                "input": question,
                "chat_history": chat_history,
            }
        )

        # Decide on the raw answer: the formatted output always contains
        # headers, so it can never be blank.
        raw_answer = response.get('answer', '')
        if isinstance(raw_answer, str) and raw_answer.strip():
            print(f"{format_answer(response)}\n")
            chat_history.extend([HumanMessage(content=question), AIMessage(content=raw_answer)])
        else:
            print("No relevant information found for your question. Please try asking a different question.")


if __name__ == "__main__":
    main()