#!/usr/bin/python3

import os
import re
import sys
try:
    import sets
except ImportError:
    # The 'sets' module was removed in Python 3; expose the builtin
    # equivalents under the old names so existing references keep working.
    class sets(object):
        Set = set
        ImmutableSet = frozenset
import time
import codecs
import socket
import logging
import optparse
try:
    import cStringIO
except ImportError:
    import io as cStringIO
import threading
import traceback
import logging.handlers
import xml.sax.expatreader

from xml.sax import saxutils, make_parser
from xml.sax.handler import ContentHandler, feature_external_ges

import gratia.common.Gratia as Gratia
import gratia.services.StorageElement as StorageElement
import gratia.services.StorageElementRecord as StorageElementRecord
from gratia.common.Gratia import DebugPrint

# Module-level logger; assigned by parse_opts().
log = None
# Time stamp attached to outgoing records; refreshed by GratiaHandler.summary().
timestamp = time.time()
# Seconds to sleep between reports; parse_opts() may override it.
SLEEP_TIME = 60
XRD_NAME = 'Xrootd'
XRD_AREA_NAME = '%s Area Tokens' % XRD_NAME
XRD_STATUS = "Production"
# Hosts that have not reported for this many minutes are dropped.
XRD_EXPIRE_MINUTES = 60

# Author: Chad J. Schroeder
# Copyright: Copyright (C) 2005 Chad J. Schroeder
# This script is one I've found to be very reliable for creating daemons.
# The license is permissible for redistribution.
# I've modified it slightly for my purposes.  -BB
UMASK = 0
WORKDIR = "/"

# Where the daemon's standard streams are redirected; fall back to the
# conventional path when the os module does not provide devnull.
REDIRECT_TO = getattr(os, "devnull", "/dev/null")

def daemonize(pidfile):
    """Detach a process from the controlling terminal and run it in the
    background as a daemon.

    The detached process will return; the process controlling the terminal
    will exit.

    If the fork is unsuccessful, it will raise an exception; DO NOT CAPTURE IT.

    pidfile is the path the daemon's process id is written to. Failing to
    write it is ignored (best effort).
    """

    try:
        pid = os.fork()
    except OSError as e:
        raise Exception("%s [%d]" % (e.strerror, e.errno))

    if pid != 0:
        os._exit(0)       # Exit parent of the first child.

    # The first child: start a new session, then fork again.
    os.setsid()
    try:
        pid = os.fork()        # Fork a second child.
    except OSError as e:
        raise Exception("%s [%d]" % (e.strerror, e.errno))

    if pid != 0:
        os._exit(0)    # Exit parent (the first child) of the second child.

    # The second child: this is the process that keeps running.
    os.chdir(WORKDIR)
    os.umask(UMASK)
    for fd in range(3):
        try:
            os.close(fd)
        except OSError:
            # The descriptor was already closed by whoever started us; that
            # must not abort daemonization.
            pass
    os.open(REDIRECT_TO, os.O_RDWR|os.O_CREAT) # standard input (0)
    os.dup2(0, 1)                        # standard output (1)
    os.dup2(0, 2)                        # standard error (2)
    try:
        with open(pidfile, 'w') as fp:
            fp.write(str(os.getpid()))
    except Exception:
        # Best effort: the standard streams now point at REDIRECT_TO, so
        # there is nowhere left to report a pidfile problem.
        pass

