#!/usr/bin/python

import os
import re
import sys
import sets
import time
import socket
import logging
import optparse
import cStringIO
import threading
import traceback
import logging.handlers
import xml.sax.expatreader

from xml.sax import saxutils, make_parser
from xml.sax.handler import ContentHandler, feature_external_ges

import gratia.common.Gratia as Gratia
import gratia.services.StorageElement as StorageElement
import gratia.services.StorageElementRecord as StorageElementRecord
from gratia.common.Gratia import DebugPrint

# Module-level logger; stays None until parse_opts() installs the
# "XrdStorage" logger.
log = None
# Time stamp put on outgoing Gratia records; GratiaHandler.summary() refreshes
# it at the start of every report.
timestamp = time.time()
# Seconds to sleep between reports; parse_opts() overrides this when
# --report_period is given.
SLEEP_TIME = 60
XRD_NAME = 'Xrootd'
XRD_AREA_NAME = '%s Area Tokens' % XRD_NAME
XRD_STATUS = "Production"
XRD_EXPIRE_MINUTES = 60 # Remove hosts that haven't reported in this many mins.

# Author: Chad J. Schroeder
# Copyright: Copyright (C) 2005 Chad J. Schroeder
# This script is one I've found to be very reliable for creating daemons.
# The license permits redistribution.
# I've modified it slightly for my purposes.  -BB
# Settings applied by daemonize(): the file-mode creation mask, the working
# directory, and the device the standard streams are pointed at.
UMASK = 0
WORKDIR = "/"

# Use os.devnull when the interpreter provides it, else the literal path.
REDIRECT_TO = getattr(os, "devnull", "/dev/null")

def daemonize(pidfile):
   """Detach a process from the controlling terminal and run it in the
   background as a daemon.

   The detached process will return; the process controlling the terminal
   will exit.

   If the fork is unsuccessful, it will raise an exception; DO NOT CAPTURE IT.

   """

   try:
      pid = os.fork()
   except OSError, e:
      raise Exception("%s [%d]" % (e.strerror, e.errno))

   if (pid == 0):	# The first child.
      os.setsid()
      try:
         pid = os.fork()	# Fork a second child.
      except OSError, e:
         raise Exception("%s [%d]" % (e.strerror, e.errno))

      if (pid == 0):	# The second child.
         os.chdir(WORKDIR)
         os.umask(UMASK)
         for i in range(3):
             os.close(i)
         os.open(REDIRECT_TO, os.O_RDWR|os.O_CREAT) # standard input (0)
         os.dup2(0, 1)                        # standard output (1)
         os.dup2(0, 2)                        # standard error (2)
         try:
             fp = open(pidfile, 'w')
             fp.write(str(os.getpid()))
             fp.close()
         except:
             pass
      else:
         os._exit(0)	# Exit parent (the first child) of the second child.
   else:
      os._exit(0)	# Exit parent of the first child.

class XrootdSummaryParser(ContentHandler):
    """SAX content handler that flattens an Xrootd summary document.

    While the document is parsed the handler keeps a tuple of the enclosing
    element names (``cur_key``).  Elements whose name starts with
    ``statistics`` are not added to the tuple, and a ``stats`` element
    contributes the value of its ``id`` attribute instead of its tag name.
    Every text node is passed to ``dataCallback(key, value)``: as a string
    when the innermost element is one of ``_string_leaves``, otherwise as an
    int (text that is not an integer is dropped).
    """

    # Innermost element names whose text is reported verbatim.
    _string_leaves = ('host', 'chars', 'lp', 'rp', 'name')

    def __init__(self, dataCallback):
        ContentHandler.__init__(self)
        self.dataCallback = dataCallback
        # Also initialised here so the handler is usable even if
        # startDocument() has not been delivered yet.
        self.cur_key = ()
        self.cur_val = None

    def startDocument(self):
        """Reset the element path at the start of each document."""
        self.cur_key = ()
        self.cur_val = None

    def emitData(self, key, val):
        """Invoke the callback, ignoring every error except a shutdown.

        SystemExit and KeyboardInterrupt are re-raised so that a stop request
        raised by the callback can unwind the parser; any other exception
        from the callback is deliberately swallowed (best effort).
        """
        try:
            self.dataCallback(key, val)
        # The tuple is required: "except SystemExit, KeyboardInterrupt" would
        # catch only SystemExit and bind it to the name KeyboardInterrupt.
        except (SystemExit, KeyboardInterrupt):
            raise
        except:
            pass

    def startElement(self, name, attrs):
        """Push the element onto the current key path."""
        if name.startswith('statistics'):
            return
        if name == 'stats':
            key = str(attrs.get('id', 'stats'))
        else:
            key = str(name)
        self.cur_key += (key, )

    def endElement(self, name):
        """Pop the innermost element from the current key path."""
        if self.cur_key:
            self.cur_key = self.cur_key[:-1]

    def characters(self, ch):
        """Report a text node as a string or an int, depending on its parent."""
        if self.cur_key and self.cur_key[-1] in self._string_leaves:
            self.emitData(self.cur_key, str(ch))
            return
        # Only the conversion is guarded; emitData() sits outside the try so
        # the shutdown exceptions it re-raises are not swallowed here.
        try:
            val = int(str(ch))
        except ValueError:
            return
        self.emitData(self.cur_key, val)

