import os
import time
from concurrent.futures import ThreadPoolExecutor

from dotenv import load_dotenv
from openai import OpenAI
from qdrant_client import QdrantClient, models
from fastembed.late_interaction import LateInteractionTextEmbedding
from fastembed.sparse.bm25 import Bm25

# Load environment variables before reading any of them.
load_dotenv()

openai_embeddings_model = os.getenv("EMBEDDING_MODEL")

qdrant_client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

openai_client = OpenAI()
# collection_name = os.getenv("COLLECTION_NAME")



# Instantiate the local embedding models once at import time so they are not
# reloaded on every query.
late_interaction_embedding_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
bm25_model = Bm25("Qdrant/bm25")


def get_dense_embedding(text):
    """Dense embedding from the OpenAI embeddings API."""
    return openai_client.embeddings.create(input=[text], model=openai_embeddings_model).data[0].embedding

def get_late_interaction_embedding(text):
    """ColBERT multivector (one vector per query token)."""
    return list(late_interaction_embedding_model.query_embed(text))[0].tolist()

def get_sparse_embedding(text):
    """BM25 sparse vector for the query."""
    return list(bm25_model.query_embed(text))[0]

def get_all_embeddings(query):
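    """Compute the dense, sparse, and ColBERT query embeddings concurrently."""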
    text = query.replace("\n", " ")
    with ThreadPoolExecutor() as executor:
        dense_future = executor.submit(get_dense_embedding, text)
        late_interaction_future = executor.submit(get_late_interaction_embedding, text)
        sparse_future = executor.submit(get_sparse_embedding, text)
        
        dense_embedding = dense_future.result()
        late_interaction_embedding = late_interaction_future.result()
        sparse_embedding = sparse_future.result()
    
    return dense_embedding, sparse_embedding, late_interaction_embedding

def get_context(query, collection_name):
    t1 = time.time()
    dense_vec, sparse_vec, late_vec = get_all_embeddings(query)
    t2 = time.time()
    print("Embedding time:", t2 - t1)

    # Prefetch candidates from each index, then fuse the three result lists.
    prefetch = [
        models.Prefetch(
            query=dense_vec,
            using="dense_embeddings",
            limit=20,
        ),
        models.Prefetch(
            query=models.SparseVector(**sparse_vec.as_object()),
            using="bm25",
            limit=20,
        ),
        models.Prefetch(
            query=late_vec,
            using="late_interactions",
            limit=20,
        ),
    ]

    # Optional: filter results by document id when querying a combined collection.
    # filter_by_name = models.Filter(
    #     must=[{"key": "pdf_id", "match": {"value": pdf_id}}]
    # )

    t3 = time.time()
    results = qdrant_client.query_points(
        collection_name,
        prefetch=prefetch,
        query=models.FusionQuery(
            fusion=models.Fusion.RRF,
        ),
        with_payload=True,
        limit=5,
    )
    t4 = time.time()
    print("Search time:", t4 - t3)
    return "\n".join([point.payload["text"] for point in results.points])

def answer(query, collection_name):
    passage = get_context(query, collection_name)
    prompt = f"""
    You are an English literature professor, tasked with answering the question asked by your student from the passage below.
    Passage: {passage}
    Question: {query}
    Form your answer using only the passage provided.
    """
    response = openai_client.chat.completions.create(
        model=os.getenv("OPENAI_MODEL"),
        messages=[
            {
                "role": "user",
                "content": prompt,
            },
        ],
    )
    return response.choices[0].message.content
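

if __name__ == "__main__":
    # Minimal usage sketch, assuming the environment variables above are set and
    # that the target collection already holds points with "dense_embeddings",
    # "bm25", and "late_interactions" named vectors plus a "text" payload field.
    # The collection name "documents" and the question below are placeholders,
    # not values from the original script.
    print(answer("Who is the narrator of the story?", "documents"))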