"""
FastAPI LTI 1.3 Tool example
- OIDC login -> LTI launch (id_token verification)
- Token exchange (client_assertion JWT)
- Calls to Names & Roles / AGS / Deep Linking
- JWKS hosting for the Tool (demo)
"""

import os
import json
import base64
import time
import uuid
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Optional
from urllib.parse import urlencode
from fastapi import FastAPI, APIRouter, Depends, Request, Response, Form, HTTPException
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from starlette.datastructures import URL
from pydantic import BaseModel
from jose import jwt, jwk
from jose.utils import base64url_decode
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.backends import default_backend
import httpx
from urllib.parse import urlparse
from db_config.models import LMSPlatform, Deployment, User, UserCourseLog, AccessToken, LoginState
import db_config.database as dbase
from dotenv import load_dotenv

# Optional: read a local .env file so configuration is easy to set in development.
load_dotenv()
# Router on which the endpoints defined below are registered.
router = APIRouter()
from dependencies.canvas_config import (
    TOOL_CLIENT_ID,
    TOOL_REDIRECT_URI,
    TOOL_BASE_URL,
    TOOL_JWKS_PATH,
    TOOL_JWKS_URL,
    PLATFORM_ISS,
    PLATFORM_OIDC_AUTH,
    PLATFORM_TOKEN_URL,
    PLATFORM_JWKS_URL,
    DEPLOYMENT_ID,
    SCOPES,
)

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Directory that holds the tool's key files; created at import time if it does not exist.
TOOL_KEY_DIR = os.getenv("TOOL_KEY_DIR", "./keys")
os.makedirs(TOOL_KEY_DIR, exist_ok=True)
# File name of the PEM private key inside TOOL_KEY_DIR (overridable via the environment).
TOOL_PRIVATE_KEY_PEM = os.getenv("TOOL_PRIVATE_KEY_PEM", "private.pem")
PRIVATE_KEY_PATH = f"{TOOL_KEY_DIR}/{TOOL_PRIVATE_KEY_PEM}"
# Key ID used in the published JWK and in the client-assertion JWT header.
PRIVATE_KEY_KID = "tool-key-1"
# app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE: these fields normally come from the platform during the registration process or are discovered dynamically.
# You must set or dynamically discover them before verifying tokens / exchanging tokens.
# Example of values you will need when making requests:
# - token_endpoint
# - jwks_uri
# - names_and_roles_url (service endpoint) / assignments and grades endpoint / deep linking endpoint
# For this demo, we'll fetch the OIDC config dynamically (if PLATFORM_OIDC_CONFIG is reachable).

# ---------------------
# Demo Key generation (for client_assertion and jwks hosting)
# ---------------------
# NOTE: Persist your key pair; DO NOT re-generate each run in production.

def b64url_uint(val: int) -> str:
    """Encode a non-negative integer as unpadded base64url.

    The integer is serialized as its minimal big-endian byte string and then
    base64url-encoded with the trailing ``=`` padding removed (the form used
    for the ``n`` and ``e`` members of an RSA JWK).

    Args:
        val: Non-negative integer to encode.

    Returns:
        The base64url text without padding. Zero encodes as ``"AA"`` (a single
        zero byte) rather than as an empty string.

    Raises:
        OverflowError: If ``val`` is negative (raised by ``int.to_bytes``).
    """
    # bit_length() is 0 for zero, which would yield an empty byte string;
    # always emit at least one byte so the result is never "".
    byte_count = max(1, (val.bit_length() + 7) // 8)
    int_bytes = val.to_bytes(byte_count, byteorder="big")
    return base64.urlsafe_b64encode(int_bytes).rstrip(b"=").decode("utf-8")

def generate_rsa_keypair(kid: str = "tool-key-1"):
    """Generate a 2048-bit RSA key pair and write it to the key directory.

    The private key is written unencrypted, as PKCS#8 PEM, to
    ``PRIVATE_KEY_PATH`` (overwriting any existing file). The matching public
    key is written as a JWK to ``public_jwk.json`` inside ``TOOL_KEY_DIR``.

    Args:
        kid: Key ID stored in the generated JWK. The module-level ``TOOL_JWK``
            always uses ``PRIVATE_KEY_KID``, so pass a different value only if
            that constant is updated to match.

    Returns:
        The public key as a JWK dictionary (``kty``, ``alg``, ``use``, ``kid``,
        ``n``, ``e``).
    """
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )

    # Persist the private key as an unencrypted PKCS#8 PEM file.
    with open(PRIVATE_KEY_PATH, "wb") as f:
        f.write(
            private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )
    logger.info("Saved private key to %s", PRIVATE_KEY_PATH)

    # Build the public JWK from the RSA modulus and exponent.
    numbers = private_key.public_key().public_numbers()
    # Named public_jwk (not jwk) so the imported jose `jwk` module is not shadowed.
    public_jwk = {
        "kty": "RSA",
        "alg": "RS256",
        "use": "sig",
        "kid": kid,
        "n": b64url_uint(numbers.n),
        "e": b64url_uint(numbers.e),
    }

    public_jwk_path = f"{TOOL_KEY_DIR}/public_jwk.json"
    with open(public_jwk_path, "w") as f:
        json.dump(public_jwk, f, indent=2)
    logger.info("Saved public JWK to %s", public_jwk_path)

    return public_jwk

# Load the existing private key from PRIVATE_KEY_PATH once, at import time.
# NOTE: nothing here generates a key. If the PEM file is missing, open() raises
# and the module fails to import; create the key beforehand (for example with
# generate_rsa_keypair()).
with open(PRIVATE_KEY_PATH, "rb") as key_file:
    PRIVATE_KEY = serialization.load_pem_private_key(
        key_file.read(),
        password=None,  # the PEM is expected to be unencrypted
        backend=default_backend()
    )

# RSA modulus (n) and public exponent (e) of the loaded key.
PUBLIC_NUMBERS = PRIVATE_KEY.public_key().public_numbers()

# Public half of the tool key in JWK form; returned by the /jwks.json endpoint.
TOOL_JWK = {
    "kty": "RSA",
    "alg": "RS256",
    "use": "sig",
    "kid": PRIVATE_KEY_KID,
    "n": b64url_uint(PUBLIC_NUMBERS.n),
    "e": b64url_uint(PUBLIC_NUMBERS.e),
}
# ---------------------
# In-memory store (demo only)
# ---------------------
# In a real deployment, persist these associations across requests (e.g. in a DB):
# module-level dicts live in a single process and are lost on restart.
REGISTRATIONS = {
    # Filled by the OIDC callback, keyed by a generated launch_id:
    # launch_id: { "decoded": {...id_token claims...}, "access_token": "...", "tenant_domain": "...", "issued_at": ... }
}

