author    Matthew Lemon <y@yulqen.org>    2024-08-27 16:25:47 +0100
committer Matthew Lemon <y@yulqen.org>    2024-08-27 16:25:47 +0100
commit    6cff0842a3d59a66f1cbb0b8b96881473c795549 (patch)
tree      15ea5c6d1a1b9df48d56b97ebef737b24c4dc0d5
parent    db3ee8795824105244ad5b3045da25726aae1cfe (diff)
Persistent embeddings
-rw-r--r--  .gitignore      1
-rw-r--r--  pdf_reader.py   102
2 files changed, 79 insertions, 24 deletions
diff --git a/.gitignore b/.gitignore
index 88785d7..7a7b1ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
chroma_db
.idea
+db
diff --git a/pdf_reader.py b/pdf_reader.py
index 7f9e2b3..af7b3d1 100644
--- a/pdf_reader.py
+++ b/pdf_reader.py
@@ -1,21 +1,24 @@
-import os
import argparse
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.document_loaders import PyMuPDFLoader
-from langchain_community.vectorstores import Chroma
-from langchain_community.llms import Ollama
-from langchain_community.embeddings import OllamaEmbeddings
-from langchain.chains import RetrievalQA
-from langchain.docstore.document import Document
-from PyPDF2 import PdfReader
-import pytesseract
-from PIL import Image
+import hashlib
import io
+import json
+import os
+
import fitz # PyMuPDF
+import pytesseract
+from PIL import Image
+from PyPDF2 import PdfReader
+from langchain.chains import RetrievalQA
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_chroma import Chroma
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.llms import Ollama
# Make sure Tesseract is installed and accessible
pytesseract.pytesseract.tesseract_cmd = r'/usr/local/bin/tesseract' # Update this path based on your tesseract installation
+
def extract_text_from_pdf(pdf_path):
"""Extract text from a PDF file."""
try:
@@ -31,6 +34,7 @@ def extract_text_from_pdf(pdf_path):
print(f"Error extracting text from PDF: {e}")
return None
+
def perform_ocr_on_pdf(pdf_path):
"""Perform OCR on a PDF file to extract text."""
try:
@@ -47,6 +51,7 @@ def perform_ocr_on_pdf(pdf_path):
print(f"Error performing OCR on PDF: {e}")
return None
+
def get_pdf_text(pdf_path):
"""Determine if OCR is necessary and extract text from PDF."""
text = extract_text_from_pdf(pdf_path)
@@ -60,8 +65,50 @@ def get_pdf_text(pdf_path):
print(f"Successfully extracted text from PDF using OCR. Total characters: {len(ocr_text)}")
return ocr_text
-def process_pdf_for_qa(pdf_path, embeddings):
- """Prepare a PDF for question answering."""
+
+def compute_pdf_hash(pdf_path):
+ """Compute a unique hash for the PDF file to identify if it's already processed."""
+ hasher = hashlib.sha256()
+ with open(pdf_path, 'rb') as f:
+ buf = f.read()
+ hasher.update(buf)
+ return hasher.hexdigest()
+
+
+def load_metadata(persist_directory):
+ """Load metadata from a JSON file."""
+ metadata_path = os.path.join(persist_directory, 'metadata.json')
+ if os.path.exists(metadata_path):
+ with open(metadata_path, 'r') as f:
+ return json.load(f)
+ else:
+ return {'processed_pdfs': []}
+
+
+def save_metadata(persist_directory, metadata):
+ """Save metadata to a JSON file."""
+ metadata_path = os.path.join(persist_directory, 'metadata.json')
+ with open(metadata_path, 'w') as f:
+ json.dump(metadata, f)
+
+
+def process_pdf_for_qa(pdf_path, embeddings, persist_directory):
+ """Prepare a PDF for question answering, using Chroma's persistence."""
+
+ pdf_hash = compute_pdf_hash(pdf_path)
+
+ # Load or initialize metadata
+ metadata = load_metadata(persist_directory)
+
+ # Check if this PDF has already been processed
+ if pdf_hash in metadata['processed_pdfs']:
+ print("This PDF has already been processed. Loading embeddings from the database.")
+ vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+ return vector_store
+
+ # Initialize or load Chroma vector store
+ vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+
# Get text from PDF
text = get_pdf_text(pdf_path)
@@ -76,10 +123,14 @@ def process_pdf_for_qa(pdf_path, embeddings):
    # Convert texts into Document objects for embedding
    documents = [Document(page_content=chunk) for chunk in texts]
-
-    # Create a vector store from the documents using Chroma
-    print("Creating a vector store using Chroma...")
-    vector_store = Chroma.from_documents(documents, embeddings)
+
+    # Add the new documents to the existing vector store
+    print("Adding new documents to the vector store and persisting...")
+    vector_store.add_documents(documents)
+
+    # Update the metadata to include this processed PDF hash
+    metadata['processed_pdfs'].append(pdf_hash)
+    save_metadata(persist_directory, metadata)
    return vector_store
@@ -93,26 +144,28 @@ def create_qa_chain(vector_store, llm):
    return qa_chain
+
def format_answer(response):
    """Format the answer to ensure plain text output without any special characters."""
-
    # Handle different response formats
    if isinstance(response, dict):
        answer = response.get('result', '')
    elif isinstance(response, list):
-        answer = "\n\n".join(item.get('query', '') for item in response if 'text' in item)
+        answer = "\n\n".join(item.get('result', '') for item in response if 'result' in item)
    else:
        answer = str(response)
-
+
    # Clean up the text: Remove excess newlines and strip whitespace
    answer = answer.replace("\\n", "\n").replace("\\'", "'").replace('\\"', '"').strip()
    return answer
+
def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='PDF Question Answering using LangChain and Ollama')
    parser.add_argument('pdf_path', type=str, help='Path to the PDF file')
+    parser.add_argument('--persist', type=str, default='db', help='Directory to save or load persisted vector store')
    args = parser.parse_args()
    # Check if the PDF file exists
@@ -128,12 +181,12 @@ def main():
    # Process the PDF and prepare it for QA
    print("Processing the PDF. Please wait...")
-    vector_store = process_pdf_for_qa(args.pdf_path, embeddings)
+    vector_store = process_pdf_for_qa(args.pdf_path, embeddings, args.persist)
    if vector_store is None:
        print("Processing failed. Exiting.")
        exit(1)
-
+
    # Create the QA chain
    qa_chain = create_qa_chain(vector_store, llm)
@@ -146,10 +199,10 @@ def main():
        if question.lower() in ["exit", "quit"]:
            print("Exiting the session. Goodbye!")
            break
-
+
        # Get the answer
        response = qa_chain.invoke(question)
-
+
        # Format the answer
        answer = format_answer(response)
@@ -159,5 +212,6 @@ def main():
        else:
            print("No relevant information found for your question. Please try asking a different question.")
+
if __name__ == "__main__":
    main()
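
Illustrative note, not part of the patch above: because the langchain_chroma Chroma store is given a persist_directory, it is meant to write itself to disk, and metadata.json in that directory records the SHA-256 hash of each PDF already embedded. Once a hash is recorded, the embeddings can be reopened without reprocessing the file. A minimal sketch of that reuse follows; the embedding model name is a placeholder, since this diff does not show which model main() configures.

# reuse_embeddings.py -- illustrative sketch only, not part of this commit.
# Reopens the Chroma store that pdf_reader.py persisted under --persist (default: db).
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings

# Placeholder model name: the diff does not show the embedding model used in main().
embeddings = OllamaEmbeddings(model="nomic-embed-text")

# Because metadata.json already lists the PDF's hash, there is no need to call
# process_pdf_for_qa() again; the persisted vectors are loaded directly.
vector_store = Chroma(persist_directory="db", embedding_function=embeddings)

for doc in vector_store.similarity_search("What is this document about?", k=3):
    print(doc.page_content[:200])

Used through the CLI in this commit, the first run of "python pdf_reader.py some.pdf --persist db" builds and persists the store; a second run on the same file finds its hash in metadata.json, prints the "already been processed" message, and reuses the stored embeddings.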