class XrootdSummaryParser(ContentHandler):
    """SAX content handler that flattens a summary document into key/value
    callbacks.

    The current element path is tracked as a tuple of names; every piece of
    character data is passed to dataCallback(key, value) together with that
    path.
    """

    # Leaf elements whose text is forwarded as a string; text of every other
    # element is forwarded only when it parses as an integer.
    _TEXT_KEYS = ('host', 'chars', 'lp', 'rp', 'name')

    def __init__(self, dataCallback):
        ContentHandler.__init__(self)
        self.dataCallback = dataCallback
        # Initialised here as well so the handler is usable even when
        # startDocument() has not been called yet.
        self.cur_key = ()
        self.cur_val = None

    def startDocument(self):
        """Reset the element path at the start of a document."""
        self.cur_key = ()
        self.cur_val = None

    def emitData(self, key, val):
        """Invoke the callback; only SystemExit/KeyboardInterrupt escape."""
        try:
            self.dataCallback(key, val)
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception:
            # Best effort: a failing callback must not abort the parse.
            pass

    def startElement(self, name, attrs):
        """Push the element onto the current path.

        Elements whose name starts with 'statistics' are not tracked; a
        'stats' element is keyed by its 'id' attribute (default 'stats').
        """
        if name.startswith('statistics'):
            return
        if name == 'stats':
            key = str(attrs.get('id', 'stats'))
        else:
            key = str(name)
        self.cur_key += (key, )

    def endElement(self, name):
        """Pop the innermost element from the current path."""
        if self.cur_key:
            self.cur_key = self.cur_key[:-1]

    def characters(self, ch):
        """Forward character data for the current path to the callback."""
        text = str(ch)
        if self.cur_key and self.cur_key[-1] in self._TEXT_KEYS:
            self.emitData(self.cur_key, text)
            return
        try:
            val = int(text)
        except ValueError:
            # Not a number (e.g. whitespace between elements): nothing to emit.
            return
        # Emitted outside the try block so that a SystemExit or
        # KeyboardInterrupt raised by the callback is not swallowed here.
        self.emitData(self.cur_key, val)

def gratia_log_traceback(lvl=0):
    """Send the traceback of the exception being handled to DebugPrint.

    lvl is the DebugPrint level the message is logged at.
    """
    tb_str = traceback.format_exc()
    if not DebugPrint: # Do nothing if DebugPrint is not initialized
        return
    DebugPrint(lvl, "Encountered exception:\n%s" % tb_str)

# Matches a closing 'statistics' tag that carries a 'toe' attribute; the code
# below replaces it with a plain closing tag before parsing. Raw string so
# that \d is a regex escape rather than an invalid string escape.
invalid_xml_re = re.compile(r'</statistics toe="\d+">')
class XrdDataHandler(object):
    """Parse complete summary documents and feed them to a callback.

    callback is invoked as callback(key, value) by XrootdSummaryParser for
    every value found in a document passed to handle().
    """

    def __init__(self, callback):
        parser = make_parser()
        handler = XrootdSummaryParser(callback)
        parser.setContentHandler(handler)
        # Never resolve external general entities.
        parser.setFeature(feature_external_ges, False)
        self.parser = parser

    def handle(self, data):
        """Parse one document given as text or as raw bytes.

        Raises whatever the SAX parser raises for malformed input.
        """
        # Datagrams read from a socket arrive as bytes on Python 3, while the
        # regex and the in-memory stream below work on text.
        if isinstance(data, bytes) and not isinstance(data, str):
            data = data.decode("utf-8", "replace")
        data = invalid_xml_re.sub("</statistics>", data)
        fp = cStringIO.StringIO(data)
        self.parser.parse(fp)

def print_handler(key, val):
    """Log one key/value pair through DebugPrint; key is joined with dots."""
    dotted_key = ".".join(key)
    DebugPrint(-1, "%s: %s" % (dotted_key, str(val)))

def udp_server(data_handler, port=3333, bind="0.0.0.0"):
    """Receive UDP datagrams forever and pass each one to data_handler.

    data_handler must provide a handle(data) method. A datagram that cannot
    be processed is logged and skipped; SystemExit and KeyboardInterrupt
    still propagate and end the loop. The socket is always closed on exit.
    """
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # Bind inside the try block so the socket is closed if binding fails.
        server_socket.bind((bind, port))
        # NOTE(review): datagrams longer than this are presumably truncated
        # and would then fail to parse -- confirm the expected maximum size.
        buf = 4096
        while True:
            data, _addr = server_socket.recvfrom(buf)
            try:
                data_handler.handle(data)
            except Exception:
                # One malformed datagram must not take the whole server down.
                gratia_log_traceback(lvl=1)
    finally:
        server_socket.close()

def test_udp_socket(port=3333, bind="0.0.0.0"):
    """
    Test the UDP socket to see if we can bind to it.
    Will throw a "Address already in use" exception if it's not usable.
    The probe socket is closed whether or not the bind succeeds.
    """
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        server_socket.bind((bind, port))
    finally:
        server_socket.close()

