#!/usr/bin/python3
# This probe imitates the LSF behavior of the pbs-lsf probe: urCollector.pl

import os
import sys, traceback  # for exceptions handling
import re

from gratia.common.Gratia import DebugPrint
#import gratia.common.GratiaWrapper as GratiaWrapper
import gratia.common.Gratia as Gratia

from gratia.common2.meter import GratiaProbe, GratiaMeter

from gratia.common2.filepinput import FileInput
import gratia.common2.timeutil as timeutil

from gratia.lsf.accounting import AcctFile, JobFinishEvent


def DebugPrintLevel(level, *args):
    """Log a message through DebugPrint, prefixed with a severity label.

    The numeric level is mapped to a label: 0 or lower is CRITICAL, 1 is ERROR,
    2 is WARNING, 3 is INFO, 4 or higher is DEBUG. The label followed by
    " - Lsf: " is passed to DebugPrint as first message argument, before *args.

    :param level: numeric debug level, forwarded unchanged to DebugPrint
    :param args: remaining arguments, forwarded unchanged to DebugPrint
    """
    labels = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
    # Out of range levels are folded onto the first/last label
    label = labels[0] if level <= 0 else labels[-1] if level >= 4 else labels[level]
    DebugPrint(level, "%s - Lsf: " % label, *args)


class _LsfInputStub:
    """Stub class, needs to be defined before the regular one, to avoid NameError
    """

    # Sample Records: paths and commands have been altered
    # Use a test file for a more realistic test
    value_records = [
        ['JOB_FINISH', '7.02', '1278441653', '4325472', '8723', '33636474', '1', '1278441448', '0', '0', '1278441633', 'glow', 'normal', '', '', '', 'grid1', '', '/dev/null', '/home/32271.1294166239/stdout', '/home/glow/32271.1294166239/stderr', '1294166248.4325472', '0', '1', 'v020', '64', '23.1', '', '#! /bin/sh; echo "missing"',
         '0.129980', '0.197969', '0', '0', '-1', '0', '0', '4009', '0', '0', '0', '0', '-1', '0', '0', '0', '88', '24', '-1', '', 'default', '0', '1', '', '', '0', '7712', '110208', '', '', '', '', '0', '', '0', '', '-1', '/glow', '', '', '', '-1', ''],
        ['JOB_FINISH', '7.02', '1278441653', '4325473', '8723', '33636474', '1', '1278441449', '0', '0', '1278441633', 'glow', 'normal', '', '', '', 'grid1', '', '/dev/null', '/home/32279.1294166239/stdout', '/home/glow/32279.1294166239/stderr', '1294166249.4325473', '0', '1', 'c381', '64', '23.1', '', '#! /bin/sh; echo "missing"',
         '0.133979', '0.203968', '0', '0', '-1', '0', '0', '4009', '0', '0', '0', '0', '-1', '0', '0', '0', '90', '25', '-1', '', 'default', '0', '1', '', '', '0', '7628', '110212', '', '', '', '', '0', '', '0', '', '-1', '/glow', '', '', '', '-1', ''],
        ['JOB_FINISH', '7.02', '1278441653', '4325470', '8723', '33636474', '1', '1278441446', '0', '0', '1278441633', 'glow', 'normal', '', '', '', 'grid1', '', '/dev/null', '/home/32288.1294166239/stdout', '/home/glow/32288.1294166239/stderr', '1294166246.4325470', '0', '1', 'c152', '64', '23.1', '', '#! /bin/sh; echo "missing"',
         '0.121981', '0.203968', '0', '0', '-1', '0', '0', '4009', '0', '0', '0', '0', '-1', '0', '0', '0', '90', '24', '-1', '', 'default', '0', '1', '', '', '0', '7100', '110204', '', '', '', '', '0', '', '0', '', '-1', '/glow', '', '', '', '-1', '']
    ]

    @staticmethod
    def get_records():
        for i in _LsfInputStub.value_records:
            # return the list, will be converted in object in the caller
            yield i