# Nonces issued by the OIDC login endpoint, stored as nonce -> "pending".
# Demo only; in production store them per user/session.
NONCE_STORE: Dict[str, str] = {}

# ---------------------
# Utilities
# ---------------------
def now() -> int:
    """Return the current Unix time, truncated to whole seconds."""
    current_seconds = time.time()
    return int(current_seconds)

def create_client_assertion(client_id: str, token_url: str, private_key, kid: str) -> str:
    """Create a signed JWT client assertion for a Canvas OAuth2 token request.

    Args:
        client_id: Tool client ID; used as both ``iss`` and ``sub``.
        token_url: Token endpoint URL; used as ``aud`` and must exactly match
            the tenant token URL.
        private_key: RSA private key object exposing ``private_bytes``.
        kid: Key ID placed in the JWT header.

    Returns:
        The RS256-signed compact JWT, valid for five minutes.
    """
    # Use the epoch clock directly. datetime.utcnow().timestamp() interprets the
    # naive value as local time, which shifts iat/exp by the host's UTC offset.
    iat = int(time.time())
    exp = iat + 300  # valid for 5 minutes

    payload = {
        "iss": client_id,
        "sub": client_id,
        "aud": token_url,   # must exactly match tenant token URL
        "iat": iat,
        "exp": exp,
        "jti": str(uuid.uuid4()),  # unique ID per assertion
    }

    # Export the private key as PEM for signing.
    pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    headers = {"alg": "RS256", "kid": kid}
    logger.info("Client assertion headers: %s", headers)
    logger.info("Client assertion payload: %s", payload)

    return jwt.encode(payload, pem, algorithm="RS256", headers=headers)

def get_latest_token(db, email: str) -> Optional[str]:
    """Return the unexpired access token string with the latest expiry for a user.

    Args:
        db: Database session used for the queries.
        email: Email address identifying the user.

    Returns:
        The token string of the user's ``AccessToken`` row with the greatest
        ``expires_at``, if that row has not expired yet; otherwise ``None``
        (also when ``email`` is empty or no such user exists).
    """
    # A missing email must not be turned into an "email IS NULL" lookup that
    # could match some unrelated user.
    if not email:
        return None

    user = db.query(User).filter(User.email == email).first()
    if user is None:
        return None

    token = (
        db.query(AccessToken)
        .filter(AccessToken.user_id == user.id)
        .order_by(AccessToken.expires_at.desc())
        .first()
    )
    # NOTE(review): expires_at is compared against naive UTC — confirm it is stored that way.
    if token is not None and token.expires_at > datetime.utcnow():
        return token.token
    return None

def get_valid_token(db, platform, deployment, user, scope: str) -> str:
    """Return a valid access token for the given platform/deployment/user/scope.

    A stored, unexpired token is reused when one exists. Otherwise a new token
    is requested from the platform token endpoint with a signed client
    assertion (``client_credentials`` grant) and saved to the database.

    Args:
        db: Database session.
        platform: Platform record providing ``id``, ``client_id`` and ``token_url``.
        deployment: Deployment record providing ``id``.
        user: User record providing ``id``.
        scope: Scope string to request.

    Returns:
        The access token string.

    Raises:
        httpx.HTTPStatusError: If the token endpoint returns an error status.
        KeyError: If the token response has no ``access_token`` member.
    """
    current_time = datetime.utcnow()

    # 1. Check DB for a valid token
    token = (
        db.query(AccessToken)
        .filter(
            AccessToken.platform_id == platform.id,
            AccessToken.deployment_id == deployment.id,
            AccessToken.user_id == user.id,
            AccessToken.scope == scope,
            AccessToken.expires_at > current_time,
        )
        .order_by(AccessToken.expires_at.desc())
        .first()
    )
    if token:
        return token.token

    # 2. Request a new token
    client_assertion = create_client_assertion(
        client_id=platform.client_id,
        token_url=platform.token_url,  # e.g. https://manish.instructure.com/login/oauth2/token
        private_key=PRIVATE_KEY,
        kid=PRIVATE_KEY_KID,
    )
    payload = {
        "grant_type": "client_credentials",
        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_assertion": client_assertion,
        "scope": scope,
    }
    # Credentials are deliberately not logged: neither the signed assertion nor
    # the response body (which carries the access token) is written to the log.
    logger.info("Requesting access token from %s for scope %s", platform.token_url, scope)
    # NOTE(review): this is a blocking call; it stalls the event loop when
    # reached from an async handler.
    response = httpx.post(platform.token_url, data=payload, timeout=10.0)
    logger.info("Token response status: %s", response.status_code)
    response.raise_for_status()
    token_json = response.json()

    access_token = token_json["access_token"]
    # Fall back to one hour when the platform does not report a lifetime.
    expires_in = int(token_json.get("expires_in", 3600))
    expires_at = current_time + timedelta(seconds=expires_in)

    # 3. Save in DB
    new_token = AccessToken(
        platform_id=platform.id,
        deployment_id=deployment.id,
        user_id=user.id,
        scope=scope,
        token=access_token,
        expires_at=expires_at,
    )
    db.add(new_token)
    db.commit()
    db.refresh(new_token)

    return new_token.token


