#!/usr/libexec/platform-python

from __future__ import print_function

import os
import sys
import time
import socket
import signal
import xml.sax
import threading
import xml.sax.expatreader

from xml.sax import saxutils
from xml.sax.handler import ContentHandler, feature_external_ges

# This probe should never need more than a few minutes to run.  Arm a
# ten-minute SIGALRM so that a stuck run is killed instead of piling up.
signal.alarm(600)

# Candidate ProbeConfig locations, tried in the order listed.
possible_probeconf = ["/etc/gratia/condor-events/ProbeConfig"]

import gratia.common.Gratia as Gratia

def checkEnviron():
    """
    Check that the Condor command-line tools run and look up the event log.

    ``condor_version`` is executed first purely as a sanity check; then
    ``condor_config_val EVENT_LOG`` is executed to obtain the setting.

    @return: The output of ``condor_config_val EVENT_LOG`` with surrounding
       whitespace stripped.
    @raise Exception: If either command finishes with a failing status.
    """
    def run(command):
        # close() on a popen pipe yields None for a clean exit and a
        # non-zero (truthy) status otherwise.
        pipe = os.popen(command)
        output = pipe.read()
        return output, pipe.close()

    _, version_status = run('condor_version')
    if version_status:
        raise Exception("Unable to successfully execute condor_version!")

    event_log, config_status = run('condor_config_val EVENT_LOG')
    if config_status:
        raise Exception(
            "Unable to successfully execute condor_config_val EVENT_LOG")
    return event_log.strip()


class ClassAdHandler(ContentHandler):
    """
    Streaming SAX handler for the output of condor_* -xml calls; it's around
    60 times faster than DOM and has a similar reduction in required memory.

    Every completed classad (a ``<c>`` element) is held back until the next
    one has been parsed, so that a JobAdInformationEvent can be merged into
    the event that precedes it.  Events are delivered by calling
    ``handler(GlobalJobId, classad_dict)``.
    """

    def __init__(self, handler):
        """
        @param handler: Callable invoked as ``handler(global_job_id, classad)``
           for every event that is emitted; ``classad`` is a dict mapping
           attribute names to their parsed values.
        """
        ContentHandler.__init__(self)
        self.handler = handler
        self.curCaInfo = {}
        self.prevCaInfo = {}
        self.attrName = None
        self.attrInfo = ''
        self.attrType = None
        self.caCtr = 0

    def startDocument(self):
        """
        Start up a parsing sequence; initialize myself.
        """
        self.attrInfo = ''
        self.curCaInfo = {}
        self.prevCaInfo = {}
        self.attrType = None
        self.caCtr = 0
        self._starttime = time.time()

    def endDocument(self):
        """
        Print out debugging information from this document parsing.
        """
        self._endtime = time.time()
        self._elapsed = self._endtime - self._starttime
        count = self.caCtr
        # The tiny epsilon keeps the rate finite when parsing was instant.
        rate = count / (self._elapsed + 1e-10)
        print("Processed %i classads in %.2f seconds; %.2f classads/second"
              % (count, self._elapsed, rate), file=sys.stderr)

    def startElement(self, name, attrs):
        """
        Open an XML element - take note if its a 'c', for the start of a new
        classad, or an 'a', the start of a new attribute.
        """
        if name == 'c':
            self.curCaInfo = {}
            self.caCtr += 1
        elif name == 'a':
            self.attrName = str(attrs.get('n', 'Unknown'))
            self.attrInfo = ''
            # Forget the previous attribute's type; otherwise a value element
            # this handler does not track would be coerced using a stale
            # 'i'/'r' left over from an earlier attribute.
            self.attrType = None
        elif name == 'b':
            # Booleans carry their value in the 'v' XML attribute.
            self.attrType = 'b'
            self.attrInfo = attrs.get('v', 'UNKNOWN')
            if self.attrInfo == 't':
                self.attrInfo = True
            elif self.attrInfo == 'f':
                self.attrInfo = False
        elif name in ('s', 'i', 'r'):
            self.attrType = name

    def endElement(self, name):
        """
        End of an XML element - save everything we learned
        """
        if name == 'c':
            if self.prevCaInfo:
                self.emit(self.curCaInfo, self.prevCaInfo)
            self.prevCaInfo = self.curCaInfo
        elif name == 'a':
            # Store the raw value first so it survives a failed conversion.
            self.curCaInfo[self.attrName] = self.attrInfo
            try:
                if self.attrType == 'i':
                    self.curCaInfo[self.attrName] = int(self.attrInfo)
                elif self.attrType == 'r':
                    self.curCaInfo[self.attrName] = float(self.attrInfo)
            except (TypeError, ValueError):
                print("Malformed attribute; type %s; value %s"
                      "; key %s" % (self.attrType, self.attrInfo,
                                    self.attrName), file=sys.stderr)
        elif name == 'classads':
            # Flush the one classad still being held back.  Nothing is held
            # if no classads were parsed, and a trailing info event is never
            # emitted on its own (mid-stream info events are not either).
            last = self.prevCaInfo
            if last and \
                    last.get('MyType', 'UNKNOWN') != 'JobAdInformationEvent':
                self.emit(last)

    def emit(self, cur, prev=None):
        """
        Deliver an event to the handler, merging info events where needed.

        @param cur: The classad that was just completed.
        @keyword prev: The classad completed before ``cur``.  When omitted,
           ``cur`` is delivered directly as a single event.
        """
        if prev is None:
            # Just one event
            if 'GlobalJobId' not in cur:
                # Synthesize an id of the form host#cluster.proc#time.
                try:
                    hostname = socket.getfqdn()
                except Exception:
                    hostname = "UNKNOWN"
                cluster = cur.get("Cluster", "-1")
                proc = cur.get("Proc", "-1")
                job = "%s.%s" % (str(cluster), str(proc))
                mytime = "0"
                cur['GlobalJobId'] = "#".join([hostname, job, mytime])
            self.handler(cur['GlobalJobId'], cur)
            return

        cur_is_info = cur.get('MyType', 'UNKNOWN') == 'JobAdInformationEvent'
        prev_is_info = \
            prev.get('MyType', 'UNKNOWN') == 'JobAdInformationEvent'
        if cur_is_info:
            if prev_is_info:
                # Two info events in a row
                print("Two JobAdInformationEvents in a row!", file=sys.stderr)
                return
            # A real event followed by an info event; combine and emit one.
            # Values already present on the real event take precedence.
            for key, val in cur.items():
                prev.setdefault(key, val)
            self.emit(prev)
        elif not prev_is_info:
            # Two real events in a row; emit the oldest, keep the next one.
            self.emit(prev)
        # Otherwise prev is an info event, which is never emitted by itself.

    def characters(self, ch):
        """
        Save up the XML characters found in the attribute.
        """
        try:
            self.attrInfo += str(ch)
        except (TypeError, UnicodeError):
            # TypeError: attrInfo currently holds a boolean from a 'b'
            # element, so stray text cannot be appended to it.
            # UnicodeError: str() could not convert the text.
            pass

