diff options
author | Matthew Lemon <y@yulqen.org> | 2024-08-27 14:08:48 +0100 |
---|---|---|
committer | Matthew Lemon <y@yulqen.org> | 2024-08-27 14:08:48 +0100 |
commit | db3ee8795824105244ad5b3045da25726aae1cfe (patch) | |
tree | c6a94f703624cfdd7de61cd143e7c21675ad77ee /pdf_reader.py |
Initial commit - basic working script
Diffstat (limited to 'pdf_reader.py')
-rw-r--r-- | pdf_reader.py | 163 |
1 files changed, 163 insertions, 0 deletions
"""Interactive question answering over a single PDF using LangChain and Ollama.

Pipeline:

1. Extract text from the PDF with PyPDF2; if that yields nothing, render each
   page with PyMuPDF and run Tesseract OCR on it.
2. Split the text into overlapping chunks and index them in a Chroma vector
   store using Ollama embeddings.
3. Answer questions typed at the prompt through a RetrievalQA chain backed by
   an Ollama-hosted LLM.

Usage: python pdf_reader.py <pdf_path>
"""
# Reconstructed from the initial-commit diff of pdf_reader.py
# (new file mode 100644, index 0000000..7f9e2b3).

import argparse
import io
import os
import sys

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from PyPDF2 import PdfReader
import pytesseract
from PIL import Image
import fitz  # PyMuPDF

# Make sure Tesseract is installed and accessible.
# Update this path based on your tesseract installation.
_DEFAULT_TESSERACT_CMD = r'/usr/local/bin/tesseract'

# Only override pytesseract's command when the hard-coded binary really exists;
# otherwise leave pytesseract's own default in place so a `tesseract` found on
# PATH still works instead of every OCR call failing.
if os.path.exists(_DEFAULT_TESSERACT_CMD):
    pytesseract.pytesseract.tesseract_cmd = _DEFAULT_TESSERACT_CMD


def extract_text_from_pdf(pdf_path):
    """Extract the embedded text layer of a PDF with PyPDF2.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The concatenated text of all pages (possibly an empty string), or
        None if the file could not be opened or parsed.
    """
    try:
        with open(pdf_path, 'rb') as file:
            pdf = PdfReader(file)
            # extract_text() may return None for a page, hence the `or ""`.
            text = "".join(page.extract_text() or "" for page in pdf.pages)
        print(f"Extracted {len(text)} characters from the PDF using PyPDF2.")
        return text
    except Exception as e:
        # Best-effort by design: the caller falls back to OCR on failure.
        print(f"Error extracting text from PDF: {e}")
        return None


def perform_ocr_on_pdf(pdf_path):
    """Render every page of a PDF and extract its text with Tesseract OCR.

    Args:
        pdf_path: Path of the PDF file to OCR.

    Returns:
        The concatenated OCR text of all pages, or None if rendering or OCR
        failed.
    """
    try:
        doc = fitz.open(pdf_path)  # Open the PDF with PyMuPDF
        try:
            page_texts = []
            for page in doc:
                pix = page.get_pixmap()
                # Hand the rendered page image to PIL via an in-memory buffer.
                img = Image.open(io.BytesIO(pix.tobytes()))
                page_texts.append(pytesseract.image_to_string(img))
        finally:
            # Release the document handle even when a page fails to OCR.
            doc.close()
        text = "".join(page_texts)
        print(f"Extracted {len(text)} characters from the PDF using OCR.")
        return text
    except Exception as e:
        print(f"Error performing OCR on PDF: {e}")
        return None


def get_pdf_text(pdf_path):
    """Return the text of a PDF, falling back to OCR when it has no text layer.

    Args:
        pdf_path: Path of the PDF file.

    Returns:
        The extracted text, or None when neither PyPDF2 nor OCR produced any
        non-whitespace text.
    """
    text = extract_text_from_pdf(pdf_path)
    if text and text.strip():  # not None and contains non-whitespace characters
        print(f"Successfully extracted text from PDF. Total characters: {len(text)}")
        return text

    print("No text found using PyPDF2, performing OCR...")
    ocr_text = perform_ocr_on_pdf(pdf_path)
    if ocr_text and ocr_text.strip():
        print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
        return ocr_text

    # Both strategies failed; the caller reports the error.
    return None