def gratia_log_traceback(lvl=0):
    """Write the traceback of the exception being handled via DebugPrint.

    *lvl* is passed to DebugPrint as its first argument.
    """
    if not DebugPrint: # Do nothing if DebugPrint is not initialized
        return
    formatted = ''.join(traceback.format_exception(*sys.exc_info()))
    DebugPrint(lvl, "Encountered exception:\n%s" % formatted)

# Matches a </statistics> closing tag that carries a "toe" attribute; the
# parsers in this file replace every match with a plain "</statistics>".
invalid_xml_re = re.compile(r'</statistics toe="\d+">')
class XrdDataHandler(object):
    """Parses complete XML messages and reports their contents to a callback.

    The constructor builds a SAX parser whose content handler is an
    XrootdSummaryParser wrapping *callback*; external general entities are
    switched off on that parser.
    """

    def __init__(self, callback):
        self.parser = make_parser()
        self.parser.setContentHandler(XrootdSummaryParser(callback))
        self.parser.setFeature(feature_external_ges, False)

    def handle(self, data):
        """Repair the closing tag in *data* and parse it as one document."""
        cleaned = invalid_xml_re.sub("</statistics>", data)
        self.parser.parse(cStringIO.StringIO(cleaned))

def print_handler(key, val):
    """Print one parsed statistic as "<dotted key>: <value>" via DebugPrint."""
    dotted_key = ".".join(key)
    DebugPrint(-1, "%s: %s" % (dotted_key, str(val)))

def udp_server(data_handler, port=3333, bind="0.0.0.0"):
    """Receive UDP datagrams forever and pass each one to the data handler.

    data_handler -- object with a handle(data) method, called once per
                    datagram with the received payload.
    port, bind   -- local UDP port and address to listen on.

    A datagram that makes handle() fail (for example malformed XML) is
    logged through gratia_log_traceback() and skipped, so one bad packet
    cannot end the listening loop.  SystemExit and KeyboardInterrupt still
    propagate.  The socket is closed on the way out, including when the
    bind itself fails.
    """
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        server_socket.bind((bind, port))
        buf = 4096
        while 1:
            data, addr = server_socket.recvfrom(buf)
            try:
                data_handler.handle(data)
            except (SystemExit, KeyboardInterrupt):
                raise
            except Exception:
                gratia_log_traceback(lvl=1)
    finally:
        server_socket.close()

def test_udp_socket(port=3333, bind="0.0.0.0"):
    """
    Test the UDP socket to see if we can bind to it.
    Will throw a "Address already in use" exception if it's not usable.

    The probe socket is always closed, whether or not the bind succeeds.
    """
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        server_socket.bind((bind, port))
    finally:
        server_socket.close()

