#!/usr/bin/env python3
'''
Collect XRootD stats and report to HTCondor collector

- StashCache reporter is called periodically from cron or timer
    - It advertises the usage data as a Condor ad to an OSG collector
    - Runs every 5 minutes (since the ad expires after 15 min)
- Reporter maintains a cache of the most recently collected usage data
    - Saved to /tmp as JSON
- If the cached usage data is current (less than two hours old), the reporter
  advertises the cached data and exits
- If the cached usage data is old, the reporter updates the cache, and also
  runs a process to advertise the most recent cached data every 10-13 minutes
  (AD_REFRESH_INTERVAL). Even if the stats collection takes over 15 minutes,
  the ad won't expire from the collector.
'''

# Viewing stats in collector:
# condor_status -pool collector1.opensciencegrid.org:9619 -any xrootd@hcc-stash.unl.edu -l

import argparse
import json
import logging
import os
import multiprocessing
import random
import sys
import socket
import tempfile
import time

# Forking (multiprocessing) may require XRootD fork handler
# Must set before importing XRootD.client
#os.putenv('XRD_RUNFORKHANDLER', '1')

import classad
import htcondor
import xrootd_cache_stats

# Seconds between re-advertisements of the cached ad. The ad expires from the
# collector after 15 minutes, so it is refreshed every 600-779 seconds; the
# random component is drawn once, at import time (presumably so that many
# caches do not advertise in lockstep -- TODO confirm).
AD_REFRESH_INTERVAL = 600 + random.randrange(180) # seconds

# Comma-separated host:port list of the central OSG collectors
# (used if OSG_COLLECTOR_HOST not defined)
CENTRAL_COLLECTORS = 'collector1.opensciencegrid.org:9619,collector2.opensciencegrid.org:9619'

# StashCache version (set at build time); reported in the ad as
# STASHCACHE_DaemonVersion
STASHCACHE_VERSION = '1.2.0'

class StashCacheReporter(object):
    '''Collect XRootD stats, cache state for performance, and report to collector

    The state (the stringified cache ad under 'cache_ad' and the time of the
    last cache walk under 'last_scan') lives in a multiprocessing.Manager dict
    and is persisted as JSON in the temporary directory between runs.
    '''
    def __init__(self, cache_path='/stash', collectors=CENTRAL_COLLECTORS,
                 cache_walk_interval=2*60*60, log_level=30):
        '''Set up the reporter and load any previously saved state

        cache_path: directory of the XRootD cache to collect stats from
        collectors: collector list handed to htcondor.Collector
        cache_walk_interval: minimum seconds between walks of the cache
        log_level: numeric logging level for this reporter (30 is WARNING)
        '''
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)
        self.cache_path = cache_path
        self.state_file = self.get_state_filename(cache_path)
        self.collectors = collectors
        self.cache_walk_interval = cache_walk_interval

        # Manager-backed dict, so updates are visible through the manager
        # rather than living only in one process's memory
        manager = multiprocessing.Manager()
        self.state = manager.dict()

        self.load_state()

    @staticmethod
    def get_state_filename(cache_path):
        '''Generate state filename for cache_path:
        - Replace non-alphanumeric characters in cache_path with a dot
        - Prepend temporary directory

        Example: '/stash' maps to '<tmpdir>/stashcache-reporter.stash.json'
        '''
        state_file = 'stashcache-reporter'
        state_file += ''.join(c if c.isalnum() else '.' for c in cache_path)
        state_file += '.json'
        return os.path.join(tempfile.gettempdir(), state_file)

    def load_state(self):
        '''Load previous run's state from JSON

        Returns True if the state file was read and merged into self.state,
        False if it could not be opened, could not be parsed, or did not
        contain a JSON object.
        '''
        try:
            # Text mode, matching how save_state() writes the file
            with open(self.state_file, 'r') as fptr:
                saved_state = json.load(fptr)
        except IOError as err:
            self.logger.debug('Could not open cache state file %s: %s', self.state_file, err)
            return False
        except ValueError as err:
            self.logger.warning('Could not parse cache state file %s: %s', self.state_file, err)
            return False

        # Only a JSON object can be merged into the state dict
        if not isinstance(saved_state, dict):
            self.logger.warning('Could not parse cache state file %s: %s', self.state_file,
                                'top-level JSON value is not an object')
            return False

        self.state.update(saved_state)
        self.logger.debug('Loaded cache state from %s', self.state_file)
        return True

    def save_state(self):
        '''Save current run's state to JSON

        Errors opening or writing the state file propagate to the caller.
        '''
        # json.dump() emits str, so the file must be opened in text mode;
        # a binary-mode file raises TypeError on Python 3
        with open(self.state_file, 'w') as fptr:
            json.dump(dict(self.state), fptr)
            self.logger.debug('Wrote cache state for %s to %s', self.cache_path, self.state_file)

    def stat_collector(self):
        '''Advertise stats, and walk cache to update stats if expired

        The cached ad is advertised first so the collector has an ad while a
        (possibly slow) cache walk runs; after a walk the ad is advertised
        again.
        '''
        self.advertise_cache_stats()

        if self.state.get('last_scan', 0) + self.cache_walk_interval < time.time():
            self.walk_cache()
            self.advertise_cache_stats()
        else:
            self.logger.debug('Skipping cache stat collection, state not expired')

    def walk_cache(self):
        '''Walk cache directory to collect stats and update state

        The state is updated and saved only when the collected ad reports
        ping_response_status == 'ok'; otherwise the previous state is kept.
        '''
        xrootd_url = 'root://' + socket.getfqdn()

        self.logger.debug('Collecting cache stats from %s', self.cache_path)
        start_time = time.time()
        cache_ad = xrootd_cache_stats.collect_cache_stats(xrootd_url, self.cache_path)
        cache_ad['STASHCACHE_DaemonVersion'] = STASHCACHE_VERSION
        end_time = time.time()
        self.logger.info('Cache stat collection for %s took %.2f seconds', self.cache_path,
                         end_time - start_time)

        if cache_ad['ping_response_status'] == 'ok':
            self.logger.debug('XRootD server (%s) status: OK', xrootd_url)
            # json can't serialize classads. Need to convert them to/from str.
            self.state['cache_ad'] = str(cache_ad)
            self.state['last_scan'] = end_time
            self.save_state()
        else:
            # Use the instance logger so the configured log level applies
            self.logger.warning('No heartbeat from XRootD server')

    def advertise_cache_stats(self):
        '''Send cache ad to collector

        Returns False if there is no cached ad to send, True once an
        advertise attempt has been made (even if that attempt failed and was
        logged).
        '''
        if 'cache_ad' not in self.state:
            return False

        # classad cannot parse unicode
        cache_ad_txt = self.state['cache_ad'].encode('ascii', 'ignore')
        cache_ad = classad.parseOne(cache_ad_txt)

        coll = htcondor.Collector(self.collectors)

        self.logger.info('Advertising StashCache ads to collectors: %s', self.collectors)
        # Save and restore euid, as advertise() changes it
        old_uid = os.geteuid()
        try:
            coll.advertise([cache_ad], 'UPDATE_STARTD_AD')
        except ValueError as err:
            self.logger.warning('Could not advertise to %s: %s', self.collectors, err)
        finally:
            # Restore the euid even if advertise() raises something unexpected
            os.seteuid(old_uid)

        return True

    def advertiser_loop(self):
        '''Loop forever: sleep AD_REFRESH_INTERVAL seconds, then send the
        cache ad (if there is one) to the collector'''
        while True:
            self.logger.debug('Sleeping %d seconds before ad refresh', AD_REFRESH_INTERVAL)
            time.sleep(AD_REFRESH_INTERVAL)
            self.advertise_cache_stats()

