from unittest import TestCase
from numerous_api_client.headers.ValidationInterceptor import ValidationInterceptor
import api.numerous_api.services.tokens as token_manager
from concurrent import futures
import typing
import os
import grpc


class GRPCTestBase(TestCase):
    """Base case for testing the API. Handles setting up server and connection.

    setUp() builds a client channel to NUMEROUS_API_HOST:NUMEROUS_API_PORT
    (defaults "localhost" and "50051") that carries an access token via
    ValidationInterceptor. It then starts a gRPC server listening insecurely on
    all interfaces at NUMEROUS_API_PORT. tearDown() releases the channel, the
    server and the server's worker pool.
    """

    def setUp(self) -> None:
        """Create the access token and client channel, then build and start the server.

        The token and channel are created before the server. If that step
        raises, no server is left bound to the port; unittest does not call
        tearDown() when setUp() fails.
        """
        self._server_host = os.environ.get("NUMEROUS_API_HOST", "localhost")
        self._server_port = os.environ.get("NUMEROUS_API_PORT", "50051")
        # The channel does not need the server to be listening yet; gRPC
        # connects on first use.
        self._channel, self._access_token = self._initialize_channel_and_access_token()

        # Keep a handle on the executor so tearDown() can shut it down explicitly.
        self._executor = futures.ThreadPoolExecutor(max_workers=10)
        self._server = grpc.server(self._executor)
        self._server.add_insecure_port(f'[::]:{self._server_port}')
        self._server.start()

    def tearDown(self) -> None:
        """Close the client channel, stop the server and release its worker threads."""
        # Close the client side first so nothing new is issued while the server goes away.
        self._channel.close()
        # grace=None aborts any in-progress RPCs immediately.
        self._server.stop(None)
        # Server.stop() is not documented to shut down the executor passed to
        # grpc.server(). Shut it down here so worker threads don't pile up
        # across test cases.
        self._executor.shutdown(wait=False)

    def _initialize_channel_and_access_token(self) -> "typing.Tuple[grpc.Channel, str]":
        """
        Generates an access token based on refresh token from env, and creates an insecure channel with the
        access token callback in the context, so it can be accessed on the API side.

        Returns:
            A ``(channel, access_token)`` tuple. ``channel`` is an insecure channel to
            ``self._server_host:self._server_port``, wrapped by ValidationInterceptor.
        """
        # NOTE(review): if NUMEROUS_ADMIN_TOKEN is unset, None is passed as the refresh
        # token. How generate_access_token handles that is not visible here.
        _access_token = token_manager.generate_access_token(
            refresh_token=os.environ.get('NUMEROUS_ADMIN_TOKEN'),
            instance_id='', project_id='', scenario_id='', job_id='', execution_id=''
        )

        # The interceptor gets both the initial token and a callback. The callback
        # reads self._access_token, which is only assigned after this method returns.
        vi = ValidationInterceptor(token=_access_token, token_callback=self._get_current_token, instance='')
        _channel = grpc.insecure_channel(f"{self._server_host}:{self._server_port}")
        _channel = grpc.intercept_channel(_channel, vi)

        return _channel, _access_token

    def _get_current_token(self) -> str:
        """
        Load in the current access token when requested.
        TODO: Should this be refreshed? Auth issues if tests last more than 10 min.
        """
        return self._access_token