def store_launch(db, decoded_token, tenant_domain: str):
    """Create or update the platform, deployment and user records for a launch.

    Args:
        db: Database session.
        decoded_token: Claims of the id_token (a dict).
        tenant_domain: Base URL of the tenant (scheme and host); used to build
            the JWKS and token endpoint URLs.

    Returns:
        A ``(platform, deployment, user)`` tuple of database records.

    Raises:
        KeyError: If ``iss``, ``aud``, ``sub`` or the deployment_id claim is missing.
    """
    issuer = decoded_token["iss"]       # global Canvas issuer

    # "aud" may be a single string or a list of audiences. For a list, prefer
    # the authorized party ("azp") and fall back to the first entry.
    aud = decoded_token["aud"]
    if isinstance(aud, (list, tuple)):
        client_id = decoded_token.get("azp") or (aud[0] if aud else None)
    else:
        client_id = aud

    # Build URLs based on tenant
    jwks_url = f"{tenant_domain}/api/lti/security/jwks"
    token_url = f"{tenant_domain}/login/oauth2/token"

    tool_platform = decoded_token.get("https://purl.imsglobal.org/spec/lti/claim/tool_platform", {})
    platform_name = tool_platform.get("name", "")
    platform_id = tool_platform.get("guid", "")

    # 1. Platform: looked up by (issuer, client_id); created when absent.
    platform = db.query(LMSPlatform).filter_by(
        issuer=issuer,
        client_id=client_id
    ).first()
    if not platform:
        platform = LMSPlatform(
            issuer=issuer,
            tenant_domain=tenant_domain,
            tenant_id=platform_id,
            tenant_name=platform_name,
            client_id=client_id,
            jwks_url=jwks_url,
            token_url=token_url
        )
        db.add(platform)
        db.commit()
        db.refresh(platform)
    else:
        # Refresh the stored tenant details and URLs on every launch.
        platform.tenant_domain = tenant_domain
        platform.tenant_id = platform_id
        platform.tenant_name = platform_name
        platform.jwks_url = jwks_url
        platform.token_url = token_url
        db.commit()

    # 2. Deployment: looked up by (platform, deployment_id); created when absent.
    deployment_id = decoded_token["https://purl.imsglobal.org/spec/lti/claim/deployment_id"]
    deployment = db.query(Deployment).filter_by(
        platform_id=platform.id,
        deployment_id=deployment_id
    ).first()
    if not deployment:
        deployment = Deployment(platform_id=platform.id, deployment_id=deployment_id)
        db.add(deployment)
        db.commit()
        db.refresh(deployment)

    # 3. User
    sub = decoded_token["sub"]
    roles = decoded_token.get("https://purl.imsglobal.org/spec/lti/claim/roles", [])
    if 'http://purl.imsglobal.org/vocab/lis/v2/institution/person#Administrator' in roles:
        role_id = 3  # college admin
    else:
        role_id = 2  # student

    email = decoded_token.get("email")
    if email:
        user = db.query(User).filter_by(email=email).first()
    else:
        # Without an email claim, filtering on email=None would match any user
        # whose email is NULL; identify the user by platform and subject instead.
        user = db.query(User).filter_by(platform_id=platform.id, sub=sub).first()

    if not user:
        # First launch for this user: create the local User record.
        user = User(
            name=decoded_token.get("name"),
            first_name=decoded_token.get("given_name"),
            last_name=decoded_token.get("family_name"),
            email=email,
            platform_id=platform.id,
            sub=sub,
            lti_roles=json.dumps(roles),
            is_active=1,
            role_id=role_id,  # 3 for institution administrators, otherwise 2
            subscription_status="active",
            is_email_verified=0,
            is_phone_verified=0,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow()
        )
        db.add(user)
        db.commit()
        db.refresh(user)

        # Only newly created administrators receive the initial course credits.
        if role_id == 3:
            credit_entry = UserCourseLog(
                user_id=user.id,
                number_of_course=10,
                platform_id=platform.id,
                type="Credit"
            )
            db.add(credit_entry)
            db.commit()

    return platform, deployment, user

# ---------------------
# Endpoints
# ---------------------
@router.get("/jwks.json")
async def jwks():
    """Serve the tool's JWKS document containing the single public key ``TOOL_JWK``."""
    key_set = {"keys": [TOOL_JWK]}
    return JSONResponse(key_set)

@router.get("/demo/launch")
async def demo_launch():
    """
    Simulate a Canvas launch for local testing.

    Builds the query string of an OIDC login initiation request from fixed demo
    values plus two fresh UUIDs, and redirects to /canvas/oidc/login.
    """
    # In a real launch Canvas supplies all of these values.
    query = urlencode(
        {
            "iss": "https://canvas.instructure.com",
            "login_hint": "demo-login-hint",
            "target_link_uri": "http://localhost:8000/canvas/launch",
            "lti_message_hint": str(uuid.uuid4()),
            "state": str(uuid.uuid4()),  # normally Canvas provides
        }
    )
    return RedirectResponse(f"/canvas/oidc/login?{query}")

@router.api_route("/launch", methods=["GET", "POST"])
async def launch_page(launch_id: str):
    """Render the auto-submitting form that hands a verified launch to the frontend.

    Looks up the launch context stored under ``launch_id``, persists the
    platform/deployment/user records, obtains an NRPS-scoped access token and
    returns an HTML page whose hidden form POSTs the launch data to
    ``{FRONT_APP_URL}/api/canvas``.

    Args:
        launch_id: Key of the launch context in ``REGISTRATIONS``.

    Returns:
        An ``HTMLResponse`` containing the self-submitting form.

    Raises:
        HTTPException: 404 for an unknown launch ID, 400 when the ``sub`` claim
            is missing, 500 when no access token could be obtained.
    """
    from html import escape  # local import keeps this block self-contained

    # NOTE(review): the session taken from the generator is never closed
    # explicitly here — confirm how get_db() manages its lifetime.
    db = next(dbase.get_db())
    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Unknown launch ID")

    decoded = reg["decoded"]
    user_sub = decoded.get("sub")
    role_list = decoded.get("https://purl.imsglobal.org/spec/lti/claim/roles", [])

    logger.info("Sub Data: %s", user_sub)
    logger.info("Launch decoded token: %s", decoded)
    tenant_domain = reg.get("tenant_domain")
    if user_sub is None:
        raise HTTPException(status_code=400, detail="Missing sub claim in id_token")
    platform, deployment, user = store_launch(db, decoded, tenant_domain)

    scope = SCOPES["nrps"]
    try:
        access_token = get_valid_token(db, platform, deployment, user, scope)
    except Exception as exc:
        # Full details go to the server log only; the response must not expose
        # database records or upstream error text to the client.
        logger.exception("Failed to get access token for scope %s", scope)
        raise HTTPException(status_code=500, detail="Failed to get access token") from exc

    # The lti1p1 claim is optional; fall back to an empty value when absent.
    lti1p1_claim = decoded.get("https://purl.imsglobal.org/spec/lti/claim/lti1p1") or {}
    lti_user_id = lti1p1_claim.get("user_id", "")
    email = decoded.get("email")
    name = decoded.get("name")
    frontend_app = os.getenv("FRONT_APP_URL", "https://answerous.contactous.com:9001")

    resource_link_claim = decoded.get("https://purl.imsglobal.org/spec/lti/claim/resource_link", {})
    context_claim = decoded.get("https://purl.imsglobal.org/spec/lti/claim/context", {})
    tool_platform = decoded.get("https://purl.imsglobal.org/spec/lti/claim/tool_platform", {})

    # Hidden form fields in submission order: (field name, value).
    fields = [
        ("iss", platform.issuer),
        ("sub", user_sub),
        ("client_id", platform.client_id),
        ("launch_id", launch_id),
        ("user_id", lti_user_id),
        ("lms_user_id", user.id),
        ("lis_person_name_full", name),
        ("lis_person_contact_email_primary", email),
        ("lis_person_name_given", decoded.get("given_name")),
        ("lis_person_name_family", decoded.get("family_name")),
        ("custom_email", email),
        ("custom_canvas_user_login_id", email),
        ("access_token", access_token),
        ("tenant_domain", tenant_domain),
        ("custom_canvas_api_domain", tenant_domain),
        ("lti_deployment_id", deployment.deployment_id),
        ("target_link_uri", frontend_app),
        ("context_id", context_claim.get("id", "")),
        ("context_title", context_claim.get("title", "")),
        ("resource_link_id", resource_link_claim.get("id", "")),
        ("resource_link_title", resource_link_claim.get("title", "")),
        ("tool_consumer_instance_guid", tool_platform.get("guid", "")),
        ("tool_consumer_instance_name", tool_platform.get("name", "")),
        ("tool_consumer_instance_product_family_code", tool_platform.get("product_family_code", "")),
        ("lti_version", "LTI-1p3"),
        ("roles", ",".join(role_list)),  # list -> comma-separated string
    ]
    # Every value is HTML-escaped: claims such as names or titles come from the
    # id_token and must not be able to break out of the attribute.
    inputs = "\n".join(
        f'        <input type="hidden" name="{escape(field, quote=True)}" value="{escape(str(value), quote=True)}"/>'
        for field, value in fields
    )

    page = f"""
    <html>
    <body>
    <form action="{escape(frontend_app, quote=True)}/api/canvas" method="POST" style="display: none;">
{inputs}
        <input type="submit" value="Start Deep Linking Flow"/>
    </form>
    <script type="text/javascript">
        document.forms[0].submit();
    </script>
  </body>
</html>
    """
    return HTMLResponse(page)
    