class XrootdXmlParser(xml.sax.expatreader.ExpatParser):
    """Expat parser that repairs the malformed </statistics ...> closing tag.

    Every buffer handed to feed() has matches of ``invalid_xml_re`` replaced
    by ``</statistics>`` before the base class sees it.  parse() reads the
    source in pieces of ``bufsize`` bytes and, unlike the base class, does
    not finish the document, so callers may keep feeding data after it
    returns.
    """

    def __init__(self, bufsize=2**9):
        self._bufsize = bufsize
        xml.sax.expatreader.ExpatParser.__init__(self)

    def parse(self, source):
        """Feed the whole of *source* to the parser without closing it.

        The repair in feed() works on one buffer at a time, so a malformed
        tag that straddles two reads would be missed.  To prevent that, any
        trailing text starting at a '<' that has no '>' after it is held
        back and prepended to the next read; the pattern starts with '<',
        ends with '>' and has neither character in between, so every tag
        reaches feed() in one piece.
        """
        source = saxutils.prepare_input_source(source)

        self._source = source
        self._cont_handler.setDocumentLocator(
            xml.sax.expatreader.ExpatLocator(self))

        stream = source.getByteStream()
        pending = ""
        chunk = stream.read(self._bufsize)
        while chunk:
            data = pending + chunk
            pending = ""
            cut = data.rfind("<")
            if cut != -1 and data.find(">", cut) == -1:
                # Unterminated tag at the end of the buffer: keep it back.
                pending = data[cut:]
                data = data[:cut]
            if data:
                self.feed(data)
            chunk = stream.read(self._bufsize)
        if pending:
            # End of input: hand over whatever was held back.
            self.feed(pending)

    def feed(self, data, **kw):
        """Repair malformed closing tags in *data*, then feed it to expat."""
        data = invalid_xml_re.sub("</statistics>", data)
        xml.sax.expatreader.ExpatParser.feed(self, data, **kw)

def file_handler(data_handler, input_file):
    """Parse a file of Xrootd summary XML and report it to a callback.

    data_handler -- callable taking (key, value); it is wrapped in an
                    XrootdSummaryParser.
    input_file   -- path of the file to read.

    The file contents are fed between a synthetic <statisticstop> open and
    close tag so the parser sees a single root element; because that name
    starts with "statistics", XrootdSummaryParser leaves it out of the
    reported keys.  The input file is closed even when parsing fails.
    """
    handler = XrootdSummaryParser(data_handler)
    parser = XrootdXmlParser()
    parser.setContentHandler(handler)
    parser.setFeature(feature_external_ges, False)
    fp = open(input_file, 'r')
    try:
        parser.feed("<statisticstop>")
        parser.parse(fp)
        parser.feed("</statisticstop>")
    finally:
        fp.close()
    handler.endDocument()

class HostInfo(object):
    """Statistics collected for a single reporting host.

    path_info  maps a path id to a dict of statistics for that path.
    area_info  maps an area name to a dict of statistics for that area.
    id_to_area maps an area id to the area name registered with add_area().
    timestamp  is the time.time() value of creation or of the most recent
               update_time() call.
    """

    def __init__(self, hostname):
        self.timestamp = time.time()
        self.hostname = hostname
        self.path_info = {}
        self.area_info = {}
        self.id_to_area = {}

    def add_path(self, id):
        """Start (or restart) an empty statistics dict for path *id*."""
        self.path_info[id] = {}

    def add_path_stat(self, id, stat, val):
        """Store one statistic for a path previously given to add_path()."""
        stats = self.path_info[id]
        stats[stat] = val

    def add_area(self, id, name):
        """Register area *name* under *id* with an empty statistics dict."""
        self.area_info[name] = {}
        self.id_to_area[id] = name

    def add_area_stat(self, id, stat, val):
        """Store one statistic for the area registered under *id*."""
        area_name = self.id_to_area[id]
        self.area_info[area_name][stat] = val

    def get_total_used_kb(self):
        """Return the total minus the free amount summed over all paths."""
        total = self.get_total_kb()
        free = self.get_total_free_kb()
        return total - free

    def get_total_free_kb(self):
        """Return the sum of the "free" statistic over all paths."""
        free = 0
        for stats in self.path_info.values():
            free += int(stats.get("free", 0))
        return free

    def get_total_kb(self):
        """Return the sum of the "tot" statistic over all paths."""
        total = 0
        for stats in self.path_info.values():
            total += int(stats.get("tot", 0))
        return total

    def get_area_info(self, area_name):
        """Return the statistics dict of *area_name*, or {} if unknown."""
        return self.area_info.get(area_name, {})

    def get_area_names(self):
        """Return the names of all registered areas."""
        return self.id_to_area.values()

    def get_hostname(self):
        return self.hostname

    def update_time(self):
        """Record that the host has just reported."""
        self.timestamp = time.time()

    def get_timestamp(self):
        return self.timestamp

    def __str__(self):
        pieces = ['\tHost: %s\n' % self.hostname, '\t\tPaths:\n']
        # Paths are looked up under the string ids "0" .. "n-1".
        for index in range(len(self.path_info)):
            stats = self.path_info[str(index)]
            for key, val in stats.items():
                pieces.append('\t\t\t%i %s: %s\n' % (index, key, val))
        pieces.append('\t\tAreas:\n')
        for area, info in self.area_info.items():
            pieces.append('\t\t\tArea %s:\n' % area)
            for key, val in info.items():
                pieces.append('\t\t\t\t%s: %s\n' % (key, val))
        return ''.join(pieces)

