#!/usr/bin/python3

"""utility to merge kerberos keytab files within the CERN environment"""
import argparse
import os
import re
import sys
import tempfile
import shutil
import pexpect

def print_verbose(msg, verbose):
    """Write ``msg`` to stdout, but only when ``verbose`` is truthy."""
    if not verbose:
        return
    print(msg)

def print_debug(msg, debug):
    """Write ``msg`` to stdout, but only when ``debug`` is truthy."""
    if not debug:
        return
    print(msg)

def print_error(msg):
    """Report a fatal problem on stdout and end the process with status 1.

    This function never returns: it always raises ``SystemExit``.
    """
    print(f"Error: {msg}")
    raise SystemExit(1)

def remove_file(file):
    """Delete ``file``; a path that does not exist is silently ignored."""
    try:
        os.remove(file)
    except FileNotFoundError:
        # Already gone, nothing left to do.
        return

def _parse_ktutil_entries(listing):
    """Parse ``list -e`` output into ``{slot: {'kvno', 'principal', 'enc'}}``."""
    entries = {}
    for line in listing.splitlines():
        if not re.search(r"^.*\@CERN\.CH", line):
            continue
        fields = line.split()
        # Drop the parentheses and any "DEPRECATED:" marker around the
        # enctype name. main() strips the same marker from the klist output
        # it turns into ``currentprinc``, so both sides must agree for the
        # enctype filter and the cleanup comparison to match.
        enc = fields[3].replace('(', '').replace(')', '').replace('DEPRECATED:', '')
        entries[int(fields[0])] = {
            'kvno': int(fields[1]),
            'principal': fields[2],
            'enc': enc,
        }
    return entries


def _slots_to_delete(entries, principals, currentprinc, enctypes):
    """Return the slot numbers rejected by the principal/enctype/cleanup filters."""
    doomed = []
    kept_newest = set()
    for slot, entry in entries.items():
        if principals is not None and entry['principal'] not in principals:
            doomed.append(slot)
            continue
        if enctypes is not None and entry['enc'] not in enctypes:
            doomed.append(slot)
            continue
        if currentprinc is None:
            continue
        # Cleanup: keep exactly one entry per principal/enctype pair, and
        # only the one carrying the newest known kvno.
        ident = f"{entry['principal']}:{entry['enc']}"
        if currentprinc.get(ident) == entry['kvno'] and ident not in kept_newest:
            kept_newest.add(ident)
        else:
            doomed.append(slot)
    return doomed


def do_ktutil(ink, outk, principals, currentprinc, enctypes, append):
    """Merge keytabs through an interactive ``/usr/bin/ktutil`` session.

    Every input keytab is loaded with ``read_kt``, the combined entry list is
    parsed from ``list -e``, rejected slots are removed with ``delent`` and
    the remainder is written with ``write_kt``.

    Args:
        ink: Input keytabs as a list of lists (the shape argparse produces
            for ``action="append"`` with ``nargs``); the first item of each
            inner list is read.
        outk: Path of the output keytab.
        principals: Principal names to keep, or None to keep all.
        currentprinc: Mapping of ``"<principal>:<enctype>"`` to the newest
            kvno, or None. When given, only one entry per pair with exactly
            that kvno survives.
        enctypes: Enctype names to keep, or None to keep all.
        append: When true ktutil writes directly to ``outk``; otherwise the
            result is staged in a private temporary directory and then copied
            over ``outk``.

    Raises:
        pexpect.ExceptionPexpect: (e.g. TIMEOUT/EOF) when ktutil does not
            answer with its prompt within 3 seconds.
    """
    staging_dir = None
    if append:
        tmp_outfile = outk
    else:
        # Stage the keytab inside a freshly created directory that only this
        # process owns. ktutil must create the file itself, so the path may
        # not exist yet, but a bare not-yet-existing name directly in /tmp
        # could be pre-created by another user.
        staging_dir = tempfile.mkdtemp(prefix="ktutil.kt.", dir="/tmp")
        tmp_outfile = os.path.join(staging_dir, "keytab")
    try:
        child = pexpect.spawn('/usr/bin/ktutil')
        try:
            # NOTE(review): 'ktutil:' presumably also prefixes ktutil's own
            # error messages, which would desynchronise prompt matching
            # after a failed command - confirm and tighten if so.
            child.expect('ktutil:', timeout=3)
            for keytabfile in ink:
                child.sendline(f"read_kt {keytabfile[0]}")
                child.expect('ktutil:', timeout=3)
            child.sendline('list -e')
            child.expect('ktutil:', timeout=3)
            entries = _parse_ktutil_entries(child.before.decode())
            doomed = _slots_to_delete(entries, principals, currentprinc, enctypes)
            # Delete from the highest slot downwards so that the numbers of
            # the slots still to be deleted do not shift.
            for slot in sorted(doomed, reverse=True):
                child.sendline(f"delent {slot}")
                child.expect('ktutil:', timeout=3)
            child.sendline(f"write_kt {tmp_outfile}")
            child.expect('ktutil:', timeout=3)
        finally:
            # Always reap the ktutil process, also on timeouts.
            child.close(force=True)
        if not append:
            shutil.copy(tmp_outfile, outk)
    finally:
        if staging_dir is not None:
            # The staged copy holds key material; never leave it behind.
            shutil.rmtree(staging_dir, ignore_errors=True)

