import os
import json
import time
from typing import Dict, Optional
import logging
import httpx
from fastapi import FastAPI, APIRouter, Depends, Request, Response, Form, HTTPException
from fastapi.responses import RedirectResponse, HTMLResponse, JSONResponse
from jose import jwt
from jose.utils import base64url_decode
from urllib.parse import quote
from urllib.parse import urlencode
import secrets
from db_config.models import LMSPlatform, Deployment, AccessToken, LoginState
import db_config.database as dbase
from dotenv import load_dotenv

load_dotenv()  # optional: read .env for easy config in dev
router = APIRouter()

# NOTE: basicConfig() configures the root logger as an import-time side effect
# of this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -------------------------
# CONFIG - change these
# -------------------------
# Canvas SSO base URL; override with the CANVAS_BASE environment variable.
CANVAS_BASE = os.getenv("CANVAS_BASE", "https://sso.canvaslms.com")

# Endpoints (common)
# NOTE(review): AUTHORIZE_ENDPOINT, JWKS_ENDPOINT and CANVAS_ISSUER are fixed
# literals and do not follow CANVAS_BASE, whereas TOKEN_ENDPOINT and
# CANVAS_API_BASE are derived from it — confirm this split is intended.
AUTHORIZE_ENDPOINT = "https://sso.canvaslms.com/api/lti/authorize_redirect"
TOKEN_ENDPOINT = f"{CANVAS_BASE}/login/oauth2/token"
JWKS_ENDPOINT = "https://canvas.instructure.com/api/lti/security/jwks"
CANVAS_ISSUER = "https://canvas.instructure.com"
CANVAS_API_BASE = f"{CANVAS_BASE}/api/v1"

# Your tool credentials (set as env vars in production)
# NOTE(review): the CLIENT_SECRET fallback below is a secret-looking literal
# committed in source. Prefer requiring it from the environment, and rotate
# this value if it was ever a real credential.
CLIENT_ID = os.getenv("CLIENT_ID", "282480000000000004")
CLIENT_SECRET = os.getenv("CLIENT_SECRET", "yVCzMWhwrC6yrLKHRW7eXFyhLN7txRFz6KaKCCrfrKJmFG6G6mzDmQK6vcRGzrrC")
# The deployment id Canvas sends in the id_token (optional for some flows).
# Read from the LTI_DEPLOYMENT_ID environment variable; empty means "not set".
DEPLOYMENT_ID = os.getenv("LTI_DEPLOYMENT_ID", "")

# Where Canvas will post the id_token (publicly reachable)
# Make sure this exact URL is configured in Canvas tool config
TARGET_LINK_URI = os.getenv("TARGET_LINK_URI", "https://answerous.contactous.com:9000/lti/launch")

# In-memory stores for the demo. They are per-process, unbounded and lost on
# restart; use a DB for production.
# Map state -> data (e.g., {state: {"nonce": "...", "created": ts}})
IN_MEMORY_STATE = {}
IN_MEMORY_LAUNCHES = {}  # launch info keyed by the id_token `sub` claim
IN_MEMORY_ACCESS_TOKENS = {}  # token responses keyed by the same `sub` value

# Timeout (seconds) passed to every httpx.AsyncClient created in this module
HTTPX_TIMEOUT = 10.0


# -------------------------
# Utilities
# -------------------------
async def fetch_jwks() -> Dict:
    """Download the platform's JSON Web Key Set.

    Sends a GET to JWKS_ENDPOINT using the module-wide HTTPX_TIMEOUT and
    returns the decoded JSON body. An error status is propagated as the
    exception raised by ``raise_for_status()``.
    """
    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as http:
        response = await http.get(JWKS_ENDPOINT)
        response.raise_for_status()
        keyset = response.json()
    return keyset


def _find_jwk_for_kid(jwks: Dict, kid: str) -> Optional[Dict]:
    """Return the first entry of ``jwks["keys"]`` whose "kid" equals *kid*.

    Yields ``None`` when the key set has no "keys" list or no entry matches.
    """
    candidates = (entry for entry in jwks.get("keys", []) if entry.get("kid") == kid)
    return next(candidates, None)