class XrootdXmlParser(xml.sax.expatreader.ExpatParser):
    """Expat-based parser that repairs closing 'statistics' tags on the fly.

    Input is read in pieces of bufsize characters and every piece is passed
    through feed(), which rewrites closing tags matched by invalid_xml_re.
    """

    def __init__(self, bufsize=2**9):
        self._bufsize = bufsize
        xml.sax.expatreader.ExpatParser.__init__(self)

    def parse(self, source):
        """Feed the whole of source to the parser without closing it.

        source may be anything saxutils.prepare_input_source() accepts.
        """
        source = saxutils.prepare_input_source(source)

        self._source = source
        self._cont_handler.setDocumentLocator(
            xml.sax.expatreader.ExpatLocator(self))

        # A file opened in text mode is exposed as a character stream and has
        # no byte stream; a byte stream is decoded incrementally so that
        # multi-byte characters are never cut in half.
        stream = source.getCharacterStream()
        if stream is None:
            stream = codecs.getreader("utf-8")(source.getByteStream(),
                errors="replace")

        pending = ""
        chunk = stream.read(self._bufsize)
        while chunk:
            pending += chunk
            # Hold back a tag that has been opened but not yet closed, so the
            # fix-up in feed() always sees complete tags.
            cut = pending.rfind("<")
            if cut != -1 and pending.find(">", cut) == -1:
                ready, pending = pending[:cut], pending[cut:]
            else:
                ready, pending = pending, ""
            if ready:
                self.feed(ready)
            chunk = stream.read(self._bufsize)
        if pending:
            self.feed(pending)

    def feed(self, data, **kw):
        """Repair closing 'statistics' tags in data and hand it to expat."""
        if isinstance(data, bytes) and not isinstance(data, str):
            data = data.decode("utf-8", "replace")
        data = invalid_xml_re.sub("</statistics>", data)
        xml.sax.expatreader.ExpatParser.feed(self, data, **kw)

def file_handler(data_handler, input_file):
    """Parse the summary data stored in input_file.

    The file content is wrapped in a synthetic <statisticstop> root element
    before parsing, and data_handler(key, value) is called for every value
    found. The file is closed when parsing ends, successfully or not.
    """
    handler = XrootdSummaryParser(data_handler)
    parser = XrootdXmlParser()
    parser.setContentHandler(handler)
    parser.setFeature(feature_external_ges, False)
    with open(input_file, 'r') as fp:
        parser.feed("<statisticstop>")
        parser.parse(fp)
        parser.feed("</statisticstop>")
    handler.endDocument()

class HostInfo(object):
    """Storage statistics reported by a single host.

    path_info maps a path id to a dict of statistics for that path;
    area_info maps an area name to a dict of statistics for that area, and
    id_to_area maps the area id used in the messages to the area name.
    """

    def __init__(self, hostname):
        self.hostname = hostname
        self.id_to_area = {}
        self.area_info = {}
        self.path_info = {}
        self.timestamp = time.time()

    def add_path(self, id):
        """Register path id with an empty set of statistics."""
        self.path_info[id] = {}

    def add_path_stat(self, id, stat, val):
        """Record a statistic for a registered path (KeyError otherwise)."""
        self.path_info[id][stat] = val

    def add_area(self, id, name):
        """Register an area under id with an empty set of statistics."""
        self.id_to_area[id] = name
        self.area_info[name] = {}

    def add_area_stat(self, id, stat, val):
        """Record a statistic for a registered area (KeyError otherwise)."""
        name = self.id_to_area[id]
        self.area_info[name][stat] = val

    def get_total_used_kb(self):
        """Return total minus free, summed over all paths."""
        return self.get_total_kb() - self.get_total_free_kb()

    def get_total_free_kb(self):
        """Return the sum of the 'free' statistic over all paths."""
        return sum(int(stats.get("free", 0))
                   for stats in self.path_info.values())

    def get_total_kb(self):
        """Return the sum of the 'tot' statistic over all paths."""
        return sum(int(stats.get("tot", 0))
                   for stats in self.path_info.values())

    def get_area_info(self, area_name):
        """Return the statistics of an area, or {} when it is unknown."""
        return self.area_info.get(area_name, {})

    def get_area_names(self):
        """Return the names of all registered areas as a list."""
        return list(self.id_to_area.values())

    def get_hostname(self):
        return self.hostname

    def update_time(self):
        """Mark the host as having reported just now."""
        self.timestamp = time.time()

    def get_timestamp(self):
        return self.timestamp

    @staticmethod
    def _path_sort_key(path_id):
        """Order numeric path ids by value, ahead of any other ids."""
        text = str(path_id)
        if text.isdigit():
            return (0, int(text), text)
        return (1, 0, text)

    def __str__(self):
        lines = ['\tHost: %s\n' % self.hostname, '\t\tPaths:\n']
        # Iterate over the ids that are actually present instead of assuming
        # they are exactly "0" .. "n-1", which raised KeyError otherwise.
        for path_id in sorted(self.path_info, key=self._path_sort_key):
            for key, val in self.path_info[path_id].items():
                lines.append('\t\t\t%s %s: %s\n' % (path_id, key, val))
        lines.append('\t\tAreas:\n')
        for area, info in self.area_info.items():
            lines.append('\t\t\tArea %s:\n' % area)
            for key, val in info.items():
                lines.append('\t\t\t\t%s: %s\n' % (key, val))
        return ''.join(lines)