class LsfInput(FileInput):
    """Get Lsf usage information from accounting file

    Records are read from a single accounting file (self.data_file) or from
    all the accounting files found in a directory (self.data_dir) and are
    returned by the get_records generator.
    """

    # Names of the configuration attributes used by this input
    VERSION_ATTRIBUTE = 'LsfVersion'
    PROCESS_LAST_RECORD = 'ProcessLastRecord'

    # Default used until start() reads the value from the configuration
    # (start is replaced by _start_stub in test mode, which does not set it)
    _process_last_record = False

    def get_init_params(self):
        """Return list of parameters to read from the config file"""
        return FileInput.get_init_params(self) + [LsfInput.VERSION_ATTRIBUTE, LsfInput.PROCESS_LAST_RECORD]

    def start(self, static_info):
        """Start the file input and set version and options from the config file

        :param static_info: mapping with the values of the parameters listed
            by get_init_params
        """
        DebugPrint(4, "Lsf start, static info: %s" % static_info)
        # static_info has always the attributes, None if not in the config file
        if static_info[LsfInput.VERSION_ATTRIBUTE]:
            self._set_version_config(static_info[LsfInput.VERSION_ATTRIBUTE])
        FileInput.start(self, static_info)
        self._process_last_record = self.parse_config_boolean(static_info[LsfInput.PROCESS_LAST_RECORD])

    def _start_stub(self, static_info):
        """start replacement for testing: file access errors are trapped

        :param static_info: mapping with the values of the parameters listed
            by get_init_params
        """
        # Same condition as in start(): the attribute may be present with value None
        if LsfInput.VERSION_ATTRIBUTE in static_info and static_info[LsfInput.VERSION_ATTRIBUTE]:
            self._set_version_config(static_info[LsfInput.VERSION_ATTRIBUTE])
        try:
            DebugPrintLevel(4, "Testing File existence. The probe will not use it")
            FileInput.start(self, static_info)
            if self.status_ok():
                DebugPrintLevel(4, "File OK")
            else:
                DebugPrintLevel(4, "Unable to read file")
            DebugPrintLevel(4, "Closing the file")
            self.stop()
        except Exception:
            # Best effort: the stubs do not need the file.
            # KeyboardInterrupt and SystemExit are not trapped
            DebugPrint(1, "Unable to read file. The test can continue since stubs are used.")
        DebugPrint(4, "Lsf start stub, static info: %s" % static_info)

    def get_version(self):
        """Return the LSF version as provided by _get_version using the lsid command"""
        # RPM package for lsf
        # lsid -V prints the version on stderr
        # http://www.ccs.miami.edu/hpc/lsf/7.0.6/admin/cluster_ops.html
        return self._get_version(version_command='lsid -V 2>&1')

    def get_records(self, limit=None):
        """Extract the job records

        Generator yielding the records parsed by AcctFile from each accounting
        file. Records with eventTime older than the checkpoint are skipped.
        Each record is yielded only once a following record has been read, so
        that the very last record (which may be incomplete) is yielded only
        if ProcessLastRecord is enabled.
        If there is a checkpoint it is updated once per file.

        :param limit: currently ignored
        :return: generator of accounting records
        """
        checkpoint = self.checkpoint
        new_checkpoint = None
        if checkpoint:
            # prepare for checkpoint: (date, line number)
            new_checkpoint = (checkpoint.date(), 0)

        files = []
        if self.data_file:
            DebugPrint(4, "Parsing new Lsf records in file: %s" % self.data_file)
            files = [self.data_file]
        elif self.data_dir:
            DebugPrint(4, "Parsing new Lsf records in directory: %s" % self.data_dir)
            # Assume that all LSF accounting files are named lsb.acct*
            # NOTE(review): the pattern looks like a glob but is passed as filename_re,
            # verify how iter_directory interprets it
            files = self.iter_directory(self.data_dir, filename_re="lsb.acct*")
        else:
            DebugPrint(3, "No data file or data directory were specified")
        prev_r = None
        this_r = None
        for acct_file in files:
            try:
                # The context manager closes the file also when the generator
                # is closed before being exhausted
                with open(acct_file) as fh:
                    more_records = True
                    while more_records:
                        # loop used to skip over malformed records
                        # the new invocation of the generator will continue from the following record
                        try:
                            for r in AcctFile(fh):
                                if checkpoint is not None:
                                    if r.eventTime < checkpoint.date():
                                        DebugPrint(6, "Skipping record before Checkpoint: %s" % r.jobID)
                                        continue
                                # This allows to know if the record is the last or not
                                prev_r = this_r
                                if prev_r:
                                    yield prev_r
                                this_r = r
                                if new_checkpoint is not None and prev_r is not None:
                                    if new_checkpoint[0] < prev_r.eventTime:
                                        new_checkpoint = (prev_r.eventTime, prev_r.fileLineNumber)
                            more_records = False
                            if new_checkpoint is not None:
                                # update checkpoint and commit (once per file)
                                # TODO: update checkpoint and commit
                                DebugPrint(4, "Saving New Checkpoint: %s (%s)" % (new_checkpoint, type(checkpoint)))
                                checkpoint.set_date_transaction(new_checkpoint[0], new_checkpoint[1])
                        except ValueError as e:
                            DebugPrint(2, "Error parsing %s: %s" % (acct_file, e))
                # The very last record may be incomplete or corrupted (written partially)
                # If processing also the very last record
                if self._process_last_record:
                    if this_r:
                        # Make sure that some events were processed
                        yield this_r
                        if new_checkpoint is not None:
                            # Never move the checkpoint back in time
                            if new_checkpoint[0] <= this_r.eventTime:
                                new_checkpoint = (this_r.eventTime, this_r.fileLineNumber)
                                # update checkpoint and commit (once per file)
                                DebugPrint(4, "Saving New Checkpoint: %s (%s)" % (new_checkpoint, type(checkpoint)))
                                checkpoint.set_date_transaction(new_checkpoint[0], new_checkpoint[1])
                        # Already yielded: it must not be yielded again as
                        # previous record when parsing the next file
                        this_r = None
            except IOError as e:
                DebugPrint(2, "Unable to open LSF accounting file %s: %s" % (acct_file, e))
            except Exception:
                # Not a bare except: GeneratorExit (raised at the yield when the
                # generator is closed), KeyboardInterrupt and SystemExit must propagate
                DebugPrint(2, "Unknown error parsing the LSF accounting file: %s" % acct_file)
                DebugPrint(4, "Exception details \n%s" % traceback.format_exc())

    def _get_records_stub(self, limit=None):
        """get_records replacement for tests: records are from a pre-filled array
        limit is ignored"""
        for counter, fields in enumerate(_LsfInputStub.get_records()):
            # counter is passed as second argument, like in the original 0-based numbering
            yield JobFinishEvent(fields, counter)

    def do_test(self, static_info=None):
        """Test with pre-arranged value sets
        replacing: start, get_records

        :param static_info: not used
        """
        # replace file access calls with stubs
        self.start = self._start_stub
        self.get_records = self._get_records_stub


