"""
A lookup plugin to query the re2o API.

For a detailed example look at https://github.com/ansible/ansible/blob/3dbf89e8aeb80eb2d1484b1cb63458e4bb12795a/lib/ansible/plugins/lookup/aws_ssm.py


The API Client has been adapted from https://gitlab.federez.net/re2o/re2oapi
"""

from ansible.plugins.loader import cache_loader

from pathlib import Path
import datetime
import requests
import stat
import json
import collections
import netaddr
from configparser import ConfigParser

from ansible.module_utils._text import to_native
from ansible.plugins.lookup import LookupBase
from ansible.errors import (AnsibleError,
                            AnsibleFileNotFound,
                            AnsibleLookupError,
                            )
from ansible.utils.display import Display
from ansible.config.manager import ConfigManager

# Ansible display helper; the plugin logs through its v()/vv()/vvv() methods,
# i.e. messages are gated by the playbook verbosity level.
display = Display()

# Number of seconds before expiration at which the token is renewed
# (see Client.need_renew_token).
TIME_FOR_RENEW = 120
# Default name of the file storing tokens: $HOME/{DEFAULT_TOKEN_FILENAME}
DEFAULT_TOKEN_FILENAME = '.re2o.token'


class Client:
    """
    Class based client to contact re2o API.

    Tokens are cached as JSON in ``$HOME/.re2o.token``, keyed first by
    hostname then by username. A token is renewed when it is within
    TIME_FOR_RENEW seconds of expiring, or when the server refuses it.
    """
    def __init__(self, hostname, username, password, use_tls=True):
        """
        :arg hostname: The hostname of the Re2o instance to use.
        :arg username: The username to use.
        :arg password: The password to use.
        :arg use_tls: A boolean to specify whether the client should use a
                      TLS connection. Default is True. Please, keep it.
        """
        self.use_tls = use_tls
        self.hostname = hostname
        self._username = username
        self._password = password

        self.token_file = Path.home() / DEFAULT_TOKEN_FILENAME

        display.v("Connecting to {hostname} as user {user}".format(
            hostname=to_native(self.hostname), user=to_native(self._username)))
        try:
            self.token = self._get_token_from_file()
        except (AnsibleFileNotFound, AnsibleLookupError):
            # A missing/unreadable file and a file without a valid entry for
            # this user@host both mean "no cached token": ask the server for
            # a fresh one instead of failing the whole lookup.
            display.vv("Force renew the token")
            self._force_renew_token()

    def _get_token_from_file(self):
        """Load the cached token of ``user@host`` from the token file.

        :returns: A dict with keys ``token`` (str) and ``expiration``
                  (naive datetime).
        :raises AnsibleFileNotFound: If the file is missing or is not
                                     readable JSON.
        :raises AnsibleLookupError: If the file holds no well-formed entry
                                    for this user and host.
        """
        display.vv("Trying to fetch token from {}".format(self.token_file))

        # Check if the token file exists
        if not self.token_file.is_file():
            display.vv("Unable to access file {}".format(self.token_file))
            raise AnsibleFileNotFound(file_name=self.token_file)

        try:
            with self.token_file.open() as f:
                data = json.load(f)
        except Exception as e:
            display.vv("File {} not readable".format(self.token_file))
            display.vvv("Original error was {}".format(to_native(e)))
            raise AnsibleFileNotFound(file_name=self.token_file.as_posix() +
                                      ' (Not readable)')

        try:
            token_data = data[self.hostname][self._username]
            ret = {
                'token': token_data['token'],
                'expiration': self._parse_date(token_data["expiration"]),
            }
        except (KeyError, TypeError, ValueError, AttributeError):
            # KeyError: no entry; TypeError: JSON is not nested objects;
            # ValueError/AttributeError: unparsable or non-string date.
            raise AnsibleLookupError("""Token for {user}@{host} not found
            in token file ({token})""".format(user=self._username,
                                              host=self.hostname,
                                              token=self.token_file,
                                              )
                                     )
        else:
            display.vv("""Token successfully retreived from
            file {token}""".format(token=self.token_file))
            return ret

    def _force_renew_token(self):
        """Fetch a new token from the server and persist it to disk."""
        self.token = self._get_token_from_server()
        self._save_token_to_file()

    def _get_token_from_server(self):
        """Authenticate against the ``token-auth`` endpoint.

        :returns: A dict with keys ``token`` and ``expiration`` (datetime).
        :raises AnsibleLookupError: On 400 Bad Request (bad credentials).
        :raises AnsibleError: On any other HTTP error status.
        """
        display.vv("Requesting a new token for {user}@{host}".format(
            user=self._username,
            host=self.hostname,
        ))
        # Authentication request
        response = requests.post(
            self.get_url_for('token-auth'),
            data={'username': self._username, 'password': self._password},
        )
        display.vv("Response code: {}".format(response.status_code))
        if response.status_code == requests.codes.bad_request:
            display.vv("Please provide valid credentials")
            raise AnsibleLookupError("Unable to connect to the API for {host}"
                                     .format(host=self.hostname))
        try:
            response.raise_for_status()
        except Exception as e:
            raise AnsibleError("""An error occured while trying to contact
            the API. This was the original exception: {}"""
                               .format(to_native(e)))

        response = response.json()
        ret = {
            'token': response['token'],
            'expiration': self._parse_date(response['expiration']),
        }

        display.vv("Token successfully retreived for {user}@{host}".format(
            user=self._username,
            host=self.hostname,
            )
        )
        return ret

    def _parse_date(self, date, date_format="%Y-%m-%dT%H:%M:%S"):
        """Parse an ISO-like timestamp, dropping anything after the first '.'
        (fractional seconds)."""
        return datetime.datetime.strptime(date.split('.')[0], date_format)

    def _save_token_to_file(self):
        """Merge the current token into the token file (best effort).

        Entries of other hosts/users are preserved. Write failures are logged
        and ignored.
        """
        display.vv("Saving token to file {}".format(self.token_file))
        try:
            # Read previous data to avoid erasures
            with self.token_file.open() as f:
                data = json.load(f)
            if not isinstance(data, dict):
                # Valid JSON that is not an object cannot be merged into.
                raise ValueError("token file does not hold a JSON object")
        except Exception:
            display.v("""Beware, token file {} was not a valid JSON readable
            file. Considered empty.""".format(self.token_file))
            data = {}

        if not isinstance(data.get(self.hostname), dict):
            data[self.hostname] = {}
        data[self.hostname][self._username] = {
            'token': self.token['token'],
            'expiration': self.token['expiration'].isoformat(),
        }

        owner_rw = stat.S_IWRITE | stat.S_IREAD
        try:
            # Create the file owner-only *before* writing the secret, so a
            # new token file never exists with umask-derived permissions.
            self.token_file.touch(mode=owner_rw, exist_ok=True)
            with self.token_file.open('w') as f:
                json.dump(data, f)
            # Also tighten permissions of a pre-existing file.
            self.token_file.chmod(owner_rw)
        except Exception as e:
            display.vv("Token file {} could not be written. Passing."
                       .format(self.token_file))
            display.vvv("Original error was {}".format(to_native(e)))
        else:
            display.vv("Token successfully written to file {}"
                       .format(self.token_file))

    def get_token(self):
        """
        Retrieves the token to use for the current connection.
        Automatically renewed if needed.
        """
        if self.need_renew_token:
            self._force_renew_token()

        return self.token['token']

    @property
    def need_renew_token(self):
        """True when the token expires within TIME_FOR_RENEW seconds."""
        return self.token['expiration'] < \
            datetime.datetime.now() + \
            datetime.timedelta(seconds=TIME_FOR_RENEW)

    def _request(self, method, url, headers=None, params=None, *args,
                 **kwargs):
        """Perform an authenticated request and return the decoded JSON.

        :arg method: Name of the ``requests`` function to call (e.g. 'get').
        :arg url: The full URL to query.
        :kwarg headers: Extra HTTP headers (copied, never modified in place).
        :kwarg params: Query parameters (copied, never modified in place);
                       ``format=json`` is added unless already given.
        :returns: The JSON body of the response.
        :raises AnsibleLookupError: If the server answers 403 Forbidden.
        :raises AnsibleError: For any other HTTP error status.
        """
        display.vv("Building the {method} request to {url}.".format(
            method=method.upper(),
            url=url,
        ))

        # Work on copies so neither the caller's dicts nor any shared default
        # object are mutated between calls.
        headers = dict(headers or {})
        params = dict(params or {})

        # Force the 'Authorization' field with the right token.
        display.vvv("Forcing authentication token.")
        headers['Authorization'] = 'Token {}'.format(self.get_token())

        # Use a json format unless the user already specified something
        if 'format' not in params:
            display.vvv("Forcing JSON format response.")
            params['format'] = 'json'

        # Perform the request
        display.v("{} {}".format(method.upper(), url))
        send = getattr(requests, method)
        response = send(url, headers=headers, params=params, *args, **kwargs)
        display.vvv("Response code: {}".format(response.status_code))

        if response.status_code == requests.codes.unauthorized:
            # Force re-login to the server (case of a wrong token but valid
            # credentials) and then retry the request without catching errors.
            display.vv("Token refused. Trying to refresh the token.")
            self._force_renew_token()

            headers['Authorization'] = 'Token {}'.format(self.get_token())
            display.vv("Re-performing the request {method} {url}".format(
                method=method.upper(),
                url=url,
            ))
            response = send(url, headers=headers, params=params, *args,
                            **kwargs)
            display.vvv("Response code: {}".format(response.status_code))

        if response.status_code == requests.codes.forbidden:
            err = "The {method} request to {url} was denied for {user}".format(
                method=method.upper(),
                url=url,
                user=self._username
            )
            display.vvv(err)
            raise AnsibleLookupError(to_native(err))

        try:
            response.raise_for_status()
        except Exception as e:
            raise AnsibleError("""An error occured while trying to contact
            the API. This was the original exception: {}"""
                               .format(to_native(e)))

        ret = response.json()
        display.vvv("{method} request to {url} successful.".format(
            method=method.upper(),
            url=url
        ))
        return ret

    def get_url_for(self, endpoint):
        """
        Retrieves the complete URL to use for a given endpoint's name.
        """
        return '{proto}://{host}/{namespace}/{endpoint}'.format(
            proto=('https' if self.use_tls else 'http'),
            host=self.hostname,
            namespace='api',
            endpoint=endpoint
        )

    def get(self, *args, **kwargs):
        """
        Perform a GET request to the API
        """
        return self._request('get', *args, **kwargs)

    def list(self, endpoint, max_results=None, params=None):
        """List all objects on the server that corresponds to the given
        endpoint. The endpoint must be valid for listing objects.

        :arg endpoint: The path of the endpoint.
        :kwarg max_results: A limit on the number of result to return
        :kwarg params: See `requests.get` params. Not modified in place.
        :returns: The list of all the objects as returned by the API.
        """
        display.v("Starting listing objects under '{}'"
                  .format(endpoint))
        display.vvv("max_results = {}".format(max_results))

        # Copy: with a shared ``{}`` default, the ``page_size`` forced below
        # would otherwise stick to every later call.
        params = dict(params or {})

        # For optimization, list all results in one page unless the user
        # is forcing a different `page_size`.
        if 'page_size' not in params:
            display.vvv("Forcing 'page_size' parameter to 'all'.")
            params['page_size'] = max_results or 'all'

        # Performs the request for the first page
        response = self.get(
            self.get_url_for(endpoint),
            params=params,
        )

        results = response['results']

        # Get all next pages and append the results
        while response['next'] is not None and \
                (max_results is None or len(results) < max_results):
            response = self.get(response['next'])
            results += response['results']

        # Returns the exact number of results if applicable
        ret = results[:max_results] if max_results else results
        display.vvv("Listing objects under '{}' successful"
                    .format(endpoint))
        return ret