def do_execute(cmd, verbose):
    """Run ``cmd`` through ``pexpect.run`` and return what it printed.

    Failures are best-effort: they are only reported in verbose mode and
    never raise.

    Args:
        cmd: Command line to execute.
        verbose: When truthy, log the command and any failure.

    Returns:
        The captured output as bytes. Empty bytes are returned when the
        command could not be run at all, so callers can still ``.decode()``
        the result.
    """
    print_verbose(f"running: {cmd}", verbose)
    try:
        output, status = pexpect.run(cmd, withexitstatus=True)
    except pexpect.ExceptionPexpect:
        # Previously this path fell through to an unbound result variable.
        print_verbose(f"Failed to execute {cmd}", verbose)
        return b""
    if status:
        print_verbose(f"Failed to execute {cmd}", verbose)
    return output

def fixselinux(keytab, verbose):
    """Label ``keytab`` as ``krb5_keytab_t`` via chcon when applicable.

    Nothing happens unless an SELinux marker path exists, the absolute
    keytab path does not start with ``/afs`` or ``/eos``, and the chcon
    binary is readable.
    """
    chcon = "/usr/bin/chcon"
    selinux_markers = ("/sys/fs/selinux", "/selinux/status")
    if not any(os.path.exists(marker) for marker in selinux_markers):
        return
    if os.path.abspath(keytab).startswith(('/afs', '/eos')):
        return
    if not os.access(chcon, os.R_OK):
        return
    command = f"{chcon} system_u:object_r:krb5_keytab_t:s0 {keytab}"
    do_execute(command, verbose)

def format_enctypes(enctypes):
    """Validate user supplied encryption types and normalise their spelling.

    Args:
        enctypes: List of lists of enctype names (the shape argparse produces
            for ``action="append"`` with ``nargs='+'``). Names are matched
            case-insensitively and may use ``-`` or ``_`` as separator.

    Returns:
        A flat list with every supplied name lower-cased and ``-`` separated
        (e.g. ``aes256-cts-hmac-sha1-96``), in the order given.

    An unknown name is reported through ``print_error``, which exits.
    """
    allowedenctypes = {
        "ARCFOUR_HMAC",
        "AES128_CTS_HMAC_SHA1_96",
        "AES256_CTS_HMAC_SHA1_96",
    }
    retenctypes = []
    for group in enctypes:
        # Honour every value given to a single -e, not just the first one.
        for name in group:
            if name.replace('-', '_').upper() in allowedenctypes:
                retenctypes.append(name.replace('_', '-').lower())
            else:
                # print_error() adds the "Error: " prefix itself; the list is
                # sorted so that the message is stable between runs.
                print_error(
                    "Unknown enctype specified. Allowed ones are: "
                    f"{'|'.join(sorted(allowedenctypes))}"
                )
    return retenctypes

def format_principals(principals):
    """Qualify user supplied principals with the ``@CERN.CH`` realm.

    Accepting names without the realm keeps input written for the perl port
    working in the python port.

    Args:
        principals: List of lists of principal names (the shape argparse
            produces for ``action="append"`` with ``nargs='+'``).

    Returns:
        A flat list of principals, each ending in ``@CERN.CH``, in the order
        given.
    """
    realm = '@CERN.CH'
    retprincipals = []
    for group in principals:
        # Honour every value given to a single -p, not just the first one.
        for name in group:
            if name.endswith(realm):
                retprincipals.append(name)
            else:
                retprincipals.append(f"{name}{realm}")
    return retprincipals

def generate_temp_file(prefix):
    """Return a fresh, unused file name below ``/tmp``.

    The temporary file is created only to obtain a unique name and is
    deleted again before returning, so the returned path does NOT exist and
    is not reserved; the caller is expected to create it.

    Args:
        prefix: Leading part of the generated file name.

    Returns:
        Absolute path of the (already removed) temporary file.

    A ``PermissionError`` while creating the file is reported through
    ``print_error``, which exits.
    """
    try:
        # Leaving the ``with`` block closes and deletes the file right here
        # instead of whenever the handle happens to be garbage collected.
        with tempfile.NamedTemporaryFile(mode="w", prefix=prefix, dir="/tmp") as tfh:
            filename = tfh.name
    except PermissionError:
        print_error(f"cannot create temporary {prefix} config file")
    return filename