# -------------------------
# Routes
# -------------------------

@router.api_route("/login", methods=["GET", "POST"])
async def login(
    request: Request
):
    """
    Initiate OIDC login by redirecting the browser to the Canvas authorize endpoint.

    Request parameters are read from the query string and, for POST requests,
    also from the form body (form values override query values):
    - login_hint: typically provided by the platform; a dummy value is used when absent (testing only).
    - target_link_uri: optional override for TARGET_LINK_URI.
    - lti_message_hint: forwarded to the authorize endpoint when present.

    A fresh state/nonce pair is generated and stored in IN_MEMORY_STATE so the
    id_token that Canvas posts back can be matched to this login attempt.

    Returns:
        RedirectResponse to AUTHORIZE_ENDPOINT carrying the OIDC query string.
    """
    state = secrets.token_urlsafe(32)
    nonce = secrets.token_urlsafe(32)
    IN_MEMORY_STATE[state] = {"nonce": nonce, "created": time.time()}

    incoming = dict(request.query_params)

    # If POST, also include form body
    if request.method == "POST":
        form = await request.form()
        incoming.update(form)
    logger.info("Login request params: %s", incoming)

    # Resolve each input exactly once. Previously login_hint was re-read from
    # the request after the fallback had been applied, which discarded the
    # fallback and sent the literal string "None" to Canvas.
    login_hint = incoming.get("login_hint") or "syllabuild"  # dummy for testing; normally provided by platform
    target_link_uri = incoming.get("target_link_uri") or TARGET_LINK_URI
    lti_message_hint = incoming.get("lti_message_hint")

    # iss is the platform issuer. If Canvas reports a different iss in your
    # platform settings, change CANVAS_ISSUER accordingly.
    auth_params = {
        "iss": CANVAS_ISSUER,
        "client_id": CLIENT_ID,
        "login_hint": login_hint,
        "target_link_uri": target_link_uri,
        "state": state,
        "nonce": nonce,
        "response_type": "id_token",
        "response_mode": "form_post",
        "scope": "openid",
        "prompt": "none",
    }

    # Optional parameters are only sent when they carry a value.
    if lti_message_hint:
        auth_params["lti_message_hint"] = lti_message_hint

    if DEPLOYMENT_ID:
        auth_params["lti_deployment_id"] = DEPLOYMENT_ID

    url = f"{AUTHORIZE_ENDPOINT}?{urlencode(auth_params)}"

    logger.info("Redirecting to Canvas authorize endpoint: %s", url)

    return RedirectResponse(url)


