#!/usr/bin/env python3

# Copyright (C) 2020- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

import os
import re
import ast
import sys
import json
import glob
import argparse
import subprocess
import importlib
import email.parser

# Description text handed to argparse.ArgumentParser in the __main__ section,
# i.e. what the user sees at the top of the --help output.
HELPMSG = '''Determine the environment required by Python code.

This script is assumed to be run from inside an active conda environment.
Modules under $CONDA_PREFIX/lib/pythonX.Y/ are assumed to be part of the
base Python installation, with the exception of the subdirectory
site-packages/, which is assumed to contain extra packages installed via
pip or conda. If a module is in site-packages/, but not known by either
pip or conda, a warning is printed. This would indicate a strange packaging
layout that we can't automatically detect, so the user should manually provide
the package name with the --pkg-mapping option. Other modules under the prefix
are assumed to be part of the Python package. It is also an error to import
anything from outside the prefix.'''

## AST-related stuff

def get_stmt_imports(stmt):
    """Return the module names imported by a single AST node.

    ``import a, b.c`` yields ``['a', 'b.c']`` and ``from a.b import c``
    yields ``['a.b']``. Relative imports (``from . import x``) and every
    node that is not an import statement yield an empty list.
    """
    if isinstance(stmt, ast.Import):
        return [alias.name for alias in stmt.names]
    # Relative imports (level > 0) are deliberately skipped.
    if isinstance(stmt, ast.ImportFrom) and stmt.level == 0:
        return [stmt.module]
    return []

def analyze_toplevel(module):
    """List the imports made by statements directly in ``module.body``.

    Only the immediate children of the module are inspected; imports nested
    inside functions, classes or conditionals are not reported.
    """
    return [name for node in module.body for name in get_stmt_imports(node)]

def analyze_full(module):
    """List every import found anywhere beneath ``module``.

    All nodes produced by ast.walk (the node itself and all descendants) are
    checked, so imports nested at any depth are included.
    """
    return [name for node in ast.walk(module) for name in get_stmt_imports(node)]

def analyze_function(module, func_name):
    """Collect the imports made inside the function named ``func_name``.

    The tree is searched with ast.walk, so methods and nested functions are
    candidates as well; the first matching definition that ast.walk yields is
    the one analyzed. Both ``def`` and ``async def`` definitions are matched.

    Returns the list produced by analyze_full for that definition, or None
    when ``module`` contains no function with that name.
    """
    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    for stmt in ast.walk(module):
        if isinstance(stmt, func_types) and stmt.name == func_name:
            return analyze_full(stmt)
    # Explicit so callers can tell "no such function" from "no imports" ([]).
    return None

## Module search functions

def builtin_module(pkg_name):
    """Report whether ``pkg_name`` is part of the base Python installation.

    The module is imported and its on-disk location is compared with
    sys.prefix. Anything under ``lib[64]/pythonX.Y/site-packages/`` counts as
    an extra (pip/conda) package and gives False; everything else under the
    prefix, and every name in sys.builtin_module_names, gives True.

    Raises:
        ImportError: propagated from importlib if the module cannot be imported.
        RuntimeError: if the module lives outside sys.prefix.
    """
    if pkg_name in sys.builtin_module_names:
        return True
    pkg = importlib.import_module(pkg_name)
    # Namespace packages have no usable __file__ (missing or None); fall back
    # to the first entry of __path__ so they can still be located.
    location = getattr(pkg, '__file__', None)
    if location is None:
        search_path = list(getattr(pkg, '__path__', []))
        if not search_path:
            # No location on disk at all (e.g. frozen modules) -- presumably
            # shipped with the interpreter, so there is nothing to package.
            return True
        location = search_path[0]
    pth = os.path.relpath(location, start=sys.prefix)
    if pth.startswith('../'):
        raise RuntimeError('Modules cannot be installed outside the conda environment')
    m = re.match(r'lib(?:64)?/python[\d.]+/site-packages/', pth)
    return not m

def conda_packages(pkg):
    """Map importable names to the conda package that installs them.

    Reads ``<sys.prefix>/conda-meta/<dist_name>.json`` and scans its ``files``
    list for entries directly under ``site-packages/``. Each top-level
    directory or ``.py`` file becomes a key whose value is ``pkg['name']``;
    names that contain a dot (for example ``*.dist-info``) are ignored.
    """
    meta_path = os.path.join(sys.prefix, 'conda-meta', pkg['dist_name'] + '.json')
    with open(meta_path) as handle:
        record = json.load(handle)

    site_entry = re.compile(r'lib(?:64)?/python[\d.]+/site-packages/([^/]+)(?:/|\.py)')
    provided = {}
    for path in record['files']:
        found = site_entry.match(path)
        if found is None:
            continue
        top_level = found.group(1)
        if '.' not in top_level:
            provided[top_level] = pkg['name']
    return provided