def _build_arg_parser():
    """Create the argparse parser describing the command line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--inkeytab",
        required=True,
        type=str,
        action="append",
        nargs='+',
        help="Input keytab file name(s). May be specified multiple times.",
    )
    parser.add_argument(
        "-o",
        "--outkeytab",
        required=True,
        action="store",
        help="Output keytab file name.",
    )
    parser.add_argument(
        "-a",
        "--append",
        required=False,
        action="store_true",
        help="Append input keytab(s) keys to output keytab rather than overwriting it.",
    )
    parser.add_argument(
        "-c",
        "--cleanup",
        required=False,
        action="store_true",
        help="Remove all but newest (highest kvno) input keytab(s) keys before writing "
        "to the output keytab",
    )
    parser.add_argument(
        "-p",
        "--principal",
        required=False,
        action="append",
        nargs='+',
        help="Only act on input keytab(s) keys matching given Kerberos principal. "
        "May be specified multiple times. "
        "(principals in CERN keytabs are defined as: 'hostname.cern.ch' , "
        "'hostname$' or 'SRVC/hostname.cern.ch')",
    )
    parser.add_argument(
        "-e",
        "--enctypes",
        required=False,
        type=str,
        action="append",
        nargs='+',
        help="Only act on input keytab(s) keys matching given encryption type(s). "
        "May be specified multiple times. Allowed encryption types are: "
        "ARCFOUR_HMAC,AES128_CTS_HMAC_SHA1_96,AES256_CTS_HMAC_SHA1_96"
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Display verbose information",
    )
    parser.add_argument(
        "-d",
        "--debug",
        required=False,
        action="store_true",
        help="Display debug information",
    )
    return parser


def _track_newest_kvnos(klist_output, newest, debug):
    """Fold ``klist -ke`` output into ``newest``.

    ``newest`` maps ``"<principal>:<enctype>"`` to the highest kvno seen so
    far and is updated in place.
    """
    for line in klist_output.splitlines():
        if not re.search(r"^.*\@CERN\.CH", line):
            continue
        fields = line.split()
        kvno = int(fields[0])
        principal = fields[1]
        enc = fields[2].replace('(', '').replace(')', '').replace('DEPRECATED:', '')
        ident = f"{principal}:{enc}"
        if ident not in newest or kvno > newest[ident]:
            newest[ident] = kvno
        print_debug(f" found keytab entry for principal: {principal}", debug)


def main():
    """Parse the command line, inspect the input keytabs and merge them.

    The verbose messages deliberately mimic the output of the perl port,
    even where the corresponding work is done inside ``do_ktutil``.
    """
    parser = _build_arg_parser()

    # This logic ensures that single dash shortname arguments work
    # (e.g. "-inkeytab" is turned into "--inkeytab").
    sys.argv = [
        "-" + arg
        if arg.startswith("-") and not arg.startswith("--") and len(arg) > 2
        else arg
        for arg in sys.argv
    ]
    args = parser.parse_args()

    # "-i a b" yields one group with two files while only the first item of a
    # group is ever read; regroup so that every file gets a group of its own.
    args.inkeytab = [[path] for group in args.inkeytab for path in group]

    if args.enctypes:
        args.enctypes = format_enctypes(args.enctypes)

    if args.principal:
        args.principal = format_principals(args.principal)

    print_verbose("Initializing Kerberos client", args.verbose)
    # Newest kvno per principal/enctype across ALL input keytabs; needed by
    # the cleanup step in do_ktutil.
    newest_kvnos = {}
    for keytabfile in args.inkeytab:
        inkeytab = keytabfile[0]
        print_verbose(f"Reading input keytab: {inkeytab}", args.verbose)
        print_verbose(f" using keytab file name: {inkeytab}", args.verbose)
        if not os.access(inkeytab, os.R_OK):
            print_error(f"keytab file not readable: {inkeytab}")
        print_verbose(f" resolving keytab file: {inkeytab}", args.verbose)
        klist = do_execute(
            f"/usr/bin/klist -ke {inkeytab}",
            args.verbose,
        )
        _track_newest_kvnos(klist.decode(), newest_kvnos, args.debug)
    if args.cleanup:
        # Not actually done here, done in the do_ktutil function
        # maintaining the output to mimic the perl port
        print_verbose('Performing key entries cleanup', args.verbose)
    else:
        newest_kvnos = None
    if not args.append:
        # We don't actually remove, as this logic is now contained in do_ktutil
        # maintaining the output to mimic the perl port
        print_verbose(f"Removing existing keytab file: {args.outkeytab}", args.verbose)
    do_ktutil(
        args.inkeytab,
        args.outkeytab,
        args.principal,
        newest_kvnos,
        args.enctypes,
        args.append,
    )
    print_verbose(f"Writing output keytab: {args.outkeytab}", args.verbose)
    print_verbose(f" using keytab file name: {args.outkeytab}", args.verbose)
    print_verbose(f" resolving keytab file: {args.outkeytab}", args.verbose)
    print_verbose("Fixing SELinux context", args.verbose)
    fixselinux(args.outkeytab, args.verbose)

# Run the command line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
