#!/usr/bin/env python3
# Copyright 2021 WebAssembly Community Group participants
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A test case update script.

This script is a utility to update wasm-opt based lit tests with new FileCheck
patterns. It is based on LLVM's update_llc_test_checks.py script.
"""

import argparse
import glob
import os
import re
import subprocess
import sys
import tempfile


# Name of this script. It is embedded in NOTICE, and itertests and main look
# for it on a test's first line to recognize files this script generated.
script_name = os.path.basename(__file__)
# First line written to every regenerated test file.
NOTICE = (f';; NOTE: Assertions have been generated by {script_name} and ' +
          'should not be edited.')


# Matches a `;; RUN: <command>` line; group 1 is the command text.
RUN_LINE_RE = re.compile(r'^\s*;;\s*RUN:\s*(.*)$')
# Extracts the prefix given as `--check-prefix=X` or `--check-prefix X`
# (group 1) from a filecheck invocation.
CHECK_PREFIX_RE = re.compile(r'.*--check-prefix[= ](\S+).*')

# Module item kinds whose opening lines are recognized by ITEM_RE.
items = ['type', 'import', 'global', 'memory', 'data', 'table', 'elem', 'tag',
         'export', 'start', 'func']
# Matches a line that opens one of the `items` kinds. Groups:
#   1 - leading indentation (it ends right before the '(')
#   2 - the item kind
#   3 - the token after the kind: the item's `$name` when present, but it may
#       also be a non-`$` token (e.g. a quoted module name for `import`) or
#       empty when a parenthesis follows the kind.
# NOTE(review): `\s` also matches newlines, so under re.MULTILINE group 1 can
# absorb preceding blank lines when used with finditer on whole module text.
ITEM_RE = re.compile(r'(^\s*)\((' + '|'.join(items) + r')\s+(\$?[^\s()]*).*$',
                     re.MULTILINE)


def warn(msg):
    """Print `msg` to standard error, prefixed with 'WARNING: '."""
    text = 'WARNING: {}'.format(msg)
    print(text, file=sys.stderr)


def itertests(args):
    """
    Yield (filename, lines) for each test specified in the command line args.

    Each entry of `args.tests` is expanded as a glob pattern (`**` recurses).
    Matches that are not regular files are skipped; a recursive `**` pattern
    also matches directories, and opening those would raise
    IsADirectoryError. The files matched by a pattern are visited in sorted
    order so that runs are reproducible. The yielded lines have trailing
    whitespace stripped. Unless `args.force` is set, a test whose first line
    does not mention this script is skipped with a warning.
    """
    for pattern in args.tests:
        tests = sorted(path for path in glob.glob(pattern, recursive=True)
                       if os.path.isfile(path))
        if not tests:
            warn(f'No tests matched {pattern}. Ignoring it.')
            continue
        for test in tests:
            with open(test) as f:
                lines = [line.rstrip() for line in f]
            first_line = lines[0] if lines else ''
            if script_name not in first_line and not args.force:
                warn(f'Skipping test {test} which was not generated by '
                     f'{script_name}. Use -f to override.')
                continue
            yield test, lines


def find_run_lines(test, lines):
    """Collect the commands from the `;; RUN:` lines of a test.

    A command ending in a backslash is continued by the next RUN line: its
    trailing backslashes are dropped and the next command is appended after a
    single space. Warns and returns an empty list if the test has no RUN
    lines at all.
    """
    commands = []
    for text in lines:
        found = RUN_LINE_RE.match(text)
        if found:
            commands.append(found.group(1))

    if not commands:
        warn(f'No RUN lines found in {test}. Ignoring.')
        return []

    joined = []
    for command in commands:
        continues_previous = bool(joined) and joined[-1].endswith('\\')
        if continues_previous:
            joined[-1] = joined[-1].rstrip('\\') + ' ' + command
        else:
            joined.append(command)
    return joined


def run_command(args, test, tmp, command):
    """Run one RUN-line `command` through the shell and return its stdout.

    The lit substitutions `%s` (the test path `test`) and `%t` (the temporary
    path `tmp`) are expanded in a single pass. A `%s` or `%t` that occurs
    inside a substituted path is therefore left untouched. `args.binaryen_bin`
    is prepended to PATH so the executables there are found first.

    Returns:
        The command's standard output decoded as UTF-8.

    Raises:
        subprocess.CalledProcessError: if the command exits with a nonzero
            status.
    """
    env = dict(os.environ)
    path = env.get('PATH')
    # An environment without PATH must not crash with KeyError; the Binaryen
    # directory alone then becomes the search path.
    if path is None:
        env['PATH'] = args.binaryen_bin
    else:
        env['PATH'] = args.binaryen_bin + os.pathsep + path
    substitutions = {'%s': test, '%t': tmp}
    command = re.sub(r'%[st]', lambda m: substitutions[m.group(0)], command)
    return subprocess.check_output(command, shell=True, env=env).decode('utf-8')


def find_end(module, start):
    """Return the index one past the ')' that closes the '(' at `start`.

    Parentheses inside double-quoted string literals are ignored, with
    backslash escapes honored inside strings. Such literals occur in data
    segments and export names, e.g. `(data (i32.const 0) ")")`. If the
    parenthesis is never closed, `len(module)` is returned.
    """
    assert module[start] == '('
    depth = 0
    in_string = False
    escaped = False
    for index in range(start, len(module)):
        char = module[index]
        if in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return index + 1
    return len(module)


def split_output(module):
    """Split printed module text into items.

    Returns a list of ((kind, name), lines) pairs, one for each match of
    ITEM_RE in `module`, in order of appearance. `lines` is the item's text
    split on newlines. The text runs from the start of the matching line
    through the closing parenthesis located by find_end. Scanning resumes
    after the matched line, not after the item, so an item nested inside
    another one on its own line would be reported as well.
    """
    def to_entry(found):
        open_paren = found.end(1)  # the indentation group ends at the '('
        text = module[found.start():find_end(module, open_paren)]
        return (found[2], found[3]), text.split('\n')

    return [to_entry(found) for found in ITEM_RE.finditer(module)]


def _parse_run_list(test, lines):
    """Return (check_prefix, command) pairs for the RUN lines of `test`.

    When the last pipeline stage of a RUN line is `filecheck ...`, that stage
    is removed from the command. `check_prefix` is then the prefix named by
    `--check-prefix`, or 'CHECK' when none is given. Otherwise the whole RUN
    line is the command and `check_prefix` is ''.
    """
    run_list = []
    for line in find_run_lines(test, lines):
        head, sep, tail = line.rpartition('|')
        tail = tail.strip()
        if sep and tail.startswith('filecheck '):
            prefix_match = CHECK_PREFIX_RE.match(tail)
            check_prefix = prefix_match.group(1) if prefix_match else 'CHECK'
            run_list.append((check_prefix, head.strip()))
        else:
            # Keep every pipeline stage. Dropping the last one would lose its
            # side effects, e.g. writing %t for a later RUN line.
            run_list.append(('', line.strip()))
    return run_list


def _update_test(args, test, lines, tmp):
    """Regenerate the check lines of one test.

    Runs every RUN command of `test`, with `tmp` substituted for %t. Returns
    the new file contents as a list of lines: NOTICE first, then the original
    lines with pre-existing check lines for the active prefixes removed and
    fresh checks inserted before each matching module item. Returns None when
    the test has no RUN lines, so the caller leaves the file untouched.
    """
    run_list = _parse_run_list(test, lines)
    if not run_list:
        return None

    # Map check prefixes to lists of ((kind, name), [lines]).
    output_modules = {}
    for prefix, command in run_list:
        output = run_command(args, test, tmp, command)
        if prefix:
            output_modules[prefix] = split_output(output)

    # Matches existing check lines for the prefixes in use. Prefixes are
    # escaped so that regex metacharacters in them are taken literally. No
    # pattern is built without prefixes, because an empty alternation would
    # match unrelated `;; :` comment lines.
    check_line_re = None
    if output_modules:
        any_prefix = '|'.join(re.escape(prefix) for prefix in output_modules)
        check_line_re = re.compile(r'^\s*;;\s*(' + any_prefix +
                                   r')(?:-NEXT|-LABEL|-NOT)?:.*$')
    output_lines = [NOTICE]

    def emit_checks(indent, prefix, item_lines):
        output_lines.append(f'{indent};; {prefix}:     {item_lines[0]}')
        for item_line in item_lines[1:]:
            output_lines.append(f'{indent};; {prefix}-NEXT:{item_line}')

    # Skip the notice if it is already in the input.
    if lines and script_name in lines[0]:
        lines = lines[1:]

    # Items declared in the test. Only output items that also appear here
    # get check lines.
    named_items = set()
    for line in lines:
        match = ITEM_RE.match(line)
        if match:
            named_items.add((match[2], match[3]))

    for line in lines:
        # Skip pre-existing check lines; we will regenerate them.
        if check_line_re is not None and check_line_re.match(line):
            continue
        match = ITEM_RE.match(line)
        # Only items with a name can be matched against the output.
        if match and match[3]:
            indent, kind, name = match.groups()
            for prefix, module_items in output_modules.items():
                position = next(
                    (index for index, (kind_name, _) in enumerate(module_items)
                     if kind_name == (kind, name)),
                    None)
                if position is None:
                    continue
                # Emit every pending output item up to and including the
                # matching one, then consume them.
                for kind_name, item_lines in module_items[:position + 1]:
                    if kind_name in named_items:
                        emit_checks(indent, prefix, item_lines)
                del module_items[:position + 1]
        output_lines.append(line)

    # Output any remaining checks for each prefix.
    for prefix, module_items in output_modules.items():
        for kind_name, item_lines in module_items:
            if kind_name in named_items:
                emit_checks('', prefix, item_lines)
    return output_lines


def main():
    """Parse the command line and regenerate the checks of each chosen test."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--binaryen-bin', dest='binaryen_bin', default='bin',
        help=('Specifies the path to the Binaryen executables in the CMake build'
              ' directory. Default: bin/ of current directory (i.e. assume an'
              ' in-tree build).'))
    parser.add_argument(
        '-f', '--force', action='store_true',
        help=('Generate FileCheck patterns even for test files whose existing '
              'patterns were not generated by this script.'))
    parser.add_argument(
        '--dry-run', action='store_true',
        help=('Print the updated test file contents instead of changing the '
              'test files'))
    parser.add_argument('tests', nargs='+', help='The test files to update')
    args = parser.parse_args()
    args.binaryen_bin = os.path.abspath(args.binaryen_bin)

    # The %t path lives in a private directory. This avoids the race inherent
    # in the deprecated tempfile.mktemp, and the directory, along with
    # anything the RUN commands write there, is removed when done.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp = os.path.join(tmp_dir, 'tmp')
        for test, lines in itertests(args):
            output_lines = _update_test(args, test, lines, tmp)
            if output_lines is None:
                continue
            if args.dry_run:
                print('\n'.join(output_lines))
            else:
                with open(test, 'w') as f:
                    f.write('\n'.join(output_lines) + '\n')


if __name__ == '__main__':
    main()