def process_pdf_for_qa(pdf_path, embeddings):
    """Prepare a PDF for question answering.

    Args:
        pdf_path: Path of the PDF file.
        embeddings: Embeddings model passed to Chroma to embed the chunks.

    Returns:
        A Chroma vector store built from the PDF's text chunks, or None if no
        text could be extracted.
    """
    text = get_pdf_text(pdf_path)

    if not text:  # Text extraction and OCR both failed
        print("Failed to extract text from the PDF. Please check the file.")
        return None

    # Split text into manageable, overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    texts = text_splitter.split_text(text)
    print(f"Split text into {len(texts)} chunks for processing.")

    # Convert texts into Document objects for embedding
    documents = [Document(page_content=chunk) for chunk in texts]

    print("Creating a vector store using Chroma...")
    vector_store = Chroma.from_documents(documents, embeddings)

    return vector_store


def create_qa_chain(vector_store, llm):
    """Create a RetrievalQA chain that answers questions from the vector store.

    Args:
        vector_store: Vector store whose retriever supplies context documents.
        llm: Language model used to generate the answers.

    Returns:
        A RetrievalQA chain using the "stuff" chain type.
    """
    retriever = vector_store.as_retriever()
    return RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)


def format_answer(response):
    """Reduce a chain response to a plain-text answer.

    Args:
        response: A dict (its 'result' entry is used), a list of dicts (the
            'text' entries are joined with blank lines), or any other object
            (converted with str()).

    Returns:
        The answer with literal escaped newlines and quotes converted to the
        real characters and surrounding whitespace stripped.
    """
    if isinstance(response, dict):
        answer = response.get('result', '')
    elif isinstance(response, list):
        # Read the same key that is tested for; the original read 'query'
        # here, which returned the question rather than the answer text.
        answer = "\n\n".join(
            item.get('text', '')
            for item in response
            if isinstance(item, dict) and 'text' in item
        )
    else:
        answer = str(response)

    # Turn literal backslash escapes (\n, \', \") into the characters they
    # stand for, then trim surrounding whitespace.
    answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()

    return answer


def main():
    """Parse arguments, index the PDF, and run the interactive question loop."""
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
    args = parser.parse_args()

    # Check if the PDF file exists
    if not os.path.exists(args.pdf_path):
        print(f"Error: The file {args.pdf_path} does not exist.")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to be available (e.g. under `python -S`).
        sys.exit(1)

    # Initialize LLM (LLaMA 3.1 model hosted on Ollama)
    llm = Ollama(model="llama3.1")

    # Initialize Ollama embeddings model using nomic-embed-text:latest
    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")

    print("Processing the PDF. Please wait...")
    vector_store = process_pdf_for_qa(args.pdf_path, embeddings)

    if vector_store is None:
        print("Processing failed. Exiting.")
        sys.exit(1)

    qa_chain = create_qa_chain(vector_store, llm)

    print("PDF processing complete. You can now ask questions about the content.")
    print("Type 'exit' or 'quit' to end the session.")

    while True:
        try:
            question = input("Enter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly.
            print("\nExiting the session. Goodbye!")
            break

        if question.lower() in ["exit", "quit"]:
            print("Exiting the session. Goodbye!")
            break

        if not question:
            # Nothing to ask; don't spend an LLM call on an empty prompt.
            continue

        response = qa_chain.invoke(question)
        answer = format_answer(response)

        # format_answer() already strips, so an empty string means no content
        if answer:
            print(f"Answer:\n\n{answer}\n")
        else:
            print("No relevant information found for your question. Please try asking a different question.")


if __name__ == "__main__":
    main()