From bbcef0570d60b8f56ac2aec4c44c1a06deb03ece Mon Sep 17 00:00:00 2001 From: "igor.udot" Date: Mon, 2 Dec 2024 10:47:32 +0800 Subject: [PATCH] ci: add test marks linter --- .flake8 | 3 +- .pre-commit-config.yaml | 30 +- conftest.py | 20 +- ruff.toml | 7 + tools/ci/check_test_files.py | 47 ++ tools/ci/exclude_check_tools_files.txt | 2 + tools/ci/executable-list.txt | 1 + tools/ci/idf_ci/app.py | 2 +- tools/ci/idf_pytest/constants.py | 41 +- tools/ci/idf_pytest/plugin.py | 17 +- .../ci/python_packages/common_test_methods.py | 8 +- tools/ci/test_linter.py | 512 ++++++++++++++++++ .../gdbstub_runtime/pytest_gdbstub_runtime.py | 3 +- 13 files changed, 663 insertions(+), 30 deletions(-) create mode 100644 ruff.toml create mode 100755 tools/ci/check_test_files.py create mode 100644 tools/ci/test_linter.py diff --git a/.flake8 b/.flake8 index 6ccb00b2ea..d078cf9002 100644 --- a/.flake8 +++ b/.flake8 @@ -53,13 +53,11 @@ select = E133, # closing bracket is missing indentation E201, # whitespace after '(' E202, # whitespace before ')' - E203, # whitespace before ':' E211, # whitespace before '(' E221, # multiple spaces before operator E222, # multiple spaces after operator E223, # tab before operator E224, # tab after operator - E225, # missing whitespace around operator E226, # missing whitespace around arithmetic operator E227, # missing whitespace around bitwise or shift operator E228, # missing whitespace around modulo operator @@ -125,6 +123,7 @@ select = ignore = E221, # multiple spaces before operator + E225, # missing whitespace around operator E231, # missing whitespace after ',', ';', or ':' E241, # multiple spaces after ',' W503, # line break before binary operator diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d8af6baff..e7534bc953 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,32 @@ default_stages: [pre-commit] repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.9.7" + hooks: + - id: ruff-format + args: [ "--preview" ] + files: 'pytest_.*\.py$' + - repo: local + hooks: + - id: pytest-linter + name: Pytest Linter Check + entry: tools/ci/check_test_files.py + language: python + files: 'pytest_.*\.py$' + require_serial: true + additional_dependencies: + - pytest-embedded-idf[serial]~=1.14 + - pytest-embedded-jtag~=1.14 + - pytest-embedded-qemu~=1.14 + - pytest-ignore-test-results~=0.3 + - pytest-rerunfailures + - pytest-timeout + - idf-build-apps~=2.8 + - python-gitlab + - minio + - click + - esp-idf-monitor - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: @@ -54,6 +80,8 @@ repos: (?x)^( .*_pb2.py )$ + |pytest_eth_iperf.py + |pytest_iperf.py - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: @@ -151,7 +179,7 @@ repos: require_serial: true additional_dependencies: - PyYAML == 5.3.1 - - idf-build-apps>=2.6.2,<3 + - idf-build-apps>=2.8,<3 - id: sort-yaml-files name: sort yaml files entry: tools/ci/sort_yaml.py diff --git a/conftest.py b/conftest.py index e14b0d200e..444b61ff0f 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2021-2024 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2021-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Apache-2.0 # pylint: disable=W0621 # redefined-outer-name # @@ -39,8 +39,14 @@ from dynamic_pipelines.constants import TEST_RELATED_APPS_DOWNLOAD_URLS_FILENAME from idf_ci.app import import_apps_from_txt from idf_ci.uploader import AppDownloader, AppUploader from idf_ci_utils import IDF_PATH, idf_relpath -from idf_pytest.constants import DEFAULT_SDKCONFIG, ENV_MARKERS, SPECIAL_MARKERS, TARGET_MARKERS, PytestCase, \ - DEFAULT_LOGDIR +from idf_pytest.constants import ( + DEFAULT_SDKCONFIG, + ENV_MARKERS, + SPECIAL_MARKERS, + TARGET_MARKERS, + PytestCase, + DEFAULT_LOGDIR, +) from idf_pytest.plugin import IDF_PYTEST_EMBEDDED_KEY, ITEM_PYTEST_CASE_KEY, IdfPytestEmbedded from idf_pytest.utils import format_case_id from pytest_embedded.plugin import multi_dut_argument, multi_dut_fixture @@ -318,7 +324,7 @@ def check_performance(idf_path: str) -> t.Callable[[str, float, str], None]: def _find_perf_item(operator: str, path: str) -> float: with open(path, encoding='utf-8') as f: data = f.read() - match = re.search(fr'#define\s+IDF_PERFORMANCE_{operator}_{item.upper()}\s+([\d.]+)', data) + match = re.search(rf'#define\s+IDF_PERFORMANCE_{operator}_{item.upper()}\s+([\d.]+)', data) return float(match.group(1)) # type: ignore def _check_perf(operator: str, standard_value: float) -> None: @@ -420,6 +426,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: def pytest_configure(config: Config) -> None: + from pytest_embedded_idf.utils import supported_targets, preview_targets + from idf_pytest.constants import SUPPORTED_TARGETS, PREVIEW_TARGETS + + supported_targets.set(SUPPORTED_TARGETS) + preview_targets.set(PREVIEW_TARGETS) + # cli option "--target" target = [_t.strip().lower() for _t in (config.getoption('target', '') or '').split(',') if _t.strip()] diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000000..5ce8d55523 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,7 @@ +line-length = 120 +target-version = "py38" + +[format] +quote-style = "single" +exclude = [] +docstring-code-format = true diff --git a/tools/ci/check_test_files.py b/tools/ci/check_test_files.py new file mode 100755 index 0000000000..e561759590 --- /dev/null +++ b/tools/ci/check_test_files.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2025 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Apache-2.0 +import argparse +import os +import sys +from pathlib import Path + +import pytest + + +sys.path.insert(0, os.path.dirname(__file__)) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from idf_ci_utils import IDF_PATH # noqa: E402 + +os.environ['IDF_PATH'] = IDF_PATH +os.environ['PYTEST_IGNORE_COLLECT_IMPORT_ERROR'] = '1' + +from idf_pytest.plugin import IdfPytestEmbedded # noqa: E402 + + +def main() -> None: + parser = argparse.ArgumentParser(description='Pytest linter check') + parser.add_argument( + 'files', + nargs='*', + help='Python files to check (full paths separated by space)', + ) + args = parser.parse_args() + + # Convert input files to pytest-compatible paths + pytest_scripts = [str(Path(f).resolve()) for f in args.files] + + cmd = [ + '--collect-only', + *pytest_scripts, + '--target', 'all', + '-p', 'test_linter', + ] + + res = pytest.main(cmd, plugins=[IdfPytestEmbedded('all')]) + sys.exit(res) + + +if __name__ == '__main__': + main() diff --git a/tools/ci/exclude_check_tools_files.txt b/tools/ci/exclude_check_tools_files.txt index 5febcfec8e..60c7c6fcab 100644 --- a/tools/ci/exclude_check_tools_files.txt +++ b/tools/ci/exclude_check_tools_files.txt @@ -61,3 +61,5 @@ tools/legacy_exports/export_legacy.bat tools/ci/idf_build_apps_dump_soc_caps.py tools/bt/bt_hci_to_btsnoop.py tools/bt/README.md +tools/ci/test_linter.py +tools/ci/check_test_files.py diff --git a/tools/ci/executable-list.txt b/tools/ci/executable-list.txt index 38d44f0c5a..4a3aebc4df 100644 --- a/tools/ci/executable-list.txt +++ b/tools/ci/executable-list.txt @@ -65,6 +65,7 @@ tools/ci/check_requirement_files.py tools/ci/check_rules_components_patterns.py tools/ci/check_soc_headers_leak.py tools/ci/check_soc_struct_headers.py +tools/ci/check_test_files.py tools/ci/check_tools_files_patterns.py tools/ci/check_type_comments.py tools/ci/checkout_project_ref.py diff --git a/tools/ci/idf_ci/app.py b/tools/ci/idf_ci/app.py index 336c374115..4f3d01d807 100644 --- a/tools/ci/idf_ci/app.py +++ b/tools/ci/idf_ci/app.py @@ -46,7 +46,7 @@ class Metrics: self.difference = difference or 0.0 self.difference_percentage = difference_percentage or 0.0 - def to_dict(self) -> dict[str, t.Any]: + def to_dict(self) -> t.Dict[str, t.Any]: """ Converts the Metrics object to a dictionary. """ diff --git a/tools/ci/idf_pytest/constants.py b/tools/ci/idf_pytest/constants.py index b0dbcaf5f0..3d52f9c4af 100644 --- a/tools/ci/idf_pytest/constants.py +++ b/tools/ci/idf_pytest/constants.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2023-2024 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Apache-2.0 """ Pytest Related Constants. Don't import third-party packages here. @@ -16,7 +16,18 @@ from idf_ci_utils import IDF_PATH from idf_ci_utils import idf_relpath from pytest_embedded.utils import to_list -SUPPORTED_TARGETS = ['esp32', 'esp32s2', 'esp32c3', 'esp32s3', 'esp32c2', 'esp32c6', 'esp32h2', 'esp32p4', 'esp32c5', 'esp32c61'] +SUPPORTED_TARGETS = [ + 'esp32', + 'esp32s2', + 'esp32c3', + 'esp32s3', + 'esp32c2', + 'esp32c6', + 'esp32h2', + 'esp32p4', + 'esp32c5', + 'esp32c61', +] PREVIEW_TARGETS: t.List[str] = [] # this PREVIEW_TARGETS excludes 'linux' target DEFAULT_SDKCONFIG = 'default' DEFAULT_LOGDIR = 'pytest-embedded' @@ -30,6 +41,8 @@ TARGET_MARKERS = { 'esp32c5': 'support esp32c5 target', 'esp32c6': 'support esp32c6 target', 'esp32h2': 'support esp32h2 target', + 'esp32h4': 'support esp32h4 target', # as preview + 'esp32h21': 'support esp32h21 target', # as preview 'esp32p4': 'support esp32p4 target', 'esp32c61': 'support esp32c61 target', 'linux': 'support linux target', @@ -174,6 +187,7 @@ class PytestApp: """ Pytest App with relative path to IDF_PATH """ + def __init__(self, path: str, target: str, config: str) -> None: self.path = idf_relpath(path) self.target = target @@ -215,8 +229,10 @@ class PytestCase: for _t in [app.target for app in self.apps]: if _t in self.target_markers: skip = False - warnings.warn(f'`pytest.mark.[TARGET]` defined in parametrize for multi-dut test cases is deprecated. ' # noqa: W604 - f'Please use parametrize instead for test case {self.item.nodeid}') + warnings.warn( + f'`pytest.mark.[TARGET]` defined in parametrize for multi-dut test cases is deprecated. ' # noqa: W604 + f'Please use parametrize instead for test case {self.item.nodeid}' + ) break if not skip: @@ -238,7 +254,7 @@ class PytestCase: return {marker.name for marker in self.item.iter_markers()} @property - def target_markers(self) -> t.Set[str]: + def skip_targets(self) -> t.Set[str]: def _get_temp_markers_disabled_targets(marker_name: str) -> t.Set[str]: temp_marker = self.item.get_closest_marker(marker_name) @@ -260,11 +276,15 @@ class PytestCase: # in CI we skip the union of `temp_skip` and `temp_skip_ci` if os.getenv('CI_JOB_ID'): - skip_targets = temp_skip_ci_targets.union(temp_skip_targets) + _skip_targets = temp_skip_ci_targets.union(temp_skip_targets) else: # we use `temp_skip` locally - skip_targets = temp_skip_targets + _skip_targets = temp_skip_targets - return {marker for marker in self.all_markers if marker in TARGET_MARKERS} - skip_targets + return _skip_targets + + @property + def target_markers(self) -> t.Set[str]: + return {marker for marker in self.all_markers if marker in TARGET_MARKERS} - self.skip_targets @property def env_markers(self) -> t.Set[str]: @@ -285,10 +305,7 @@ class PytestCase: if 'jtag' in self.env_markers or 'usb_serial_jtag' in self.env_markers: return True - cases_need_elf = [ - 'panic', - 'gdbstub_runtime' - ] + cases_need_elf = ['panic', 'gdbstub_runtime'] for case in cases_need_elf: if any(case in Path(app.path).parts for app in self.apps): diff --git a/tools/ci/idf_pytest/plugin.py b/tools/ci/idf_pytest/plugin.py index 06603a5a9c..cfbeca5e4d 100644 --- a/tools/ci/idf_pytest/plugin.py +++ b/tools/ci/idf_pytest/plugin.py @@ -265,12 +265,17 @@ class IdfPytestEmbedded: # 3.3. CollectMode.MULTI_ALL_WITH_PARAM, intended to be used by `get_pytest_cases` else: - items[:] = [ - _item - for _item in items - if not item_to_case_dict[_item].is_single_dut_test_case - and self.get_param(_item, 'target', None) is not None - ] + filtered_items = [] + for item in items: + case = item_to_case_dict[item] + target = self.get_param(item, 'target', None) + if ( + not case.is_single_dut_test_case and + target is not None and + target not in case.skip_targets + ): + filtered_items.append(item) + items[:] = filtered_items # 4. filter according to the sdkconfig, if there's param 'config' defined if self.config_name: diff --git a/tools/ci/python_packages/common_test_methods.py b/tools/ci/python_packages/common_test_methods.py index fe5db2a9e7..d18f3f0e68 100644 --- a/tools/ci/python_packages/common_test_methods.py +++ b/tools/ci/python_packages/common_test_methods.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2022-2024 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Apache-2.0 import logging import os @@ -6,6 +6,8 @@ import socket from typing import Any from typing import List +from idf_ci_utils import IDF_PATH + try: import netifaces except ImportError: @@ -17,8 +19,8 @@ except ImportError: import yaml ENV_CONFIG_FILE_SEARCH = [ - os.path.join(os.environ['IDF_PATH'], 'EnvConfig.yml'), - os.path.join(os.environ['IDF_PATH'], 'ci-test-runner-configs', os.environ.get('CI_RUNNER_DESCRIPTION', ''), 'EnvConfig.yml'), + os.path.join(IDF_PATH, 'EnvConfig.yml'), + os.path.join(IDF_PATH, 'ci-test-runner-configs', os.environ.get('CI_RUNNER_DESCRIPTION', ''), 'EnvConfig.yml'), ] ENV_CONFIG_TEMPLATE = ''' $IDF_PATH/EnvConfig.yml: diff --git a/tools/ci/test_linter.py b/tools/ci/test_linter.py new file mode 100644 index 0000000000..eb03da0a87 --- /dev/null +++ b/tools/ci/test_linter.py @@ -0,0 +1,512 @@ +# SPDX-FileCopyrightText: 2024-2025 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Apache-2.0 +import ast +import itertools +import os +import typing as t +import warnings +from collections import defaultdict + +import pytest +from idf_pytest.constants import PREVIEW_TARGETS +from idf_pytest.constants import SUPPORTED_TARGETS +from idf_pytest.constants import TARGET_MARKERS +from pytest import Config +from pytest import Function +from pytest import Mark + + +def is_target_in_marker(mark: Mark) -> bool: + return mark.name in TARGET_MARKERS or mark.name in ('supported_targets', 'preview_targets', 'linux') + + +def remove_keys(data: t.Dict[str, t.Any], keys_to_remove: t.List[str]) -> t.Dict[str, t.Any]: + """ + Remove specific keys from a dictionary. + """ + return {key: value for key, value in data.items() if key not in keys_to_remove} + + +def get_values_by_keys(data: t.Dict[str, t.Any], keys: t.List[str]) -> t.Tuple[t.Any, ...]: + """ + Retrieve values from a dictionary for specified keys. + """ + return tuple([data[key] for key in keys if key in data]) + + +def group_by_target(vals: t.List[t.Dict[str, t.Any]]) -> t.List[t.Dict[str, t.Any]]: + """ + Groups rows by non-target keys and modifies targets to 'supported_targets' + if all supported targets are present in a group. + + Parameters: + vals: List of dictionaries to process. + + Returns: + Processed list of dictionaries with supported targets. + """ + if not vals or 'target' not in vals[0]: + return vals + + def _process_group( + _vals: t.List[t.Dict[str, t.Any]], group: t.List[str], group_name: str + ) -> t.List[t.Dict[str, t.Any]]: + # Identify keys excluding 'target' + non_target_keys = [key for key in sorted(_vals[0].keys()) if key != 'target'] + + # Group rows by values of keys excluding 'target' + grouped_rows = defaultdict(list) + for index, row in enumerate(_vals): + key = get_values_by_keys(row, non_target_keys) + grouped_rows[key].append((index, row['target'])) + + # Identify groups that contain all supported targets + to_skip_lines: t.Set[int] = set() + to_update_lines: t.Set[int] = set() + for _, rows in grouped_rows.items(): + lines = [] + remaining_targets = set(group) + for index, target in rows: + if target in remaining_targets: + lines.append(index) + remaining_targets.remove(target) + if not remaining_targets: + to_skip_lines.update(lines[1:]) # Skip all but the first matching line + to_update_lines.add(lines[0]) # Update the first matching line + break + + # Construct new list of rows with modifications + new_values = [] + for ind, row in enumerate(_vals): + if ind in to_update_lines: + row['target'] = group_name + if ind not in to_skip_lines: + new_values.append(row) + + return new_values + + if SUPPORTED_TARGETS: + vals = _process_group(vals, SUPPORTED_TARGETS, 'supported_targets') + if PREVIEW_TARGETS: + vals = _process_group(vals, PREVIEW_TARGETS, 'preview_targets') + return vals + + +class CurrentItemContext: + test_name: str + + +class PathRestore: + # If restored is True, then add the import os when the file is being formatted. + restored: bool = False + + def __init__(self, path: str) -> None: + PathRestore.restored = True + self.path = path + + def __repr__(self) -> str: + return f"f'{self.path}'" + + +def restore_path(vals: t.List[t.Dict[str, t.Any]], file_path: str) -> t.List[t.Dict[str, t.Any]]: + if 'app_path' not in vals[0].keys(): + return vals + file_path = os.path.dirname(os.path.abspath(file_path)) + for row in vals: + paths = row['app_path'].split('|') + row['app_path'] = '|'.join([ + f'{{os.path.join(os.path.dirname(__file__), "{os.path.relpath(p, file_path)}")}}' for p in paths + ]) + row['app_path'] = PathRestore(row['app_path']) + return vals + + +def make_hashable(item: t.Any) -> t.Union[t.Tuple[t.Any, ...], t.Any]: + """Recursively convert object to a hashable form, storing original values.""" + if isinstance(item, (set, list, tuple)): + converted = tuple(make_hashable(i) for i in item) + elif isinstance(item, dict): + converted = tuple(sorted((k, make_hashable(v)) for k, v in item.items())) + else: + converted = item # Primitives are already hashable + + return converted + + +def restore_params(data: t.List[t.Dict[str, t.Any]]) -> t.List[t.Tuple[t.List[str], t.List[t.Any]]]: + """ + Restore parameters from pytest --collect-only data structure. + """ + # Ensure all dictionaries have the same number of keys + if len({len(d) for d in data}) != 1: + raise ValueError( + f'Inconsistent parameter {CurrentItemContext.test_name} structure: all rows must have the same number of keys.' + ) + + all_markers_is_empty = [] + for d in data: + if 'markers' in d: + all_markers_is_empty.append(not (d['markers'])) + d['markers'] = list(set(d['markers'])) + if all(all_markers_is_empty): + for d in data: + del d['markers'] + + hashable_to_original: t.Dict[t.Tuple[str, t.Any], t.Any] = {} + + def save_to_hash(key: str, hashable_value: t.Any, original_value: t.Any) -> t.Any: + """Stores the mapping of hashable values to their original.""" + if isinstance(original_value, list): + original_value = tuple(original_value) + hashable_to_original[(key, hashable_value)] = original_value + return hashable_value + + def restore_from_hash(key: str, hashable_value: t.Any) -> t.Any: + """Restores the original value from its hashable equivalent.""" + return hashable_to_original.get((key, hashable_value), hashable_value) + + # Convert data to a hashable format + data = [{k: save_to_hash(k, make_hashable(v), v) for k, v in row.items()} for row in data] + unique_data = [] + for d in data: + if d not in unique_data: + unique_data.append(d) + data = unique_data + data = group_by_target(data) + + params_multiplier: t.List[t.Tuple[t.List[str], t.List[t.Any]]] = [] + current_keys: t.List[str] = sorted(data[0].keys(), key=lambda x: (x == 'markers', x)) + i = 1 + + while len(current_keys) > i: + # It should be combinations because we are only concerned with the elements, not their order. + for _ in itertools.combinations(current_keys, i): + perm: t.List[str] = list(_) + if perm == ['markers'] or [k for k in current_keys if k not in perm] == ['markers']: + # The mark_runner must be used together with another parameter. + continue + + grouped_buckets = defaultdict(list) + for row in data: + grouped_buckets[get_values_by_keys(row, perm)].append(remove_keys(row, perm)) + + grouped_values = list(grouped_buckets.values()) + if all(v == grouped_values[0] for v in grouped_values): + current_keys = [k for k in current_keys if k not in perm] + params_multiplier.append((perm, list(grouped_buckets.keys()))) + data = grouped_values[0] + break + else: + i += 1 + + if data: + remaining_values = [get_values_by_keys(row, current_keys) for row in data] + params_multiplier.append((current_keys, remaining_values)) + + for key, values in params_multiplier: + values[:] = [tuple(restore_from_hash(key[i], v) for i, v in enumerate(row)) for row in values] + output: t.List[t.Any] = [] + if len(key) == 1: + for row in values: + output.extend(row) + values[:] = output + + for p in params_multiplier: + if 'markers' in p[0]: + for i, el in enumerate(p[1]): + if el[-1] == (): + p[1][i] = el[:-1] + + return params_multiplier + + +def format_mark(name: str, args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any]) -> str: + """Format pytest mark with given arguments and keyword arguments.""" + args_str = ', '.join(repr(arg) if isinstance(arg, str) else str(arg) for arg in args) + kwargs_str = ', '.join(f'{key}={repr(value) if isinstance(value, str) else value}' for key, value in kwargs.items()) + combined = ', '.join(filter(None, [args_str, kwargs_str])) + return f'@pytest.mark.{name}({combined})\n' if combined else f'@pytest.mark.{name}\n' + + +def format_parametrize(keys: t.Union[str, t.List[str]], values: t.List[t.Any], indirect: t.Sequence[str]) -> str: + """Format pytest parametrize for given keys and values.""" + # Ensure keys is always a list + if isinstance(keys, str): + keys = [keys] + + # Markers will always be at the end, so just remove markers from the keys if it is present + # keys = [k for k in keys if k not in ('__markers',)] + + key_str = repr(keys[0]) if len(keys) == 1 else repr(','.join(keys)) + + # If there any value which need to be represented in some spec way, best way is wrap it with class like PathRestore + formatted_values = [' ' + repr(value) for value in values] + values_str = ',\n'.join(formatted_values) + + if indirect: + return f'@idf_parametrize({key_str}, [\n{values_str}\n], indirect={indirect})\n' + + return f'@idf_parametrize({key_str}, [\n{values_str}\n])\n' + + +def key_for_item(item: Function) -> t.Tuple[str, str]: + return item.originalname, str(item.fspath) + + +def collect_markers(item: Function) -> t.Tuple[t.List[Mark], t.List[Mark]]: + """Separate local and global markers for a pytest item.""" + local_markers, global_markers = [], [] + + for mark in item.iter_markers(): + if mark.name == 'parametrize': + continue + if 'callspec' in dir(item) and mark in item.callspec.marks: + local_markers.append(mark) + else: + global_markers.append(mark) + + return local_markers, global_markers + + +class MarkerRepr(str): + def __new__(cls, mark_name: str, kwargs_str: str, args_str: str, all_args: str) -> 'MarkerRepr': + if not all_args: + instance = super().__new__(cls, f'pytest.mark.{mark_name}') + else: + instance = super().__new__(cls, f'pytest.mark.{mark_name}({all_args})') + return instance # type: ignore + + def __init__(self, mark_name: str, kwargs_str: str, args_str: str, all_args: str) -> None: + super().__init__() + self.kwargs_str = kwargs_str + self.args_str = args_str + self.all_args = all_args + self.mark_name = mark_name + + def __hash__(self) -> int: + return hash(repr(self)) + + def __repr__(self) -> str: + if not self.all_args: + return f'pytest.mark.{self.mark_name}' + return f'pytest.mark.{self.mark_name}({self.all_args})' + + +def mark_to_source(mark: Mark) -> MarkerRepr: + """Convert a Mark instance to its pytest.mark source code representation.""" + kwargs_str = ', '.join(f'{k}={repr(v)}' for k, v in mark.kwargs.items()) + args_str = ', '.join(repr(arg) for arg in mark.args) + all_args = ', '.join(filter(None, [args_str, kwargs_str])) + + return MarkerRepr(mark.name, kwargs_str, args_str, all_args) + + +def process_local_markers(local_markers: t.List[Mark]) -> t.Tuple[t.List[str], t.List[MarkerRepr]]: + """Process local markers to extract targets and runners.""" + local_targets, other_markers = [], [] + + for mark in local_markers: + if is_target_in_marker(mark): + local_targets.append(mark.name) + else: + other_markers.append(mark_to_source(mark)) + return sorted(local_targets), sorted(other_markers) + + +def validate_global_markers( + global_markers: t.List[Mark], local_targets: t.List[str], function_name: str +) -> t.List[Mark]: + """Validate and normalize global markers.""" + normalized_markers = [] + + for mark in global_markers: + if is_target_in_marker(mark): + if local_targets: + warnings.warn(f'IN {function_name} IGNORING GLOBAL TARGET {mark.name} DUE TO LOCAL TARGETS') + continue + normalized_markers.append(mark) + + return normalized_markers + + +def filter_target(_targets: t.List[str]) -> t.List[str]: + """ + Filters the input targets based on certain conditions. + """ + if len(_targets) == 1: + return _targets + + def remove_duplicates(target_list: t.List[str], group: t.List[str], group_name: str) -> t.List[str]: + updated_target = [] + for _t in target_list: + if _t in group: + warnings.warn(f'{_t} is already included in {group_name}, no need to specify it separately.') + continue + updated_target.append(_t) + return updated_target + + if 'supported_targets' in _targets: + _targets = remove_duplicates(_targets, SUPPORTED_TARGETS, 'supported_targets') + + if 'preview_targets' in _targets: + _targets = remove_duplicates(_targets, PREVIEW_TARGETS, 'preview_targets') + + return _targets + + +@pytest.hookimpl(tryfirst=True) +def pytest_collection_modifyitems(config: Config, items: t.List[Function]) -> None: + """ + Local and Global marks in my diff are as follows: + - Local: Used with a parameter inside a parameterized function, like: + parameterized(param(marks=[....])) + - Global: A regular mark. + """ + + test_name_to_params: t.Dict[t.Tuple[str, str], t.List] = defaultdict(list) + test_name_to_global_mark: t.Dict[t.Tuple[str, str], t.List] = defaultdict(list) + test_name_has_local_target_marks = defaultdict(bool) + + # Collect all fixtures to determine if a parameter is regular or a fixture + fm = config.pluginmanager.get_plugin('funcmanage') + known_fixtures = set(fm._arg2fixturedefs.keys()) + + # Collecting data + for item in items: + collected = [] + item_key = key_for_item(item) + + local_markers, global_markers = collect_markers(item) + # global_markers.sort(key=lambda x: x.name) + global_markers.reverse() # markers of item need to be reverted to save origin order + + local_targets, other_markers = process_local_markers(local_markers) + if local_targets: + test_name_has_local_target_marks[item_key] = True + local_targets = filter_target(local_targets) + other_markers_dict = {'markers': other_markers} if other_markers else {'markers': []} + + if local_targets: + for target in local_targets: + params = item.callspec.params if 'callspec' in dir(item) else {} + collected.append({**params, **other_markers_dict, 'target': target}) + else: + if 'callspec' in dir(item): + collected.append({**other_markers_dict, **item.callspec.params}) + + global_markers = validate_global_markers(global_markers, local_targets, item.name) + + # Just warning if global markers was changed + if item_key in test_name_to_global_mark: + if test_name_to_global_mark[item_key] != global_markers: + warnings.warn( + f'{item.originalname} HAS DIFFERENT GLOBAL MARKERS! {test_name_to_global_mark[item_key]} {global_markers}' + ) + + test_name_to_global_mark[item_key] = global_markers + test_name_to_params[item_key].extend(collected) + + # Post-processing: Modify files based on collected data + for (function_name, file_path), function_params in test_name_to_params.items(): + CurrentItemContext.test_name = function_name + to_add_lines = [] + global_targets = [] + for mark in test_name_to_global_mark[(function_name, file_path)]: + if is_target_in_marker(mark): + global_targets.append(mark.name) + continue + to_add_lines.append(format_mark(mark.name, mark.args, mark.kwargs)) + + function_params_will_not_update = True + if test_name_has_local_target_marks[(function_name, file_path)]: + function_params_will_not_update = False + + # After filter_target, it will lose part of them, but we need them when removing decorators in the file. + original_global_targets = global_targets + global_targets = filter_target(global_targets) + + is_target_already_in_params = any({'target' in param for param in function_params}) + extra = [] + if global_targets: + # If any of param have target then skip add global marker. + if is_target_already_in_params: + warnings.warn(f'Function {function_name} already have target params! Skip adding global target') + else: + extra = [{'target': _t} for _t in global_targets] + + def _update_file(file_path: str, to_add_lines: t.List[str], lines: t.List[str]) -> None: + output = [] + start_with_comment = True + + imports = ['from pytest_embedded_idf.utils import idf_parametrize'] + if PathRestore.restored: + imports += ['import os'] + + for i, line in enumerate(lines): + if line.strip() in imports: + continue + if start_with_comment: + if not line == '\n' and not line.startswith(('from', 'import', '#')): + output.extend([f'{_imp}\n' for _imp in imports]) + start_with_comment = False + + if i in skip_lines: + continue + + if line.startswith(f'def {function_name}('): + output.extend(to_add_lines) + output.append(line) + + with open(file_path, 'w+') as file: + file.writelines(output) + + if not function_params_will_not_update: + buffered_params: t.List[str] = [] + if function_params: + function_params = restore_path(function_params, file_path) + + for parameter_names, parameter_values in restore_params(function_params): + buffered_params.append( + format_parametrize( + parameter_names, + parameter_values, + indirect=[p for p in parameter_names if p in known_fixtures], + ) + ) + + to_add_lines.extend(buffered_params) + + with open(file_path) as file: + lines = file.readlines() + tree = ast.parse(''.join(lines)) + + skip_lines: t.Set[int] = set() + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == function_name: + for dec in node.decorator_list: + assert dec.end_lineno is not None + skip_lines.update(list(range(dec.lineno - 1, dec.end_lineno))) # ast count lines from 1 not 0 + break + + _update_file(file_path, to_add_lines, lines) + + if global_targets: + with open(file_path) as file: + lines = file.readlines() + tree = ast.parse(''.join(lines)) + + skip_lines = set() + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == function_name: + for dec in node.decorator_list: + if isinstance(dec, ast.Attribute): + if dec.attr in original_global_targets: + assert dec.end_lineno is not None + skip_lines.update(list(range(dec.lineno - 1, dec.end_lineno))) + break + + if extra: + to_add_lines = [format_parametrize('target', [_t['target'] for _t in extra], ['target'])] if extra else [] + else: + to_add_lines = [] + _update_file(file_path, to_add_lines, lines) diff --git a/tools/test_apps/system/gdbstub_runtime/pytest_gdbstub_runtime.py b/tools/test_apps/system/gdbstub_runtime/pytest_gdbstub_runtime.py index dc51c3681a..2f994fedb9 100644 --- a/tools/test_apps/system/gdbstub_runtime/pytest_gdbstub_runtime.py +++ b/tools/test_apps/system/gdbstub_runtime/pytest_gdbstub_runtime.py @@ -4,6 +4,7 @@ import os import os.path as path import sys from typing import Any +from typing import Dict import pytest @@ -26,7 +27,7 @@ def start_gdb(dut: PanicTestDut) -> None: dut.start_gdb_for_gdbstub() -def run_and_break(dut: PanicTestDut, cmd: str) -> dict[Any, Any]: +def run_and_break(dut: PanicTestDut, cmd: str) -> Dict[Any, Any]: responses = dut.gdb_write(cmd) assert dut.find_gdb_response('running', 'result', responses) is not None if not dut.find_gdb_response('stopped', 'notify', responses): # have not stopped on breakpoint yet