def fetch_roster(decoded, access_token: str) -> Any:
    """Fetch the course roster from the NRPS endpoint named in the id_token.

    Args:
        decoded: Claims of the id_token (a dict).
        access_token: Bearer token sent in the Authorization header.

    Returns:
        The parsed JSON response, or ``{"error": "<message>"}`` when the request
        or the JSON decoding fails (deliberately best-effort).

    Raises:
        HTTPException: 400 when the token carries no NRPS claim.
        KeyError: If the NRPS claim has no ``context_memberships_url``.
    """
    nrps_claim = decoded.get("https://purl.imsglobal.org/spec/lti-nrps/claim/namesroleservice")
    if not nrps_claim:
        raise HTTPException(status_code=400, detail="No NRPS endpoint in id_token")

    context_url = nrps_claim["context_memberships_url"]

    try:
        # The bearer token is a credential and is intentionally not logged.
        headers = {"Authorization": f"Bearer {access_token}"}
        resp = httpx.get(context_url, headers=headers)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        logger.error("Failed to fetch roster: %s", str(e))
        return {"error": str(e)}

@router.api_route("/oidc/login", methods=["GET", "POST"])
async def oidc_login(
    request: Request
):
    """
    OIDC login initiation endpoint.

    Canvas may call this with POST (hidden form) or GET (query params). The
    handler works out the tenant domain, stores it under a fresh ``state``,
    records a fresh ``nonce`` and redirects the browser to the platform's
    authorization endpoint.

    Raises:
        HTTPException: 400 when ``login_hint`` is missing.
    """
    db = next(dbase.get_db())
    # Collect params from the query string; for POST, form fields take precedence.
    params = dict(request.query_params)
    if request.method == "POST":
        form = await request.form()
        params.update(form)

    login_hint = params.get("login_hint")
    lti_message_hint = params.get("lti_message_hint")

    # Without a login_hint the redirect would carry the literal text "None".
    if not login_hint:
        raise HTTPException(status_code=400, detail="Missing login_hint")

    # Try to get the tenant domain from the Referer or Origin header.
    # NOTE(review): these headers are supplied by the client, and the callback
    # later fetches the JWKS from this domain — consider an allow-list.
    referer = request.headers.get("referer") or request.headers.get("origin")
    tenant_domain = None
    if referer:
        parsed = urlparse(referer)
        # Accept only a complete origin; otherwise fall through to the hint.
        if parsed.scheme and parsed.netloc:
            tenant_domain = f"{parsed.scheme}://{parsed.netloc}"

    # Fallback: read canvas_domain from the (unverified) lti_message_hint.
    if not tenant_domain and lti_message_hint:
        try:
            decoded_hint = jwt.get_unverified_claims(lti_message_hint)
            canvas_domain = decoded_hint.get("canvas_domain")
            if canvas_domain:
                tenant_domain = f"https://{canvas_domain}"
        except Exception as e:
            # Best effort: the hint may not be a JWT at all.
            logger.warning("Failed to decode lti_message_hint: %s", e)

    logger.info("Detected tenant domain: %s", tenant_domain)

    # if iss != PLATFORM_ISS:
    #     raise HTTPException(status_code=400, detail="Unrecognized issuer")

    # Save tenant_domain under the state so the callback can retrieve it.
    state = str(uuid.uuid4())
    login_state = LoginState(state=state, tenant_domain=tenant_domain)
    db.add(login_state)
    db.commit()

    nonce = str(uuid.uuid4())
    NONCE_STORE[nonce] = "pending"

    redirect_params = {
        "response_type": "id_token",
        "response_mode": "form_post",
        "client_id": TOOL_CLIENT_ID,
        "redirect_uri": TOOL_REDIRECT_URI,
        "scope": "openid",
        "login_hint": login_hint,
        "nonce": nonce,
        "state": state,
        "prompt": "none",
    }
    # Forward the message hint only when one was supplied.
    if lti_message_hint is not None:
        redirect_params["lti_message_hint"] = lti_message_hint

    url = f"{PLATFORM_OIDC_AUTH}?{urlencode(redirect_params)}"
    return RedirectResponse(url)


class OIDCLaunchForm(BaseModel):
    """Shape of an OIDC launch form: a required id_token and an optional nonce.

    NOTE(review): no handler visible in this part of the file uses this model —
    confirm whether it is still needed.
    """
    id_token: str  # the id_token JWT as a string
    nonce: Optional[str] = None  # optional; defaults to None