def main():
    '''Main function

    Parses the CLI options, configures logging, checks that the host
    cert/key pair exists (exits with status 1 if not), then runs the stat
    collector in a child process, once with --one-shot or repeatedly
    otherwise, while a second child process keeps re-advertising the cached
    ad.
    '''
    args = parse_args()

    # Enable logging: no -v gives 30 (WARNING), each -v lowers the level by
    # 10, bottoming out at 0
    log_level = max(3 - args.verbose_count, 0) * 10
    log_format = '%(levelname)s: %(message)s'
    logging.basicConfig(level=log_level, format=log_format)

    # Check for existence of host cert/key pair
    for pki in 'cert', 'key':
        pki_path = '/etc/grid-security/host%s.pem' % pki
        if not os.path.exists(pki_path):
            logging.error('Could not find host %s at %s', pki, pki_path)
            sys.exit(1)

    scr = StashCacheReporter(cache_path=args.cache_path, collectors=args.collectors,
                             cache_walk_interval=args.cache_walk_interval,
                             log_level=log_level)

    # Periodically advertise stats while collector runs
    ad_reporter = multiprocessing.Process(target=scr.advertiser_loop)
    ad_reporter.start()

    try:
        while True:
            # Start collecting stats and wait for it to finish
            stat_collector = multiprocessing.Process(target=scr.stat_collector)
            stat_collector.start()
            stat_collector.join()

            if args.one_shot:
                break

            # Sleep before refreshing cache, but keep advertising
            logging.debug('Sleeping %d seconds before cache refresh', args.cache_walk_interval)
            time.sleep(args.cache_walk_interval)
    finally:
        # The advertiser loops forever and is not a daemon process, so it must
        # be stopped on every exit path (including Ctrl-C); otherwise
        # interpreter shutdown would block waiting to join it.
        ad_reporter.terminate()
        ad_reporter.join()

def parse_args():
    '''Build the command-line parser and return the parsed options'''
    two_hours = 2 * 60 * 60

    # Collector list from the HTCondor configuration, falling back to the
    # central OSG collectors when OSG_COLLECTOR_HOST is not defined
    default_collectors = htcondor.param.get('OSG_COLLECTOR_HOST', CENTRAL_COLLECTORS)

    # (flags, add_argument keyword settings), in the order shown by --help
    option_specs = [
        (('--one-shot',),
         dict(default=False, action='store_true',
              help='Run once, rather than persistently')),
        (('--cache-path',),
         dict(default='/stash',
              help='Path to the local XRootD stashcache directory')),
        (('--collectors',),
         dict(default=default_collectors,
              help='List of HTCondor collectors to receive ads')),
        (('--cache-walk-interval',),
         dict(type=int, default=two_hours,
              help='Minimum seconds between walking the cache to refresh stats')),
        (('-v', '--verbose'),
         dict(dest='verbose_count', action='count', default=0,
              help='Increase log verbosity (repeatable)')),
    ]

    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()

# Run the reporter only when executed as a script, not when imported
if __name__ == '__main__':
    main()