def pip_packages(pkg):
    """Map importable names to the pip package that installs them.

    Runs ``pip show -f <name>`` and parses its RFC 822-style output. The first
    path component of each entry under ``Files`` is taken as a top-level
    module name and mapped to ``pkg['name']``. A top-level ``foo.py`` file is
    recorded as ``foo`` (matching what conda_packages does); any other name
    containing a dot, such as ``*.dist-info`` or ``..``, is ignored.

    Raises:
        subprocess.CalledProcessError: if ``pip show`` exits with an error.
    """
    out = {}
    parser = email.parser.BytesParser()
    m = parser.parsebytes(subprocess.check_output(['pip', 'show', '-f', pkg['name']]))
    # The Files header can be missing altogether (get() then returns None),
    # e.g. when the header section of the output is cut short; treat that as
    # an empty file list instead of failing on None.splitlines().
    listing = m.get('Files') or ''
    for line in listing.splitlines():
        top_level = line.strip().split('/')[0]
        # Single-file modules are installed as site-packages/<name>.py.
        if top_level.endswith('.py'):
            top_level = top_level[:-len('.py')]
        if not top_level or '.' in top_level:
            continue
        out[top_level] = pkg['name']
    return out

def package_mappings():
    """Build a module-name -> package-name table for the active environment.

    Packages come from ``conda list --json``. Entries whose channel is
    ``pypi`` are inspected through pip, all others through their conda
    metadata. Conda packages are processed first, so a pip package providing
    the same module name takes precedence in the result.
    """
    listing = json.loads(subprocess.check_output(['conda', 'list', '--json']))

    from_conda = []
    from_pip = []
    for entry in listing:
        if entry['channel'] == 'pypi':
            from_pip.append(entry)
        else:
            from_conda.append(entry)

    mapping = {}
    for entry in from_conda:
        mapping.update(conda_packages(entry))
    for entry in from_pip:
        mapping.update(pip_packages(entry))
    return mapping

def _add_installed(name, conda_env, pip_env, conda_pkgs, pip_pkgs):
    """Record the installed spec of package ``name``; return True if one exists.

    Conda specs are searched before pip specs, and the first spec that starts
    with ``name=`` is added to the corresponding output set.
    """
    prefix = name + '='
    for spec in conda_env:
        if spec.startswith(prefix):
            conda_pkgs.add(spec)
            return True
    for spec in pip_env:
        if spec.startswith(prefix):
            pip_pkgs.add(spec)
            return True
    return False

def choose_dep(pkg_mapping, overrides, conda_env, pip_env, conda_pkgs, pip_pkgs, pkg):
    """Find the installed package that provides the imported module ``pkg``.

    On success the matching spec string is added to ``conda_pkgs`` or
    ``pip_pkgs`` (both are modified in place). Nothing is returned. If no
    package can be matched, a warning is written to stderr.

    Arguments:
        pkg_mapping: module name -> package name, as detected automatically.
        overrides:   module name -> package name, as supplied by the user.
        conda_env:   conda spec strings of the environment (``name=ver=build``).
        pip_env:     pip spec strings of the environment (``name==ver``).
        conda_pkgs:  output set of selected conda specs.
        pip_pkgs:    output set of selected pip specs.
        pkg:         imported module name, possibly dotted (``numpy.linalg``).
    """
    # Don't try to pack up standard modules.
    # This also throws exceptions for modules that can't be imported,
    # or those outside the Conda prefix.
    if builtin_module(pkg):
        return

    # Packages are keyed by top-level module, so a dotted import such as
    # "numpy.linalg" must also be looked up as "numpy". The full dotted name
    # is still tried first so that existing exact overrides keep working.
    top_level = pkg.split('.')[0]
    names = (pkg,) if top_level == pkg else (pkg, top_level)

    # Candidate package names in priority order: user-provided names, then a
    # package named like the module, then packages known to ship the module.
    candidates = [overrides[n] for n in names if n in overrides]
    candidates.extend(names)
    candidates.extend(pkg_mapping[n] for n in names if n in pkg_mapping)

    for candidate in candidates:
        if _add_installed(candidate, conda_env, pip_env, conda_pkgs, pip_pkgs):
            return

    # Warn the user the package wasn't found
    print("Warning: couldn't match {} to a conda/pip package!".format(pkg), file=sys.stderr)
    print("You may need to use the --pkg-mapping flag.", file=sys.stderr)