@router.post("/oidc/callback")
async def oidc_callback(
    request: Request
):
    """
    OIDC callback endpoint.

    Canvas POSTs here with id_token + state. The handler resolves the tenant
    domain saved for that state, verifies the id_token against the tenant's
    JWKS, checks that the nonce was issued by this process, stores the launch
    context and redirects to target_link_uri (the launch page).

    Raises:
        HTTPException: 400 for a missing/invalid id_token, state or nonce, or
            when no JWKS key matches; 502 when the JWKS cannot be fetched.
    """
    db = next(dbase.get_db())
    form = await request.form()

    id_token = form.get("id_token")
    state = form.get("state")

    if not id_token:
        raise HTTPException(status_code=400, detail="Canvas did not send id_token")

    # Retrieve tenant domain saved earlier
    login_state = db.query(LoginState).filter_by(state=state).first()
    if not login_state:
        raise HTTPException(status_code=400, detail="Invalid or expired state")
    tenant_domain = login_state.tenant_domain

    if not tenant_domain:
        raise HTTPException(status_code=400, detail="Missing tenant domain for state")
    logger.info("Using tenant domain from state: %s", tenant_domain)

    # --- Fetch Canvas JWKS ---
    jwks_url = f"{tenant_domain}/api/lti/security/jwks"
    try:
        async with httpx.AsyncClient() as client:
            r = await client.get(jwks_url, timeout=10.0)
            r.raise_for_status()
            jwks = r.json()
    except (httpx.HTTPError, ValueError) as exc:
        # ValueError covers a response body that is not valid JSON.
        logger.error("Failed to fetch JWKS from %s: %s", jwks_url, exc)
        raise HTTPException(status_code=502, detail="Failed to fetch platform JWKS") from exc

    # --- Decode without verifying to extract header ---
    try:
        header = jwt.get_unverified_header(id_token)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid id_token: {str(e)}")
    kid = header.get("kid")
    keys = jwks.get("keys", []) if isinstance(jwks, dict) else []
    key = next((k for k in keys if isinstance(k, dict) and k.get("kid") == kid), None)
    if not key:
        raise HTTPException(status_code=400, detail="No matching JWKS key")

    # --- Verify token ---
    try:
        decoded = jwt.decode(
            id_token,
            key,
            algorithms=["RS256"],
            audience=TOOL_CLIENT_ID,
            issuer=PLATFORM_ISS
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid id_token: {str(e)}")

    # --- Validate nonce ---
    # Each nonce is issued by the login endpoint and accepted exactly once, so
    # a captured id_token cannot be replayed. Like REGISTRATIONS, the store is
    # in-memory and therefore only valid within a single process.
    nonce = decoded.get("nonce")
    if not nonce or NONCE_STORE.pop(nonce, None) is None:
        raise HTTPException(status_code=400, detail="Invalid or reused nonce")

    # --- Validate deployment ---
    # deployment_id = decoded.get("https://purl.imsglobal.org/spec/lti/claim/deployment_id")
    # if deployment_id != DEPLOYMENT_ID:
    #     return {"error": "Deployment ID mismatch", "deployment_id": deployment_id}

    # --- Store launch context ---
    launch_id = str(uuid.uuid4())
    REGISTRATIONS[launch_id] = {
        "decoded": decoded,
        "access_token": id_token,  # holds the raw id_token, not an OAuth access token
        "tenant_domain": tenant_domain,
        "issued_at": int(time.time())
    }

    # The state is single-use: remove it once the launch has been accepted.
    db.delete(login_state)
    db.commit()

    # --- Find target_link_uri ---
    target_link_uri = decoded.get(
        "https://purl.imsglobal.org/spec/lti/claim/target_link_uri",
        f"{TOOL_BASE_URL}/canvas/launch"
    )

    # Add launch_id so the launch page can retrieve the context; use "&" when
    # the target URI already carries a query string.
    separator = "&" if "?" in target_link_uri else "?"
    redirect_url = f"{target_link_uri}{separator}{urlencode({'launch_id': launch_id})}"

    return RedirectResponse(redirect_url)

def _upsert_webhook_platform(db, payload: Dict[str, Any], fallback_client_id: str):
    """Find the platform named by a webhook payload, creating or updating it."""
    issuer = payload.get('iss') or PLATFORM_ISS
    client_id = payload.get('client_id') or fallback_client_id

    platform = db.query(LMSPlatform).filter_by(issuer=issuer, client_id=client_id).first()
    if not platform:
        platform = LMSPlatform(
            issuer=issuer,
            tenant_domain=payload.get('custom_canvas_api_domain'),
            tenant_id=payload.get('tool_consumer_instance_guid'),
            tenant_name=payload.get('tool_consumer_instance_name'),
            client_id=client_id,
            jwks_url="",
            token_url=""
        )
        db.add(platform)
        db.commit()
        db.refresh(platform)
    else:
        platform.tenant_domain = payload.get('custom_canvas_api_domain')
        platform.tenant_id = payload.get('tool_consumer_instance_guid')
        platform.tenant_name = payload.get('tool_consumer_instance_name')
        db.commit()
    return platform


def _webhook_payload_is_admin(payload: Dict[str, Any], ext_roles: list) -> bool:
    """Return True when the payload's roles or ext_roles mark an administrator."""
    lti1_admin = 'urn:lti:instrole:ims/lis/Administrator'
    lti13_admin = 'http://purl.imsglobal.org/vocab/lis/v2/institution/person#Administrator'
    roles_raw = payload.get('roles')
    if roles_raw == lti1_admin:
        return True
    # "roles" may be absent; only test membership on a string or a sequence.
    if isinstance(roles_raw, (str, list, tuple)) and lti13_admin in roles_raw:
        return True
    return lti13_admin in ext_roles


@router.post("/lti/canvasdata")
async def canvas_data_webhook(
    canvas_data: str = Form(...)
):
    """
    Canvas Data Webhook endpoint.

    Canvas can POST data exports here. ``canvas_data`` is a large JSON string
    holding a JSON object. When it carries ``custom_canvas_user_login_id``, the
    matching user (looked up by email) is linked to a platform, or a new user
    is created together with the platform record.

    Returns:
        A JSON response ``{"status": "received", "data": {...}}`` whose data is
        the user summary (empty when the payload has no login ID), or a 400
        response when ``canvas_data`` is missing or is not a JSON object.
    """
    # NOTE(review): nothing in this handler authenticates the caller, yet the
    # payload can create users (including administrators) — confirm that access
    # is restricted elsewhere.
    db = next(dbase.get_db())
    if not canvas_data:
        return JSONResponse(
            status_code=400,
            content={"error": "Missing canvas_data"}
        )
    try:
        payload = json.loads(canvas_data)
    except ValueError:
        return JSONResponse(
            status_code=400,
            content={"error": "canvas_data is not valid JSON"}
        )
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "canvas_data must be a JSON object"}
        )
    logger.info("Received Canvas Data webhook: %s", payload)

    login_id = payload.get('custom_canvas_user_login_id')
    if not login_id:
        logger.warning("No user login ID in payload")
        return JSONResponse({"status": "received", "data": {}})

    # Client ID used when the payload does not carry one.
    custom_client_id = f"LTI-1p0-client-{login_id}"
    logger.info("Data for user: %s", login_id)

    user = db.query(User).filter_by(email=login_id).first()
    if user:
        logger.info("Found user in DB: %s", user.email)
        # Link the existing user to a platform only when not linked yet.
        if not user.platform_id:
            platform = _upsert_webhook_platform(db, payload, custom_client_id)
            user.platform_id = platform.id
            db.commit()
    else:
        logger.warning("User not found in DB")
        platform = _upsert_webhook_platform(db, payload, custom_client_id)

        # Normalize ext_roles to a list (split comma-separated string, trim whitespace, remove empties)
        ext_roles_raw = payload.get("ext_roles", "")
        if isinstance(ext_roles_raw, str):
            ext_roles = [r.strip() for r in ext_roles_raw.split(",") if r.strip()]
        elif isinstance(ext_roles_raw, (list, tuple)):
            ext_roles = [r.strip() for r in ext_roles_raw if isinstance(r, str) and r.strip()]
        else:
            ext_roles = []

        # 3 = college admin, 2 = student
        role_id = 3 if _webhook_payload_is_admin(payload, ext_roles) else 2

        user = User(
            name=payload.get('lis_person_name_full', 'Unknown'),
            first_name=payload.get('lis_person_name_given', ''),
            last_name=payload.get('lis_person_name_family', ''),
            email=login_id,
            platform_id=platform.id,
            sub=payload.get("sub", str(uuid.uuid4())),
            lti_roles=json.dumps(ext_roles),
            is_active=1,
            role_id=role_id,
            subscription_status="active",
            is_email_verified=0,
            is_phone_verified=0,
            activation_date=datetime.utcnow(),
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow()
        )
        db.add(user)
        db.commit()
        db.refresh(user)

        # Only newly created administrators receive the initial course credits.
        if role_id == 3:
            credit_entry = UserCourseLog(
                user_id=user.id,
                platform_id=platform.id,
                number_of_course=10,
                type="Credit"
            )
            db.add(credit_entry)
            db.commit()
        logger.info("Created new user in DB")

    user_data = {
        "user_id": user.id,
        "name": user.name,
        "email": user.email,
        "platform_id": user.platform_id,
        "role_id": user.role_id
    }
    return JSONResponse({"status": "received", "data": user_data})

