summaryrefslogtreecommitdiffstats
path: root/pdf_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'pdf_reader.py')
-rw-r--r--pdf_reader.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/pdf_reader.py b/pdf_reader.py
index e70a88b..62d79fb 100644
--- a/pdf_reader.py
+++ b/pdf_reader.py
@@ -3,6 +3,7 @@ import hashlib
import io
import json
import os
+from typing import Any
import fitz # PyMuPDF
import pytesseract
@@ -14,9 +15,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
+from langchain_core.language_models import LLM
# Make sure Tesseract is installed and accessible
-pytesseract.pytesseract.tesseract_cmd = r'/usr/local/bin/tesseract' # Update this path based on your tesseract installation
+pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract' # Update this path based on your tesseract installation
def extract_text_from_pdf(pdf_path):
@@ -135,13 +137,26 @@ def process_pdf_for_qa(pdf_path, embeddings, base_persist_directory):
return vector_store
-def create_qa_chain(vector_store, llm):
- """Create a QA chain for answering questions."""
+def create_qa_chain(vector_store: Any, llm: LLM) -> RetrievalQA:
+ """
+ Create a QA chain for answering questions using the provided vector store and Large Language Model (LLM).
+
+ Args:
+ vector_store: The database of vectors representing text embeddings.
+ llm: The Large Language Model instance.
+
+ Returns:
+ A RetrievalQA object that can be used to answer questions by querying the LLM with user input.
+ """
# Use the vector store retriever directly
- retriever = vector_store.as_retriever()
+ vector_store_retriever = vector_store.as_retriever()
# Create a RetrievalQA chain using the new approach
- qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
+ qa_chain = RetrievalQA.from_chain_type(
+ llm=llm,
+ chain_type="stuff",
+ retriever=vector_store_retriever # Use a more specific variable name
+ )
return qa_chain