#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import shutil
import socket
import subprocess
import sys
from typing import Optional, Tuple
import urllib
import urllib.error
import urllib.request


# Instance names accepted as the single <instance> command-line argument;
# main() rejects anything that is not in this list (or --cache / --origin).
KNOWN_INSTANCES = [
    "stash-origin",
    "stash-origin-auth",
    "stash-cache",
    "stash-cache-auth",
]


# Lookup table: config file name -> instance name -> URL path template.
# The template is appended to the topology server URL after "{fqdn}" has been
# filled in (see Download.fetch).  A value of None means there is no endpoint
# for that config file / instance combination, so nothing is downloaded.
# fmt: off
ENDPOINTS = {
    "Authfile": {
        "stash-origin"      : "/origin/Authfile-public?fqdn={fqdn}",
        "stash-origin-auth" : "/origin/Authfile?fqdn={fqdn}",
        "stash-cache"       : "/cache/Authfile-public?fqdn={fqdn}",
        "stash-cache-auth"  : "/cache/Authfile?fqdn={fqdn}",
    },
    "scitokens.conf": {
        "stash-origin"      : None,
        "stash-origin-auth" : "/origin/scitokens.conf?fqdn={fqdn}",
        "stash-cache"       : None,
        "stash-cache-auth"  : "/cache/scitokens.conf?fqdn={fqdn}",
    },
    "grid-mapfile": {
        "stash-origin"      : None,
        "stash-origin-auth" : "/origin/grid-mapfile?fqdn={fqdn}",
        "stash-cache"       : None,
        "stash-cache-auth"  : "/cache/grid-mapfile?fqdn={fqdn}",
    },
}
# fmt: on


# Names of the config files to process, in the order they appear in ENDPOINTS;
# handle_instance() iterates over this list for each instance.
CONFIG_FILES = list(ENDPOINTS.keys())


def complain(*values, **kwargs):
    """Print `values` to standard error.

    Takes the same positional and keyword arguments as the built-in print();
    a caller-supplied `file` argument is ignored in favour of sys.stderr.
    Returns whatever print() returns (None).
    """
    options = dict(kwargs, file=sys.stderr)
    return print(*values, **options)


def die_with_usage(prog):
    """Print the usage message to standard error and exit with status 2.

    `prog` is the program name to show in the usage lines.  This function
    does not return: it always raises SystemExit via sys.exit().
    """
    usage = f"""
Usage: {prog} <instance>
   or  {prog} --cache
   or  {prog} --origin

where <instance> is one of:
    stash-cache
    stash-cache-auth
    stash-origin
    stash-origin-auth

--cache is equivalent to running it for stash-cache and stash-cache-auth
--origin is equivalent to running it for stash-origin and stash-origin-auth

Environment variables used:
    CACHE_FQDN      FQDN used for cache authfile query (default: the full hostname)
    ORIGIN_FQDN     FQDN used for origin authfile query (default: the full hostname)
    TOPOLOGY        Topology server to get the data from (default: https://topology.opensciencegrid.org)
    DESTDIR         The base directory to write results to (default: /run)
"""
    print(usage, file=sys.stderr)
    sys.exit(2)


class Download:
    """One config file to be fetched from Topology for one instance.

    Holds the parameters of the download and the paths derived from them:

    - full_destdir: "<destdir>/<instance>", the directory the result goes in
    - dest_file: "<full_destdir>/<config_file>", the file that gets written
    - local_files: files with local additions that are merged into the
      downloaded text, if they exist
    - prepend_local: whether local additions go before (True) or after
      (False) the downloaded text
    """

    def __init__(self, topology, destdir, instance, config_file, fqdn):
        """Store the download parameters and compute the derived paths.

        topology is the base URL of the topology server, destdir the base
        directory for results, instance and config_file the keys used to look
        up the endpoint in ENDPOINTS, and fqdn the host name put into the
        query.
        """
        self.topology = topology
        self.destdir = destdir
        self.instance = instance
        self.config_file = config_file
        self.fqdn = fqdn

        self.full_destdir = f"{self.destdir}/{self.instance}"
        self.dest_file = f"{self.full_destdir}/{self.config_file}"
        self.local_files = [
            f"{self.destdir}/{self.instance}/{self.config_file}.local",
            f"/etc/xrootd/{self.instance}-{self.config_file}.local",
        ]
        self.prepend_local = config_file == "grid-mapfile"
        # ^^ local additions to the grid-mapfile are prepended, not appended
        #    to what's downloaded from topology

    def fetch(self) -> Tuple[Optional[str], bool]:
        """Download the data for this config file from Topology and return
        the content of the download (`text`) and a boolean indicating
        success/failure (`ok`).

        Returns (None, True) if there is no endpoint for this config file e.g.
        scitokens.conf for an unauthenticated cache.

        `ok` is False when the server answers with an HTTP error (`text` is
        then the body of the error response), when the response body is
        empty, or when the server could not be contacted at all (`text` is
        then a description of the failure).  Connection problems are reported
        this way instead of being raised so that the caller can carry on with
        the remaining config files.

        """
        endpoint = ENDPOINTS[self.config_file][self.instance]
        if not endpoint:
            return None, True

        url = self.topology + endpoint.format(fqdn=self.fqdn)
        try:
            # The context manager makes sure the connection is closed.
            with urllib.request.urlopen(url) as response:
                text = response.read()
            # An empty body is treated as a failed download.
            ok = bool(text)
        except urllib.error.HTTPError as err:
            # An HTTP error might indicate an error with the Topology registration
            # or the query; the contents are useful.
            text = err.read()
            ok = False
        except (OSError, ValueError) as err:
            # OSError covers urllib.error.URLError (DNS failure, connection
            # refused, TLS problems) as well as socket errors while reading;
            # ValueError is what urlopen raises for a malformed URL, e.g. a
            # TOPOLOGY value without a scheme.
            text = f"Couldn't retrieve {url}: {err}"
            ok = False
        if not isinstance(text, str):
            text = text.decode("utf-8", errors="replace")

        return text, ok

    def combine_with_local_files(self, text: str) -> str:
        """Return the given text with the additions from the local files
        for this config file, if there are any.  Missing files are silently
        skipped; other read errors are reported but are not failures.

        Each local file's contents are introduced by a "## The following
        lines are from ..." header.  When at least one local file was read,
        the result ends with exactly one newline; otherwise `text` is
        returned unchanged.

        """
        local_parts = []
        for local_file in self.local_files:
            try:
                with open(local_file, "rt", encoding="utf-8", errors="replace") as fh:
                    contents = fh.read()
            except FileNotFoundError:
                continue
            except OSError as err:
                complain(f"Couldn't read local file {local_file}: {err} (continuing)")
                continue
            local_parts.append(
                f"## The following lines are from {local_file}:\n"
                + contents.rstrip("\n")
                + "\n\n"
            )

        if not local_parts:
            return text

        local_text = "".join(local_parts)
        if self.prepend_local:
            combined = (
                local_text + "## The following lines are from OSG Topology:\n" + text
            )
        else:
            combined = text + "\n\n" + local_text
        return combined.rstrip("\n") + "\n"  # have exactly one final newline

    def report_download_error(self, text):
        """Print errors downloading the config file for the instance.

        `text` is the response (or failure description) returned by fetch();
        it is printed to stderr if it is not empty.
        """
        complain(f"Error fetching {self.config_file} for {self.instance}")
        if not text:
            complain("No data received")
            return
        complain("Response follows:")
        complain(text)

    def write_dest_file(self, text):
        """Writes the destination file atomically.

        The text is first written to "<dest_file>.new" and that file is then
        moved over the destination, so a failed write does not leave a
        truncated destination file behind.  Raises OSError if writing or
        moving fails.
        """
        with open(self.dest_file + ".new", "wt", encoding="utf-8") as new_fh:
            new_fh.write(text)
        shutil.move(self.dest_file + ".new", self.dest_file)
        lines = text.count("\n")
        print(f"{lines} lines written successfully to {self.dest_file}.")