# ---------------------
# Example API calls using the obtained access_token
# ---------------------
@router.get("/call/namesroles/{launch_id}")
async def call_names_roles(launch_id: str):
    """Canvas: Call Names & Roles (NRPS) with pagination support.

    Looks up the launch stored under ``launch_id``, takes the latest stored
    access token for the launch's email and walks every page of the NRPS
    membership endpoint, following the ``next`` link of each response.

    Returns:
        A JSON response ``{"members": [...], "count": <number of members>}``.

    Raises:
        HTTPException: 404 for an unknown launch, 400 when no token or no NRPS
            endpoint is available, 500 when an NRPS request fails.
    """
    db = next(dbase.get_db())
    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Launch not found")
    decoded = reg["decoded"]
    access_token = get_latest_token(db, decoded.get("email"))
    if not access_token:
        raise HTTPException(status_code=400, detail="No access token available")

    nrps_claim = decoded.get("https://purl.imsglobal.org/spec/lti-nrps/claim/namesroleservice", {})
    context_memberships_url = nrps_claim.get("context_memberships_url")
    if not context_memberships_url:
        raise HTTPException(status_code=400, detail="No NRPS endpoint in id_token")

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/vnd.ims.lti-nrps.v2.membershipcontainer+json",
    }

    members = []
    url = context_memberships_url
    async with httpx.AsyncClient() as client:
        while url:
            r = await client.get(url, headers=headers, timeout=10.0)
            try:
                r.raise_for_status()
            except httpx.HTTPStatusError:
                raise HTTPException(status_code=500, detail=f"NRPS call failed: {r.text}")

            data = r.json()
            members.extend(data.get("members", []))

            # Canvas paginates via the Link header; httpx parses it into
            # r.links, keyed by rel. No "next" entry means this was the last page.
            next_link = r.links.get("next") or {}
            url = next_link.get("url")

    return JSONResponse({"members": members, "count": len(members)})

@router.get("/call/ags/{launch_id}")
async def call_ags(launch_id: str):
    """Canvas: fetch all line items from AGS, following pagination.

    Starts at the ``lineitems`` URL from the launch's AGS endpoint claim and
    keeps requesting the ``rel="next"`` Link target until the platform stops
    advertising one, so the result is not limited to the first page.

    Args:
        launch_id: Key of the launch in ``REGISTRATIONS``.

    Returns:
        JSONResponse with the concatenated list of line items. If a page is
        not a JSON array, that page's payload is returned unchanged.

    Raises:
        HTTPException: 404 if the launch is unknown; 400 if no access token
            is available or the id_token has no AGS lineitems URL; 500 if the
            platform answers with an error status.
    """
    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Launch not found")
    decoded = reg["decoded"]

    # Keep a reference to the generator so its cleanup runs as soon as we are
    # done with the session, instead of whenever the generator is collected.
    # NOTE(review): assumes dbase.get_db is a generator-style dependency.
    db_gen = dbase.get_db()
    try:
        access_token = get_latest_token(next(db_gen), decoded.get("email"))
    finally:
        db_gen.close()
    if not access_token:
        raise HTTPException(status_code=400, detail="No access token available")

    ags_claim = decoded.get("https://purl.imsglobal.org/spec/lti-ags/claim/endpoint", {})
    lineitems_url = ags_claim.get("lineitems")
    if not lineitems_url:
        raise HTTPException(status_code=400, detail="No AGS lineitems URL in id_token")

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/vnd.ims.lis.v2.lineitemcontainer+json",
    }

    line_items = []
    url = lineitems_url
    async with httpx.AsyncClient() as client:
        while url:
            r = await client.get(url, headers=headers, timeout=10.0)
            try:
                r.raise_for_status()
            except httpx.HTTPStatusError as exc:
                raise HTTPException(status_code=500, detail=f"AGS call failed: {r.text}") from exc

            page = r.json()
            if not isinstance(page, list):
                # Unexpected shape for a line item container: hand the payload
                # back untouched, exactly as a single-request fetch would.
                return JSONResponse(page)
            line_items.extend(page)

            # Follow the platform's pagination (Link: <...>; rel="next").
            url = r.links.get("next", {}).get("url")

    return JSONResponse(line_items)