def get_se():
    """Return the SiteName attribute of the Gratia configuration.

    Raises Exception when the attribute is missing or empty.
    """
    site_name = Gratia.Config.getConfigAttribute("SiteName")
    if site_name:
        return site_name
    raise Exception("SiteName attribute not found in ProbeConfig")

def get_rpm_version():
    """Ask rpm for the version-release of the installed probe package.

    Returns the command output; raises Exception when the command reports a
    failure on close.
    """
    query = "rpm -q gratia-probe-xrootd-storage --queryformat '%{VERSION}-%{RELEASE}'"
    pipe = os.popen(query)
    version_text = pipe.read()
    exit_status = pipe.close()
    if exit_status:
        raise Exception("Unable to successfully query rpm for %s "
            "version information." % XRD_NAME)
    return version_text

# Result of the first successful get_version() call.
version_cache = None
def get_version():
    """Return the version string, computing it only once.

    The "<XRD_NAME>Version" configuration attribute is preferred; when it is
    not set, the version is queried from rpm instead.
    """
    global version_cache
    if version_cache is None:
        configured = Gratia.Config.getConfigAttribute("%sVersion" % XRD_NAME)
        version_cache = configured if configured else get_rpm_version()
    return version_cache

class GratiaHandler(object):
    """Collect per-host storage statistics and report them to Gratia.

    handle() receives the key/value pairs produced by the XML parser and
    stores them in one HostInfo per reporting host; summary() sends the
    collected state to Gratia.
    """

    def __init__(self):
        self.host_info = {}
        self.cur_host = None
        self.cur_host_info = None
        # When stop is set, handle() raises stop_exception instead of
        # processing further messages.
        self.stop = False
        self.stop_exception = None
        DebugPrint(4, "Finished with the GratiaHandler constructor.")

    def handle(self, key, val):
        """Process one parsed value.

        key is the tuple of element names leading to the value, val the
        value itself. Only the ('info', 'host') message and keys starting
        with 'oss' are interpreted; everything else is ignored.
        """
        DebugPrint(4, "GratiaHandler processed message: %s, %s" % (".".join(key),
            str(val)))
        if self.stop:
            if self.stop_exception is None:
                # stop was requested without an exception to re-raise;
                # "raise None" would only produce a TypeError.
                raise SystemExit("GratiaHandler has been stopped.")
            raise self.stop_exception
        # Switch host if necessary.
        if key == ('info', 'host'):
            host = self.host_info.get(val)
            if host is None:
                host = HostInfo(val)
                self.host_info[val] = host
            self.cur_host = val
            self.cur_host_info = host
            # Refresh on every report, not only when the host changes;
            # otherwise a host that is the only one reporting would expire
            # in summary() and never be registered again.
            host.update_time()
        if not self.cur_host:
            return

        # Parse OSS stats only
        if len(key) < 2 or key[0] != 'oss':
            return
        section = key[1]
        # Parse path information
        if section == 'paths':
            if len(key) == 2:
                # Format: oss.paths -> number of paths
                for i in range(val):
                    self.cur_host_info.add_path(str(i))
            elif len(key) >= 4:
                # Format: oss.paths.<id>.<stat>
                self.cur_host_info.add_path_stat(key[2], key[3], val)
        # Parse area information
        elif section == 'space' and len(key) >= 4:
            if key[3] == 'name':
                # Format: oss.space.<integer>.name
                self.cur_host_info.add_area(key[2], val)
            else:
                self.cur_host_info.add_area_stat(key[2], key[3], val)

    def summary(self):
        """Drop hosts that stopped reporting and send all records to Gratia."""
        # Update our global timestamp so we don't have to do this in all the
        # Gratia send functions.
        global timestamp
        timestamp = time.time()

        DebugPrint(2, "%s Host info:" % XRD_NAME)
        # Work on snapshots: handle() may add hosts from the UDP server
        # thread while this report is being built.
        expired = [host for host, host_info in list(self.host_info.items())
                   if timestamp - host_info.get_timestamp()
                       > XRD_EXPIRE_MINUTES*60]
        for host in expired:
            del self.host_info[host]
            DebugPrint(1, "** Deleted host %s that stopped reporting." % host)

        # Print out host info
        for host, host_info in list(self.host_info.items()):
            DebugPrint(2, "**", host)
            DebugPrint(2, str(host_info))
            self.send_node_props(host_info)
        self.send_system_props()
        self.send_master_area()
        self.send_area_props()

    def send_node_props(self, host_info):
        """
        Send the storage information to Gratia for a single host.
        """
        se = get_se()
        version = get_version()
        name = host_info.get_hostname()
        unique_id = '%s:Pool:%s' % (se, name)
        parent_id = "%s:SE:%s" % (se, se)

        sa = StorageElement.StorageElement()
        sar = StorageElementRecord.StorageElementRecord()
        sa.UniqueID(unique_id)
        sa.Name(name)
        sa.SE(se)
        sa.SpaceType("Pool")
        sa.Implementation(XRD_NAME)
        sa.Version(version)
        sa.Status(XRD_STATUS)
        sa.ParentID(parent_id)
        sa.Timestamp(timestamp)
        sar.Timestamp(timestamp)
        sar.UniqueID(unique_id)
        sar.MeasurementType("raw")
        sar.StorageType("disk")
        sar.TotalSpace(1024*host_info.get_total_kb())
        sar.FreeSpace(1024*host_info.get_total_free_kb())
        sar.UsedSpace(1024*host_info.get_total_used_kb())
        Gratia.Send(sa)
        Gratia.Send(sar)

    def send_system_props(self):
        """Send the totals over all known hosts as the "SE" record."""
        # Standard Gratia properties
        se = get_se()
        version = get_version()
        unique_id = "%s:SE:%s" % (se, se)
        parent_id = unique_id

        # Calculate the raw path totals for all hosts
        tot, used, free = 0, 0, 0
        for host_info in list(self.host_info.values()):
            tot += host_info.get_total_kb()
            used += host_info.get_total_used_kb()
            free += host_info.get_total_free_kb()
        # Convert back to bytes
        tot *= 1024
        used *= 1024
        free *= 1024

        # Send out gratia information
        sa = StorageElement.StorageElement()
        sar = StorageElementRecord.StorageElementRecord()
        sa.UniqueID(unique_id)
        sa.Name(se)
        sa.SE(se)
        sa.SpaceType("SE")
        sa.Implementation(XRD_NAME)
        sa.Version(version)
        sa.Status(XRD_STATUS)
        sa.ParentID(parent_id)
        sa.Timestamp(timestamp)
        sar.UniqueID(unique_id)
        sar.MeasurementType("raw")
        sar.StorageType("disk")
        sar.TotalSpace(tot)
        sar.FreeSpace(free)
        sar.UsedSpace(used)
        sar.Timestamp(timestamp)
        Gratia.Send(sa)
        Gratia.Send(sar)

    def send_master_area(self):
        """Send the "Area" element that the per-area records refer to as
        their parent."""
        se = get_se()
        name = XRD_AREA_NAME
        space_type = 'Area'
        version = get_version()

        unique_id = '%s:%s:%s' % (se, space_type, name)
        parent_id = '%s:SE:%s' % (se, se)

        sa = StorageElement.StorageElement()
        sa.UniqueID(unique_id)
        sa.Name(name)
        sa.SE(se)
        sa.SpaceType(space_type)
        sa.Implementation(XRD_NAME)
        sa.Version(version)
        sa.Status(XRD_STATUS)
        sa.ParentID(parent_id)
        sa.Timestamp(timestamp)
        Gratia.Send(sa)

    def send_area_props(self):
        """Send one element and one record per area name, with the
        statistics of that area summed over all hosts."""
        hosts = list(self.host_info.values())

        # Build the set of all area names
        all_areas = set()
        for host_info in hosts:
            all_areas.update(host_info.get_area_names())

        se = get_se()
        version = get_version()

        for area_name in all_areas:
            # Collect relevant area statistics
            tot, used, free, quot = 0, 0, 0, 0
            for host_info in hosts:
                info = host_info.get_area_info(area_name)
                area_tot = info.get('tot', 0)
                # NOTE(review): unlike 'free' and 'usg', 'tot' is not
                # multiplied by 1024 here -- confirm the intended units.
                tot += area_tot
                free += 1024*info.get('free', 0)
                area_quot = info.get('qta', 0)
                # A quota larger than the area's total counts as the total.
                if area_quot > area_tot:
                    quot += area_tot
                else:
                    quot += area_quot
                used += 1024*info.get('usg', 0)

            # Generic stuff for the StorageElement
            if quot:
                space_type = "Quota"
            else:
                space_type = "Directory"
            unique_id = '%s:%s:%s' % (se, space_type, area_name)
            parent_id = '%s:Area:%s' % (se, XRD_AREA_NAME)

            sa = StorageElement.StorageElement()
            sa.Name(area_name)
            sa.SE(se)
            sa.UniqueID(unique_id)
            sa.ParentID(parent_id)
            sa.SpaceType(space_type)
            sa.Implementation(XRD_NAME)
            sa.Version(version)
            sa.Status(XRD_STATUS)
            sa.Timestamp(timestamp)
            sar = StorageElementRecord.StorageElementRecord()
            sar.UniqueID(unique_id)
            sar.MeasurementType("logical")
            sar.StorageType("disk")
            sar.TotalSpace(tot)
            sar.FreeSpace(free)
            sar.UsedSpace(used)
            sar.Timestamp(timestamp)
            Gratia.Send(sa)
            Gratia.Send(sar)