def get_se():
    """Return the SiteName attribute of the Gratia ProbeConfig.

    Raises Exception when the attribute is missing or empty.
    """
    site_name = Gratia.Config.getConfigAttribute("SiteName")
    if site_name:
        return site_name
    raise Exception("SiteName attribute not found in ProbeConfig")

def get_rpm_version():
    """Ask rpm for the VERSION-RELEASE of gratia-probe-xrootd-storage.

    Returns the command's output; raises Exception when closing the pipe
    reports a non-zero status.
    """
    query = "rpm -q gratia-probe-xrootd-storage --queryformat '%{VERSION}-%{RELEASE}'"
    pipe = os.popen(query)
    version = pipe.read()
    status = pipe.close()
    if status:
        raise Exception("Unable to successfully query rpm for %s "
            "version information." % XRD_NAME)
    return version

# Result of the first successful get_version() call.
version_cache = None
def get_version():
    """Return the version string, looking it up only on the first call.

    The "<XRD_NAME>Version" ProbeConfig attribute is preferred; when it is
    unset or empty the rpm database is queried instead.
    """
    global version_cache
    if version_cache is None:
        found = Gratia.Config.getConfigAttribute("%sVersion" % XRD_NAME)
        if not found:
            found = get_rpm_version()
        version_cache = found
    return version_cache

