summaryrefslogtreecommitdiffstats
path: root/pdf_reader.py
diff options
context:
space:
mode:
authorMatthew Lemon <y@yulqen.org>2024-08-27 14:08:48 +0100
committerMatthew Lemon <y@yulqen.org>2024-08-27 14:08:48 +0100
commitdb3ee8795824105244ad5b3045da25726aae1cfe (patch)
treec6a94f703624cfdd7de61cd143e7c21675ad77ee /pdf_reader.py
Initial commit - basic working script
Diffstat (limited to 'pdf_reader.py')
-rw-r--r--pdf_reader.py163
1 files changed, 163 insertions, 0 deletions
diff --git a/pdf_reader.py b/pdf_reader.py
new file mode 100644
index 0000000..7f9e2b3
--- /dev/null
+++ b/pdf_reader.py
@@ -0,0 +1,163 @@
+import os
+import argparse
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyMuPDFLoader
+from langchain_community.vectorstores import Chroma
+from langchain_community.llms import Ollama
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain.chains import RetrievalQA
+from langchain.docstore.document import Document
+from PyPDF2 import PdfReader
+import pytesseract
+from PIL import Image
+import io
+import fitz # PyMuPDF
+
+# Make sure Tesseract is installed and accessible
+pytesseract.pytesseract.tesseract_cmd = r'/usr/local/bin/tesseract' # Update this path based on your tesseract installation
+
+def extract_text_from_pdf(pdf_path):
+ """Extract text from a PDF file."""
+ try:
+ # Try to extract text using PyPDF2
+ text = ""
+ with open(pdf_path, 'rb') as file:
+ pdf = PdfReader(file)
+ for page in pdf.pages:
+ text += page.extract_text() or ""
+ print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
+ return text
+ except Exception as e:
+ print(f"Error extracting text from PDF: {e}")
+ return None
+
+def perform_ocr_on_pdf(pdf_path):
+ """Perform OCR on a PDF file to extract text."""
+ try:
+ doc = fitz.open(pdf_path) # Open the PDF with PyMuPDF
+ text = ""
+ for page in doc:
+ pix = page.get_pixmap()
+ img = Image.open(io.BytesIO(pix.tobytes()))
+ ocr_text = pytesseract.image_to_string(img)
+ text += ocr_text
+ print(f"Extracted {len(text)} characters from the PDF using OCR.")
+ return text
+ except Exception as e:
+ print(f"Error performing OCR on PDF: {e}")
+ return None
+
+def get_pdf_text(pdf_path):
+ """Determine if OCR is necessary and extract text from PDF."""
+ text = extract_text_from_pdf(pdf_path)
+ if text and text.strip(): # Check if text is not None and contains non-whitespace characters
+ print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
+ return text
+ else:
+ print("No text found using PyPDF2, performing OCR...")
+ ocr_text = perform_ocr_on_pdf(pdf_path)
+ if ocr_text and ocr_text.strip():
+ print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
+ return ocr_text
+
+def process_pdf_for_qa(pdf_path, embeddings):
+ """Prepare a PDF for question answering."""
+ # Get text from PDF
+ text = get_pdf_text(pdf_path)
+
+ if not text: # Check if text extraction or OCR failed
+ print("Failed to extract text from the PDF. Please check the file.")
+ return None
+
+ # Split text into manageable chunks
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+ texts = text_splitter.split_text(text)
+ print(f"Split text into {len(texts)} chunks for processing.")
+
+ # Convert texts into Document objects for embedding
+ documents = [Document(page_content=chunk) for chunk in texts]
+
+ # Create a vector store from the documents using Chroma
+ print("Creating a vector store using Chroma...")
+ vector_store = Chroma.from_documents(documents, embeddings)
+
+ return vector_store
+
+def create_qa_chain(vector_store, llm):
+ """Create a QA chain for answering questions."""
+ # Use the vector store retriever directly
+ retriever = vector_store.as_retriever()
+
+ # Create a RetrievalQA chain using the new approach
+ qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+
+ return qa_chain
+
+def format_answer(response):
+ """Format the answer to ensure plain text output without any special characters."""
+
+ # Handle different response formats
+ if isinstance(response, dict):
+ answer = response.get('result', '')
+ elif isinstance(response, list):
+ answer = "\n\n".join(item.get('query', '') for item in response if 'text' in item)
+ else:
+ answer = str(response)
+
+ # Clean up the text: Remove excess newlines and strip whitespace
+ answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
+
+ return answer
+
+def main():
+ # Parse command line arguments
+ parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
+ parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
+ args = parser.parse_args()
+
+ # Check if the PDF file exists
+ if not os.path.exists(args.pdf_path):
+ print(f"Error: The file {args.pdf_path} does not exist.")
+ exit(1)
+
+ # Initialize LLM (LLaMA 3.1 model hosted on Ollama)
+ llm = Ollama(model="llama3.1")
+
+ # Initialize Ollama embeddings model using nomic-embed-text:latest
+ embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")
+
+ # Process the PDF and prepare it for QA
+ print("Processing the PDF. Please wait...")
+ vector_store = process_pdf_for_qa(args.pdf_path, embeddings)
+
+ if vector_store is None:
+ print("Processing failed. Exiting.")
+ exit(1)
+
+ # Create the QA chain
+ qa_chain = create_qa_chain(vector_store, llm)
+
+ # Interactive mode for asking questions
+ print("PDF processing complete. You can now ask questions about the content.")
+ print("Type 'exit' or 'quit' to end the session.")
+
+ while True:
+ question = input("Enter your question: ").strip()
+ if question.lower() in ["exit", "quit"]:
+ print("Exiting the session. Goodbye!")
+ break
+
+ # Get the answer
+ response = qa_chain.invoke(question)
+
+ # Format the answer
+ answer = format_answer(response)
+
+ # Check if the answer is empty or only contains newlines
+ if answer.strip():
+ print(f"Answer:\n\n{answer}\n")
+ else:
+ print("No relevant information found for your question. Please try asking a different question.")
+
+if __name__ == "__main__":
+ main()