def handle_instance(instance, topology, destdir):
    """Download and write every config file for one instance.

    For each entry of CONFIG_FILES the file is fetched from `topology`,
    combined with any local additions and written below
    "<destdir>/<instance>".  The FQDN used in the query comes from the
    CACHE_FQDN or ORIGIN_FQDN environment variable (depending on the kind of
    instance) and defaults to the full host name.

    Returns 0 if everything succeeded and 1 if the destination directory is
    missing or any download or write failed.  A failure for one config file
    does not stop the remaining ones from being processed.

    Raises AssertionError if `instance` is neither a cache nor an origin;
    callers are expected to have validated the name already.
    """
    if "cache" in instance:
        fqdn_var = "CACHE_FQDN"
    elif "origin" in instance:
        fqdn_var = "ORIGIN_FQDN"
    else:
        # Raised explicitly rather than with an `assert` statement so that the
        # check is still performed when Python runs with -O.
        raise AssertionError(f"bad instance {instance} should have been caught")

    # Only ask for the host name (which may involve a DNS lookup) when the
    # environment does not override it.
    fqdn = os.environ.get(fqdn_var)
    if fqdn is None:
        fqdn = socket.getfqdn()

    ret = 0

    for config_file in CONFIG_FILES:
        dl = Download(
            topology=topology,
            destdir=destdir,
            instance=instance,
            config_file=config_file,
            fqdn=fqdn,
        )

        if not os.path.isdir(dl.full_destdir):
            complain(f"Destination directory {dl.full_destdir} doesn't exist")
            return 1  # none of the other downloads will work either

        text, ok = dl.fetch()

        if not ok:
            # some failure happened; inform user of the error but then continue with
            # the next file
            ret = 1
            dl.report_download_error(text)
            continue

        if not text:
            # we didn't download text but that may be ok for this instance
            continue

        # download is successful; now combine the file with any local files
        text = dl.combine_with_local_files(text)

        try:
            dl.write_dest_file(text)
        except OSError as err:
            complain(f"Couldn't write {dl.dest_file}: {err}")
            ret = 1

    return ret


def main(argv=None):
    """Entry point: process the instance(s) selected on the command line.

    `argv` is the full argument vector including the program name and
    defaults to sys.argv.  Exactly one argument is expected: an instance
    name, --cache or --origin; anything else prints the usage message and
    exits.  The topology server and the base destination directory are taken
    from the TOPOLOGY and DESTDIR environment variables.

    Returns the combined (bitwise OR) result of handle_instance() for all
    selected instances, i.e. 0 only if all of them succeeded.
    """
    args = sys.argv if argv is None else argv

    topology = os.environ.get("TOPOLOGY", "https://topology.opensciencegrid.org")
    destdir = os.environ.get("DESTDIR", "/run")

    if len(args) != 2:
        die_with_usage(args[0])

    selector = args[1]
    if selector == "--cache":
        instances = ["stash-cache", "stash-cache-auth"]
    elif selector == "--origin":
        instances = ["stash-origin", "stash-origin-auth"]
    elif selector in KNOWN_INSTANCES:
        instances = [selector]
    else:
        complain(f"Unknown instance {selector}")
        die_with_usage(args[0])

    # Every instance is processed even if an earlier one failed.
    status = 0
    for instance in instances:
        status |= handle_instance(instance, topology=topology, destdir=destdir)
    return status


# When run as a script, use the return value of main() as the exit status.
if __name__ == "__main__":
    sys.exit(main())