class GratiaHandler(object):
    """Collects parsed Xrootd statistics per host and reports them to Gratia.

    handle() is the (key, value) callback given to the XML parsers; it keeps
    one HostInfo per reporting host in ``host_info``.  summary() drops hosts
    that have not reported for XRD_EXPIRE_MINUTES and sends the current
    state to Gratia.

    Setting ``stop`` to True makes the next handle() call raise
    ``stop_exception``.
    """

    def __init__(self):
        self.host_info = {}
        self.cur_host = None
        self.cur_host_info = None
        self.stop = False
        self.stop_exception = None
        DebugPrint(4, "Finished with the GratiaHandler constructor.")

    def handle(self, key, val):
        """Record one statistic.

        key -- tuple of element names, e.g. ('oss', 'paths', '0', 'tot').
        val -- the value parsed for that key.

        An ('info', 'host') key selects the host that subsequent statistics
        belong to.  Only keys starting with 'oss' are stored; keys too short
        to name a statistic are ignored.
        """
        DebugPrint(4, "GratiaHandler processed message: %s, %s" % (".".join(key),
            str(val)))
        if self.stop:
            raise self.stop_exception
        # Switch host if necessary.
        if key == ('info', 'host'):
            info = self.host_info.get(val, None)
            if info is None:
                # First report from this host, or its entry was expired by
                # summary(): start a fresh record.
                info = HostInfo(val)
                self.host_info[val] = info
            self.cur_host = val
            self.cur_host_info = info
            # Refresh on every report, not only when the host changes;
            # otherwise a host reporting alone would be expired by summary()
            # although it is still alive.
            info.update_time()
        if not self.cur_host:
            return

        # Parse OSS stats only
        if len(key) < 2 or key[0] != 'oss':
            return
        section = key[1]
        if section == 'paths':
            if len(key) == 2:
                # oss.paths carries the number of paths that follow.
                for i in range(val):
                    self.cur_host_info.add_path(str(i))
            elif len(key) >= 4:
                # Format: oss.paths.<integer>.<stat>
                self.cur_host_info.add_path_stat(key[2], key[3], val)
        elif section == 'space' and len(key) >= 4:
            if key[3] == 'name':
                # Format: oss.space.<integer>.name
                self.cur_host_info.add_area(key[2], val)
            else:
                self.cur_host_info.add_area_stat(key[2], key[3], val)

    def summary(self):
        """Expire silent hosts and send a full report to Gratia."""
        # Update our global timestamp so we don't have to do this in all the
        # Gratia send functions.
        global timestamp
        timestamp = time.time()

        DebugPrint(2, "%s Host info:" % XRD_NAME)
        max_age = XRD_EXPIRE_MINUTES*60
        to_delete = []
        for host, host_info in self.host_info.items():
            if timestamp - host_info.get_timestamp() > max_age:
                to_delete.append(host)
        for host in to_delete:
            del self.host_info[host]
            DebugPrint(1, "** Deleted host %s that stopped reporting." % host)

        # Print out host info
        for host, host_info in self.host_info.items():
            DebugPrint(2, "**", host)
            DebugPrint(2, str(host_info))
            self.send_node_props(host_info)
        self.send_system_props()
        self.send_master_area()
        self.send_area_props()

    def send_node_props(self, host_info):
        """
        Send the storage information to Gratia for a single host.
        """
        se = get_se()
        version = get_version()
        name = host_info.get_hostname()
        unique_id = '%s:Pool:%s' % (se, name)
        parent_id = "%s:SE:%s" % (se, se)

        sa = StorageElement.StorageElement()
        sar = StorageElementRecord.StorageElementRecord()
        sa.UniqueID(unique_id)
        sa.Name(name)
        sa.SE(se)
        sa.SpaceType("Pool")
        sa.Implementation(XRD_NAME)
        sa.Version(version)
        sa.Status(XRD_STATUS)
        sa.ParentID(parent_id)
        sa.Timestamp(timestamp)
        sar.Timestamp(timestamp)
        sar.UniqueID(unique_id)
        sar.MeasurementType("raw")
        sar.StorageType("disk")
        # Host totals are kept in KB; records carry bytes.
        sar.TotalSpace(1024*host_info.get_total_kb())
        sar.FreeSpace(1024*host_info.get_total_free_kb())
        sar.UsedSpace(1024*host_info.get_total_used_kb())
        Gratia.Send(sa)
        Gratia.Send(sar)

    def send_system_props(self):
        """Send the storage element itself, with totals over all hosts."""
        # Standard Gratia properties
        se = get_se()
        version = get_version()
        unique_id = "%s:SE:%s" % (se, se)
        parent_id = unique_id

        # Calculate the raw path totals for all hosts
        tot, used, free = 0, 0, 0
        for host_info in self.host_info.values():
            tot += host_info.get_total_kb()
            used += host_info.get_total_used_kb()
            free += host_info.get_total_free_kb()
        # Convert back to bytes
        tot *= 1024
        used *= 1024
        free *= 1024

        # Send out gratia information
        sa = StorageElement.StorageElement()
        sar = StorageElementRecord.StorageElementRecord()
        sa.UniqueID(unique_id)
        sa.Name(se)
        sa.SE(se)
        sa.SpaceType("SE")
        sa.Implementation(XRD_NAME)
        sa.Version(version)
        sa.Status(XRD_STATUS)
        sa.ParentID(parent_id)
        sa.Timestamp(timestamp)
        sar.UniqueID(unique_id)
        sar.MeasurementType("raw")
        sar.StorageType("disk")
        sar.TotalSpace(tot)
        sar.FreeSpace(free)
        sar.UsedSpace(used)
        sar.Timestamp(timestamp)
        Gratia.Send(sa)
        Gratia.Send(sar)

    def send_master_area(self):
        """Send the 'Area' element that is the parent of every area."""
        se = get_se()
        name = XRD_AREA_NAME
        space_type = 'Area'
        version = get_version()

        unique_id = '%s:%s:%s' % (se, space_type, name)
        parent_id = '%s:SE:%s' % (se, se)

        sa = StorageElement.StorageElement()
        sa.UniqueID(unique_id)
        sa.Name(name)
        sa.SE(se)
        sa.SpaceType(space_type)
        sa.Implementation(XRD_NAME)
        sa.Version(version)
        sa.Status(XRD_STATUS)
        sa.ParentID(parent_id)
        sa.Timestamp(timestamp)
        Gratia.Send(sa)

    def send_area_props(self):
        """Send one element and one record per area, summed over all hosts.

        An area is sent with space type "Quota" when the summed quota is
        non-zero and "Directory" otherwise.
        """
        # Build the set of all area names
        all_areas = sets.Set()
        for host_info in self.host_info.values():
            all_areas.update(host_info.get_area_names())

        se = get_se()
        version = get_version()

        for area_name in all_areas:
            # Collect relevant area statistics
            tot, used, free, quot = 0, 0, 0, 0
            for host_info in self.host_info.values():
                info = host_info.get_area_info(area_name)
                area_tot = info.get('tot', 0)
                # NOTE(review): 'tot' is summed as reported while 'free' and
                # 'usg' are multiplied by 1024 -- confirm the units.
                tot += area_tot
                free += 1024*info.get('free', 0)
                area_quot = info.get('qta', 0)
                # A quota larger than the area's total counts as the total.
                if area_quot > area_tot:
                    quot += area_tot
                else:
                    quot += area_quot
                used += 1024*info.get('usg', 0)

            # Generic stuff for the StorageElement
            if quot:
                space_type = "Quota"
            else:
                space_type = "Directory"
            unique_id = '%s:%s:%s' % (se, space_type, area_name)
            parent_id = '%s:Area:%s' % (se, XRD_AREA_NAME)

            sa = StorageElement.StorageElement()
            sa.Name(area_name)
            sa.SE(se)
            sa.UniqueID(unique_id)
            sa.ParentID(parent_id)
            sa.SpaceType(space_type)
            sa.Implementation(XRD_NAME)
            sa.Version(version)
            sa.Status(XRD_STATUS)
            sa.Timestamp(timestamp)
            sar = StorageElementRecord.StorageElementRecord()
            sar.UniqueID(unique_id)
            sar.MeasurementType("logical")
            sar.StorageType("disk")
            sar.TotalSpace(tot)
            sar.FreeSpace(free)
            sar.UsedSpace(used)
            sar.Timestamp(timestamp)
            Gratia.Send(sa)
            Gratia.Send(sar)

