import datetime
import json
import logging
import random
from statistics import mean
from typing import Any, List
from unittest.mock import MagicMock

import pytest
from numpy import linspace

import spm_pb2

from numerous_api_client.client.numerous_client import NumerousClient, ScenarioStatus
from numerous_api_client.optimization.configuration import ScenarioSetting
from numerous_api_client.optimization.orchestrator import BaseOptimizationOrchestrator, OptimizationIteration


@pytest.fixture
def fixture_optimizer_scenario_data():
    """Build the optimizer scenario document served for ``"scenario_id"``.

    It contains three sim components:

    * ``goal_function`` with two string parameters (``setting_1``/``setting_2``),
      one scenario-sourced input (tag ``input_tag``) and one static input;
    * a component whose name matches neither ``goal_function`` nor
      ``aggregator``, so it should be ignored by the optimization config;
    * ``aggregator`` with a single string parameter.

    ``optimizationTargetScenarioID`` names the scenario whose document comes
    from ``fixture_target_scenario_data``.
    """
    def string_param(uuid, param_id, value):
        # Key order (uuid, id, type, value) matches the documents used elsewhere in this module.
        return {"uuid": uuid, "id": param_id, "type": "string", "value": value}

    goal_function = {
        "uuid": "comp_1_uuid",
        "name": "goal_function",
        "parameters": [
            string_param("param_1_uuid", "setting_1", "sum"),
            string_param("param_2_uuid", "setting_2", "mean"),
        ],
        "inputVariables": [{
            "dataSourceType": "scenario",
            "id": "input_1",
            "uuid": "input_1_uuid",
            "offset": 0,
            "scaling": 1,
            "value": 1,
            "tagSource": {
                "tag": "input_tag"
            }
        }, {
            "dataSourceType": "static",
            "id": "input_2",
            "uuid": "input_2_uuid",
            "offset": 0,
            "scaling": 1,
            "value": 1
        }]
    }
    unrelated_component = {
        "uuid": "comp_2_uuid",
        # Fixed typo ("unreleated"); the name only needs to differ from the recognised components.
        "name": "unrelated_component",
        "parameters": [string_param("param_3_uuid", "unrelated_setting", "unrelated_value")]
    }
    aggregator = {
        "uuid": "comp_3_uuid",
        "name": "aggregator",
        "parameters": [string_param("param_4_uuid", "aggregator_setting", "aggregator_value")]
    }
    return {
        "optimizationTargetScenarioID": "target_scenario_id",
        "simComponents": [goal_function, unrelated_component, aggregator]
    }


@pytest.fixture
def fixture_target_scenario_data():
    """Build the target scenario document referenced by the optimizer scenario.

    Every optimization setting key maps to a ``[component_uuid, item_uuid]``
    path and to a one-element sim-component list holding ``min=1``/``max=2``
    parameters. ``simComponents`` holds the raw items those paths point at.
    """
    def min_max_range():
        # A fresh structure per key so no two settings share mutable lists/dicts.
        return [{
            "parameters": [
                {"id": "min", "value": 1},
                {"id": "max", "value": 2},
            ]
        }]

    settings_paths = {
        "comp_1_uuid_param_1_uuid_param": ["comp_1_uuid", "param_1_uuid"],
        "comp_1_uuid_input_1_uuid_scaling": ["comp_1_uuid", "input_1_uuid"],
        "comp_1_uuid_input_1_uuid_offset": ["comp_1_uuid", "input_1_uuid"],
        "comp_1_uuid_input_2_uuid_value": ["comp_1_uuid", "input_2_uuid"],
    }
    target_component = {
        "id": "comp_1",
        "uuid": "comp_1_uuid",
        "parameters": [{"uuid": "param_1_uuid"}],
        "inputVariables": [
            {"uuid": "input_1_uuid", "type": "scenario"},
            {"uuid": "input_2_uuid", "type": "static"},
        ],
    }
    return {
        "jobs": {"target_job_id": {"isMain": True}},
        "optimizationSettingsPaths": settings_paths,
        "optimizationSettingsSimComponents": {key: min_max_range() for key in settings_paths},
        "simComponents": [target_component],
    }


@pytest.fixture
def fixture_get_scenarios(mock_spm: MagicMock, fixture_optimizer_scenario_data, fixture_target_scenario_data):
    """Make the mocked ``GetScenario`` return the optimizer document, then the target document.

    ``side_effect`` is given a list, so each call consumes the next item;
    only two ``GetScenario`` calls are served.
    """
    documents_in_call_order = (fixture_optimizer_scenario_data, fixture_target_scenario_data)
    mock_spm().GetScenario.side_effect = [
        spm_pb2.ScenarioDocument(scenario_document=json.dumps(document))
        for document in documents_in_call_order
    ]


