from __future__ import annotations

import asyncio
import copy
import io
import json
import logging
import random
import time
from typing import Optional, Tuple, Any

import httpx
import requests
from requests.structures import CaseInsensitiveDict
from requests_toolbelt.multipart.encoder import MultipartEncoder

from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
from unstructured_client._hooks.custom.form_utils import (
    PARTITION_FORM_FILES_KEY,
    PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
    PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
    PARTITION_FORM_PAGE_RANGE_KEY,
    PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
    FormData,
)

logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)


def create_request_body(
    form_data: FormData, page_content: io.BytesIO, filename: str, page_number: int
) -> MultipartEncoder:
    """Build the multipart body for a single split-page request.

    The form data is first run through ``prepare_request_payload``. Any field
    whose value is a list is expanded into one multipart field per list item,
    all sharing the same key. The file part and the starting page number are
    appended after the regular fields.

    Args:
        form_data: The original form data.
        page_content: The content sent as the file part, declared as
            "application/pdf".
        filename: The filename reported for the file part.
        page_number: The page number, sent as a string under the starting
            page number key.

    Returns:
        The multipart encoder wrapping all the fields.
    """
    fields: list[tuple[str, Any]] = []
    for name, entry in prepare_request_payload(form_data).items():
        # A list value becomes several fields with the same name.
        entries = entry if isinstance(entry, list) else [entry]
        fields.extend((name, item) for item in entries)

    file_part = (filename, page_content, "application/pdf")
    fields.append((PARTITION_FORM_FILES_KEY, file_part))
    fields.append((PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, str(page_number)))

    return MultipartEncoder(fields=fields)


def create_httpx_request(
    original_request: requests.Request, body: MultipartEncoder
) -> httpx.Request:
    """Build an httpx POST request that mirrors the original request.

    Args:
        original_request: The request whose url and headers are reused. Its
            headers are passed through ``prepare_request_headers`` first.
        body: The multipart body; it is serialized with ``to_string()`` and
            its content type becomes the "Content-Type" header.

    Returns:
        The new httpx request.
    """
    request_headers = dict(prepare_request_headers(original_request.headers))
    target_url = original_request.url or ""
    serialized_body = body.to_string()
    request_headers["Content-Type"] = body.content_type

    return httpx.Request(
        method="POST",
        url=target_url,
        content=serialized_body,
        headers=request_headers,
    )


def create_request(
    request: requests.PreparedRequest,
    body: MultipartEncoder,
) -> requests.Request:
    """Build a requests POST request that mirrors the given prepared request.

    Args:
        request: The prepared request whose url and headers are reused. Its
            headers are passed through ``prepare_request_headers`` first.
        body: The multipart body, passed on as the request data; its content
            type becomes the "Content-Type" header.

    Returns:
        The new, not yet prepared, request.
    """
    request_headers = dict(prepare_request_headers(request.headers))
    target_url = request.url or ""
    request_headers["Content-Type"] = body.content_type

    return requests.Request(
        method="POST",
        url=target_url,
        data=body,
        headers=request_headers,
    )


async def retry_with_backoff_async(
    request_func,
    page_number,
    initial_interval,
    max_interval,
    exponent,
    max_elapsed_time,
):
    """
    A copy of the autogenerated backoff code adapted for asyncio.

    Awaits request_func() repeatedly until it yields a response whose status
    code is not 502, 503 or 504, or until the retry budget is used up.

    Args:
        request_func: Zero-argument callable returning an awaitable that
            resolves to an object with a ``status_code`` attribute.
        page_number: The page number, used only in the log messages.
        initial_interval: Base delay between attempts, in milliseconds.
        max_interval: Upper bound for a single delay, in milliseconds.
        exponent: Growth factor applied to the delay on every retry.
        max_elapsed_time: Retry budget in milliseconds, measured from the
            start of the first attempt.

    Returns:
        The first response with a non-retryable status code, or the last
        retryable response once the budget is exceeded.

    Raises:
        Exception: Whatever request_func raised, re-raised once the budget
            is exceeded.
    """
    # Use the monotonic clock for the elapsed-time budget: the wall clock
    # can jump (NTP sync, manual changes), which would either end the
    # retries too early or keep them going for too long.
    start = round(time.monotonic() * 1000)
    retries = 0

    retry_status_codes = [502, 503, 504]

    while True:
        try:
            response = await request_func()

            if response.status_code not in retry_status_codes:
                return response

            logger.error("Request (page %d) failed with status code %d. Waiting to retry.", page_number, response.status_code)

            # Is it time to get out of the loop?
            now = round(time.monotonic() * 1000)
            if now - start > max_elapsed_time:
                return response
        except Exception as e:
            logger.error("Request (page %d) failed (%s). Waiting to retry.", page_number, repr(e))

            # Is it time to get out of the loop?
            now = round(time.monotonic() * 1000)
            if now - start > max_elapsed_time:
                raise

        # Otherwise go back to sleep: exponential delay in seconds plus up to
        # one second of random jitter, capped at max_interval.
        sleep = (initial_interval / 1000) * exponent**retries + random.uniform(0, 1)
        sleep = min(sleep, max_interval / 1000)
        await asyncio.sleep(sleep)
        retries += 1