def parse_opts():
    """Parse the command line, then set up logging and Gratia.

    Side effects: may override SLEEP_TIME, assigns the module-level logger
    'log', creates the log directory when needed and initializes Gratia from
    the selected ProbeConfig.

    Returns the parsed options object. Raises Exception when the log file
    cannot be opened or no ProbeConfig can be found.
    """
    parser = optparse.OptionParser()
    parser.add_option("-d", "--daemon", help="Run as daemon; automatically " \
        "background the process.", default=False, action="store_true",
        dest="daemon")
    parser.add_option("-l", "--logfile", help="Log file location.  Defaults " \
        "to the Gratia logging infrastructure.", dest="logfile")
    parser.add_option("-i", "--input", help="Input file name; if this option" \
        " is given, the process does not listen for UDP messages", dest="input")
    parser.add_option("-p", "--port", help="UDP Port to listen on for " \
        "messages.  Overridden by Gratia ProbeConfig.", type="int",
        default=3333, dest="port")
    parser.add_option("--gratia_config", help="Location of the Gratia config;" \
        " defaults to /etc/gratia/xrootd-storage/ProbeConfig",
        dest="gratia_config")
    parser.add_option("-b", "--bind", help="Listen for messages on a " \
        "specific address; defaults to 0.0.0.0.  Overridden by Gratia " \
        "ProbeConfig", default="0.0.0.0", dest="bind")
    parser.add_option("-v", "--verbose", help="Enable verbose logging to " \
        "stdout.", default=False, action="store_true", dest="verbose")
    parser.add_option("--print_only", help="Only print data recieved; do not" \
        " send to Gratia.", dest="print_only", action="store_true")
    parser.add_option("-r", "--report_period", help="Time in minutes between" \
        " reports to Gratia.", dest="report_period", type="int")

    opts, args = parser.parse_args()

    # Expand our input paths:
    if opts.input:
        opts.input = os.path.expanduser(opts.input)
    if opts.logfile:
        opts.logfile = os.path.expanduser(opts.logfile)
    if opts.gratia_config:
        opts.gratia_config = os.path.expanduser(opts.gratia_config)

    # Adjust sleep time as necessary
    if opts.report_period:
        global SLEEP_TIME
        SLEEP_TIME = opts.report_period*60

    # Initialize logging
    logfile = "/var/log/gratia/xrootd-storage.log"
    if opts.logfile:
        logfile = opts.logfile
    path, _ = os.path.split(logfile)
    if path and not os.path.exists(path):
        os.makedirs(path)
    try:
        # Writability probe only. Append mode keeps the existing log content
        # (mode 'w' would truncate it on every start), and the handle is
        # closed again right away.
        with open(logfile, 'a'):
            pass
    except Exception as e:
        raise Exception("Could not open %s-Storage logfile, %s, for " \
            "write.  Error: %s." % (XRD_NAME, logfile, str(e)))
    global log
    log = logging.getLogger("XrdStorage")
    log.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler = logging.handlers.RotatingFileHandler(
        logfile, maxBytes=20*1024*1024, backupCount=5)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)
    log.addHandler(handler)
    if opts.verbose:
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        log.addHandler(handler)

    # Initialize Gratia
    gratia_config = None
    if opts.gratia_config and os.path.exists(opts.gratia_config):
        gratia_config = opts.gratia_config
    else:
        tmp = "/etc/gratia/xrootd-storage/ProbeConfig"
        if os.path.exists(tmp):
            gratia_config = tmp
    if not gratia_config:
        raise Exception("Unable to find a suitable ProbeConfig to use!")
    Gratia.Initialize(gratia_config)

    return opts

