#!/usr/bin/env python3
import os
import sys
import glob
import argparse
import configparser
import logging
import logging.handlers
import uproot
import subprocess
import zlib
import lzma
import xxhash
from multiprocessing import Process, Value, Lock
import sqlite3
from sqlite3 import Error
import time
import re
import pdb

def get_byte_ranges(byte_map_lines, blocksize):
    """Translate a block bitmap into the byte ranges it marks as present.

    Every "x" in the map stands for one present block and every "." for one
    missing block; any other character is ignored. Consecutive "x" blocks,
    also across line boundaries, are merged into a single range.

    Args:
        byte_map_lines: iterable of strings holding the bitmap.
        blocksize: size of one block, in bytes.

    Returns:
        A tuple (range_list, error). range_list is a list of
        [start_byte, end_byte] pairs where end_byte is the offset just past
        the last block of the run. error is 1 when the map contains no "x"
        at all (range_list is then empty), 0 otherwise.
    """
    error = 0
    range_list = []

    # Index of the first block of the run of "x" currently being scanned,
    # or None while scanning missing blocks. Starting with None avoids
    # emitting a zero-length [0, 0] range for maps that are empty or that
    # begin with ".", which would also hide the "empty" error below.
    run_start = None
    block_index = 0

    for line in byte_map_lines:
        for c in line:
            if c == "x":
                if run_start is None:
                    run_start = block_index
            elif c == ".":
                if run_start is not None:
                    range_list.append([run_start * blocksize, block_index * blocksize])
                    run_start = None
            else:
                # Not part of the bitmap, does not count as a block
                continue
            block_index += 1

    # Close the run that reaches the end of the map
    if run_start is not None:
        range_list.append([run_start * blocksize, block_index * blocksize])

    if not range_list:
        log.error("Cannot get the byte_range or byte range is empty")
        error = 1

    return range_list, error


def parse_cinfo(filename):
    """Extract the download state described by a cinfo file.

    Runs the parsing steps in sequence and stops at the first one that
    reports an error.

    Returns:
        A tuple (full_file, byte_ranges, error). full_file and byte_ranges
        keep the value None for every step that was not reached; error is
        the error code of the last step that ran (0 on success).
    """
    full_file = None
    byte_ranges = None

    # Process the cinfo file with the command: xrdpfc_print
    lines, error = get_xrdpfc_print_ouput(filename)
    if error != 0:
        return full_file, byte_ranges, error

    # Extract the block size from the output of "xrdpfc_print" command
    blocksize, error = get_block_size(lines)
    if error != 0:
        return full_file, byte_ranges, error

    # Check whether the file is fully downloaded or it is a partial file
    full_file, error = is_file_fully_downloaded(lines)
    if error != 0:
        return full_file, byte_ranges, error

    # Extract the lines containing the byte map
    byte_map, error = get_byte_map(lines)
    if error != 0:
        return full_file, byte_ranges, error

    # Use the byte map to calculate byte ranges within the file
    byte_ranges, error = get_byte_ranges(byte_map, blocksize)
    return full_file, byte_ranges, error

def get_byte_map(lines):
    """Collect the bitmap rows from the lines of an "xrdpfc_print" dump.

    The map starts on the line after the first one containing "012345" and
    ends at the first line containing "access" (scanning stops there even
    if the map never started). From every map line only the "x" and "."
    characters are kept.

    Returns:
        A tuple (byte_map, error): the list of filtered rows and 0, or an
        empty list and 1 when no row was collected.
    """
    error = 0
    byte_map = []
    inside_map = False

    for line in lines:
        if "access" in line:
            break
        if inside_map:
            byte_map.append("".join(ch for ch in line if ch in ("x", ".")))
        elif "012345" in line:
            # This is the line immediately before the byte map
            inside_map = True

    if not byte_map:
        log.error("Cannot get byte_map or byte_map empty")
        error = 1

    return byte_map, error


def get_block_size(lines):
    """Find the block size in the lines of an "xrdpfc_print" dump.

    The value is the integer token following a token that contains
    "bufferSize" or "buffer_size"; when the token after the value is "kB,"
    the value is multiplied by 1024. If several matches exist, the last one
    wins.

    Returns:
        A tuple (blocksize, error): the block size and 0, or -1 and 1 when
        no block size was found.
    """
    markers = ("bufferSize", "buffer_size")
    error = 0
    blocksize = -1
    log.debug("Extracting block size")

    for line in lines:
        if not any(marker in line for marker in markers):
            continue
        log.debug("Found a line with 'bufferSize': %s", line)
        tokens = line.split()
        for position, token in enumerate(tokens):
            if not any(marker in token for marker in markers):
                continue
            blocksize = int(tokens[position + 1])
            if tokens[position + 2] == "kB,":
                blocksize = blocksize * 1024
            log.debug("Found blocksize = %d", blocksize)

    if blocksize == -1:
        log.error("Cannot find blocksize in 'xrdpfc_print' output")
        error = 1

    return blocksize, error

# Check whether the file is fully downloaded
def is_file_fully_downloaded(lines):
    """Compare the total and downloaded block counts of an "xrdpfc_print" dump.

    Only lines containing "bufferSize" or "buffer_size" are inspected. The
    counts follow the tokens "nBlocks"/"nDownloaded" (value taken as is) or
    "n_blocks"/"n_downloaded" (last character of the value dropped).

    Returns:
        A tuple (is_full, error). is_full is True when both counts are
        equal; error is 1 when either count was not found, 0 otherwise.
    """
    error = 0
    # Different sentinels, so that two missing values never compare equal
    nBlocks = -1
    nDownloaded = -2
    log.debug("Checking if file is fully downloaded")

    for line in lines:
        if "bufferSize" not in line and "buffer_size" not in line:
            continue
        tokens = line.split()
        for position, token in enumerate(tokens):
            if "nBlocks" in token:
                nBlocks = int(tokens[position + 1])
                log.debug("Found nBlocks = %d", nBlocks)
            elif "n_blocks" in token:
                nBlocks = int(tokens[position + 1][:-1])
                log.debug("Found nBlocks = %d", nBlocks)
            elif "nDownloaded" in token:
                nDownloaded = int(tokens[position + 1])
                log.debug("Found nDownloaded = %d", nDownloaded)
            elif "n_downloaded" in token:
                nDownloaded = int(tokens[position + 1][:-1])
                log.debug("Found nDownloaded = %d", nDownloaded)

    if nBlocks == -1 or nDownloaded == -2:
        error = 1

    # If nBlocks is equal to nDownloaded then the file is fully downloaded
    return (nBlocks == nDownloaded), error