@router.api_route("/launch", methods=["GET", "POST"])
async def launch(request: Request):
    """
    Handle the LTI launch: validate the id_token (JWT) Canvas posts after OIDC login.

    Steps:
    - read id_token and state from the form body;
    - consume the state issued by /login (single use) and recover its nonce;
    - verify the token's signature, audience and issuer against the platform JWKS;
    - compare the token's nonce claim with the nonce issued at login;
    - store the launch in IN_MEMORY_LAUNCHES keyed by the `sub` claim.

    Returns:
        HTMLResponse with a debug summary of the launch; every claim value is
        HTML-escaped before being rendered.

    Raises:
        HTTPException 400: missing id_token/state, unknown state, unreadable
            token header, or no JWK matching the token's kid.
        HTTPException 401: token verification failed or the nonce does not match.
        HTTPException 502: the platform JWKS could not be fetched or decoded.
    """
    from html import escape  # stdlib; local import keeps this handler self-contained

    form = await request.form()
    id_token = form.get("id_token")
    state = form.get("state")

    if not id_token or not state:
        raise HTTPException(status_code=400, detail="Missing id_token or state")

    # Consume the state: a state value is valid for exactly one launch, so a
    # captured launch POST cannot be replayed.
    login_state = IN_MEMORY_STATE.pop(state, None)
    if login_state is None:
        raise HTTPException(status_code=400, detail="Invalid or expired state")

    expected_nonce = login_state["nonce"]

    # Decode headers (unverified) only to learn which key signed the token.
    try:
        headers = jwt.get_unverified_header(id_token)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid id_token header: {e}") from e

    kid = headers.get("kid")
    try:
        jwks = await fetch_jwks()
    except (httpx.HTTPError, ValueError) as e:
        # Upstream failure, not a client error: report it as a bad gateway.
        raise HTTPException(status_code=502, detail=f"Failed to fetch platform JWKS: {e}") from e

    jwk = _find_jwk_for_kid(jwks, kid)
    if not jwk:
        raise HTTPException(status_code=400, detail=f"No matching JWK for kid={kid}")

    try:
        claims = jwt.decode(
            id_token,
            jwk,
            algorithms=["RS256"],
            audience=CLIENT_ID,
            issuer=CANVAS_ISSUER,
        )
    except Exception as e:
        raise HTTPException(status_code=401, detail=f"Failed to verify id_token: {e}") from e

    # Verify nonce
    if claims.get("nonce") != expected_nonce:
        raise HTTPException(status_code=401, detail="Invalid nonce in id_token")

    sub = claims.get("sub")
    auth_code = claims.get("code")  # if Canvas sent authorization_code

    # Save launch
    IN_MEMORY_LAUNCHES[sub] = {
        "claims": claims,
        "state": state,
        "auth_code": auth_code,
        "created": time.time(),
    }

    def _esc(value) -> str:
        """HTML-escape any claim value; claims can contain user-controlled text."""
        return escape(str(value))

    lti_claim = "https://purl.imsglobal.org/spec/lti/claim/"
    next_url = f"/token_exchange?sub={quote(str(sub), safe='')}&grant_type=authorization_code"

    page = f"""
    <h2>LTI Launch Success</h2>
    <p>sub (user id): {_esc(sub)}</p>
    <p>deployment_id: {_esc(claims.get(lti_claim + "deployment_id"))}</p>
    <p>message_type: {_esc(claims.get(lti_claim + "message_type"))}</p>
    <p>lti_version: {_esc(claims.get(lti_claim + "version"))}</p>
    <p>auth_code: {_esc(auth_code)}</p>
    <pre>{_esc(json.dumps(claims, indent=2))}</pre>
    <p>Next: call <code>{_esc(next_url)}</code></p>
    """
    return HTMLResponse(page)


@router.get("/token_exchange")
async def token_exchange(sub: str, grant_type: str = "authorization_code"):
    """
    Exchange for an access token that can be used to call Canvas APIs.

    Supports:
    - grant_type=authorization_code: uses the code recorded for the launch,
      taken from the LTI "code" claim or, failing that, from the launch's
      stored "auth_code" value.
    - grant_type=client_credentials (if enabled by your Canvas admin)

    The token response is stored in IN_MEMORY_ACCESS_TOKENS under *sub* and
    also returned to the caller.

    NOTE(review): this handler does no authentication of its own; any caller
    that knows a `sub` receives the token response. Confirm this route is not
    exposed outside development.

    Raises:
        HTTPException 404: no launch is recorded for *sub*.
        HTTPException 400: unsupported grant_type, or no authorization code available.
        HTTPException 502: the token endpoint was unreachable, returned an
            error status, or returned a body that is not JSON.
    """
    launch_record = IN_MEMORY_LAUNCHES.get(sub)
    if launch_record is None:
        raise HTTPException(status_code=404, detail="Launch not found for sub")

    if grant_type == "authorization_code":
        # The code may arrive under the LTI-namespaced claim; the launch
        # handler also records a plain "code" claim as "auth_code", so honour
        # both instead of silently ignoring the stored value.
        code = (
            launch_record["claims"].get("https://purl.imsglobal.org/spec/lti/claim/code")
            or launch_record.get("auth_code")
        )
        if not code:
            raise HTTPException(status_code=400, detail="No authorization code found in launch claims")

        data = {
            "grant_type": "authorization_code",
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "redirect_uri": TARGET_LINK_URI,
            "code": code,
        }

    elif grant_type == "client_credentials":
        data = {
            "grant_type": "client_credentials",
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            # adjust scope depending on your Canvas API needs
            "scope": "url:GET|/api/v1/*",
        }

    else:
        raise HTTPException(status_code=400, detail="Unsupported grant_type")

    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as client:
        try:
            resp = await client.post(TOKEN_ENDPOINT, data=data, headers={"Accept": "application/json"})
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            text = exc.response.text if exc.response is not None else str(exc)
            raise HTTPException(status_code=502, detail=f"Token endpoint error: {text}") from exc
        except httpx.RequestError as exc:
            # Timeouts / connection failures: report as a gateway error
            # rather than letting them surface as an unhandled 500.
            raise HTTPException(status_code=502, detail=f"Token endpoint unreachable: {exc}") from exc

    try:
        token_response = resp.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="Token endpoint returned a non-JSON response") from exc

    # Save token for this user
    IN_MEMORY_ACCESS_TOKENS[sub] = {
        "token_response": token_response,
        "obtained_at": time.time()
    }

    return JSONResponse(token_response)


