from unittest.mock import MagicMock

import grpc
import pytest
from pytest_mock import MockerFixture

import spm_pb2

from numerous_api_client.client.numerous_client import NumerousClient


def test_initial_get_access_token(mocker: MockerFixture):
    """Constructing a client with a refresh token calls ``GetAccessToken`` exactly once."""
    # Patch out channel creation and the generated stubs so no real gRPC connection is made.
    mocker.patch("grpc.secure_channel", return_value=None)
    mocker.patch("grpc._interceptor._Channel")
    mock_spm = mocker.patch("spm_pb2_grpc.SPMStub")
    # NOTE(review): this configures the patched *class*, not the stub instance ``mock_spm()``;
    # confirm whether ``mock_spm().SubscribeForUpdates`` was intended.
    mock_spm.SubscribeForUpdates.return_value = (("channel", "{}") for _ in range(1))
    mock_token_manager = mocker.patch("spm_pb2_grpc.TokenManagerStub")
    mock_token_manager().GetAccessToken.return_value = spm_pb2.Token(val="token")
    mock_spm().GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta="{}")

    # Construct before any cleanup: if the constructor raises, its own error is reported
    # instead of an UnboundLocalError from calling close() on an unassigned ``client``.
    client = NumerousClient("job_id", "project_id", "scenario_id", "https://server", port=50000, secure=True,
                            refresh_token="refresh")
    client.close()

    mock_token_manager().GetAccessToken.assert_called_once()


def test_write_data(mock_spm: MagicMock):
    """Rows written through a writer created with ``buffer_size=0`` reach ``WriteDataList`` on the stub."""
    # NOTE(review): this sets ``mock_spm.WriteDataList.return_value.return_value`` on the patched
    # class, not ``mock_spm().WriteDataList.return_value``; looks unintended — confirm.
    mock_spm.WriteDataList().return_value = (None for _ in range(1))
    mock_spm().GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta="{}")

    # Construct outside ``try`` so a failing constructor is not masked by an
    # UnboundLocalError from ``client.close()`` in ``finally``.
    client = NumerousClient("job_id", "project_id", "scenario_id", "https://server", port=50000, secure=True,
                            refresh_token="refresh")
    try:
        writer = client.new_writer(buffer_size=0)
        writer.write_row({"_index": 1, "val1": 2.5, "val2": 16.31})
        writer.write_row({"_index": 2, "val1": 3.0, "val2": -3.17})
        writer.write_row({"_index": 3, "val1": 5.3, "val2": 103.87})
        writer.write_row({"_index": 4, "val1": 2.0, "val2": 28.49})
    finally:
        client.close()

    # ``mock_calls`` also records calls made on the returned mock (including magic methods such as
    # iteration), so this is not simply the number of direct invocations.
    # NOTE(review): confirm which calls make up the expected count of 2.
    assert len(mock_spm().WriteDataList.mock_calls) == 2


def test_set_timeseries_meta_data_correct_defaults(mock_spm: MagicMock):
    """A tag given only by ``name`` is sent to ``SetScenarioMetaData`` with the expected default fields."""
    mock_spm().GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta="{}")

    # Construct outside ``try`` so a failing constructor is not masked by an
    # UnboundLocalError from ``client.close()`` in ``finally``.
    client = NumerousClient("job_id", "project_id", "scenario_id", "https://server", port=50000, secure=True,
                            refresh_token="refresh")
    try:
        client.set_timeseries_meta_data([{"name": "tag"}], scenario="scenario_id", execution="execution_id",
                                        project="project_id")
    finally:
        client.close()

    expected_tag = spm_pb2.Tag(name="tag", displayName="", unit="", description="", type="double", scaling=1, offset=0)
    expected_scenario_metadata = spm_pb2.ScenarioMetaData(project="project_id", scenario="scenario_id",
                                                          execution="execution_id", tags=[expected_tag], aliases=[],
                                                          offset=0, timezone="UTC", epoch_type="s")
    mock_spm().SetScenarioMetaData.assert_called_once_with(expected_scenario_metadata)


def test_set_timeseries_meta_data_name_required(mock_spm: MagicMock):
    """A tag without ``name`` raises ``KeyError``, and nothing is sent to ``SetScenarioMetaData``."""
    mock_spm().GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta="{}")

    # Construct outside ``try`` so a failing constructor is not masked by an
    # UnboundLocalError from ``client.close()`` in ``finally``.
    client = NumerousClient("job_id", "project_id", "scenario_id", "https://server", port=50000, secure=True,
                            refresh_token="refresh")
    try:
        with pytest.raises(KeyError):
            client.set_timeseries_meta_data([{"unit": "s"}], scenario="scenario_id", execution="execution_id",
                                            project="project_id")
    finally:
        client.close()

    # Comparing with [] (rather than checking the length) lists any unexpected calls when it fails.
    assert mock_spm().SetScenarioMetaData.mock_calls == []


def test_raises_grpc_errors_on_data_read_after_retry(mock_spm: MagicMock, monkeypatch):
    """``data_read_df`` keeps calling ``ReadData`` on ``grpc.RpcError`` and finally re-raises it.

    The retry timeout and delay are patched to known values, so a lower bound on the number of
    ``ReadData`` attempts can be derived from them.
    """
    retry_timeout_s = 15
    retry_delay_s = 0.25
    monkeypatch.setattr("numerous.client.config.GRPC_RETRY_TIMEOUT", retry_timeout_s)
    monkeypatch.setattr("numerous.client.config.GRPC_RETRY_DELAY", retry_delay_s)

    stub = mock_spm()
    stub.GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta="{}")
    read_data = stub.ReadData

    client = NumerousClient("job_id", "project_id", "scenario_id", "https://server", port=50000, secure=True,
                            refresh_token="refresh")
    try:
        read_data.side_effect = grpc.RpcError("Error: ReadData exception details")
        with pytest.raises(grpc.RpcError):
            client.data_read_df(tags=['test_tag_1', 'test_tag_2'], execution='test_execution_id')
    finally:
        client.close()

    # Leave ~10% slack below the ideal timeout/delay attempt count for per-attempt overhead.
    minimum_attempts = retry_timeout_s / retry_delay_s * 0.9
    assert read_data.call_count > minimum_attempts
