import json
import logging
import os

import uuid
import ssl
from importlib.metadata import version
from paho.mqtt import client as mqtt
from threading import Event
from time import sleep, time

from numerous.image_tools.app import NumerousApp
from numerous.image_tools.job import NumerousBaseJob
from numerous.client import ScenarioStatus

from numerous.client import NumerousClient

# Named logger shared by everything in this module.
logger = logging.getLogger("mqtt_runner")

# Version string of the installed ``numerous.sdk`` distribution, resolved once at import time.
VERSION = version("numerous.sdk")


class MQTTOutputHandler:
    """Write simulation output rows to a numerous writer and mirror them over MQTT.

    Every row passed to :meth:`write_row` is written to the writer obtained from
    ``client.new_writer`` and is also published as an insert request on
    ``{project_id}/simulation/insertRequest``. Responses on the matching
    ``insertResponse`` / ``deleteResponse`` topics are processed by
    :meth:`on_message`, which paho invokes on its network thread (started with
    ``loop_start`` in :meth:`_get_mqtt_client`).
    """

    _insert_response: str
    _insert_request: str
    _delete_response: str
    _delete_request: str
    _delete_id: str  # id of the outstanding delete request, "" when none is pending
    _timeout: float = 60.0  # seconds; upper bound for every wait in this class
    connected: bool = False

    def __init__(self, client: NumerousClient, username: str, password: str, hostname: str, port: int,
                 project_id: str, simulation_id: str, delete_simulation: bool):
        """Create the output writer and connect an MQTT client for ``project_id``.

        :param client: numerous client used for the output writer and progress updates.
        :param username: MQTT user; sent to the broker as ``simulation:<username>``.
        :param password: MQTT password.
        :param hostname: MQTT broker host.
        :param port: MQTT broker port (converted with ``int``).
        :param project_id: prefix of all request/response topics.
        :param simulation_id: prefix of every data path written for this simulation.
        :param delete_simulation: if true, the first :meth:`write_row` first asks the
            broker to delete the data under ``{simulation_id}/``.
        """
        self.client = client
        self._username = username
        self._password = password
        self._hostname = hostname
        self._port = port
        self._project_id = project_id
        self._simulation_id = simulation_id
        self._request_delete_simulation = delete_simulation

        # Everything the MQTT callbacks read must exist before _get_mqtt_client
        # starts the network loop; otherwise an early message raises
        # AttributeError on paho's thread.
        self.cnt = 0  # not read anywhere in this class
        self._ids = set()  # ids of insert requests not yet acknowledged
        self._delete_id = ""

        self.output = self.client.new_writer(buffer_size=0)
        self.mqtt_client = self._get_mqtt_client(project_id)

    def _get_mqtt_client(self, project_id):
        """Create, configure and connect a paho client for ``project_id``.

        This also (re)sets the request/response topic names and the ``_cleared``
        event on ``self``. The client is returned, not stored: the caller must
        assign it to ``self.mqtt_client``. ``connect`` is called synchronously,
        so connection errors propagate to the caller.
        """
        mqtt_client = mqtt.Client()

        mqtt_client.tls_set(tls_version=ssl.PROTOCOL_TLSv1_2)
        mqtt_client.on_connect = self.on_connect
        mqtt_client.on_message = self.on_message
        mqtt_client.on_publish = self.on_publish
        mqtt_client.username_pw_set(username=f"simulation:{self._username}", password=self._password)
        mqtt_client.connect(host=self._hostname, port=int(self._port), keepalive=60)

        self._insert_response = f"{project_id}/simulation/insertResponse"
        self._insert_request = f"{project_id}/simulation/insertRequest"
        self._delete_response = f"{project_id}/simulation/deleteResponse"
        self._delete_request = f"{project_id}/simulation/deleteRequest"
        self._cleared = Event()

        # Callbacks only start firing once the network thread is running.
        mqtt_client.loop_start()

        return mqtt_client

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Handle paho's ``on_connect`` callback by subscribing to the response topics.

        paho calls ``on_connect`` on every (re)connection, so subscribing here
        restores the subscriptions after a reconnect. A non-zero ``rc`` means the
        broker refused the connection, so there is nothing to subscribe on.
        """
        logger.debug(f"client connected: {rc}")
        if rc != 0:
            logger.warning(f"MQTT connection refused: {rc}")
            return
        client.subscribe(self._insert_response)
        client.subscribe(self._delete_response)

    def on_publish(self, client: mqtt.Client, userdata, msg, **kwargs):
        """Do nothing; delivery is tracked by polling the publish result in :meth:`write_row`."""
        pass

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Handle paho's ``on_message`` callback for insert/delete responses.

        A successful insert response removes its id from ``self._ids``. A
        successful response to the currently pending delete request sets
        ``self._cleared``. Payloads that are not a UTF-8 JSON object are logged
        and ignored, so a bad message cannot raise on paho's network thread.
        """
        logger.debug("message on topic %s: %s", msg.topic, msg.payload)
        try:
            response = json.loads(msg.payload.decode())
        except (UnicodeDecodeError, json.JSONDecodeError):
            response = None
        if not isinstance(response, dict):
            logger.warning("ignoring malformed message on topic %s: %r", msg.topic, msg.payload)
            return

        _id = response.get('id')
        success = response.get('success')
        if msg.topic == self._insert_response and success:
            self._ids.discard(_id)
        # ``self._delete_id`` is "" when no delete is pending; never match that.
        if msg.topic == self._delete_response and success and self._delete_id and _id == self._delete_id:
            self._delete_id = ""
            self._cleared.set()

    def _delete_simulation(self):
        """Publish a delete request for everything under ``{simulation_id}/``."""
        self._delete_id = str(uuid.uuid4())
        body = {
            "id": self._delete_id,
            "path": f"{self._simulation_id}/"
        }
        self.mqtt_client.publish(self._delete_request, json.dumps(body))

    def delete_simulation(self):
        """Request deletion of this simulation's data and wait for the acknowledgement.

        :returns: True if the broker acknowledged within ``self._timeout`` seconds,
            otherwise False (a warning is logged).
        """
        # Reset first so an acknowledgement from an earlier request is not mistaken
        # for this one.
        self._cleared.clear()
        self._delete_simulation()

        if not self._cleared.wait(self._timeout):
            logger.warning(f"simulation data not cleared after {self._timeout} seconds")
            return False
        logger.info("simulation data cleared")
        return True

    def wait_for_connect(self):
        """Block until the MQTT client is connected, reporting progress to the scenario.

        Progress messages are sent with a growing (x3) interval. ``self.connected``
        records whether a connection was seen before, which decides between the
        "connecting" and "reconnecting" messages.

        :raises RuntimeError: if no connection is established within ``self._timeout``
            seconds (the scenario is marked FAILED first).
        """
        if self.connected:
            message = "MQTT client connection lost... attempting to reconnect"
            post_connect = "MQTT client reconnected!"
            self.connected = False
        else:
            message = 'MQTT client connecting...'
            post_connect = "MQTT client connected!"

        t0 = time()
        backofftime = 1
        while not self.mqtt_client.is_connected():
            elapsed = time() - t0
            if elapsed > backofftime:
                logger.info("waiting for mqtt to connect")
                backofftime *= 3
                self.client.set_scenario_progress(message, ScenarioStatus.WAITING, force=True)

            if elapsed > self._timeout:
                self.client.set_scenario_progress("MQTT client not connected", ScenarioStatus.FAILED, force=True)
                raise RuntimeError("MQTT client not connected")
            sleep(1)

        self.client.set_scenario_progress(post_connect, ScenarioStatus.WAITING, force=True)
        self.connected = True

    def write_row(self, t, row):
        """Write ``row`` to the output writer and publish it as an MQTT insert request.

        :param t: row time in seconds (sent in milliseconds, see :meth:`format_scheme`).
        :param row: mapping of tag to value; the ``_index`` key is not published.
        :raises RuntimeError: if the simulation/project id is missing, the client
            cannot connect, or the publish does not complete within ``self._timeout``.
        """
        # Configuration errors are checked before any (possibly long) connection wait.
        if not self._simulation_id:
            raise RuntimeError("no simulation id specified (remember to call set_simulation_id)")
        if not self._project_id:
            raise RuntimeError("no project id specified (remember to call set_project_id)")
        if not self.mqtt_client.is_connected():
            self.wait_for_connect()

        if self._request_delete_simulation:
            self.delete_simulation()
            self._request_delete_simulation = False

        self.output.write_row(row)
        formatted_message = self.format_scheme(t, row)
        self._ids.add(str(formatted_message['id']))
        msg = self.mqtt_client.publish(self._insert_request, json.dumps(formatted_message))

        backofftime = 1
        t0 = time()
        while not msg.is_published():
            elapsed = time() - t0
            if elapsed > backofftime:
                backofftime *= 3
                logger.info("waiting for publish")
            if elapsed > self._timeout:
                raise RuntimeError(f"could not publish message after {self._timeout} seconds")
            sleep(0.1)

        logger.debug(f"message {msg} published")

    def format_scheme(self, t: float, row: dict) -> dict:
        """Build the insert-request payload for one row.

        :param t: row time in seconds; converted to integer milliseconds.
        :param row: mapping of tag to value; the ``_index`` key is skipped.
        :returns: ``{'id': <uuid4 str>, 'data': [...]}`` with one float series
            entry (path ``{simulation_id}/{tag}``) per tag.
        """
        t_ms = int(t * 1000)  # seconds to milliseconds
        outputs = [
            {
                "path": f"{self._simulation_id}/{tag}",
                "type": "float",
                "unit": "",
                "times": [t_ms],
                "values": [val]
            }
            for tag, val in row.items()
            if tag != "_index"
        ]

        return {
            'id': str(uuid.uuid4()),
            'data': outputs
        }

    def close(self):
        """Close the output writer and shut down the MQTT connection."""
        self.output.close()
        # disconnect() closes the broker session; loop_stop() alone only stops the
        # network thread and leaves the connection open.
        self.mqtt_client.disconnect()
        self.mqtt_client.loop_stop()