class MyIncrementalParser(xml.sax.expatreader.ExpatParser):
    """
    Expat parser whose parse() feeds a stream into an ongoing parse.

    Unlike the base class, parse() neither resets nor closes the parser, so
    callers can feed() extra markup before and after the stream's content.
    """

    def __init__(self, bufsize=2**9):
        """
        @keyword bufsize: Maximum amount of data read from the input stream
           per chunk.
        """
        # The base initializer assigns self._bufsize itself, so the value has
        # to be handed to it rather than set beforehand.
        xml.sax.expatreader.ExpatParser.__init__(self, bufsize=bufsize)

    def parse(self, source):
        """
        Feed the entire content of ``source`` to the parser in chunks.

        @param source: Anything accepted by
           xml.sax.saxutils.prepare_input_source.
        """
        source = saxutils.prepare_input_source(source)

        self._source = source
        self._cont_handler.setDocumentLocator(
            xml.sax.expatreader.ExpatLocator(self))

        # A text-mode file is exposed as a character stream and a binary one
        # as a byte stream; use whichever is available.
        stream = source.getCharacterStream()
        if stream is None:
            stream = source.getByteStream()

        # An empty read ('' or b'') signals end of input.
        chunk = stream.read(self._bufsize)
        while chunk:
            self.feed(chunk)
            chunk = stream.read(self._bufsize)
        
 
class EventLogXmlClient:
    """
    Parses an XML event log and fans each event out to subscribers.

    The log is parsed once, started by the first call to subscribe().
    """

    def __init__(self, eventlog):
        """
        @param eventlog: Path of the event log file to parse.
        """
        self.started = False
        self.event_source = eventlog
        self.subscriptions = []

    def subscribe(self, callback, blocking=False):
        """
        Register a callback and start parsing if it has not started yet.

        @param callback: Callable invoked as ``callback(id, info)`` for each
           event.
        @keyword blocking: If true, parse in the calling thread; otherwise
           parse in a new thread.
        @return: Always True.
        """
        self.subscriptions.append(callback)
        if blocking:
            if self.started:
                # Parsing was already started by an earlier subscriber; a
                # blocking subscriber simply never returns.
                while True:
                    time.sleep(100)
            self.started = True
            self.checkResults()
        elif not self.started:
            self.started = True
            self.thread = threading.Thread(target=self.checkResults)
            self.thread.start()
        return True

    def checkResults(self):
        """
        Parse the event log, dispatching every event to the subscribers.
        """
        self.dh = ClassAdHandler(self.resultCallback)
        # The context manager closes the log even if parsing raises.
        with open(self.event_source, 'r') as event_fd:
            self.event_fd = event_fd
            self.parser = MyIncrementalParser()
            self.parser.setContentHandler(self.dh)
            # Wrap the log's content in a single <classads> root element.
            self.parser.feed('<classads>')
            self.parser.parse(event_fd)
            self.parser.feed('</classads>')

    def resultCallback(self, id, info):
        """
        Forward one event to every subscribed callback, in subscription
        order.
        """
        for subs in self.subscriptions:
            subs(id, info)