async def call_api_async(
    client: httpx.AsyncClient,
    page: Tuple[io.BytesIO, int],
    original_request: requests.Request,
    form_data: FormData,
    filename: str,
    limiter: asyncio.Semaphore,
) -> httpx.Response:
    """
    Send one page to the API with httpx, retrying on failure.

    The httpx request is built from the original requests.Request and the
    request is sent inside a retry loop while holding the limiter. The retry
    values are copied from the API spec, and will not be auto updated. Long
    term solution is to reuse SDK logic. We'll need the hook context to have
    access to the rest of the SDK.

    Args:
        client: The httpx client used to send the request.
        page: A (page content, page number) pair.
        original_request: The request whose url and headers are reused.
        form_data: The original form data.
        filename: The filename reported for the uploaded page.
        limiter: Semaphore held for the whole duration of the retry loop.

    Returns:
        The response returned by the retry loop.
    """
    page_content, page_number = page
    httpx_request = create_httpx_request(
        original_request,
        create_request_body(form_data, page_content, filename, page_number),
    )

    # All retry settings are expressed in milliseconds.
    second_ms = 1000
    minute_ms = 60 * second_ms

    async def send_once():
        return await client.send(httpx_request)

    async with limiter:
        return await retry_with_backoff_async(
            send_once,
            page_number=page_number,
            initial_interval=3 * second_ms,
            max_interval=12 * minute_ms,
            max_elapsed_time=30 * minute_ms,
            exponent=1.88,
        )


def call_api(
    client: Optional[requests.Session],
    page: Tuple[io.BytesIO, int],
    request: requests.PreparedRequest,
    form_data: FormData,
    filename: str,
) -> requests.Response:
    """Send one page to the API synchronously with the given session.

    Sending is best effort: if it raises, the error is logged and an empty
    response is returned instead of propagating the exception.

    Args:
        client: The session used to prepare and send the request.
        page: A (page content, page number) pair.
        request: The prepared request whose url and headers are reused.
        form_data: The original form data.
        filename: The filename reported for the uploaded page.

    Returns:
        The response from the API, or a blank ``requests.Response()`` when
        sending failed.

    Raises:
        RuntimeError: If no client was given.
    """
    if client is None:
        raise RuntimeError("HTTP client not accessible!")
    page_content, page_number = page

    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_request(request, body)
    prepared_request = client.prepare_request(new_request)

    try:
        return client.send(prepared_request)
    except Exception:
        # Keep the best-effort behavior, but attach the exception details to
        # the log record so the cause of the failure is not lost.
        logger.error("Failed to send request for page %d", page_number, exc_info=True)
        return requests.Response()


def prepare_request_headers(
    headers: CaseInsensitiveDict[str],
) -> CaseInsensitiveDict[str]:
    """Return a copy of the headers without 'Content-Type' and 'Content-Length'.

    The given headers are left untouched; the removal happens on a deep copy.

    Args:
        headers: The original request headers.

    Returns:
        The copied headers with both entries removed, if they were present.
    """
    cleaned = copy.deepcopy(headers)
    for header_name in ("Content-Type", "Content-Length"):
        cleaned.pop(header_name, None)
    return cleaned


def prepare_request_payload(form_data: FormData) -> FormData:
    """Build the payload for a single page request from the original form data.

    Works on a deep copy, so the given form data is not modified. The split,
    allow-failed, files, page range and starting page number entries are
    dropped, and the split flag is then set to "false".

    Args:
        form_data: The original form data.

    Returns:
        The updated request payload.
    """
    payload = copy.deepcopy(form_data)

    dropped_keys = (
        PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
        PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
        PARTITION_FORM_FILES_KEY,
        PARTITION_FORM_PAGE_RANGE_KEY,
        PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
    )
    for key in dropped_keys:
        payload.pop(key, None)

    # The split flag is re-added with an explicit "false" value.
    payload[PARTITION_FORM_SPLIT_PDF_PAGE_KEY] = "false"
    return payload


def create_failure_response(response: requests.Response) -> requests.Response:
    """
    Return a copy of the given response with its status code set to 500.

    This is because the split logic catches and retries 502, 503, etc.
    If a failure is passed back to the SDK, we shouldn't trigger
    another layer of retries, we just want to print the error. 500 is
    non retryable up above.

    Args:
        response: The original response object. It is not modified.

    Returns:
        A deep copy of the response with status code 500.
    """
    failure = copy.deepcopy(response)
    failure.status_code = 500

    # Some server errors return a lower case content-type
    # The SDK error parsing expects Content-Type
    lower_case_value = failure.headers.get("content-type")
    if lower_case_value:
        failure.headers["Content-Type"] = lower_case_value

    return failure


def create_response(elements: list) -> requests.Response:
    """
    Creates a new requests.Response object holding the list of elements.

    Args:
        elements: The list of elements to be serialized as JSON and used
        as the response content.

    Returns:
        A response with status code 200, the serialized elements as content
        and matching Content-Length and Content-Type headers.
    """
    serialized = json.dumps(elements).encode()

    result = requests.Response()
    result.status_code = 200
    result.headers["Content-Length"] = str(len(serialized))
    result.headers["Content-Type"] = "application/json"
    # The content is stored directly on the "_content" attribute.
    setattr(result, "_content", serialized)
    return result


def log_after_split_response(status_code: int, split_number: int):
    """Log the outcome of partitioning one split.

    Args:
        status_code: The status code of the response for the split. 200 is
            logged as a success at info level, anything else as a failure
            at warning level.
        split_number: The number of the split, used in the log message.
    """
    if status_code != 200:
        logger.warning(
            "Failed to partition set #%d, its elements will be omitted in the final result.",
            split_number,
        )
        return

    logger.info(
        "Successfully partitioned set #%d, elements added to the final result.",
        split_number,
    )