def test_goal_function_config(fixture_get_scenarios):  # noqa: F841
    """The goal_function component's params and scenario input are exposed on the config."""
    expected_config = {
        "params": {"setting_1": "sum", "setting_2": "mean"},
        "components": {},
    }
    expected_inputs = [{"tag": "input_tag", "scale": 1, "offset": 0}]
    client = NumerousClient(
        "job_id", "project_id", "scenario_id", "https://server",
        port=50000, secure=True, refresh_token="refresh",
    )
    try:
        assert client.optimization_config.goal_function_config == expected_config
        assert client.optimization_config.goal_function_inputs == expected_inputs
    finally:
        client.close()


def test_aggregator_config(fixture_get_scenarios):  # noqa: F841
    """The aggregator component's params are exposed as ``aggregator_config``."""
    expected = {
        "params": {"aggregator_setting": "aggregator_value"},
        "components": {},
    }
    client = NumerousClient(
        "job_id", "project_id", "scenario_id", "https://server",
        port=50000, secure=True, refresh_token="refresh",
    )
    try:
        actual = client.optimization_config.aggregator_config
        assert actual == expected
    finally:
        client.close()


def test_param_setting(fixture_get_scenarios):  # noqa: F841
    """A component parameter listed in the settings paths yields a 'param' setting."""
    wanted = (
        ["comp_1_uuid", "param_1_uuid"],
        "param",
        {"params": {"min": 1, "max": 2}, "components": {}},
        {"uuid": "param_1_uuid"},
    )
    client = NumerousClient(
        "job_id", "project_id", "scenario_id", "https://server",
        port=50000, secure=True, refresh_token="refresh",
    )
    try:
        found = []
        for setting in client.optimization_config.settings:
            found.append((setting.path, setting.type, setting.value, setting.raw))
        assert wanted in found, "Param setting not found"
    finally:
        client.close()


def test_scenario_input_setting(fixture_get_scenarios):  # noqa: F841
    """A scenario-sourced input yields both an 'offset' and a 'scaling' setting."""
    client = NumerousClient(
        "job_id", "project_id", "scenario_id", "https://server",
        port=50000, secure=True, refresh_token="refresh",
    )
    try:
        found = [
            (setting.path, setting.type, setting.value, setting.raw)
            for setting in client.optimization_config.settings
        ]
        # Checked in the same order as before: offset first, then scaling.
        checks = (
            ("offset", "Input offset setting not found"),
            ("scaling", "Input scaling setting not found"),
        )
        for setting_type, failure_message in checks:
            wanted = (
                ["comp_1_uuid", "input_1_uuid"],
                setting_type,
                {"params": {"min": 1, "max": 2}, "components": {}},
                {"uuid": "input_1_uuid", "type": "scenario"},
            )
            assert wanted in found, failure_message
    finally:
        client.close()


def test_static_input_setting(fixture_get_scenarios):  # noqa: F841
    """A static input yields a single 'value' setting."""
    wanted = (
        ["comp_1_uuid", "input_2_uuid"],
        "value",
        {"params": {"min": 1, "max": 2}, "components": {}},
        {"uuid": "input_2_uuid", "type": "static"},
    )
    client = NumerousClient(
        "job_id", "project_id", "scenario_id", "https://server",
        port=50000, secure=True, refresh_token="refresh",
    )
    try:
        found = []
        for setting in client.optimization_config.settings:
            found.append((setting.path, setting.type, setting.value, setting.raw))
        assert wanted in found, "Static input setting not found"
    finally:
        client.close()