class CondorUploader(object):
    """
    Event subscriber that turns event classads into Gratia usage records.
    """

    # Status reported for each EventTypeNumber that needs no extra fields.
    # NOTE(review): the numbers are presumably HTCondor's job event codes --
    # confirm against the HTCondor documentation.
    _STATUS_BY_EVENT = {
        0: "queued",
        1: "started",
        7: "Error",
        12: "held",
        13: "started",
    }
    # EventTypeNumbers for which no record is sent.
    _UNSENT_EVENTS = (6, 28)

    def __init__(self):
        """
        Initialize Gratia from the first ProbeConfig that exists.

        @raise Exception: If none of the candidate locations exists.
        """
        super(CondorUploader, self).__init__()
        for probeConf in possible_probeconf:
            probeConf = os.path.expandvars(probeConf)
            if os.path.exists(probeConf):
                Gratia.Initialize(probeConf)
                break
        else:
            raise Exception("Unable to find a suitable ProbeConfig; searched "
                "locations are:\n%s" % "\n".join(possible_probeconf))

    def eventCallback(self, id, info):
        """
        Build a usage record from one event and send it to Gratia.

        @param id: Identifier supplied by the event source (not used here;
           the job id is read from ``info``).
        @param info: Dict of event attributes.  ``EventTime`` gets a trailing
           'Z' appended in place when it lacks one.
        @raise KeyError: If ``info`` lacks ``EventTime`` or ``GlobalJobId``,
           or ``JobCurrentStartDate`` for an event of type 5.
        @raise IndexError: If ``GlobalJobId`` contains no '#' separator.
        """
        record = Gratia.UsageRecord("Events")
        if 'EventTime' in info and not info['EventTime'].endswith('Z'):
            info['EventTime'] += 'Z'
        record.EndTime(info['EventTime'])
        etype = info.get('EventTypeNumber', -1)

        global_job_id = info['GlobalJobId']
        owner = info.get('Owner', 'UNKNOWN')
        # Prefer the proxy subject as the user's identity when present.
        identity = info.get('x509userproxysubject', owner)
        hosts = info.get('CurrentHosts', 0)

        record.GlobalJobId(global_job_id)
        # The local job id is the second '#'-separated field.
        record.LocalJobId(global_job_id.split('#')[1])
        record.LocalUserId(owner)
        record.UserKeyInfo(identity)
        record.GlobalUsername(identity)
        record.Queue(info.get('AccountingGroup', owner))
        record.Processors(hosts)
        record.NodeCount(hosts, description="Number of nodes used")

        if etype in self._UNSENT_EVENTS:
            return
        if etype == 5:
            record.Status("Completed")
            record.StartTime(info['JobCurrentStartDate'])
            record.EndTime(info['EventTime'])
        else:
            # Event types without a dedicated status report their number.
            record.Status(self._STATUS_BY_EVENT.get(etype, str(etype)))

        print(Gratia.Send(record))

def initializeEnvironment():
    """
    Initialize Gratia and export the configured Condor settings.

    Gratia is initialized from the first candidate ProbeConfig that exists.
    The ``CondorBinaryDir`` config attribute, when set to an existing path,
    is prepended to ``PATH``; ``CondorConfig``, when set to an existing path,
    is exported as ``CONDOR_CONFIG``.

    @raise Exception: If none of the candidate ProbeConfig locations exists.
    """
    for candidate in possible_probeconf:
        expanded = os.path.expandvars(candidate)
        if not os.path.exists(expanded):
            continue
        Gratia.Initialize(expanded)
        print("Using probe config file %s" % expanded)
        break
    else:
        # The loop finished without finding a usable config.
        searched = "\n".join(possible_probeconf)
        raise Exception("Unable to find a suitable ProbeConfig; searched "
            "locations are:\n%s" % searched)

    binary_dir = Gratia.Config.getConfigAttribute("CondorBinaryDir")
    if binary_dir and os.path.exists(binary_dir):
        os.environ['PATH'] = ':'.join((binary_dir, os.environ['PATH']))

    config_file = Gratia.Config.getConfigAttribute("CondorConfig")
    if config_file and os.path.exists(config_file):
        os.environ['CONDOR_CONFIG'] = config_file

def main():
    """
    Set up the environment and wire the uploader to the event log client.

    A progress message is printed before each step.
    """
    print("Initializing environment...")
    initializeEnvironment()

    print("Checking environment...")
    event_log_path = checkEnviron()

    print("Creating information client...")
    log_client = EventLogXmlClient(event_log_path)

    print("Creating uploader...")
    record_uploader = CondorUploader()

    print("Subscribing uploader to information client.")
    log_client.subscribe(record_uploader.eventCallback)

# Run the probe only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