class LsfProbe(GratiaMeter):
    """
    LSF probe

    Reads the job records provided by LsfInput, converts each of them to a
    Gratia usage record and sends it.
    """

    # Set pre-defined values
    PROBE_NAME = 'lsf'
    CE_NAME = 'LsfCE'
    # Production
    CE_STATUS = 'Production'
    # Exclusive node use (will account for all CPUs)
    RE_EXCLUSIVE_NODE = re.compile('#BSUB -x', re.IGNORECASE)
    # Fixes for invalid records zeroing time out (see GRATIA-119)
    MAX_JOB_TIME = 2000000000

    def __init__(self):
        GratiaMeter.__init__(self, self.PROBE_NAME)
        self._probeinput = LsfInput()
        # Directory of the LSF commands, empty to rely on the command search path
        self._lsf_bindir = ""
        # Job environment, used to look for GRID_JOBID
        self.job_env = {}

    # Functions to process and send
    def complete_and_send(self, jrecord):
        """
        Complete and send a Gratia record
        :param jrecord: LSF accounting record
        :return: None (also when the record is skipped)
        """
        # TODO: change return value to allow cumulative accounting (# of records sent)
        # TODO: use exceptions to skip discarded records?
        r = self.lsf_to_user_record(jrecord)
        if r is None:
            DebugPrint(4, "Skipping LSF job: %s, %s" % (jrecord.jobID, jrecord.jobName))
            return
        if jrecord.execHosts:
            # The LSF command listing the hosts is lshosts (as in the perl probe)
            lshosts = "lshosts"
            if self._lsf_bindir:
                lshosts = os.path.join(self._lsf_bindir, lshosts)
            self.add_hostinfo(r, jrecord, lshosts)

        DebugPrint(4, "Sending record for LSF job: %s, %s" % (jrecord.jobID, jrecord.jobName))
        DebugPrint(5, "Record being sent: %s" % r)
        Gratia.Send(r)

    def lsf_to_user_record(self, jrecord):
        """Prepare the Gratia usage record for a job record
        Add here some notes about conventions/conversions:
        - assuming that all time stamps in lsb.acct are UTC
        :param jrecord: record coming from lsb.acct file (as returned by LsfInput)
        :return: the Gratia usage record, None if no global job ID can be
            determined (the record should be skipped)
        """
        DebugPrint(5, "Processing job record: %s" % jrecord)

        # Empty usage record
        resource_type = "Batch"
        # resource_type = "GridMonitor"
        # resource_type = 'BatchPilot'
        r = Gratia.UsageRecord(resource_type)
        # get Grid type from configuration, default OSG
        r.Grid(self.get_config_attribute("Grid", "OSG"))

        # Adding ID
        r.LocalJobId(str(jrecord.jobID))
        # TODO: global id?
        # TODO: is the ID OK? Unique enough? (jrecord.fromHost, jrecord.jobID,) get_hostname?
        #global_id = "%s-%s" % (jrecord.fromHost, jrecord.jobID )

        global_id = None
        if 'GRID_JOBID' in self.job_env:
            global_id = self.job_env['GRID_JOBID']
            DebugPrint(4, "Using grid job ID %s also as RecordIdentity!" % global_id)
        elif jrecord.fromHost and jrecord.jobID and jrecord.startTimeEpoch:
            # host:jobID_startTime
            global_id = "%s:%s_%d" % (self._normalize_hostname(jrecord.fromHost), jrecord.jobID, jrecord.startTimeEpoch)
            # ":".$urAcctlogInfo{lrmsId}."_".$urAcctlogInfo{start}
            DebugPrint(4, "Composing global job ID %s for RecordIdentity" % global_id)
        if not global_id:
            DebugPrint(2, "ERROR: Cannot determine or construct a unique job ID ... skipping record!")
            # Print the components tested above (the time used is the start time)
            DebugPrint(4, "ERROR: One component is not defiled: host:%s, job:%s, time:%s)" %
                       (jrecord.fromHost, jrecord.jobID, jrecord.startTimeEpoch))
            return None  # return 3
        r.GlobalJobId(global_id)

        # It is actually a user name
        r.LocalUserId(jrecord.userName)

        # r.VOName still missing

        r.EndTime(timeutil.format_datetime(jrecord.eventTime))
        # jrecord.creationTime (submit/creation time)
        # Previous probe setting machine name, not submit host r.SubmitHost()
        r.MachineName(self._normalize_hostname(jrecord.fromHost), "Server")
        # Previous probe was adding description: LRMS type: lsf
        r.Queue(jrecord.queue, "LRMS type: lsf")
        # user above
        # jobID above
        r.Processors(jrecord.numProcessors)  # processors as from the acct record
        #if jrecord.startTimeEpoch:  # used as flag to see if the job ran
        #    walltime = int(jrecord.eventTimeEpoch - jrecord.startTimeEpoch)
        if jrecord.startTimeEpoch >= 1:
            # the job started (and ran)
            # python 2.7 jrecord.runTime.total_seconds
            seconds_walltime = timeutil.datetime_timedelta_to_seconds(jrecord.runTime)
            seconds_system = timeutil.datetime_timedelta_to_seconds(jrecord.stime)
            seconds_user = timeutil.datetime_timedelta_to_seconds(jrecord.utime)

            # Fixes for invalid records (see GRATIA-119), zero out time fields if they
            # have really large numbers, threshold time based on bad values used in
            # condor probe
            if (seconds_system+seconds_user) > self.MAX_JOB_TIME:
                DebugPrint(3, "WARNING: INVALID DATA: Record for %s has invalid cpu time (system: %s, user: %s), replacing values with 0" %(global_id, seconds_system, seconds_user))
                seconds_system = 0
                seconds_user = 0
            if seconds_walltime > self.MAX_JOB_TIME:
                DebugPrint(3, "WARNING: INVALID DATA: Record for %s has invalid walltime time %s, replacing value with 0" % (global_id, seconds_walltime))
                seconds_walltime = 0

            # Gratia durations accept either string or # of seconds
            r.WallDuration(timeutil.format_interval(seconds_walltime))
            # previous probe was reporting total (user+system) - was accounted as user and system was 0
            r.CpuDuration(seconds_system, "system", "Was entered in seconds")
            r.CpuDuration(seconds_user, "user", "Was entered in seconds")

            # mem jrecord.maxRMem
            # swap jrecord.maxRSwap
            r.StartTime(timeutil.format_datetime(jrecord.startTime))
        # end jrecord.eventTime(Epoch)
        r.TimeInstant(timeutil.format_datetime(jrecord.submitTime), 'SubmitTime')  # ctime
        #if jrecord.execHosts:
        #    # last exec host
        #    r.Host(jrecord.execHosts[-1])
        r.JobName(jrecord.jobName)  #jobName
        # job_command = jrecord.command  # command (actual script?)
        r.Status(jrecord.exitStatus, "exit status")  # exitStatus
        return r

    @staticmethod
    def add_hostinfo(r, jrecord, lshost):
        """Add to the usage record the information about the execution host

        The last execution host is added as Host. If the job requested
        exclusive use of the node (#BSUB -x in the command) and the number of
        processors is missing or 1, the number of CPUs of the host (ncpus
        column in the output of the lshosts command, 1 if not available)
        is added as Processors with metric "max".

        :param r: Gratia usage record to complete
        :param jrecord: LSF accounting record
        :param lshost: command (name or path) to run to get the host information
        :return: None
        """
        if not jrecord.execHosts:
            return
        # last exec host
        exec_host = jrecord.execHosts[-1]
        r.Host(exec_host, description="executing host")
        # Conditions from perl code:
        # command =~ m/#BSUB -x/i
        #  & ( exist execHost & execHost != "" )
        #  & ( ! exist proc | proc ="" | proc = 1 )
        # The perl match is not anchored: search (not match) must be used
        # to find the directive anywhere in the command
        if not LsfProbe.RE_EXCLUSIVE_NODE.search(jrecord.command or ""):
            return
        if jrecord.numProcessors and jrecord.numProcessors != 1:
            return

        # Execute: lshosts $execHost   Host
        # which returns
        #HOST_NAME      type    model  cpuf ncpus maxmem maxswp server RESOURCES
        #c485         X86_64   PC1133  23.1     8 16046M 16378M    Yes (mpich2)
        #
        # perl code to imitate:
        #   # print "Will execute lshost $urAcctlogInfo{execHost}\n";
        #   my $lshosts = $lsfBinDir."/lshosts";
        #   open FH, "$lshosts $urAcctlogInfo{execHost} |" or die "Failed to open pipeline to/from lshost";
        #   my @lines = <FH>;
        #   close(FH);
        #   print "Need: @lines\n";
        #   if ( scalar @lines == 2) {
        #      my @headers = split(/ +/,$lines[0]);
        #      my @values = split(/ +/,$lines[1]);
        #      my $search = "ncpus";
        #      my $index = first { $headers[$_] eq $search } 0 .. $#headers;
        #      # print "Found value for $search = $values[$index]\n";
        #      $urAcctlogInfo{processors} = $values[$index] || 1;
        #   }
        cmd = "%s %s" % (lshost, exec_host)
        out = GratiaProbe.run_command(cmd)
        ncpus = 1
        if out:
            # splitlines and the filter avoid counting the empty string
            # following a trailing newline as a third row
            rows = [row for row in out.splitlines() if row.strip()]
            if len(rows) == 2:
                headers = rows[0].split()
                values = rows[1].split()
                if 'ncpus' in headers:
                    column = headers.index('ncpus')
                    # keep the default if the values row is shorter than the header
                    if column < len(values) and values[column]:
                        ncpus = values[column]
        r.Processors(ncpus, metric="max")   # max or average or total (default)

    @staticmethod
    def _is_ok(jrecord):
        """Return True if the job record is valid: currently all records are accepted"""
        return True

    def register_gratia(self):
        """Register the probe in Gratia and set lsf as batch manager"""
        # The class is named explicitly: super(self.__class__, self) would
        # recurse forever if this method were inherited by a subclass
        super(LsfProbe, self).register_gratia()
        #These have been taken care in the parent registerGratia
        #Gratia.RegisterReporter()
        #Gratia.RegisterReporterLibrary()
        #Gratia.RegisterService()
        # This sets a global variable used by certinfo.py to guess the certinfo file name
        Gratia.setProbeBatchManager('lsf')

    def main(self):
        """Start the probe, then process and send all the job records from the input

        Each job record is an accounting.JobFinishEvent object:
        version - LSF Version number as a string
        eventTimeEpoch - time generated (= end time) in seconds since epoch
        eventTime - time generated (= end time) as datetime
        jobID - LSF job id
        userId - numeric user ID of the user who owned the job
        submitTimeEpoch, submitTime - job submitted
        beginTimeEpoch, beginTime - scheduled job start time
        termTimeEpoch, termTime - Job termination deadline
        startTimeEpoch, startTime - start time
        userName - user name of the submitter
        queue - job queue
        fromHost
        cwd
        inFile
        outFile
        errFile
        jobFile
        numAskedHosts - Number of host names to which job dispatching will be limited
        askedHosts - List of host names to which job dispatching will be limited
        numExHosts - Number of processors used for execution (if LSF_HPC_EXTENSIONS="SHORT_EVENTFILE" in lsf.conf, number of .hosts listed in execHosts)
        execHosts - List of execution host names (allocation at job finish time)
        jStatus - Job status. 32=EXIT, 64=DONE
        hostFactor - CPU factor of the first execution host
        jobName - Job name (up to 4094 characters)
        command - Complete batch job command
        # -1=unavailable, units are sec or KB
        utime - User time used in seconds
        stime - System time used in seconds
        exutime - Exact user time used
        mailUser - Name of the user to whom job related mail was sent
        projectName - LSF project name
        exitStatus - UNIX exit status of the job
        maxNumProcessors - Maximum number of processors specified for the job
        loginShell - Login shell used for the job
        timeEvent
        termInfo - TermInfo object
        warningAction
        warningTimePeriod
        chargedSAAP - Share Attribute Account Path (SAAP) that was charged for the job
        runTime - run time (0 if never started, startTime-eventTime) datetime.timedelta
        # Job never started
        waitTime=eventTime-submitTime
        startTime=termTime
        # else:
        waitTime=startTime-submitTime
        pendTime=waitTime - The time the job was pending
        """
        # Initialize the probe and the input
        self.start()
        DebugPrintLevel(4, "LSF probe started - main loop")

        #se = self.get_sitename()
        #name = self.get_probename()
        #hostname = self.get_hostname()

        # Loop over job records
        for jrecord in self._probeinput.get_records():
            DebugPrint(5, "Preparing job record for: %s" % jrecord)

            if self._is_ok(jrecord):
                # TODO: send only completed jobs?
                # also jobs that did not start?
                self.complete_and_send(jrecord)
            else:
                # Invalid job!
                # Add also some message why it is not valid
                DebugPrint(3, "Skipping Invalid job %s." % jrecord)


if __name__ == "__main__":
    # Script entry point: create the probe and run its main loop
    probe = LsfProbe()
    probe.main()