class LookupModule(LookupBase):
    """
    Available terms =
       - dnszones: Queries the re2o API and returns the list of all dns zones
                   nicely formatted to be rendered in a template.

       - dnsreverse: Queries the re2o API and returns the list of all reverse
                     dns zones, formatted to be rendered in a template.

       - get_role, role_name: Works in pair. Fails if role_name not provided.
                              Queries the re2o API and returns the list of
                              all machines whose role_type is role_name.

    If a term is not in the previous list, make a raw query to the API
    with endpoint term.

    It uses arguments api_hostname, api_username, api_password to connect
    to the API. api_hostname can also be defined in the ansible configuration
    file (e.g. ansible.cfg) in section re2o; a value passed when the plugin
    is called takes precedence over the configured one. The same section may
    also set use_cpasswords (boolean), cache (only 'jsonfile' is supported)
    and timeout (integer, cache timeout, default 120).

    Usage:

    The following play will use the debug module to output
    all the DNS zone names, querying the API hostname defined in configuration.

    - hosts: sputnik.adm.crans.org
      vars:
        dnszones: "{{ lookup('re2oapi', 'dnszones') }}"
      tasks:
        - debug: var=dnszones
    """

    def _readconfig(self, section="re2o", key=None, boolean=False,
                    integer=False):
        """Read ``key`` from ``section`` of the Ansible configuration file.

        :returns: The value converted to bool/int when requested, the raw
                  string otherwise, or None when the option is absent.
        """
        config = self._config
        if not config or not config.has_option(section, key):
            return None
        display.vvv("Found key {} in configuration file".format(key))
        if boolean:
            return config.getboolean(section, key)
        if integer:
            return config.getint(section, key)
        return config.get(section, key)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        config_manager = ConfigManager()
        config_file = config_manager.data.get_setting(name="CONFIG_FILE").value
        self._config = ConfigParser()
        # No ansible.cfg may have been found; ConfigParser.read(None) fails.
        if config_file:
            self._config.read(config_file)

        display.vvv("Using {} as configuration file.".format(config_file))

        self._api_hostname = None
        self._api_username = None
        self._api_password = None
        self._use_cpasswords = None
        self._cache_plugin = None
        self._cache = None
        self._timeout = 120

        if self._config.has_section("re2o"):
            display.vvv("Found section re2o in configuration file")

            self._api_hostname = self._readconfig(key="api_hostname")
            self._use_cpasswords = self._readconfig(key="use_cpasswords",
                                                    boolean=True)
            self._cache_plugin = self._readconfig(key="cache")
            # Keep the default when the key is absent instead of
            # overwriting it with None.
            timeout = self._readconfig(key="timeout", integer=True)
            if timeout is not None:
                self._timeout = timeout

        if self._cache_plugin is not None:
            display.vvv("Using {} as cache plugin".format(self._cache_plugin))

            if self._cache_plugin == 'jsonfile':
                self._cachedir = Path.home() / ".cache/Ansible/re2oapi"
                display.vvv("Cache directory is {}".format(self._cachedir))
                if not self._cachedir.exists():
                    # Creates Ansible cache directory with right permissions
                    # if it doesn't exist yet.
                    display.vvv("Cache directory doesn't exist. Creating it.")
                    try:
                        self._cachedir.mkdir(mode=0o700, parents=True)
                    except Exception as e:
                        raise AnsibleError("""Unable to create {dir}.
                        Original error was : {err}"""
                                           .format(dir=self._cachedir,
                                                   err=to_native(e)))
                self._cache = cache_loader.get('jsonfile',
                                               _uri=self._cachedir,
                                               _timeout=self._timeout,
                                               )
            else:
                raise AnsibleError("Cache plugin {} not supported"
                                   .format(self._cache_plugin))

    def run(self, terms, variables=None, api_hostname=None, api_username=None,
            api_password=None, use_tls=True):

        """
           :arg terms: a list of lookups to run
               e.g. ['dnszones']
           :kwarg variables: ansible variables active at the time of the lookup
           :kwarg api_hostname: The hostname of re2o instance. Overrides the
                                configured api_hostname when given.
           :kwarg api_username: The username to connect to the API.
           :kwarg api_password: The password to use to connect to the API.
           :kwarg use_tls: A boolean to specify whether to use tls. You should!
           :returns: A list of results to the specific queries.
        """

        # Use the hostname specified by the user if it exists.
        if api_hostname is not None:
            display.vvv("Overriding api_hostname with {}".format(api_hostname))
        else:
            api_hostname = self._api_hostname

        # Validate the *effective* hostname (argument or configuration).
        if api_hostname is None:
            raise AnsibleError(to_native(
                'You must specify a hostname to contact re2oAPI'
            ))

        if (api_username is None and api_password is None
                and self._use_cpasswords):
            display.vvv("Using cpasswords vault to get API credentials.")
            variables = variables or {}
            api_username = variables.get('vault_re2o_service_user')
            api_password = variables.get('vault_re2o_service_password')

        if api_username is None:
            raise AnsibleError(to_native(
                'You must specify a valid username to connect to re2oAPI'
            ))

        if api_password is None:
            raise AnsibleError(to_native(
                'You must specify a valid password to connect to re2oAPI'
            ))

        api_client = Client(api_hostname, api_username,
                            api_password, use_tls=use_tls)

        res = []
        dterms = collections.deque(terms)

        display.vvv("Lookup terms are {}".format(terms))
        while dterms:
            term = dterms.popleft()
            display.v("\nLookup for {} \n".format(term))
            if term == 'dnszones':
                res.append(self._getzones(api_client))
            elif term == 'dnsreverse':
                res.append(self._getreverse(api_client))
            elif term == 'get_role':
                # Only the missing role_name must map to this error; errors
                # raised while querying the roles propagate unchanged.
                try:
                    role_name = dterms.popleft()
                except IndexError:
                    display.v("Error in re2oapi : No role_name provided")
                    raise AnsibleError("role_name not found in arguments.")
                res.append(self._get_role(api_client, role_name))
            else:
                try:
                    res.append(self._rawquery(api_client, term))
                except Exception as e:
                    raise AnsibleError("""
                    An error occured while running re2oapi
                    lookup plugin. Original message was : {}"""
                                       .format(to_native(e)))
        return res

    def _get_cache(self, key):
        """Return the cached value for ``key``, or None without a cache."""
        if self._cache:
            return self._cache.get(key)
        else:
            return None

    def _set_cache(self, key, value):
        """Store ``value`` under ``key`` when a cache is configured."""
        if self._cache:
            return self._cache.set(key, value)
        else:
            return None

    def _is_cached(self, key):
        """Whether ``key`` is present in the configured cache."""
        if self._cache:
            return self._cache.contains(key)
        else:
            return False

    def _getzones(self, api_client):
        """Return the DNS zone names, stripped of their first character."""
        display.v("Getting dns zone names")
        zones, zones_name = None, None

        if self._is_cached('dnszones'):
            zones_name = self._get_cache('dnszones')

        if zones_name is not None:
            display.vvv("Found dnszones in cache.")

        else:
            if self._is_cached('dns_zones'):
                zones = self._get_cache('dns_zones')
            if zones is not None:
                display.vvv("Found dns/zones in cache.")
            else:
                display.vvv("Contacting the API, endpoint dns/zones...")
                zones = api_client.list('dns/zones')
                display.vvv("...Done")
            # Names are returned with a leading character (presumably the
            # '.' of '.example.org') which is dropped here.
            zones_name = [zone["name"][1:] for zone in zones]
            display.vvv("Storing dnszones in cache.")
            self._set_cache('dnszones', zones_name)

        return zones_name

    @staticmethod
    def _ipv4_reverse_zones(cidrs):
        """Yield the in-addr.arpa zone names covering the given IPv4 CIDRs.

        Each network is split on the octet boundary at or below its prefix
        (/8, /16, /24 or /32), and one zone is produced per resulting subnet.
        """
        for cidr in cidrs:
            net = netaddr.IPNetwork(cidr)
            if net.prefixlen > 24:
                new_prefix = 32
            elif net.prefixlen > 16:
                new_prefix = 24
            elif net.prefixlen > 8:
                new_prefix = 16
            else:
                new_prefix = 8
            for subnet in net.subnet(new_prefix):
                _address = netaddr.IPAddress(subnet.first)
                # e.g. ['0', '2', '1', '10', 'in-addr', 'arpa'] for 10.1.2.0
                rev_dns_a = _address.reverse_dns.split('.')[:-1]
                # Keep only the octets fixed by the prefix: /8 -> 1,
                # /16 -> 2, /24 -> 3, /32 -> all 4 (previously /32 subnets
                # matched no branch and reused a stale/unbound name).
                yield '.'.join(rev_dns_a[4 - subnet.prefixlen // 8:])

    def _getreverse(self, api_client):
        """Return the reverse (in-addr.arpa / ip6.arpa) DNS zone names.

        Duplicates are removed (first occurrence kept), and the returned list
        is exactly what gets cached.
        """
        display.v("Getting dns reverse zones")

        zones, res = None, None

        if self._is_cached('dnsreverse'):
            res = self._get_cache('dnsreverse')

        if res is not None:
            display.vvv("Found dnsreverse in cache.")
            return res

        if self._is_cached('dns_reverse-zones'):
            zones = self._get_cache('dns_reverse-zones')

        if zones is not None:
            display.vvv("Found dns/reverse-zones in cache.")
        else:
            display.vvv("Contacting the API, endpoint dns/reverse-zones..")
            zones = api_client.list('dns/reverse-zones')
            display.vvv("...Done")

        display.vvv("Trying to format dns reverse in a nice way.")
        res = []
        for zone in zones:
            if zone['ptr_records']:
                display.vvv('Found PTR records')
                for zone_name in self._ipv4_reverse_zones(zone['cidrs']):
                    res.append(zone_name)
                    display.vvv("Found reverse zone {}".format(zone_name))

            if zone['ptr_v6_records']:
                display.vvv("Found PTR v6 record")
                net = netaddr.IPNetwork(zone['prefix_v6']
                                        + '/'
                                        + str(zone['prefix_v6_length']))
                # Number of nibbles covered by the prefix (rounded up).
                net_class = max(((net.prefixlen - 1) // 4) + 1, 1)
                zone6_name = ".".join(
                    netaddr.IPAddress(net.first)
                    .reverse_dns.split('.')[32 - net_class:])[:-1]
                res.append(zone6_name)
                display.vvv("Found reverse zone {}".format(zone6_name))

        # Deduplicate deterministically so the first call and cached calls
        # return the same list.
        res = list(collections.OrderedDict.fromkeys(res))
        display.vvv("Storing dns reverse zones in cache.")
        self._set_cache('dnsreverse', res)

        return res

    def _rawquery(self, api_client, endpoint):
        """List every object of ``endpoint``, using the cache if possible."""
        cache_key = endpoint.replace('/', '_')
        res = None
        if self._is_cached(cache_key):
            res = self._get_cache(cache_key)
        if res is not None:
            display.vvv("Found {} in cache.".format(endpoint))
            return res

        # Use the client's hostname: it reflects a per-call override, and
        # `self.api_hostname` never existed (AttributeError on every miss).
        display.v("Making a raw query {host}/api/{endpoint}"
                  .format(host=api_client.hostname, endpoint=endpoint))
        res = api_client.list(endpoint)
        display.vvv("Storing result in cache.")
        self._set_cache(cache_key, res)
        return res

    def _get_role(self, api_client, role_name):
        """Return the machines/role entries whose role_type is role_name."""
        res, machines_roles = None, None

        if self._is_cached(role_name):
            res = self._get_cache(role_name)

        if res is not None:
            display.vvv("Found {} in cache.".format(role_name))
        else:
            if self._is_cached("machines_role"):
                machines_roles = self._get_cache("machines_role")

            if machines_roles is not None:
                display.vvv("Found machines/roles in cache.")
            else:
                machines_roles = api_client.list("machines/role")
                display.vvv("Storing machines/role in cache.")
                self._set_cache("machines_role", machines_roles)

            res = [m for m in machines_roles if m["role_type"] == role_name]
            display.vvv("Storing {} in cache.".format(role_name))
            self._set_cache(role_name, res)

        return res