@router.post("/call/ags/{launch_id}/create")
async def create_line_item(launch_id: str, label: str = Form("Demo Assignment"), score_maximum: float = Form(100)):
    """Canvas: create a new AGS line item.

    POSTs a line item with a freshly generated ``resourceId`` to the
    ``lineitems`` URL from the launch's AGS endpoint claim.

    Args:
        launch_id: Key of the launch in ``REGISTRATIONS``.
        label: Label of the new line item.
        score_maximum: Maximum score of the new line item.

    Returns:
        JSONResponse with the platform's JSON answer.

    Raises:
        HTTPException: 404 if the launch is unknown; 400 if no access token
            is available or the id_token has no AGS lineitems URL; 500 if the
            platform answers with an error status.
    """
    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Launch not found")
    decoded = reg["decoded"]

    # Keep a reference to the generator so its cleanup runs as soon as we are
    # done with the session, instead of whenever the generator is collected.
    # NOTE(review): assumes dbase.get_db is a generator-style dependency.
    db_gen = dbase.get_db()
    try:
        access_token = get_latest_token(next(db_gen), decoded.get("email"))
    finally:
        db_gen.close()
    if not access_token:
        raise HTTPException(status_code=400, detail="No access token available")

    ags_claim = decoded.get("https://purl.imsglobal.org/spec/lti-ags/claim/endpoint", {})
    lineitems_url = ags_claim.get("lineitems")
    if not lineitems_url:
        raise HTTPException(status_code=400, detail="No AGS lineitems URL in id_token")

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/vnd.ims.lis.v2.lineitem+json",
    }

    line_item = {
        "scoreMaximum": score_maximum,
        "label": label,
        "resourceId": str(uuid.uuid4()),  # unique id
    }

    async with httpx.AsyncClient() as client:
        r = await client.post(lineitems_url, headers=headers, json=line_item, timeout=10.0)
        try:
            r.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise HTTPException(status_code=500, detail=f"Create line item failed: {r.text}") from exc

    return JSONResponse(r.json())

@router.post("/call/ags/{launch_id}/score")
async def post_score(launch_id: str, lineitem_url: str = Form(...), user_id: str = Form(...), score_given: float = Form(80), score_maximum: float = Form(100)):
    """Canvas: post a score to a line item for a user.

    Args:
        launch_id: Key of the launch in ``REGISTRATIONS``.
        lineitem_url: URL of the target line item (caller supplied).
        user_id: Platform user id the score belongs to.
        score_given: Score awarded.
        score_maximum: Maximum achievable score.

    Returns:
        JSONResponse with the platform's JSON answer, or ``{}`` when the
        platform replies with an empty body.

    Raises:
        HTTPException: 404 if the launch is unknown; 400 if ``lineitem_url``
            is not on the same origin as the launch's AGS endpoints or no
            access token is available; 500 if the platform answers with an
            error status.
    """
    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Launch not found")

    decoded = reg["decoded"]

    # lineitem_url comes straight from the request and will receive our bearer
    # token, so only accept URLs on the same origin as the AGS endpoints the
    # platform advertised in this launch.
    ags_claim = decoded.get("https://purl.imsglobal.org/spec/lti-ags/claim/endpoint", {})
    trusted_origins = set()
    for known_url in (ags_claim.get("lineitems"), ags_claim.get("lineitem")):
        if known_url:
            known = urlparse(known_url)
            trusted_origins.add((known.scheme.lower(), known.netloc.lower()))
    target = urlparse(lineitem_url)
    if (target.scheme.lower(), target.netloc.lower()) not in trusted_origins:
        raise HTTPException(status_code=400, detail="lineitem_url does not match the AGS endpoint of this launch")

    # Keep a reference to the generator so its cleanup runs as soon as we are
    # done with the session, instead of whenever the generator is collected.
    # NOTE(review): assumes dbase.get_db is a generator-style dependency.
    db_gen = dbase.get_db()
    try:
        access_token = get_latest_token(next(db_gen), decoded.get("email"))
    finally:
        db_gen.close()
    if not access_token:
        raise HTTPException(status_code=400, detail="No access token available")

    # "/scores" must be appended to the *path*; plain string concatenation
    # would land after any query string carried by the line item URL.
    score_url = target._replace(path=target.path.rstrip("/") + "/scores").geturl()

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/vnd.ims.lis.v1.score+json",
    }

    score_payload = {
        "userId": user_id,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "scoreGiven": score_given,
        "scoreMaximum": score_maximum,
        "activityProgress": "Completed",
        "gradingProgress": "FullyGraded",
    }

    async with httpx.AsyncClient() as client:
        r = await client.post(score_url, headers=headers, json=score_payload, timeout=10.0)
        try:
            r.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise HTTPException(status_code=500, detail=f"Post score failed: {r.text}") from exc

    # A successful score POST may come back without a body; don't fail on it.
    return JSONResponse(r.json() if r.content else {})

# ---------------------
# Deep Linking example
# ---------------------
@router.get("/deep_linking")
async def deep_linking_start(launch_id: str):
    """Start a Deep Linking flow from the tool side by rendering a selection form.

    Verifies that the stored launch carries a deep linking return URL, then
    returns a minimal HTML form that posts the chosen resource to
    ``/canvas/deep_linking/complete``.

    Args:
        launch_id: Key of the launch in ``REGISTRATIONS``.

    Returns:
        HTMLResponse containing the selection form.

    Raises:
        HTTPException: 404 if the launch is unknown; 400 if the id_token has
            no deep linking return URL.
    """
    from html import escape

    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Launch not found")
    decoded = reg["decoded"]

    # The Deep Linking spec names the claim ".../deep_linking_settings" and the
    # property "deep_link_return_url"; the previously used names are kept as a
    # fallback so launches stored in that shape keep working.
    dl_claim = (
        decoded.get("https://purl.imsglobal.org/spec/lti-dl/claim/deep_linking_settings")
        or decoded.get("https://purl.imsglobal.org/spec/lti-dl/claim/deep_linking")
        or {}
    )
    deep_link_return_url = dl_claim.get("deep_link_return_url") or dl_claim.get("return_url")
    if not deep_link_return_url:
        raise HTTPException(status_code=400, detail="No deep_linking return_url in id_token")

    # Present a very simple HTML UI to select resources.
    # launch_id is echoed into an attribute, so escape it.
    page = f"""
    <h3>Deep Linking - Select Resource</h3>
    <form action="/canvas/deep_linking/complete" method="post">
      <input type="hidden" name="launch_id" value="{escape(launch_id, quote=True)}"/>
      <label>Title: <input name="title" value="Demo Activity"/></label><br/>
      <label>URL: <input name="url" value="https://tool.example/activity/123"/></label><br/>
      <label>Text: <input name="text" value="A demo deep link item"/></label><br/>
      <input type="submit" value="Send deep_linking_response">
    </form>
    """
    return HTMLResponse(page)

