import os
import uuid

from dotenv import load_dotenv
from fastapi import UploadFile, File, Form, APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm, HTTPBearer, HTTPAuthorizationCredentials
from fastembed.late_interaction import LateInteractionTextEmbedding
from fastembed.sparse.bm25 import Bm25
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from qdrant_client import QdrantClient, models
from qdrant_client.models import VectorParams, Distance, SparseVectorParams, Modifier
from qdrant_client.models import PointStruct

from dependencies.utils import *
import db_config.database as dbase
from db_config.models import Documents

router = APIRouter()
load_dotenv()

# Qdrant client configured from the environment.  NOTE(review): `os` and
# `OpenAI` are not imported in this file explicitly — presumably they come
# from `dependencies.utils import *`; confirm.
qdrant_client = QdrantClient(
    url=os.getenv("QDRANT_URL"), 
    api_key=os.getenv("QDRANT_API_KEY"),
)
# collection_name=os.getenv("COLLECTION_NAME")
# NOTE(review): the uuid suffix is generated ONCE at import time, so every
# upload handled by this process shares a single collection name; a second
# upload will call create_collection on an already-existing name — confirm
# this is intended.
collection_name = f"{os.getenv('COLLECTION_NAME')}_{uuid.uuid4().hex}"
openai_client=OpenAI()

@router.post("/uploadpdf")
async def upload_pdf(file: UploadFile = File(...), folderId: int = Form(...)):
    """Ingest an uploaded PDF: save, extract text, chunk, embed, index.

    The upload is written to ``uploads/``, its text extracted with PyPDF2,
    split into overlapping chunks, embedded with OpenAI, and indexed in the
    module-level Qdrant collection via ``qdrant_collection``.  When
    ``folderId`` is truthy a ``Documents`` row linking file to folder is
    recorded.

    Returns:
        dict with ``status``, original ``filename`` and ``collection_name``.
    """
    # Persist the upload so PdfReader can open it from disk; make sure the
    # target directory exists instead of failing with FileNotFoundError.
    os.makedirs("uploads", exist_ok=True)
    file_location = f"uploads/{file.filename}"
    with open(file_location, "wb+") as file_object:
        file_object.write(file.file.read())

    # Extract text page by page.  extract_text() can return None for
    # image-only pages, so coalesce to "" to avoid a TypeError.
    reader = PdfReader(file_location)
    pdf_txt = ""
    for page in reader.pages:
        pdf_txt += ' \n' + (page.extract_text() or "")

    # Split text into overlapping chunks sized for the embedding model.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=2000,
        chunk_overlap=500,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = text_splitter.split_text(pdf_txt)

    openai_embeddings_model = "text-embedding-3-small"
    pdf_id = f"{file.filename}"

    def generate_embeddings(text):
        # Collapse newlines before embedding (recommended for OpenAI models).
        text = text.replace("\n", " ")
        return openai_client.embeddings.create(
            input=[text], model=openai_embeddings_model
        ).data[0].embedding

    dense_embeddings = [generate_embeddings(chunk) for chunk in chunks]

    ### Use QDrant vector database
    # BUG FIX: the original passed `chunk` (only the loop's last chunk, a
    # plain string) instead of the full `chunks` list, so sparse/ColBERT
    # vectors were built from one chunk's characters and were misaligned
    # with `dense_embeddings`.
    qdrant_collection(chunks, dense_embeddings, pdf_id)

    if folderId:
        db = dbase.SessionLocal()
        try:
            new_document = Documents(
                collection_name=collection_name,
                document_name=file.filename,
                document_type='pdf',
                folder_id=folderId
            )
            db.add(new_document)
            db.commit()
            db.refresh(new_document)
        finally:
            # Always release the session, even if the commit fails.
            db.close()

    return {"status": "PDF uploaded and processed", "filename": file.filename,"collection_name":collection_name}

@router.get("/chat_with_file")
async def chat_with_file(query: str, filename: str, doc_id: str):
    """Answer `query` against the Qdrant collection backing document `doc_id`.

    `filename` is accepted for interface compatibility but is not used in the
    lookup; the collection name is resolved from the Documents row.

    Raises:
        HTTPException: 404 when no document with `doc_id` exists.
    """
    db = dbase.SessionLocal()
    try:
        document = db.query(Documents).filter(Documents.id == int(doc_id)).first()
    finally:
        # Release the session whether or not the query succeeded.
        db.close()
    # Guard against a missing row instead of raising AttributeError on None.
    if document is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found",
        )
    response = answer(query, document.collection_name)
    return {"full_result": response}

def qdrant_collection(chunks, dense_embeddings, pdf_id):
    """Create the module-level Qdrant collection and index every chunk.

    For each chunk this builds a ColBERT late-interaction multivector and a
    BM25 sparse vector, creates a hybrid collection (dense + late-interaction
    + sparse) with indexing disabled, upserts one point per chunk, then
    re-enables indexing.
    """
    # Per-chunk ColBERT (multivector) and BM25 (sparse) representations.
    colbert_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    colbert_vectors = list(colbert_model.passage_embed(chunks))

    bm25 = Bm25('Qdrant/bm25')
    sparse_vectors = list(bm25.passage_embed(chunks))

    dense_params = VectorParams(
        size=1536,
        distance=Distance.COSINE,
    )
    colbert_params = VectorParams(
        size=len(colbert_vectors[0][0]),
        distance=models.Distance.COSINE,
        multivector_config=models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM,
        ),
    )

    # indexing_threshold=0 disables indexing during the bulk upsert; it is
    # restored to 20000 after the load below.
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "dense_embeddings": dense_params,
            "late_interactions": colbert_params,
        },
        sparse_vectors_config={
            "bm25": SparseVectorParams(modifier=Modifier.IDF),
        },
        optimizers_config=models.OptimizersConfigDiff(indexing_threshold=0),
    )

    # One point per chunk; the sequential index doubles as point id and
    # payload "_id".
    points = [
        PointStruct(
            id=idx,
            vector={
                "dense_embeddings": dense,
                "bm25": sparse.as_object(),
                "late_interactions": colbert.tolist(),
            },
            payload={
                "_id": idx,
                "text": text,
                "pdf_id": pdf_id,
            },
        )
        for idx, (text, dense, sparse, colbert) in enumerate(
            zip(chunks, dense_embeddings, sparse_vectors, colbert_vectors)
        )
    ]

    qdrant_client.upsert(collection_name=collection_name, points=points)
    print(f"Uploaded document points.")

    # Re-enable indexing now that the bulk load is complete.
    qdrant_client.update_collection(
        collection_name=collection_name,
        optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000),
    )