## Main code

def export_env(pkg_mapping, overrides, imports):
    """Produce a trimmed ``conda env export`` containing only what is imported.

    The current environment is exported as JSON and its ``dependencies`` list
    is replaced by the specs that choose_dep selects for each name in
    ``imports``, plus the environment's ``python`` and ``pip`` specs. Selected
    pip packages are nested under a trailing ``{'pip': [...]}`` entry. All
    other keys of the export are passed through unchanged.

    Arguments:
        pkg_mapping: module name -> package name, as detected automatically.
        overrides:   module name -> package name, as supplied by the user.
        imports:     iterable of imported module names (duplicates allowed).

    Returns:
        The environment description as a dict, ready for json.dump.
    """
    imports = set(imports)
    env = json.loads(subprocess.check_output(['conda', 'env', 'export', '--json']))

    # Split the exported dependencies into plain conda specs and the nested
    # {'pip': [...]} entry. Building new lists avoids removing items from the
    # list being indexed, which broke whenever the dict was not the last item.
    conda_env = []
    pip_env = []
    for entry in env.pop('dependencies', []):
        if isinstance(entry, dict):
            pip_env.extend(entry.get('pip', []))
        else:
            conda_env.append(entry)

    conda_pkgs = set()
    pip_pkgs = set()
    # Sorted so that any warnings come out in a reproducible order.
    for pkg in sorted(imports):
        choose_dep(pkg_mapping, overrides, conda_env, pip_env, conda_pkgs, pip_pkgs, pkg)

    # Always include python and pip
    for spec in conda_env:
        if spec.startswith(('python=', 'pip=')):
            conda_pkgs.add(spec)

    conda_pkgs = sorted(conda_pkgs)
    pip_pkgs = sorted(pip_pkgs)
    if pip_pkgs:
        conda_pkgs.append({'pip': pip_pkgs})
    if conda_pkgs:
        env['dependencies'] = conda_pkgs
    return env

if __name__ == '__main__':
    # Command-line entry point: parse the arguments, collect the imports of
    # the given source, and write the matching environment description as JSON.
    parser = argparse.ArgumentParser(
        description=HELPMSG)
    parser.add_argument('source',
        help='Analyze the given Python source code, or - for stdin.')
    parser.add_argument('out',
        help='Path to write the JSON description, or - for stdout.')
    # --toplevel and --function select the analysis mode and exclude each other.
    actions = parser.add_mutually_exclusive_group()
    actions.add_argument('--toplevel', action='store_true',
        help='Only include imports at the top level of the script.')
    actions.add_argument('--function',
        help='Only include imports in the given function.')
    # --pkg-mapping is independent of the analysis mode, so it is registered on
    # the parser itself; inside the exclusive group it could not be combined
    # with --toplevel or --function.
    parser.add_argument('--pkg-mapping', action='append', metavar='IMPORT=NAME', default=[],
        help='Specify that the module imported as IMPORT in the code is provided by the pip/conda package NAME.')
    args = parser.parse_args()

    # Validate the mappings up front so a typo gives a usage error rather
    # than a traceback.
    overrides = {}
    for a in args.pkg_mapping:
        parts = a.split('=')
        if len(parts) != 2 or not all(parts):
            parser.error('--pkg-mapping expects IMPORT=NAME, got {!r}'.format(a))
        (i, n) = parts
        overrides[i] = n

    if args.source == '-':
        filename = '<stdin>'
        text = sys.stdin.read()
    else:
        filename = args.source
        with open(args.source, 'r') as source:
            text = source.read()

    code = ast.parse(text, filename=filename)
    if args.toplevel:
        imports = analyze_toplevel(code)
    elif args.function:
        imports = analyze_function(code, args.function)
        # analyze_function yields None when the function does not exist.
        if imports is None:
            parser.error('no function named {!r} in {}'.format(args.function, filename))
    else:
        imports = analyze_full(code)
    pkg_mapping = package_mappings()
    env = export_env(pkg_mapping, overrides, imports)

    # The output file is opened only once the result is ready, so a failed
    # run does not leave an empty or truncated file behind.
    out = sys.stdout if args.out == '-' else open(args.out, 'w')
    try:
        json.dump(env, out, indent=4, sort_keys=True)
        out.write('\n')
    finally:
        if out is not sys.stdout:
            out.close()