@router.post("/deep_linking/complete")
async def deep_linking_complete(
    launch_id: str = Form(...),
    title: str = Form("Demo Activity"),
    url: str = Form("http://localhost:8000/canvas/launch/123"),
    score_maximum: float = Form(100),
    text: str = Form("Launch this activity in the tool"),
):
    """
    Canvas: Build the Deep Linking Response JWT and return it to Canvas.

    Creates a single ``ltiResourceLink`` content item from the form fields,
    wraps it in an ``LtiDeepLinkingResponse`` message signed with the Tool's
    private key (RS256), and returns an auto-submitting HTML form that POSTs
    the JWT to the platform's deep linking return URL.

    Args:
        launch_id: Key of the launch in ``REGISTRATIONS``.
        title: Title of the content item (also used as the line item label).
        url: Launch URL of the content item.
        score_maximum: Maximum score for the pre-created line item.
        text: Description of the content item.

    Returns:
        HTMLResponse with a self-posting form carrying the ``JWT`` field.

    Raises:
        HTTPException: 404 if the launch is unknown; 400 if the id_token has
            no deep linking return URL.
    """
    from html import escape

    reg = REGISTRATIONS.get(launch_id)
    if not reg:
        raise HTTPException(status_code=404, detail="Launch not found")
    decoded = reg["decoded"]

    # The Deep Linking spec names the claim ".../deep_linking_settings" and the
    # property "deep_link_return_url"; the previously used names are kept as a
    # fallback so launches stored in that shape keep working.
    dl_claim = (
        decoded.get("https://purl.imsglobal.org/spec/lti-dl/claim/deep_linking_settings")
        or decoded.get("https://purl.imsglobal.org/spec/lti-dl/claim/deep_linking")
        or {}
    )
    deep_link_return_url = dl_claim.get("deep_link_return_url") or dl_claim.get("return_url")
    if not deep_link_return_url:
        raise HTTPException(status_code=400, detail="No deep_linking return_url in id_token")

    # Build deep link response JWT (valid for 10 minutes)
    iat = now()
    exp = iat + 600

    content_item = {
        "type": "ltiResourceLink",
        "title": title,
        "url": url,
        "text": text,
        "lineItem": {  # optional: pre-create gradebook entry
            "scoreMaximum": score_maximum,
            "label": title,
        },
        "custom": {  # optional: pass back params into the launch
            "tool_setting": "example",
        },
    }

    message = {
        "iss": TOOL_CLIENT_ID,
        "aud": decoded.get("iss"),  # Canvas issuer
        "iat": iat,
        "exp": exp,
        "nonce": str(uuid.uuid4()),
        "https://purl.imsglobal.org/spec/lti/claim/message_type": "LtiDeepLinkingResponse",
        "https://purl.imsglobal.org/spec/lti/claim/version": "1.3.0",
        "https://purl.imsglobal.org/spec/lti-dl/claim/content_items": [content_item],
    }

    # The response must name the deployment the request was issued for.
    deployment_id = decoded.get("https://purl.imsglobal.org/spec/lti/claim/deployment_id")
    if deployment_id is not None:
        message["https://purl.imsglobal.org/spec/lti/claim/deployment_id"] = deployment_id

    # Echo the platform's opaque "data" value under the namespaced claim, and
    # only when the request actually carried one.
    request_data = dl_claim.get("data")
    if request_data is not None:
        message["https://purl.imsglobal.org/spec/lti-dl/claim/data"] = request_data

    pem = PRIVATE_KEY.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Use the shared constant so the kid cannot drift from the published key.
    jwt_headers = {"kid": PRIVATE_KEY_KID}
    deep_linking_jwt = jwt.encode(message, pem, algorithm="RS256", headers=jwt_headers)

    # Canvas expects a POST with form field `JWT`.
    # Both values land inside HTML attributes, so escape them.
    page = f"""
    <html>
      <body onload="document.forms[0].submit()">
        <form action="{escape(deep_link_return_url, quote=True)}" method="POST">
          <input type="hidden" name="JWT" value="{escape(deep_linking_jwt, quote=True)}"/>
          <noscript>
            <p>JavaScript is disabled. Please click below to continue.</p>
            <button type="submit">Continue</button>
          </noscript>
        </form>
      </body>
    </html>
    """
    return HTMLResponse(page)

# ---------------------
# Simple utility endpoints for debugging
# ---------------------
@router.get("/_debug/registrations")
def debug_regs():
    """Debug helper: return the whole ``REGISTRATIONS`` mapping as JSON.

    NOTE(review): entries include the decoded id_token claims of each launch,
    so this route presumably should not be reachable outside development.
    """
    registry_dump = REGISTRATIONS
    return JSONResponse(content=registry_dump)


@router.get("/_debug/access_token")
def debug_token():
    """Debug helper: return the Tool's public signing key as a JWK.

    Despite the route name, no access token is involved: the private key is
    loaded from ``PRIVATE_KEY_PATH`` and only the public RSA numbers
    (``n`` and ``e``) are returned, alongside the key id and algorithm.

    Returns:
        JSONResponse with a single RSA JWK.
    """
    with open(PRIVATE_KEY_PATH, "rb") as key_file:
        private_key = serialization.load_pem_private_key(
            key_file.read(),
            password=None,
            backend=default_backend(),
        )
    numbers = private_key.public_key().public_numbers()

    # Named public_jwk so the imported jose `jwk` module is not shadowed.
    public_jwk = {
        "kty": "RSA",
        "alg": "RS256",
        "use": "sig",
        "kid": PRIVATE_KEY_KID,  # must match your Developer Key JSON
        "n": b64url_uint(numbers.n),
        "e": b64url_uint(numbers.e),
    }
    return JSONResponse(public_jwk)