def parse_opts():
    parser = optparse.OptionParser()
    parser.add_option("-d", "--daemon", help="Run as daemon; automatically " \
        "background the process.", default=False, action="store_true",
        dest="daemon")
    parser.add_option("-l", "--logfile", help="Log file location.  Defaults " \
        "to the Gratia logging infrastructure.", dest="logfile")
    parser.add_option("-i", "--input", help="Input file name; if this option" \
        " is given, the process does not listen for UDP messages", dest="input")
    parser.add_option("-p", "--port", help="UDP Port to listen on for " \
        "messages.  Overridden by Gratia ProbeConfig.", type="int",
        default=3333, dest="port")
    parser.add_option("--gratia_config", help="Location of the Gratia config;" \
        " defaults to /etc/gratia/xrootd-storage/ProbeConfig",
        dest="gratia_config")
    parser.add_option("-b", "--bind", help="Listen for messages on a " \
        "specific address; defaults to 0.0.0.0.  Overridden by Gratia " \
        "ProbeConfig", default="0.0.0.0", dest="bind")
    parser.add_option("-v", "--verbose", help="Enable verbose logging to " \
        "stdout.", default=False, action="store_true", dest="verbose")
    parser.add_option("--print_only", help="Only print data recieved; do not" \
        " send to Gratia.", dest="print_only", action="store_true")
    parser.add_option("-r", "--report_period", help="Time in minutes between" \
        " reports to Gratia.", dest="report_period", type="int")

    opts, args = parser.parse_args()

    # Expand our input paths:
    if opts.input:
        opts.input = os.path.expanduser(opts.input)
    if opts.logfile:
        opts.logfile = os.path.expanduser(opts.logfile)
    if opts.gratia_config:
        opts.gratia_config = os.path.expanduser(opts.gratia_config)

    # Adjust sleep time as necessary
    if opts.report_period:
        global SLEEP_TIME
        SLEEP_TIME = opts.report_period*60

    # Initialize logging
    logfile = "/var/log/gratia/xrootd-storage.log"
    if opts.logfile:
        logfile = opts.logfile
    path, _ = os.path.split(logfile)
    if path and not os.path.exists(path):
        os.makedirs(path)
    try:
        fp = open(logfile, 'w')
    except Exception, e:
        raise Exception("Could not open %s-Storage logfile, %s, for " \
            "write.  Error: %s." % (XRD_NAME, logfile, str(e)))
    global log
    log = logging.getLogger("XrdStorage")
    log.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler = logging.handlers.RotatingFileHandler(
        logfile, maxBytes=20*1024*1024, backupCount=5)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)
    log.addHandler(handler)
    if opts.verbose:
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        log.addHandler(handler)

    # Initialize Gratia
    gratia_config = None
    if opts.gratia_config and os.path.exists(opts.gratia_config):
        gratia_config = opts.gratia_config
    else:
        tmp = "/etc/gratia/xrootd-storage/ProbeConfig"
        if os.path.exists(tmp):
            gratia_config = tmp
    if not gratia_config:
        raise Exception("Unable to find a suitable ProbeConfig to use!")
    Gratia.Initialize(gratia_config)

    return opts