@router.get("/canvas_api")
async def canvas_api(sub: str, endpoint: str = "/users/self"):
    """
    Call a Canvas REST API endpoint using the stored access token.
    - sub: key under which /token_exchange stored the token response.
    - endpoint: path under /api/v1 (e.g. /users/self/profile, /courses); a
      leading slash is added when missing.

    Returns:
        JSONResponse with the decoded body of the Canvas response.

    Raises:
        HTTPException 404: no token stored for *sub*.
        HTTPException 400: the stored token response has no access_token.
        HTTPException 502: Canvas was unreachable or returned a non-JSON body.
        HTTPException (upstream status): Canvas answered with a status >= 400.
    """
    stored = IN_MEMORY_ACCESS_TOKENS.get(sub)
    if stored is None:
        raise HTTPException(status_code=404, detail="No access token found for sub. Call /token_exchange first.")

    token = stored["token_response"].get("access_token")
    if not token:
        raise HTTPException(status_code=400, detail="Stored token response does not include access_token")

    # Build full URL
    # Ensure endpoint starts with slash
    if not endpoint.startswith("/"):
        endpoint = "/" + endpoint
    url = f"{CANVAS_API_BASE}{endpoint}"

    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as client:
        try:
            r = await client.get(url, headers={"Authorization": f"Bearer {token}"})
        except httpx.RequestError as exc:
            # Timeouts / connection failures would otherwise surface as a 500.
            raise HTTPException(status_code=502, detail=f"Canvas API unreachable: {exc}") from exc

    # forward errors
    if r.status_code >= 400:
        raise HTTPException(status_code=r.status_code, detail=f"Canvas API returned error: {r.text}")

    try:
        payload = r.json()
    except ValueError as exc:
        # e.g. an empty body or an HTML page instead of JSON
        raise HTTPException(status_code=502, detail="Canvas API returned a non-JSON response") from exc

    return JSONResponse(payload)


# Simple debug endpoints to see in-memory stores (for dev only)
@router.get("/_debug/state")
async def debug_state():
    """Return the in-memory OIDC state store (state -> nonce/created) unchanged.

    Dev-only helper: the handler performs no authentication itself, so every
    pending login nonce is visible to any caller that can reach this route.
    """
    return IN_MEMORY_STATE


@router.get("/_debug/launches")
async def debug_launches():
    """Return the in-memory launch store (keyed by `sub`) unchanged.

    Dev-only helper: each entry includes the full id_token claims recorded at
    launch, and the handler performs no authentication itself.
    """
    return IN_MEMORY_LAUNCHES


@router.get("/_debug/tokens")
async def debug_tokens():
    """Return the in-memory access-token store (keyed by `sub`) unchanged.

    Dev-only helper: the stored token responses are returned verbatim and the
    handler performs no authentication itself.
    NOTE(review): confirm this route is never mounted in production.
    """
    return IN_MEMORY_ACCESS_TOKENS
