from unittest.mock import MagicMock

import grpc
from grpc._channel import _InactiveRpcError, _RPCState

import spm_pb2

from numerous.client import NumerousClient


def mock_subscribes(*streams, channel_patterns=('ch1', 'ch2')):
    """Build a fake ``SubscribeForUpdates`` stub method that replays scripted streams.

    Each positional argument is one scripted stream: an iterable of items that
    the returned subscription yields in order. Any item that is an
    ``Exception`` instance is raised instead of yielded, so a test can simulate
    a stream breaking part-way through (e.g. with an ``_InactiveRpcError``) and
    force the client to subscribe again.

    Args:
        *streams: Scripted streams, consumed one per call whose request matches
            ``channel_patterns``.
        channel_patterns: Channel patterns a request must carry to receive the
            next scripted stream. Requests with any other patterns get an empty
            subscription and do not consume a stream.

    Returns:
        A callable suitable for ``side_effect``. It takes a subscribe request and
        returns a ``MagicMock`` whose iteration lazily produces the scripted
        items. It raises ``AssertionError`` if a matching request arrives after
        every scripted stream has already been handed out.
    """
    expected_patterns = list(channel_patterns)
    remaining_streams = iter(streams)

    def raise_or_return(value):
        # Exceptions embedded in a stream are raised at the point the client
        # iterates to them, mimicking a gRPC stream failing mid-flight.
        if isinstance(value, Exception):
            raise value
        return value

    def mock_subscribe(request):
        mock = MagicMock()
        if request.channel_patterns == expected_patterns:
            try:
                stream = next(remaining_streams)
            except StopIteration:
                # A bare StopIteration escaping a side_effect becomes an opaque
                # "generator raised StopIteration" RuntimeError when the caller
                # is itself a generator (PEP 479); report the real problem.
                raise AssertionError(
                    f'SubscribeForUpdates called for {expected_patterns} more often '
                    f'than the {len(streams)} scripted stream(s) allow'
                ) from None
            # Lazy generator: items are produced (and exceptions raised) only as
            # the client iterates the subscription.
            mock.__iter__.return_value = (raise_or_return(update) for update in stream)
        else:
            mock.__iter__.return_value = iter([])
        return mock

    return mock_subscribe


def test_get_scenario_document_retry(mock_spm: MagicMock):
    """A single UNAVAILABLE failure of GetScenario must not reach the caller.

    The stub's first ``GetScenario`` call raises an ``_InactiveRpcError`` with
    ``StatusCode.UNAVAILABLE``; the second returns a scenario document. The
    client must still hand back the parsed document from the second attempt.
    """
    unavailable = _InactiveRpcError(_RPCState((), None, None, grpc.StatusCode.UNAVAILABLE, 'details'))
    document = spm_pb2.ScenarioDocument(scenario_document='{"prop": "val"}')

    mock_spm().GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta='{}')
    mock_spm().GetScenario.side_effect = [unavailable, document]

    numerous_client = NumerousClient(
        "job_id",
        "project_id",
        "scenario_id",
        "https://server",
        port=50000,
        secure=True,
        refresh_token="refresh",
    )
    try:
        scenario, _ = numerous_client.get_scenario_document()
        assert scenario == {'prop': 'val'}
    finally:
        numerous_client.close()


def test_subscribe_messages_retry(mock_spm: MagicMock):
    """subscribe_messages must recover from UNAVAILABLE and keep delivering.

    Three scripted streams are served, in order, to ``['ch1', 'ch2']``
    subscriptions. The first two each carry one message and then fail with an
    UNAVAILABLE ``_InactiveRpcError``; the last carries one message and ends
    normally. All three decoded messages must come out in arrival order.
    """
    def message(channel, payload):
        return spm_pb2.SubscriptionMessage(channel=channel, message=payload)

    def unavailable():
        # A fresh error instance per stream, as each stream fails independently.
        return _InactiveRpcError(_RPCState((), None, None, grpc.StatusCode.UNAVAILABLE, 'details'))

    mock_spm().GetScenarioCustomMetaData.return_value = spm_pb2.ScenarioCustomMetaData(meta='{}')
    mock_spm().SubscribeForUpdates.side_effect = mock_subscribes(
        [message('ch1', '{"msg": "hello"}'), unavailable()],
        [message('ch2', '{"msg": "hi, bye"}'), unavailable()],
        [message('ch1', '{"msg": "see you"}')],
    )
    expected = [
        ('ch1', {'msg': 'hello'}),
        ('ch2', {'msg': 'hi, bye'}),
        ('ch1', {'msg': 'see you'}),
    ]

    numerous_client = NumerousClient(
        "job_id",
        "project_id",
        "scenario_id",
        "https://server",
        port=50000,
        secure=True,
        refresh_token="refresh",
    )
    try:
        received = list(numerous_client.subscribe_messages(['ch1', 'ch2']))
        assert received == expected
    finally:
        numerous_client.close()