def main():
    opts = parse_opts()

    if opts.print_only:
        my_handler = print_handler
    else:
        gratia_handler = GratiaHandler()
        my_handler = gratia_handler.handle

    # Do all the daemon-specific tests.
    if not opts.input:
        if opts.daemon:
            # Test the socket first before we daemonize and lose the ability
            # to alert the user of potential errors.
            test_udp_socket(port=opts.port, bind=opts.bind)
            # Test to see if the pidfile is writable.
            pidfile = "/var/run/xrd_storage_probe.pid"
            open(pidfile, 'w').close()
            daemonize(pidfile)
            # Must re-initialize here because we changed processes and lost
            # the previous thread
            if not opts.print_only:
                gratia_handler = GratiaHandler()
                my_handler = gratia_handler.handle

    se = get_se()
    version = get_version()
    DebugPrint(1, "Running %s version %s for SE %s." % (XRD_NAME, version, se))

    handler = XrdDataHandler(my_handler)
    if not opts.input:
        try:
	    udpserverthread = threading.Thread(target=udp_server,args=(handler, opts.port, opts.bind))
	    udpserverthread.setDaemon(True)
	    udpserverthread.setName("Gratia udp_server Thread")
	    udpserverthread.start()
        finally:
            if not opts.print_only:
                gratia_handler.summary()
    else:
        input = opts.input
        try:
            fp = open(input, 'r')
        except Exception, e:
            raise Exception("Could not open input file, %s, for read due to " \
                "exception %s." % (input, str(e)))
        try:
            file_handler(my_handler, opts.input)
        finally:
            if not opts.print_only:
                gratia_handler.summary()
    try:
	while 1 == 1:
	    try:
	        DebugPrint(0, "Will send new Gratia data in %i seconds." % SLEEP_TIME)
	        time.sleep(SLEEP_TIME)
	    except KeyboardInterrupt, SystemExit:
	        raise
	    except Exception, e:
	        gratia_log_traceback(lvl=1)

            if opts.print_only:
                continue

	    try:
	        DebugPrint(1, "Creating a new Xrootd Gratia report.")
	        gratia_handler.summary()
	    except KeyboardInterrupt, SystemExit:
	        gratia_log_traceback(lvl=0)
	        gratia_handler.stop_exception()
	        gratia_handler.stop = True
	        raise
	    except Exception, e:
	        gratia_log_traceback(lvl=1)
    finally:
        if not opts.print_only:
            gratia_handler.summary()


if __name__ == '__main__':
    try:
        main()
    except:
        # Record the traceback through gratia_log_traceback(), then re-raise
        # so the failure is not hidden from the caller.
        gratia_log_traceback()
        raise