def get_xrdpfc_print_ouput(filename):
    """Run "xrdpfc_print -v" on a cinfo file and return its output lines.

    stderr is merged into stdout. The process exits with status 1 when the
    command cannot be launched.

    Returns:
        A tuple (lines, error): the output split on newlines and 0, or an
        empty list and 1 when the output contains "not cinfo file".
    """
    error = 0
    lines = []
    try:
        out = subprocess.Popen(['xrdpfc_print', '-v', filename], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    except Exception as e:
        if "No such file or directory" in str(e):
            log.error("Cannot find 'xrdpfc_print', exiting...")
        else:
            log.error("An exception was raised when trying to execute 'xrdpfc_print', exiting...")
        sys.exit(1)

    stdout, _ = out.communicate()
    # The pipe delivers bytes on Python 3; decode once so that the text
    # searches below (and in the callers) operate on str
    if isinstance(stdout, bytes):
        stdout = stdout.decode("utf-8", "replace")

    if "not cinfo file" in stdout:
        log.error("Something went wrong when parsing cinfo file: %s", filename)
        error = 1
    else:
        log.debug("Parsing cinfo file: %s", filename)
        lines = stdout.split("\n")

    return lines, error

# Return a list of files, with a given @extension, found under @path
def list_files_recursively(path, extension):
    """Return the paths under @path that match the glob "*" + @extension.

    Every directory visited by os.walk(path) is globbed separately and the
    matches are concatenated in visiting order.
    """
    pattern = "*" + extension
    matches = []
    for walk_entry in os.walk(path):
        directory = walk_entry[0]
        matches.extend(glob.glob(os.path.join(directory, pattern)))
    return matches


def basket_in_file(range_list, basket_start, basket_length):
    """Classify how a basket overlaps the given byte ranges.

    The ranges are examined in order and the first one that touches the
    basket decides the result (bounds are inclusive on both sides):

        1 - the basket lies completely inside the range
        2 - only the start of the basket lies inside the range
        3 - only the end of the basket lies inside the range
        0 - no range contains the start or the end of the basket
    """
    basket_end = basket_start + basket_length
    for byte_range in range_list:
        low = byte_range[0]
        high = byte_range[1]
        if low <= basket_start and basket_end <= high:
            return 1
        if low <= basket_start <= high:
            return 2
        if low <= basket_end <= high:
            return 3
    return 0


def get_list_of_baskets_in_branch(branch, byte_ranges, is_full_file):
    """List the payload offsets of the checkable baskets of one branch.

    A basket is skipped when its key is a '_RecoveredTBasket' or when the
    key's source is a 'MemmapSource' (per the original author's note, such
    baskets are not compressed). For the remaining baskets the offset is
    kept when the file is full or when the basket lies completely inside
    one of @byte_ranges.

    Returns:
        List of offsets, each pointing just past the 9-byte basket header.
    """
    # Size of the header that precedes the compressed payload
    header_size = 9

    offsets = []
    for index in range(0, branch.numbaskets):
        key = branch._threadsafe_key(index, None, True)

        if key.__class__.__name__ == '_RecoveredTBasket':
            #TODO: do something
            continue

        if key.source.__class__.__name__ == "MemmapSource":
            continue

        basket_start = key.source._cursor.index
        basket_length = branch.basket_compressedbytes(index)
        if is_full_file == True or basket_in_file(byte_ranges, basket_start, basket_length) == 1:
            #TODO:
            # append also the length of the basket to compare it with
            # the size read from the header of the basket
            offsets.append(basket_start + header_size)

    return offsets


def recursive_branch(branch, byte_ranges, is_full_file):
    """Gather the basket offsets of a branch and of everything below it.

    A branch that evaluates as false is treated as a leaf and its own
    baskets are listed; otherwise only the entries of branch.values() are
    visited, recursively.
    """
    if not branch:
        return list(get_list_of_baskets_in_branch(branch, byte_ranges, is_full_file))

    collected = []
    for sub_branch in branch.values():
        collected.extend(recursive_branch(sub_branch, byte_ranges, is_full_file))
    return collected


def get_list_of_baskets_in_file(filename,  byte_ranges, is_full_file):
    """Return the sorted, duplicate-free basket offsets of a ROOT file.

    Every value of the opened file is walked as a tree and every value of a
    tree as a branch; the per-branch work is done by recursive_branch().
    """
    root_file = uproot.open(filename)

    offsets = []
    for tree in root_file.values():
        for branch in tree.values():
            offsets.extend(recursive_branch(branch, byte_ranges, is_full_file))

    # Remove duplicates and sort
    unique_offsets = sorted(set(offsets))
    log.info("done sorting. Starting to check baskets")

    return unique_offsets


def convert_checksum(checksum_8bytes):
    """Combine the first 8 bytes of @checksum_8bytes into one integer.

    The first byte is the most significant one (big-endian order).

    Args:
        checksum_8bytes: bytes, bytearray or a string with one character
            per byte; at least 8 items long.

    Raises:
        IndexError: when fewer than 8 items are given.
    """
    if isinstance(checksum_8bytes, (bytes, bytearray)):
        # Indexing a bytearray yields integers on every Python version,
        # whereas ord() rejects the integers obtained from Python 3 bytes
        values = bytearray(checksum_8bytes)
    else:
        values = [ord(character) for character in checksum_8bytes]

    return sum(values[position] << shift
               for position, shift in enumerate(range(56, -8, -8)))


def check_baskets(baskets_list, shrd_basket_index, chunk, num_baskets, shrd_corrupted, lock, filename):
    """Worker that verifies baskets of @filename until none are left.

    Workers share a cursor into @baskets_list: each one reserves the next
    @chunk entries under @lock and checks them. A basket is verified by
    decompressing it (zlib "ZL", lzma "XZ") or by comparing the stored
    xxhash64 checksum with the one of its payload (lz4 "L4"). Baskets using
    any other algorithm are skipped. The worker returns as soon as it, or
    another worker, finds a corrupted basket.

    Args:
        baskets_list: sorted offsets pointing just past each 9-byte header.
        shrd_basket_index: shared object whose .value is the index of the
            next basket not yet reserved.
        chunk: number of baskets reserved at a time.
        num_baskets: number of entries of @baskets_list to process.
        shrd_corrupted: shared object whose .value is set to 1 when a
            corrupted basket is found.
        lock: lock protecting both shared objects.
        filename: path of the file holding the baskets.
    """
    pid = str(os.getpid())
    log.debug("process : "+pid+ " starting check_baskets")

    def flag_corrupted(seek, reason=None):
        # Publish the finding so that every worker stops
        message = "process: " + pid + ", Corrupted Basket on Seek: "+str(seek)
        if reason is not None:
            message += ". Error message: "+str(reason)
        with lock:
            shrd_corrupted.value = 1
            log.info(message)

    # Binary mode: the content is compressed data, not text. The context
    # manager closes the file on every exit path.
    with open(filename, "rb") as fd:
        while True:
            with lock:
                # If somebody else found something corrupted or there are no more baskets to analyze
                if shrd_corrupted.value == 1 or shrd_basket_index.value >= num_baskets:
                    log.debug("process : "+pid+ " finishing here. shrd_corrupted = "+str(shrd_corrupted.value)+" shrd_basket_index= "+str(shrd_basket_index.value))
                    return
                start = shrd_basket_index.value
                shrd_basket_index.value += chunk

            stop = min(start + chunk, num_baskets)
            for seek in baskets_list[start:stop]:
                fd.seek(seek-9)
                # bytearray gives integer items on every Python version
                header = bytearray(fd.read(9))
                if len(header) < 9:
                    flag_corrupted(seek, "truncated basket header")
                    return

                # Bytes 0-1 name the algorithm, bytes 3-5 hold the
                # compressed size (little-endian). Bytes 6-8 hold the
                # uncompressed size, which is not needed here.
                algo = bytes(header[0:2]).decode("latin-1")
                num_compressed_bytes = header[3] + (header[4] << 8) + (header[5] << 16)

                # After reading the header the file position is @seek,
                # i.e. the start of the payload.
                if algo == "ZL":
                    try:
                        zlib.decompress(fd.read(num_compressed_bytes))
                    except Exception as e:
                        flag_corrupted(seek, e)
                        return

                # lzma
                elif algo == "XZ":
                    try:
                        # The payload on disk has the compressed length
                        lzma.decompress(fd.read(num_compressed_bytes))
                    except Exception as e:
                        flag_corrupted(seek, e)
                        return

                # lz4
                elif algo == "L4":
                    # Baskets compressed with this algorithm have an extra 8-byte header containing
                    # a checksum of the compressed data, meaning that the length of the basket is
                    # actually 8 bytes shorter
                    checksum_8bytes = fd.read(8)
                    if num_compressed_bytes < 8 or len(checksum_8bytes) < 8:
                        flag_corrupted(seek, "truncated lz4 basket")
                        return

                    # Convert the 8 separated bytes into a single number.
                    # latin-1 maps every byte to exactly one character,
                    # the form convert_checksum() handles on any version.
                    checksum = convert_checksum(checksum_8bytes.decode("latin-1"))

                    # Calculate the checksum of the compressed bytes and compare it to the one
                    # in the header of the basket
                    compressed_bytes = fd.read(num_compressed_bytes - 8)
                    if xxhash.xxh64(compressed_bytes).intdigest() != checksum:
                        flag_corrupted(seek)
                        return

                elif algo == "CS":
                    log.info("process: " +pid +", Unsupported very OLD algorithm: "+algo+" , skipping basket...")

                else:
                    # Logged at debug level: this can happen once per
                    # basket of the file
                    log.debug("process: " +pid + ", Unsupported algorithm. skipping...")


def create_table(conn, create_table_sql):
    """Execute the given CREATE TABLE statement on @conn.

    On a sqlite3 error the error and the statement are logged and the
    process exits with status 1.
    """
    try:
        cursor = conn.cursor()
        cursor.execute(create_table_sql)
    except Error as e:
        log.error("Cannot create database table, verify you have write access on database file. Error message: "+str(e))
        log.error(create_table_sql)
        sys.exit(1)


def create_connection(db_file):
    """Open and return a sqlite3 connection to @db_file.

    On a sqlite3 error the error is logged and the process exits with
    status 1.
    """
    try:
        return sqlite3.connect(db_file)
    except Error as e:
        log.error("Cannot create database connection, verify you have write access on database file. Error message: "+str(e))
        sys.exit(1)


def create_db(filename):
    """Make sure the database file @filename contains the 'files' table.

    The connection opened for this purpose is committed and closed before
    returning.
    """
    sql = """ CREATE TABLE IF NOT EXISTS files (
                                        id integer PRIMARY KEY,
                                        filename text,
                                        last_check_ts integer,
                                        last_modification_date integer,
                                        checksum text
                                    ); """

    # create a database connection
    conn = create_connection(filename)
    if conn is None:
        log.error("Cannot create database connection.")
        return

    try:
        # create tables
        create_table(conn, sql)
        conn.commit()
    finally:
        # Release the database handle, also when create_table() exits
        conn.close()

def calculate_checksum(filename):
    """Return the xxhash64 digest of the content of @filename.

    The file is read in binary mode and hashed block by block, so its size
    does not determine the memory used.

    Returns:
        The integer digest formatted as a decimal string.
    """
    read_size = 1024 * 1024
    hasher = xxhash.xxh64()
    with open(filename, "rb") as fd:
        for block in iter(lambda: fd.read(read_size), b""):
            hasher.update(block)
    return str(hasher.intdigest())

def insert_in_db(conn, root_filepath, ts, last_modification_date, file_checksum):
    """Add a record for @root_filepath to the 'files' table and commit it."""
    statement = ''' INSERT INTO files(filename, last_check_ts, last_modification_date, checksum)
              VALUES(?,?,?,?) '''
    values = (root_filepath, ts, last_modification_date, file_checksum)

    # The connection as context manager commits on success
    with conn:
        cursor = conn.cursor()
        cursor.execute(statement, values)

    log.debug("inserted record: "+str(cursor.lastrowid))

# When a file hasn't changed but a quick checksum comparison has been done, we need to update the 'last_check_ts'
# of the file's record
def update_last_check_ts(conn, root_filepath, ts):
    """Set 'last_check_ts' to @ts for the records of @root_filepath."""
    statement = ''' UPDATE files
                SET last_check_ts = ?
                WHERE filename = ?'''

    # The connection as context manager commits on success
    with conn:
        conn.cursor().execute(statement, (ts, root_filepath))

    log.debug("updated record's last check ts: "+root_filepath)

# When a file has changed (new blocks have been downloaded) its record needs to be updated
def update_db(conn, root_filepath, ts, last_modification_date, file_checksum):
    """Refresh check time, modification date and checksum of @root_filepath."""
    statement = ''' UPDATE files
                SET last_check_ts = ?,
                    last_modification_date = ?,
                    checksum = ?
                WHERE filename = ?'''
    values = (ts, last_modification_date, file_checksum, root_filepath)

    # The connection as context manager commits on success
    with conn:
        conn.cursor().execute(statement, values)

    log.debug("updated record: "+root_filepath)



def get_file_from_db(conn, root_filepath):
    """Fetch the record stored for @root_filepath.

    Returns:
        The first matching row as (id, last_check_ts,
        last_modification_date, checksum), or (None, -1, -1, -1) when there
        is none. More than one match is logged as an error.
    """
    query = ''' SELECT id, last_check_ts, last_modification_date, checksum
              FROM files
              WHERE filename =?'''

    cursor = conn.cursor()
    cursor.execute(query, (root_filepath,))
    matches = cursor.fetchall()

    if len(matches) > 1:
        log.error("Duplicated file in the DB: "+root_filepath)

    if not matches:
        return (None, -1, -1, -1)
    return matches[0]

# Removes a record from the database, it is used when a file is found to be corrupted and removed from the cache
def remove_from_db(conn, root_filepath):
    """Delete the records of @root_filepath from the 'files' table."""
    statement = ''' DELETE FROM files
                WHERE filename = ?'''

    # The connection as context manager commits on success
    with conn:
        conn.cursor().execute(statement, (root_filepath,))

    log.debug("file removed from db: "+root_filepath)


def safe_remove(filename, allowed_paths, is_symlink):
    """Remove @filename if it is located in an allowed place.

    Args:
        filename: path to remove.
        allowed_paths: comma separated list of paths. Only these paths and
            what is below them may be removed. When empty, removing from
            anywhere is allowed.
        is_symlink: only selects the wording of the log message.

    Returns:
        1 when the removal was refused because @filename is outside of the
        allowed paths, 0 otherwise (a failing os.remove is only logged).
    """
    def actual_remove(filename):
        try:
            os.remove(filename)
            if is_symlink:
                log.info("Symlink removed: "+filename)
            else:
                log.info("File removed: "+filename)
        except Exception as e:
            log.error("Something went wrong when trying to remove: "+filename+"\n Exception message is:  "+str(e))

    if not allowed_paths:
        # Removing from anywhere is allowed
        actual_remove(filename)
        return 0

    for entry in allowed_paths.split(","):
        allowed = entry.strip()
        # An empty entry (e.g. caused by a trailing comma) would be a
        # prefix of every filename and lift the restriction altogether
        if not allowed:
            continue

        # Compare on a directory boundary, otherwise "/data/cache" would
        # also allow "/data/cache-old/..."
        prefix = allowed if allowed.endswith(os.sep) else allowed + os.sep
        if filename == allowed or filename.startswith(prefix):
            actual_remove(filename)
            return 0

    log.warning("Attempt to remove file in forbidden path:"+filename)
    return 1

# Removes a file from the cache, it is used when a file is found to be corrupted
def remove_file(filename, allowed_paths, is_full_file):
    """Remove @filename and, for a partial file, its ".cinfo" companion.

    When the path is a symbolic link, every link of the chain and the final
    target are removed, as far as safe_remove() permits.

    Args:
        filename: path of the file to remove.
        allowed_paths: restriction handed over to safe_remove().
        is_full_file: when false, "<filename>.cinfo" is removed as well.
    """
    def remove_single(filename, allowed_paths):
        # Check that the file still exists:
        if os.path.isfile(filename):
            # In case of a chain of symbolic links e.g. a -> b -> c
            # Remove all links
            stop = 0
            while os.path.islink(filename) and stop == 0:
                # Get symlink target. A relative target is relative to the
                # directory holding the link, not to the current working
                # directory; an absolute target is left as it is by join().
                link_dir = os.path.dirname(filename)
                target = os.path.abspath(os.path.join(link_dir, os.readlink(filename)))
                stop = safe_remove(filename, allowed_paths, True)
                filename = target
            if stop == 0:
                # Remove the real file
                safe_remove(filename, allowed_paths, False)
        elif os.path.islink(filename):
            # This is the case of a broken link
            log.info("file to remove is actually a broken symlink: "+filename)
            safe_remove(filename, allowed_paths, True)
        else:
            log.error("cannot find file to remove: "+filename)

    remove_single(filename, allowed_paths)
    if not is_full_file:
        cinfo_filepath = filename+".cinfo"
        remove_single(cinfo_filepath, allowed_paths)

def find_xrootd_config_file():
    """Find the configuration file of the running xrootd process.

    The command lines of the processes of user 'xrootd' are obtained from
    "ps"; in the first one whose first word contains "xrootd", the value
    of the "-c" option is looked up.

    Returns:
        The configuration file path, or "" when it cannot be determined
        (the reason is logged).
    """
    # First get the configuration file being used by xrootd
    out = subprocess.Popen(['ps', '-u','xrootd', '-o', 'command='], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    stdout, _ = out.communicate()
    # The pipe delivers bytes on Python 3; the parsing below needs text
    if isinstance(stdout, bytes):
        stdout = stdout.decode("utf-8", "replace")

    xrootd_config_file = ""
    if out.returncode != 0:
        log.error("Something went wrong when executing ps")
        return xrootd_config_file

    log.debug("Getting xrootd config file from ps")
    xrootd_p = None
    for p in stdout.split('\n'):
        splitted = p.split()
        if len(splitted) > 0 and "xrootd" in splitted[0]:
            xrootd_p = p
            break

    if xrootd_p is None:
        log.error("Cannot find a running xrootd process from which obtain the configuration file used")
        return xrootd_config_file

    # Look for the config file passed to xrootd, as either "-c foo.cfg"
    # (two parts) or "-cfoo.cfg" (one part).
    # Return the first match.
    dash_c = False
    for chunk in xrootd_p.split():
        if dash_c:
            xrootd_config_file = chunk
            break
        if chunk == "-c":
            dash_c = True
        elif chunk[0:2] == "-c":
            xrootd_config_file = chunk[2:]
            break

    if xrootd_config_file:
        log.info("found xrootd configuration file:"+ xrootd_config_file)
    else:
        log.error("Cannot find the config file used by the xrootd running process")

    return xrootd_config_file


# When the configuration parameter 'path' is set to auto, we will try to set it to xrootd's configuration parameter
# 'oss.localroot'
def get_localroot_from_xrootd_config_file(xrootd_config_file):
    """Read the 'oss.localroot' value from an xrootd configuration file.

    The file is processed with "cconfig -c"; the second word of the last
    output line containing "oss.localroot" is returned.

    Returns:
        The path, or "" when cconfig fails or no value is found.
    """
    # Use xrootd config file to find the 'oss.localroot' parameter
    out = subprocess.Popen(['cconfig', '-c', xrootd_config_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    stdout, _ = out.communicate()
    # The pipe delivers bytes on Python 3; the parsing below needs text
    if isinstance(stdout, bytes):
        stdout = stdout.decode("utf-8", "replace")

    path = ""
    if out.returncode != 0:
        log.error("Something went wrong when executing cconfig -c "+ xrootd_config_file)
        return path

    log.debug("Parsing cconfig output to get 'oss.localroot'")
    for line in stdout.split('\n'):
        if "oss.localroot" in line:
            words = line.split()
            # The value is the second word, so two words are required
            if len(words) > 1:
                path = words[1]
                # no break intentionally, it could be more than one
                # we want to take the last one.
    return path

def get_pfc_spaces_from_xrootd_config_file(xrootd_config_file):
    """List the paths of the spaces named by 'pfc.spaces'.

    The configuration file is processed with "cconfig -c". The last
    well-formed "pfc.spaces <data> <metadata>" line names two spaces; the
    paths of all "oss.space <name> <path>" lines of these two spaces are
    collected.

    Returns:
        The paths, each ending with "/", sorted and joined by commas; ""
        when cconfig fails, no valid 'pfc.spaces' line exists or no path is
        found.
    """
    # 1.  Use xrootd config file to find the 'pfc.spaces' parameter
    out = subprocess.Popen(['cconfig', '-c', xrootd_config_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    stdout, _ = out.communicate()
    # The pipe delivers bytes on Python 3; the parsing below needs text
    if isinstance(stdout, bytes):
        stdout = stdout.decode("utf-8", "replace")

    pfc_spaces = ""
    if out.returncode != 0:
        log.error("Something went wrong when executing cconfig -c "+ xrootd_config_file)
        return pfc_spaces

    log.debug("Parsing cconfig output to get 'pfc.spaces'")
    lines = stdout.split('\n')

    data = meta = None
    for line in lines:
        if "pfc.spaces" in line:
            words = line.split()
            if len(words) == 3:
                data, meta = words[1], words[2]
                # no break intentionally, it could be more than one
                # we want to take the last one.
            else:
                log.error("pfc.spaces do not have the expected format: pfc.spaces data metadata")

    # Without the space names there is nothing to look up
    if data is None:
        log.error("Cannot find a valid 'pfc.spaces' definition in the output of cconfig -c "+ xrootd_config_file)
        return pfc_spaces

    # 2. for both data and meta look for all definitions of oss.space
    set_of_paths = set()
    for line in lines:
        if "oss.space "+data in line or "oss.space "+meta in line:
            words = line.split()
            if len(words) == 3:
                path = words[2]
                # Make sure all paths end with "/"
                if not path.endswith("/"):
                    path += "/"
                set_of_paths.add(path)
                # no break intentionally, it could be more than one
                # we want them all
            else:
                log.error("oss.space does not have the expected format: oss.space data|metadata path")

    # Transform set to string of comma separated paths
    pfc_spaces = ",".join(sorted(set_of_paths))

    return pfc_spaces


# Argument parsing
def parseargs():
    """Parse the command line (sys.argv) and return the argparse namespace.

    Options without an explicit default are None when not given, so that
    the caller can tell them apart from values set on purpose.
    """
    # Nothing is "required" because we don't want the args parser to enforce an argument
    # that could be defined in the config file, we need to enforce in our own way, see Step 5 on set_configuration()
    parser = argparse.ArgumentParser()

    parser.add_argument("--config", dest="config_file",
                         help="Path to configuration file")

    parser.add_argument("--logfile", dest="logfile",
                         help="Path to the files to store the logs")

    parser.add_argument("--path", dest="path",
                         help="Path to the files to analyze, for a single file use --rootfile")

    parser.add_argument("--rootfile", dest="rootfile",
                         help=".root filename to be analyzed, if this is a partial file a .cinfo file "
                              "with the same name is expected. If this is a full file set --full-file")

    parser.add_argument("--full-file", dest="full_file", action="store_true", default=False,
                         help="Assume that the root file passed in --rootfile is fully downloaded "
                              "thus does not require a cinfo file (default: False)")

    parser.add_argument("--debug", dest="debug", action="store_true", default=False,
                         help="Set log to DEBUG mode (default: False)")

    parser.add_argument("--dry-run", dest="dry_run", action="store_true", default=False,
                         help="Just list the files to be analyzed but do not run the check")

    parser.add_argument("--max", dest="max", type=int, default=-1,
                         help="Maximum number of files to analyze (default:analyze all files)")

    parser.add_argument("--num-procs", dest="num_procs", type=int,
                         help="Number of parallel processes used to analyze the file(s) (default: 1)")

    parser.add_argument("--db", dest="db",
                         help="Database used to keep track of analyzed files "
                              "(default: /var/lib/xcache_consistency_check/db.sql)")

    parser.add_argument("--last-check-threshold", dest="last_check_threshold", type=int,
                         help="If a file has been checked within less than the number of seconds defined here, "
                              "the check on this file will be skipped (default: 86400, i.e. 24hrs)")

    parser.add_argument("--dont-remove", dest="dont_remove", action="store_true", default=False,
                         help="When set, corrupted files are only logged but not removed")

    return parser.parse_args()

def validate_path(path, parameter_name):
    """Exit with status 1 unless the directory part of @path exists.

    An empty @path is accepted without any check. @parameter_name is only
    used in the error message.
    """
    if not path:
        return

    # A bare name such as "file.root" has an empty directory part, which
    # stands for the current directory
    parent = os.path.dirname(path) or os.curdir
    if not os.path.exists(parent):
        print("ERROR configuration parameter '"+parameter_name+"' expects a valid path")
        sys.exit(1)

def validate_positive_integer(var, parameter_name):
    """Exit with status 1 when @var is lower than 1.

    @parameter_name is only used in the error message.
    """
    too_small = var < 1
    if not too_small:
        return

    print("ERROR configuration parameter '"+parameter_name+"' expects a positive integer")
    sys.exit(1)

def set_configuration():
    """Build the effective configuration of the program.

    Values are taken, in increasing order of precedence, from the built-in
    defaults, from the [Main] section of the file given with --config and
    from the command line. The process exits with status 1 when a value is
    invalid or when neither 'path' nor 'rootfile' is defined.

    Returns:
        Dictionary with the keys: path, rootfile, db, num_procs,
        last_check_threshold, logfile, dont_remove, remove_only_from,
        full_file, max, debug and dry_run.
    """
    # Read command line arguments
    cmdline_args = parseargs()

    # Step 1. Set defaults:
    # We need to set everything as a string so that the interpolation works in the case
    # that a config file is provided. When no config file is provided, we need to change
    # for the correct data type manually.
    args = dict()
    args['path']                    = ""
    args['rootfile']                = ""
    args['db']                      = "/var/lib/xcache_consistency_check/db.sql"
    args['num_procs']               = "1"
    args['last_check_threshold']    = "86400"
    # When empty, no log file is used
    args['logfile']                 = ""
    args['dont_remove']             = "False"
    # Allow to remove corrupted files that are in this list of directories
    # This can only be configured via configuration file. If empty, it will
    # allow to remove from anywhere
    args['remove_only_from']        = ""

    # The following params cannot be read from config file so we don't care about
    # interpolation issues and thus we can define the desired data types
    args['full_file']   = False
    args['max']         = -1
    args['debug']       = False
    args['dry_run']     = False

    # Step 2. Read config file
    config = configparser.ConfigParser(args)
    config_file = cmdline_args.config_file
    if config_file:
        print("configuration file: "+config_file)
        # read() skips files it cannot open and returns the ones it parsed
        if not config.read(config_file):
            print("ERROR cannot read configuration file: "+config_file)
            sys.exit(1)
        if not config.has_section('Main'):
            print("ERROR configuration file does not have a [Main] section: "+config_file)
            sys.exit(1)

        args['path']                   = config.get('Main', 'path').replace('"','')
        args['db']                     = config.get('Main', 'db').replace('"','')
        args['logfile']                = config.get('Main', 'logfile').replace('"','')
        args['remove_only_from']       = config.get('Main', 'remove_only_from').replace('"','')
        try:
            args['num_procs']              = config.getint('Main', 'num_procs')
        except (ValueError, configparser.Error):
            print("ERROR configuration parameter 'num_procs' expects a positive integer")
            sys.exit(1)
        try:
            args['last_check_threshold']   = config.getint('Main', 'last_check_threshold')
        except (ValueError, configparser.Error):
            print("ERROR configuration parameter 'last_check_threshold' expects a positive integer")
            sys.exit(1)
        try:
            args['dont_remove']            = config.getboolean('Main', 'dont_remove')
        except (ValueError, configparser.Error):
            print("ERROR configuration parameter 'dont_remove' expects a boolean")
            sys.exit(1)
    else:
        # We need to convert strings into the desired data types since that didn't
        # happen above because no config file was provided
        args['num_procs']               = int(args['num_procs'])
        args['last_check_threshold']    = int(args['last_check_threshold'])
        args['dont_remove']             = args['dont_remove'].lower() != "false"
        print("no config file provided")

    # Step 3. Overwrite with command line arguments
    if cmdline_args.path:
        args['path'] = cmdline_args.path
    if cmdline_args.logfile:
        args['logfile'] = cmdline_args.logfile
    if cmdline_args.db:
        args['db'] = cmdline_args.db
    # Compared against None so that an explicit 0 reaches the validation
    # of Step 5 instead of being silently replaced by the default
    if cmdline_args.num_procs is not None:
        args['num_procs'] = cmdline_args.num_procs
    if cmdline_args.last_check_threshold is not None:
        args['last_check_threshold'] = cmdline_args.last_check_threshold
    if cmdline_args.dont_remove:
        args['dont_remove'] = cmdline_args.dont_remove

    # Step 4. Get the arguments not allowed in the config file from the command line arguments
    if cmdline_args.rootfile:
        args['rootfile']    = cmdline_args.rootfile
    if cmdline_args.full_file:
        args['full_file']   = cmdline_args.full_file
    if cmdline_args.max:
        args['max']         = cmdline_args.max
    if cmdline_args.debug:
        args['debug']       = cmdline_args.debug
    if cmdline_args.dry_run:
        args['dry_run']     = cmdline_args.dry_run

    # Step 5. Extra validation for some special arguments
    if args['path'] and args['path'] != "auto":
        validate_path(args['path'], "path")

    validate_path(args['logfile'], "logfile")
    validate_path(args['db'], "db")
    validate_path(args['rootfile'], "rootfile")
    validate_positive_integer(args['num_procs'], "num_procs")
    validate_positive_integer(args['last_check_threshold'], "last_check_threshold")

    # Step 6. Transform relative paths to absolute ones
    if args['path'] and args['path'] != "auto":
        relative_path = args['path']
        args['path'] = os.path.abspath(relative_path)+"/"
        print("Transforming relative path: "+relative_path+" to absolute: "+args['path'])

    if args['rootfile']:
        relative_path = args['rootfile']
        args['rootfile'] = os.path.abspath(relative_path)
        print("Transforming relative rootfile: "+relative_path+" to absolute: "+args['rootfile'])

    #TODO:
    # following prints have to go to stderr and make sure they appear on journalctl
    # Step 7. Verify required conditions

    if not args['path'] and not args['rootfile']:
        print("ERROR: Either --path or --rootfile need to be defined")
        sys.exit(1)
    # If --path is not defined then --rootfile is
    if not args['path']:
        # Check that the rootfile is not empty and that the file exists
        if not os.path.isfile(args['rootfile']):
            print("ERROR: --rootfile: "+args['rootfile']+" does not exist")
            sys.exit(1)
    #TODO:
    # remove_from and dont_remove are mutually ex
    # DB is set
    # We can write where the DB is supposed to be stored
    # We can write where logs are supposed to be stored

    return args

###############################################################################
#                               MAIN
###############################################################################
def _configure_logging(args):
    """Attach a handler to the global logger according to 'debug' and 'logfile'."""
    # NOTE: boolean options are compared with '== True' throughout because their
    # exact type is decided by set_configuration(), which lives elsewhere.
    # Log level: {CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET}
    log_lvl = logging.DEBUG if args['debug'] == True else logging.INFO
    log_format = '%(asctime)s %(levelname)s - %(message)s'
    log_format_time = '%Y%m%d %H:%M:%S'

    if args['logfile']:
        log_handler = logging.handlers.WatchedFileHandler(args['logfile'])
        log_handler.setFormatter(logging.Formatter(log_format, log_format_time))
        log.addHandler(log_handler)
        log.setLevel(log_lvl)
    else:
        # No logfile configured: basicConfig() without a filename installs a
        # stream handler on the root logger.
        logging.basicConfig(level=log_lvl, format=log_format, datefmt=log_format_time)


def _require_xrootd_config_file():
    """Return XRootD's configuration file, exiting with status 1 if it cannot be found."""
    xrootd_config_file = find_xrootd_config_file()
    if not xrootd_config_file:
        log.error("Cannot get XRootD's configuration file from 'ps', is XRootD running?")
        sys.exit(1)
    return xrootd_config_file


def _resolve_remove_paths(args, xrootd_config_file):
    """Build the final comma-separated value for 'remove_only_from'.

    args['remove_only_from'] must be non-empty. When it is "auto" the list is
    taken from XRootD's configuration file (looked up here if
    'xrootd_config_file' is empty); otherwise it is the user's comma-separated
    list with every entry forced to end in "/". In both cases args['path'], or
    the directory of args['rootfile'] when 'path' is empty, is appended.
    Exits with status 1 if the "auto" lookup fails.
    """
    if args['remove_only_from'] == "auto":
        if not xrootd_config_file:
            xrootd_config_file = _require_xrootd_config_file()
        remove_paths = get_pfc_spaces_from_xrootd_config_file(xrootd_config_file)
        # Check the lookup result itself (not the config file name again).
        if not remove_paths:
            log.error("Cannot get 'pfc_spaces' from XRootD's configuration file")
            sys.exit(1)
    else:
        # Skip empty entries (e.g. produced by a trailing comma) instead of
        # crashing while indexing an empty string.
        entries = [entry for entry in args['remove_only_from'].split(",") if entry]
        remove_paths = ",".join(
            entry if entry.endswith("/") else entry + "/" for entry in entries)

    # Always put either 'path' or the rootfile's directory in the list
    if args['path']:
        base_dir = args['path']
    else:
        base_dir = os.path.dirname(args['rootfile']) + "/"
    return remove_paths + "," + base_dir if remove_paths else base_dir


def _handle_corrupted_file(conn, args, root_filepath):
    """Remove a corrupted file and its DB record unless 'dont_remove' is set.

    Returns True when removal was attempted, False when it was suppressed.
    """
    if args['dont_remove']:
        log.info("File will not be removed, --dont-remove is set to True")
        return False
    log.info("Will attempt to remove file: %s", root_filepath)
    remove_file(root_filepath, args['remove_only_from'], args['full_file'])
    remove_from_db(conn, root_filepath)
    return True


def _has_corrupted_basket(args, root_filepath, list_of_baskets):
    """Run check_baskets() over the baskets with 'num_procs' processes in total.

    Returns True when the shared corruption flag ends up set.
    """
    chunk = 10
    basket_index = Value('i', 0)
    corrupted_flag = Value('i', 0)
    lock = Lock()
    worker_args = (list_of_baskets, basket_index, chunk, len(list_of_baskets),
                   corrupted_flag, lock, root_filepath)

    # The total number of processes used is num_workers + 1, since the parent
    # process also takes part in the check.
    num_workers = args['num_procs'] - 1
    process_list = []
    for _ in range(num_workers):
        worker = Process(target=check_baskets, args=worker_args)
        worker.start()
        process_list.append(worker)

    # The parent process is also doing its part
    check_baskets(*worker_args)

    for worker in process_list:
        worker.join()

    return corrupted_flag.value == True


def _analyze_file(conn, args, root_filepath):
    """Check one root file for corruption and keep its DB record up to date.

    Returns True when the file counts towards the 'max' limit (missing .cinfo,
    or a full basket analysis was performed) and False when the file was
    skipped early (vanished, checked recently, settled by the quick checksum
    comparison, or its .cinfo could not be parsed).
    """
    root_filename = os.path.basename(root_filepath)
    cinfo_filepath = root_filepath + ".cinfo"

    # We need a .cinfo file unless the file is to be assumed fully downloaded
    if not (args['full_file'] == True or os.path.isfile(cinfo_filepath)):
        log.error("Cannot find a corresponding .cinfo file for root file:  %s", root_filepath)
        return True

    log.info("Analyzing file: "+ root_filename)
    # The file could have disappeared between the moment it was listed and
    # the moment its turn comes, so check that it is still there.
    if not os.path.isfile(root_filepath):
        log.info("File: "+ root_filename+" does not longer exist, skipping..")
        return False

    ### Step 1.0 Do I need to fully analyze this file or only verify the checksum?
    last_modification_date = int(os.path.getmtime(root_filepath))
    db_file_id, db_last_check_ts, db_last_modification_date, db_checksum = get_file_from_db(conn, root_filepath)

    ts = int(time.time())
    if db_file_id is None:
        log.debug("file %s, not in the DB", root_filepath)
    elif ts - db_last_check_ts <= args['last_check_threshold']:
        # Checked less than 'last_check_threshold' seconds ago: skip
        log.info("file %s, was analyzed less than 'last_check_threshold' seconds ago, skipping", root_filename)
        return False
    elif last_modification_date == db_last_modification_date:
        # Same modification date as recorded in the DB: the file should not
        # have changed, so a quick checksum comparison is enough.
        curr_checksum = calculate_checksum(root_filepath)
        if db_checksum == curr_checksum:
            log.info("OK file's checksum:  %s", root_filename)
            update_last_check_ts(conn, root_filepath, ts)
            return False
        # The modification date hasn't changed but the content has:
        # the file is corrupted.
        log.info("CORRUPTED file's checksum:  %s, db:%s, curr:%s", root_filepath, db_checksum, curr_checksum)
        if _handle_corrupted_file(conn, args, root_filepath):
            return False
        # NOTE(review): with --dont-remove the file falls through to the full
        # basket analysis below (original behaviour, kept as is) -- confirm
        # that this is intended.
    else:
        log.debug("file %s, has changed since last analysis and last check was more that 'last_check_treshold' seconds ago", root_filepath)

    # Step 1.1 Calculate the byte ranges on the file unless the file is to be
    # assumed fully downloaded (see option --full-file)
    if args['full_file'] == True:
        is_full_file = True
        byte_ranges = None
    else:
        log.debug("Starting to parse cinfo file")
        is_full_file, byte_ranges, error = parse_cinfo(cinfo_filepath)
        if error != 0:
            log.error("Cannot parse cinfo file:" +cinfo_filepath+"\n skipping file")
            return False

    # Step 2. Create a list of baskets in the file
    list_of_baskets = get_list_of_baskets_in_file(root_filepath, byte_ranges, is_full_file)

    # Step 3. Look for a corrupted basket within the list
    if _has_corrupted_basket(args, root_filepath, list_of_baskets):
        log.info("CORRUPTED file:  %s", root_filepath)
        _handle_corrupted_file(conn, args, root_filepath)
    else:
        log.info("OK file:  %s", root_filepath)
        file_checksum = calculate_checksum(root_filepath)
        if db_file_id is None:
            insert_in_db(conn, root_filepath, ts, last_modification_date, file_checksum)
        else:
            update_db(conn, root_filepath, ts, last_modification_date, file_checksum)
    return True


def main():
    """Entry point: run the consistency check over one root file or a whole tree.

    Reads the configuration with set_configuration(), sets up logging, resolves
    the 'auto' values of 'path' and 'remove_only_from', prepares the DB and then
    either lists what would be analyzed ('dry_run') or analyzes every file.
    At most args['max'] files are processed when 'max' is greater than 0.
    Exits with status 1 when a required piece of XRootD configuration cannot
    be obtained.
    """
    # Get arguments
    args = set_configuration()

    #----- Setup the logger and the log level ------------------------------------
    _configure_logging(args)

    log.info("*******************************************")
    log.info("**** Xcache consistency check starting ****")
    log.info("*******************************************")

    # If both path and rootfile are set, disable 'path'
    if args['path'] and args['rootfile']:
        args['path'] = ""
        log.warning("Both 'path' and 'rootfile' are set, disabling 'path'")

    xrootd_config_file = ""
    # If path is set to 'auto' try to find the path to the rootfiles
    if not args['rootfile'] and args['path'] == "auto":
        xrootd_config_file = _require_xrootd_config_file()
        localroot = get_localroot_from_xrootd_config_file(xrootd_config_file)
        if not localroot:
            log.error("Cannot get 'localroot' from XRootD's configuration file")
            sys.exit(1)
        if not localroot.endswith("/"):
            localroot += "/"
        args['path'] = localroot

    # Figure out 'remove_only_from' and make sure all paths end with "/"
    if args['remove_only_from']:
        args['remove_only_from'] = _resolve_remove_paths(args, xrootd_config_file)
    else:
        log.debug("remove_only_from is not set")

    # Print the arguments
    log.info("********* Configuration parametes *********")
    for key in args:
        log.info(key+" : "+str(args[key]))
    log.info("*******************************************")

    #------ DB setup--------------------------------------------------------------
    create_db(args['db'])
    conn = create_connection(args['db'])

    #------ Find file(s)  ---------------------------------------------------------
    # If args['path'] is empty, that means we are analyzing a single root file
    if not args['path']:
        rootfile = args['rootfile']
        log.debug("path is not defined, analyzing single file: "+rootfile)
        filepath_list = [rootfile]
    else:
        filepath_list = list_files_recursively(args['path'], ".root")
        log.info("found: "+str(len(filepath_list))+" .root files in: "+args['path'])
    max_counter = 0

    #------ Dry Run  --------------------------------------------------------------
    if args['dry_run'] == True:
        for root_filepath in filepath_list:
            if args['max'] > 0 and max_counter >= args['max']:
                break
            # Verify that there is a corresponding .cinfo file
            cinfo_filepath = root_filepath + ".cinfo"
            if args['full_file'] or os.path.isfile(cinfo_filepath):
                log.info("Analyzing file: "+ root_filepath)
            else:
                log.error("Cannot find a corresponding .cinfo file for root file:  %s", root_filepath)
            max_counter += 1

    #------ Real Run  -------------------------------------------------------------
    else:
        for root_filepath in filepath_list:
            if args['max'] > 0 and max_counter >= args['max']:
                break
            # Files that were skipped early do not count towards 'max'
            if _analyze_file(conn, args, root_filepath):
                max_counter += 1

# Module-wide logger (the root logger). Every function in this file reaches it
# through the global name 'log'; main() attaches the handler and sets the level.
log = logging.getLogger()

# Run the check only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