def test_orchestrator(mock_spm: MagicMock, mock_job_manager: MagicMock, fixture_optimizer_scenario_data,
                      fixture_target_scenario_data):
    """Run the orchestrator end to end against mocked SPM and job-manager stubs.

    Sixteen iteration scenarios (``iteration_1`` .. ``iteration_16``) are handed
    out by ``DuplicateScenario``. Each iteration's data is its own iteration
    number, so the minimum score (1.0) belongs to ``iteration_1``.
    """
    # Iterations still "running", mapped to the time they will report FINISHED.
    iteration_finish_times = {}
    # Iterations that have already reported FINISHED at least once.
    finished_iterations = set()

    def mock_start_job(job: spm_pb2.Job):
        # Predictable execution id derived from the scenario being started.
        return spm_pb2.ExecutionId(execution_id=f"execution_{job.scenario_id}")

    def mock_get_scenario(scenario: spm_pb2.Scenario):
        scenario_id = scenario.scenario
        if scenario_id == "scenario_id":
            return spm_pb2.ScenarioDocument(scenario_document=json.dumps(fixture_optimizer_scenario_data), files=[])
        if scenario_id == "target_scenario_id":
            return spm_pb2.ScenarioDocument(scenario_document=json.dumps(fixture_target_scenario_data), files=[])

        # Any other id is an iteration scenario: simulate a job with a random runtime.
        now = datetime.datetime.now()
        if scenario_id in finished_iterations:
            # Repeated poll after completion. Previously no branch matched here and
            # ``status`` was left unbound (UnboundLocalError inside the mock).
            status = ScenarioStatus.FINISHED
        elif scenario_id not in iteration_finish_times:
            # First poll: schedule completion 0-3 seconds from now.
            runtime = datetime.timedelta(seconds=random.uniform(0, 3))
            iteration_finish_times[scenario_id] = now + runtime
            status = ScenarioStatus.RUNNING
        elif iteration_finish_times[scenario_id] <= now:
            del iteration_finish_times[scenario_id]
            finished_iterations.add(scenario_id)
            status = ScenarioStatus.FINISHED
        else:
            status = ScenarioStatus.RUNNING
        doc = {"jobs": {"target_job_id": {"status": {"status": status}}}}
        return spm_pb2.ScenarioDocument(scenario_document=json.dumps(doc), files=[])

    def mock_read_data(read_scenario: spm_pb2.ReadScenario):
        # Both tags carry the iteration number parsed from "iteration_<n>".
        iteration_number = float(read_scenario.scenario.replace("iteration_", ""))
        yield spm_pb2.DataList(data=[spm_pb2.DataBlock(tag="tag1", values=[iteration_number])])
        yield spm_pb2.DataList(data=[spm_pb2.DataBlock(tag="tag2", values=[iteration_number])])

    mock_spm().DuplicateScenario.side_effect = [
        spm_pb2.Scenario(scenario=f"iteration_{i+1}", project="project_id") for i in range(16)
    ]
    mock_job_manager().StartJob.side_effect = mock_start_job
    mock_spm().GetScenario.side_effect = mock_get_scenario
    mock_spm().ReadData.side_effect = mock_read_data

    class TestOrchestrator(BaseOptimizationOrchestrator):
        def __init__(self, client: NumerousClient):
            super().__init__(client, 8)
            self._agg1 = self._agg_func(self.goal_function_config["params"]["setting_1"])  # type: ignore
            self._agg2 = self._agg_func(self.goal_function_config["params"]["setting_2"])  # type: ignore

        def _agg_func(self, name: str):
            """Map a goal-function setting value ("sum" or "mean") to its aggregation callable.

            Raises:
                ValueError: for any other name (previously ``None`` was returned silently,
                    failing later with an unrelated ``TypeError``).
            """
            aggregators = {"sum": sum, "mean": mean}
            try:
                return aggregators[name]
            except KeyError:
                raise ValueError(f"Unsupported aggregation function: {name!r}") from None

        def aggregate(self, iteration: OptimizationIteration):
            df = self.client.data_read_df(["tag1", "tag2"], scenario=iteration.scenario_id,
                                          execution=iteration.execution_id)
            return self._agg2(self._agg1(df.values))

        def get_iteration_configs(self) -> List[List[ScenarioSetting]]:
            setting_ranges = []
            for setting in self.settings:
                params = setting.value["params"]
                # NOTE(review): the upper bound is max + 1, so sampled values exceed the
                # configured max; the assertions below do not depend on it — confirm intent.
                space = linspace(params["min"], params["max"] + 1, 16)
                setting_range = list(setting.scenario_setting(val) for val in space)
                setting_ranges.append(setting_range)
            # Transpose: one list of settings per iteration.
            return [list(settings) for settings in zip(*setting_ranges)]

        def goal_function(self, current: Any, best: Any):
            # Minimisation: keep the lower score.
            return min(current, best)

    client = NumerousClient("job_id", "project_id", "scenario_id", "https://server", port=50000, secure=True,
                            refresh_token="refresh", log_level=logging.DEBUG)
    try:
        orchestrator = TestOrchestrator(client)
        result = orchestrator.run(2)
        best_iteration, best_score = result
        assert best_iteration.execution_id == "execution_iteration_1"
        assert best_score == 1.0
    finally:
        client.close()