def main():
    """Run the probe.

    Data comes either from the input file given on the command line or from
    a UDP server running in a background thread. Unless --print_only is
    given, a report is sent to Gratia right after start-up, then every
    SLEEP_TIME seconds, and once more when the loop is left.
    """
    opts = parse_opts()

    gratia_handler = None
    if opts.print_only:
        my_handler = print_handler
    else:
        gratia_handler = GratiaHandler()
        my_handler = gratia_handler.handle

    # Do all the daemon-specific tests.
    if not opts.input and opts.daemon:
        # Test the socket first before we daemonize and lose the ability
        # to alert the user of potential errors.
        test_udp_socket(port=opts.port, bind=opts.bind)
        # Test to see if the pidfile is writable.
        pidfile = "/var/run/xrd_storage_probe.pid"
        open(pidfile, 'w').close()
        daemonize(pidfile)
        # Must re-initialize here because we changed processes and lost
        # the previous thread
        if not opts.print_only:
            gratia_handler = GratiaHandler()
            my_handler = gratia_handler.handle

    se = get_se()
    version = get_version()
    DebugPrint(1, "Running %s version %s for SE %s." % (XRD_NAME, version, se))

    handler = XrdDataHandler(my_handler)
    if not opts.input:
        try:
            udpserverthread = threading.Thread(target=udp_server,
                args=(handler, opts.port, opts.bind),
                name="Gratia udp_server Thread")
            # Daemon thread: it must not keep the process alive on exit.
            udpserverthread.daemon = True
            udpserverthread.start()
        finally:
            if not opts.print_only:
                gratia_handler.summary()
    else:
        input_path = opts.input
        try:
            # Readability check only; close the handle straight away.
            open(input_path, 'r').close()
        except Exception as e:
            raise Exception("Could not open input file, %s, for read due to " \
                "exception %s." % (input_path, str(e)))
        try:
            file_handler(my_handler, input_path)
        finally:
            if not opts.print_only:
                gratia_handler.summary()
    try:
        while True:
            try:
                DebugPrint(0, "Will send new Gratia data in %i seconds." % SLEEP_TIME)
                time.sleep(SLEEP_TIME)
            except (KeyboardInterrupt, SystemExit):
                raise
            except Exception:
                gratia_log_traceback(lvl=1)

            if opts.print_only:
                continue

            try:
                DebugPrint(1, "Creating a new Xrootd Gratia report.")
                gratia_handler.summary()
            except (KeyboardInterrupt, SystemExit) as stop_exc:
                gratia_log_traceback(lvl=0)
                # Hand the exception to the handler so that handle() raises
                # it from now on. The attribute holds an exception, not a
                # callable, so it is assigned rather than called.
                gratia_handler.stop_exception = stop_exc
                gratia_handler.stop = True
                raise
            except Exception:
                gratia_log_traceback(lvl=1)
    finally:
        if not opts.print_only:
            gratia_handler.summary()


# Script entry point: log any exception that escapes main(), then let it
# propagate unchanged.
if __name__ == '__main__':
    try:
        main()
    except BaseException:
        gratia_log_traceback()
        raise
