diff options
author | Matthew Lemon <y@yulqen.org> | 2024-08-27 16:25:47 +0100 |
---|---|---|
committer | Matthew Lemon <y@yulqen.org> | 2024-08-27 16:25:47 +0100 |
commit | 6cff0842a3d59a66f1cbb0b8b96881473c795549 (patch) | |
tree | 15ea5c6d1a1b9df48d56b97ebef737b24c4dc0d5 | |
parent | db3ee8795824105244ad5b3045da25726aae1cfe (diff) |
Persistent embeddings
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | pdf_reader.py | 102 |
2 files changed, 79 insertions, 24 deletions
@@ -1,2 +1,3 @@ chroma_db .idea +db diff --git a/pdf_reader.py b/pdf_reader.py index 7f9e2b3..af7b3d1 100644 --- a/pdf_reader.py +++ b/pdf_reader.py @@ -1,21 +1,24 @@ -import os import argparse -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.document_loaders import PyMuPDFLoader -from langchain_community.vectorstores import Chroma -from langchain_community.llms import Ollama -from langchain_community.embeddings import OllamaEmbeddings -from langchain.chains import RetrievalQA -from langchain.docstore.document import Document -from PyPDF2 import PdfReader -import pytesseract -from PIL import Image +import hashlib import io +import json +import os + import fitz # PyMuPDF +import pytesseract +from PIL import Image +from PyPDF2 import PdfReader +from langchain.chains import RetrievalQA +from langchain.docstore.document import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_chroma import Chroma +from langchain_community.embeddings import OllamaEmbeddings +from langchain_community.llms import Ollama # Make sure Tesseract is installed and accessible pytesseract.pytesseract.tesseract_cmd = r'/usr/local/bin/tesseract' # Update this path based on your tesseract installation + def extract_text_from_pdf(pdf_path): """Extract text from a PDF file.""" try: @@ -31,6 +34,7 @@ def extract_text_from_pdf(pdf_path): print(f"Error extracting text from PDF: {e}") return None + def perform_ocr_on_pdf(pdf_path): """Perform OCR on a PDF file to extract text.""" try: @@ -47,6 +51,7 @@ def perform_ocr_on_pdf(pdf_path): print(f"Error performing OCR on PDF: {e}") return None + def get_pdf_text(pdf_path): """Determine if OCR is necessary and extract text from PDF.""" text = extract_text_from_pdf(pdf_path) @@ -60,8 +65,50 @@ def get_pdf_text(pdf_path): print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}") return ocr_text -def process_pdf_for_qa(pdf_path, embeddings): - """Prepare a PDF for question answering.""" + +def compute_pdf_hash(pdf_path): + """Compute a unique hash for the PDF file to identify if it's already processed.""" + hasher = hashlib.sha256() + with open(pdf_path, 'rb') as f: + buf = f.read() + hasher.update(buf) + return hasher.hexdigest() + + +def load_metadata(persist_directory): + """Load metadata from a JSON file.""" + metadata_path = os.path.join(persist_directory, 'metadata.json') + if os.path.exists(metadata_path): + with open(metadata_path, 'r') as f: + return json.load(f) + else: + return {'processed_pdfs': []} + + +def save_metadata(persist_directory, metadata): + """Save metadata to a JSON file.""" + metadata_path = os.path.join(persist_directory, 'metadata.json') + with open(metadata_path, 'w') as f: + json.dump(metadata, f) + + +def process_pdf_for_qa(pdf_path, embeddings, persist_directory): + """Prepare a PDF for question answering, using Chroma's persistence.""" + + pdf_hash = compute_pdf_hash(pdf_path) + + # Load or initialize metadata + metadata = load_metadata(persist_directory) + + # Check if this PDF has already been processed + if pdf_hash in metadata['processed_pdfs']: + print("This PDF has already been processed. Loading embeddings from the database.") + vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings) + return vector_store + + # Initialize or load Chroma vector store + vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings) + # Get text from PDF text = get_pdf_text(pdf_path) @@ -76,10 +123,14 @@ def process_pdf_for_qa(pdf_path, embeddings): # Convert texts into Document objects for embedding documents = [Document(page_content=chunk) for chunk in texts] - - # Create a vector store from the documents using Chroma - print("Creating a vector store using Chroma...") - vector_store = Chroma.from_documents(documents, embeddings) + + # Add the new documents to the existing vector store + print("Adding new documents to the vector store and persisting...") + vector_store.add_documents(documents) + + # Update the metadata to include this processed PDF hash + metadata['processed_pdfs'].append(pdf_hash) + save_metadata(persist_directory, metadata) return vector_store @@ -93,26 +144,28 @@ def create_qa_chain(vector_store, llm): return qa_chain + def format_answer(response): """Format the answer to ensure plain text output without any special characters.""" - # Handle different response formats if isinstance(response, dict): answer = response.get('result', '') elif isinstance(response, list): - answer = "\n\n".join(item.get('query', '') for item in response if 'text' in item) + answer = "\n\n".join(item.get('result', '') for item in response if 'result' in item) else: answer = str(response) - + # Clean up the text: Remove excess newlines and strip whitespace answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip() return answer + def main(): # Parse command line arguments parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama') parser.add_argument('pdf_path', type=str, help='Path to the PDF file') + parser.add_argument('--persist', type=str, default='db', help='Directory to save or load persisted vector store') args = parser.parse_args() # Check if the PDF file exists @@ -128,12 +181,12 @@ def main(): # Process the PDF and prepare it for QA print("Processing the PDF. Please wait...") - vector_store = process_pdf_for_qa(args.pdf_path, embeddings) + vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist) if vector_store is None: print("Processing failed. Exiting.") exit(1) - + # Create the QA chain qa_chain = create_qa_chain(vector_store, llm) @@ -146,10 +199,10 @@ def main(): if question.lower() in ["exit", "quit"]: print("Exiting the session. Goodbye!") break - + # Get the answer response = qa_chain.invoke(question) - + # Format the answer answer = format_answer(response) @@ -159,5 +212,6 @@ def main(): else: print("No relevant information found for your question. Please try asking a different question.") + if __name__ == "__main__": main() |