class NumerousMQTTApp(NumerousApp):
    """NumerousApp whose job output is sent through an :class:`MQTTOutputHandler`.

    MQTT connection settings are read from ``self.client.params`` (keys
    ``MQTT_HOSTNAME``, ``MQTT_PASSWORD``, ``MQTT_PORT``, ``MQTT_USERNAME``),
    falling back to environment variables of the same names.
    """

    client: NumerousClient
    _hostname: str
    _password: str
    _port: int
    _username: str
    _simulation_id: str
    _project_id: str
    output: MQTTOutputHandler

    def __init__(self, appname="defaultnumerousApp",
                 max_restarts=0, numerous_job: NumerousBaseJob = None, model_folder: str = None,
                 working_folder: str = None, reset_job=False, project_id: str = None,
                 simulation_id: str = None,
                 trace: bool = False):
        """Initialise the base app, then create the MQTT output handler.

        Existing simulation data is deleted on the first write only when
        ``reset_job`` is truthy and the client state has no ``'states'`` entry.
        ``int()`` is applied to the port, so a ``MQTT_PORT`` missing from both the
        params and the environment raises ``TypeError``.
        """
        super().__init__(appname=appname,
                         max_restarts=max_restarts, numerous_job=numerous_job,
                         model_folder=model_folder,
                         working_folder=working_folder, reset_job=reset_job,
                         trace=trace)

        params = self.client.params
        self._hostname = params.get('MQTT_HOSTNAME', os.getenv('MQTT_HOSTNAME'))
        self._password = params.get('MQTT_PASSWORD', os.getenv('MQTT_PASSWORD'))
        self._port = int(params.get('MQTT_PORT', os.getenv('MQTT_PORT')))
        self._username = params.get('MQTT_USERNAME', os.getenv('MQTT_USERNAME'))
        self._project_id = project_id
        self._simulation_id = simulation_id

        logger.debug(self._hostname)
        logger.debug(self._port)
        logger.debug(self._username)

        states = self.client.state.get('states', None)

        self.output = MQTTOutputHandler(self.client,
                                        self._username,
                                        self._password,
                                        self._hostname,
                                        self._port,
                                        project_id,
                                        simulation_id,
                                        reset_job and not states
                                        )

    def set_simulation_id(self, simulation_id: str):
        """Set the simulation id used as the path prefix for all written data."""
        self._simulation_id = simulation_id
        self.output._simulation_id = simulation_id

    def set_project_id(self, project_id):
        """Switch the output handler to ``project_id``'s MQTT topics.

        A new MQTT client is created for the new topics and installed on the
        output handler. The previous client is disconnected and its network
        loop stopped, so it does not remain connected in the background.
        """
        output = self.output
        self._project_id = project_id
        output._project_id = project_id

        previous_client = output.mqtt_client
        output.mqtt_client = output._get_mqtt_client(project_id)
        previous_client.disconnect()
        previous_client.loop_stop()




def run_mqtt_job(numerous_job=None, appname=None, max_restarts=0, model_folder="models", working_folder="tmp",
                 reset_job=None, project_id=None, simulation_id=None, trace=False):
    """Build a :class:`NumerousMQTTApp` from the given settings and run it.

    Every argument is forwarded unchanged to ``NumerousMQTTApp``; the new
    app's ``_run`` method is then called.
    """
    app_settings = dict(
        appname=appname,
        max_restarts=max_restarts,
        numerous_job=numerous_job,
        model_folder=model_folder,
        working_folder=working_folder,
        reset_job=reset_job,
        project_id=project_id,
        simulation_id=simulation_id,
        trace=trace,
    )
    NumerousMQTTApp(**app_settings)._run()
