from sqlalchemy import asc, case, desc, func
from sqlalchemy.orm import Session
from db.models.lms_platform import LMSPlatform
from db.schemas.lms_platform import LMSPlatformCreate
from fastapi.responses import JSONResponse
from fastapi import  HTTPException, status

from fastapi.encoders import jsonable_encoder

from db.models.user_course_log import UserCourseLog

def _resolve_sort_key(sort_by):
    """Return *sort_by* if it names a mapped column attribute of LMSPlatform, else ``"id"``."""
    # Local import keeps this helper self-contained; this ``inspect`` is
    # SQLAlchemy's runtime inspection API, not the stdlib module.
    from sqlalchemy import inspect as sa_inspect

    if isinstance(sort_by, str) and sort_by in sa_inspect(LMSPlatform).column_attrs.keys():
        return sort_by
    return "id"


def _course_totals_by_platform(db: Session, platform_ids):
    """Sum credited/debited ``number_of_course`` from UserCourseLog for *platform_ids*.

    Returns ``{platform_id: {"total_credited": ..., "total_debited": ...}}``;
    platforms without any log rows are absent from the mapping.
    """
    if not platform_ids:
        return {}

    rows = (
        db.query(
            UserCourseLog.platform_id,
            func.sum(
                case(
                    (UserCourseLog.type == "Credit", UserCourseLog.number_of_course),
                    else_=0,
                )
            ).label("total_credited"),
            func.sum(
                case(
                    (UserCourseLog.type == "Debit", UserCourseLog.number_of_course),
                    else_=0,
                )
            ).label("total_debited"),
        )
        # Aggregate only the platforms on the current page rather than the
        # whole log table.
        .filter(UserCourseLog.platform_id.in_(platform_ids))
        .group_by(UserCourseLog.platform_id)
        .all()
    )

    return {
        row.platform_id: {
            "total_credited": row.total_credited or 0,
            "total_debited": row.total_debited or 0,
        }
        for row in rows
    }


def get_lms_platforms(
    db: Session,
    pageNo: int = 1,
    recordsPerPage: int = 10,
    search: str = None,
    sort_by: str = "id",
    sort_order: str = "asc",
):
    """Return a paginated, searchable, sortable list of LMS platforms as JSON.

    Each platform entry also carries the summed ``number_of_course`` of its
    ``"Credit"`` and ``"Debit"`` UserCourseLog rows (0 when it has none).

    Args:
        db: Active SQLAlchemy session.
        pageNo: 1-based page number; values below 1 are treated as 1.
        recordsPerPage: Page size; values below 1 are treated as 1.
        search: Optional case-insensitive substring matched against
            ``issuer``, ``tenant_domain`` and ``tenant_name``.
        sort_by: Name of an LMSPlatform column to sort by; unknown or
            non-column names fall back to ``id``.
        sort_order: ``"desc"`` (case-insensitive) sorts descending; any
            other value, including ``None``, sorts ascending.

    Returns:
        JSONResponse (HTTP 200) with ``pageNo``, ``recordsPerPage``,
        ``totalRecords``, ``totalPages`` and ``data``.
    """
    # Guard against a negative OFFSET and a zero division in totalPages.
    pageNo = max(pageNo, 1)
    recordsPerPage = max(recordsPerPage, 1)

    query = db.query(LMSPlatform)

    if search:
        search_pattern = f"%{search}%"
        query = query.filter(
            LMSPlatform.issuer.ilike(search_pattern)
            | LMSPlatform.tenant_domain.ilike(search_pattern)
            | LMSPlatform.tenant_name.ilike(search_pattern)
        )

    # Count before ORDER BY is added; ordering does not change the count.
    total_records = query.count()

    sort_key = _resolve_sort_key(sort_by)
    direction = desc if (sort_order or "").lower() == "desc" else asc
    order_clauses = [direction(getattr(LMSPlatform, sort_key))]
    if sort_key != "id":
        # Tie-break on id so rows with equal sort values keep a stable order
        # across pages (no duplicated or skipped rows between pages).
        order_clauses.append(direction(LMSPlatform.id))
    query = query.order_by(*order_clauses)

    offset = (pageNo - 1) * recordsPerPage
    platforms = query.offset(offset).limit(recordsPerPage).all()

    totals_map = _course_totals_by_platform(db, [p.id for p in platforms])
    empty_totals = {"total_credited": 0, "total_debited": 0}

    platform_list = []
    for p in platforms:
        totals = totals_map.get(p.id, empty_totals)
        platform_list.append({
            "id": p.id,
            "issuer": p.issuer,
            "tenant_domain": p.tenant_domain,
            "tenant_name": p.tenant_name,
            "client_id": p.client_id,
            "jwks_url": p.jwks_url,
            "token_url": p.token_url,
            "total_credited_courses": totals["total_credited"],
            "total_debited_courses": totals["total_debited"],
        })

    response_data = jsonable_encoder({
        "pageNo": pageNo,
        "recordsPerPage": recordsPerPage,
        "totalRecords": total_records,
        # Ceiling division; recordsPerPage >= 1 is guaranteed above.
        "totalPages": (total_records + recordsPerPage - 1) // recordsPerPage,
        "data": platform_list,
    })

    return JSONResponse(content=response_data, status_code=status.HTTP_200_OK)


def get_lms_platform(db: Session, platform_id: int):
    """Look up one LMS platform by its ``id``.

    Returns the first matching LMSPlatform instance, or ``None`` when no row
    has that ``id``.
    """
    matches = db.query(LMSPlatform).filter_by(id=platform_id)
    return matches.first()

def create_lms_platform(db: Session, platform: LMSPlatformCreate):
    """Insert a new LMS platform built from *platform* and return it.

    Args:
        db: Active SQLAlchemy session; the insert is committed on success.
        platform: Validated create payload whose fields are passed as
            keyword arguments to ``LMSPlatform``.

    Returns:
        The persisted LMSPlatform instance, refreshed from the database.

    Raises:
        Any exception raised by ``commit()`` is re-raised unchanged after the
        session is rolled back, so the session stays usable for the caller.
    """
    # Pydantic v2 renamed ``.dict()`` to ``.model_dump()``; support both.
    dump = getattr(platform, "model_dump", None) or platform.dict
    db_platform = LMSPlatform(**dump())
    db.add(db_platform)
    try:
        db.commit()
    except Exception:
        # A failed flush/commit leaves the session in an inactive
        # transaction; roll back before propagating the error.
        db.rollback()
        raise
    db.refresh(db_platform)
    return db_platform
