ci: add test marks linter

This commit is contained in:
igor.udot 2024-12-02 10:47:32 +08:00
parent 47df2ed524
commit bbcef0570d
13 changed files with 663 additions and 30 deletions

View File

@ -53,13 +53,11 @@ select =
E133, # closing bracket is missing indentation E133, # closing bracket is missing indentation
E201, # whitespace after '(' E201, # whitespace after '('
E202, # whitespace before ')' E202, # whitespace before ')'
E203, # whitespace before ':'
E211, # whitespace before '(' E211, # whitespace before '('
E221, # multiple spaces before operator E221, # multiple spaces before operator
E222, # multiple spaces after operator E222, # multiple spaces after operator
E223, # tab before operator E223, # tab before operator
E224, # tab after operator E224, # tab after operator
E225, # missing whitespace around operator
E226, # missing whitespace around arithmetic operator E226, # missing whitespace around arithmetic operator
E227, # missing whitespace around bitwise or shift operator E227, # missing whitespace around bitwise or shift operator
E228, # missing whitespace around modulo operator E228, # missing whitespace around modulo operator
@ -125,6 +123,7 @@ select =
ignore = ignore =
E221, # multiple spaces before operator E221, # multiple spaces before operator
E225, # missing whitespace around operator
E231, # missing whitespace after ',', ';', or ':' E231, # missing whitespace after ',', ';', or ':'
E241, # multiple spaces after ',' E241, # multiple spaces after ','
W503, # line break before binary operator W503, # line break before binary operator

View File

@ -4,6 +4,32 @@
default_stages: [pre-commit] default_stages: [pre-commit]
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.9.7"
hooks:
- id: ruff-format
args: [ "--preview" ]
files: 'pytest_.*\.py$'
- repo: local
hooks:
- id: pytest-linter
name: Pytest Linter Check
entry: tools/ci/check_test_files.py
language: python
files: 'pytest_.*\.py$'
require_serial: true
additional_dependencies:
- pytest-embedded-idf[serial]~=1.14
- pytest-embedded-jtag~=1.14
- pytest-embedded-qemu~=1.14
- pytest-ignore-test-results~=0.3
- pytest-rerunfailures
- pytest-timeout
- idf-build-apps~=2.8
- python-gitlab
- minio
- click
- esp-idf-monitor
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 rev: v4.5.0
hooks: hooks:
@ -54,6 +80,8 @@ repos:
(?x)^( (?x)^(
.*_pb2.py .*_pb2.py
)$ )$
|pytest_eth_iperf.py
|pytest_iperf.py
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.3.0 rev: v2.3.0
hooks: hooks:
@ -151,7 +179,7 @@ repos:
require_serial: true require_serial: true
additional_dependencies: additional_dependencies:
- PyYAML == 5.3.1 - PyYAML == 5.3.1
- idf-build-apps>=2.6.2,<3 - idf-build-apps>=2.8,<3
- id: sort-yaml-files - id: sort-yaml-files
name: sort yaml files name: sort yaml files
entry: tools/ci/sort_yaml.py entry: tools/ci/sort_yaml.py

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2021-2024 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2021-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# pylint: disable=W0621 # redefined-outer-name # pylint: disable=W0621 # redefined-outer-name
# #
@ -39,8 +39,14 @@ from dynamic_pipelines.constants import TEST_RELATED_APPS_DOWNLOAD_URLS_FILENAME
from idf_ci.app import import_apps_from_txt from idf_ci.app import import_apps_from_txt
from idf_ci.uploader import AppDownloader, AppUploader from idf_ci.uploader import AppDownloader, AppUploader
from idf_ci_utils import IDF_PATH, idf_relpath from idf_ci_utils import IDF_PATH, idf_relpath
from idf_pytest.constants import DEFAULT_SDKCONFIG, ENV_MARKERS, SPECIAL_MARKERS, TARGET_MARKERS, PytestCase, \ from idf_pytest.constants import (
DEFAULT_LOGDIR DEFAULT_SDKCONFIG,
ENV_MARKERS,
SPECIAL_MARKERS,
TARGET_MARKERS,
PytestCase,
DEFAULT_LOGDIR,
)
from idf_pytest.plugin import IDF_PYTEST_EMBEDDED_KEY, ITEM_PYTEST_CASE_KEY, IdfPytestEmbedded from idf_pytest.plugin import IDF_PYTEST_EMBEDDED_KEY, ITEM_PYTEST_CASE_KEY, IdfPytestEmbedded
from idf_pytest.utils import format_case_id from idf_pytest.utils import format_case_id
from pytest_embedded.plugin import multi_dut_argument, multi_dut_fixture from pytest_embedded.plugin import multi_dut_argument, multi_dut_fixture
@ -318,7 +324,7 @@ def check_performance(idf_path: str) -> t.Callable[[str, float, str], None]:
def _find_perf_item(operator: str, path: str) -> float: def _find_perf_item(operator: str, path: str) -> float:
with open(path, encoding='utf-8') as f: with open(path, encoding='utf-8') as f:
data = f.read() data = f.read()
match = re.search(fr'#define\s+IDF_PERFORMANCE_{operator}_{item.upper()}\s+([\d.]+)', data) match = re.search(rf'#define\s+IDF_PERFORMANCE_{operator}_{item.upper()}\s+([\d.]+)', data)
return float(match.group(1)) # type: ignore return float(match.group(1)) # type: ignore
def _check_perf(operator: str, standard_value: float) -> None: def _check_perf(operator: str, standard_value: float) -> None:
@ -420,6 +426,12 @@ def pytest_addoption(parser: pytest.Parser) -> None:
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None:
from pytest_embedded_idf.utils import supported_targets, preview_targets
from idf_pytest.constants import SUPPORTED_TARGETS, PREVIEW_TARGETS
supported_targets.set(SUPPORTED_TARGETS)
preview_targets.set(PREVIEW_TARGETS)
# cli option "--target" # cli option "--target"
target = [_t.strip().lower() for _t in (config.getoption('target', '') or '').split(',') if _t.strip()] target = [_t.strip().lower() for _t in (config.getoption('target', '') or '').split(',') if _t.strip()]

7
ruff.toml Normal file
View File

@ -0,0 +1,7 @@
line-length = 120
target-version = "py38"
[format]
quote-style = "single"
exclude = []
docstring-code-format = true

47
tools/ci/check_test_files.py Executable file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import sys
from pathlib import Path
import pytest
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from idf_ci_utils import IDF_PATH # noqa: E402
os.environ['IDF_PATH'] = IDF_PATH
os.environ['PYTEST_IGNORE_COLLECT_IMPORT_ERROR'] = '1'
from idf_pytest.plugin import IdfPytestEmbedded # noqa: E402
def main() -> None:
    """Run pytest in collect-only mode over the given files to lint test marks.

    Exits with pytest's return code, so pre-commit fails when the linter
    plugin reports a problem.
    """
    arg_parser = argparse.ArgumentParser(description='Pytest linter check')
    arg_parser.add_argument(
        'files',
        nargs='*',
        help='Python files to check (full paths separated by space)',
    )
    parsed = arg_parser.parse_args()

    # Resolve every input file to an absolute path pytest can collect.
    scripts = [str(Path(name).resolve()) for name in parsed.files]

    pytest_args = ['--collect-only']
    pytest_args.extend(scripts)
    # Collect for every target and load the linter plugin by module name.
    pytest_args += ['--target', 'all', '-p', 'test_linter']

    exit_code = pytest.main(pytest_args, plugins=[IdfPytestEmbedded('all')])
    sys.exit(exit_code)


if __name__ == '__main__':
    main()

View File

@ -61,3 +61,5 @@ tools/legacy_exports/export_legacy.bat
tools/ci/idf_build_apps_dump_soc_caps.py tools/ci/idf_build_apps_dump_soc_caps.py
tools/bt/bt_hci_to_btsnoop.py tools/bt/bt_hci_to_btsnoop.py
tools/bt/README.md tools/bt/README.md
tools/ci/test_linter.py
tools/ci/check_test_files.py

View File

@ -65,6 +65,7 @@ tools/ci/check_requirement_files.py
tools/ci/check_rules_components_patterns.py tools/ci/check_rules_components_patterns.py
tools/ci/check_soc_headers_leak.py tools/ci/check_soc_headers_leak.py
tools/ci/check_soc_struct_headers.py tools/ci/check_soc_struct_headers.py
tools/ci/check_test_files.py
tools/ci/check_tools_files_patterns.py tools/ci/check_tools_files_patterns.py
tools/ci/check_type_comments.py tools/ci/check_type_comments.py
tools/ci/checkout_project_ref.py tools/ci/checkout_project_ref.py

View File

@ -46,7 +46,7 @@ class Metrics:
self.difference = difference or 0.0 self.difference = difference or 0.0
self.difference_percentage = difference_percentage or 0.0 self.difference_percentage = difference_percentage or 0.0
def to_dict(self) -> dict[str, t.Any]: def to_dict(self) -> t.Dict[str, t.Any]:
""" """
Converts the Metrics object to a dictionary. Converts the Metrics object to a dictionary.
""" """

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023-2024 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Pytest Related Constants. Don't import third-party packages here. Pytest Related Constants. Don't import third-party packages here.
@ -16,7 +16,18 @@ from idf_ci_utils import IDF_PATH
from idf_ci_utils import idf_relpath from idf_ci_utils import idf_relpath
from pytest_embedded.utils import to_list from pytest_embedded.utils import to_list
SUPPORTED_TARGETS = ['esp32', 'esp32s2', 'esp32c3', 'esp32s3', 'esp32c2', 'esp32c6', 'esp32h2', 'esp32p4', 'esp32c5', 'esp32c61'] SUPPORTED_TARGETS = [
'esp32',
'esp32s2',
'esp32c3',
'esp32s3',
'esp32c2',
'esp32c6',
'esp32h2',
'esp32p4',
'esp32c5',
'esp32c61',
]
PREVIEW_TARGETS: t.List[str] = [] # this PREVIEW_TARGETS excludes 'linux' target PREVIEW_TARGETS: t.List[str] = [] # this PREVIEW_TARGETS excludes 'linux' target
DEFAULT_SDKCONFIG = 'default' DEFAULT_SDKCONFIG = 'default'
DEFAULT_LOGDIR = 'pytest-embedded' DEFAULT_LOGDIR = 'pytest-embedded'
@ -30,6 +41,8 @@ TARGET_MARKERS = {
'esp32c5': 'support esp32c5 target', 'esp32c5': 'support esp32c5 target',
'esp32c6': 'support esp32c6 target', 'esp32c6': 'support esp32c6 target',
'esp32h2': 'support esp32h2 target', 'esp32h2': 'support esp32h2 target',
'esp32h4': 'support esp32h4 target', # as preview
'esp32h21': 'support esp32h21 target', # as preview
'esp32p4': 'support esp32p4 target', 'esp32p4': 'support esp32p4 target',
'esp32c61': 'support esp32c61 target', 'esp32c61': 'support esp32c61 target',
'linux': 'support linux target', 'linux': 'support linux target',
@ -174,6 +187,7 @@ class PytestApp:
""" """
Pytest App with relative path to IDF_PATH Pytest App with relative path to IDF_PATH
""" """
def __init__(self, path: str, target: str, config: str) -> None: def __init__(self, path: str, target: str, config: str) -> None:
self.path = idf_relpath(path) self.path = idf_relpath(path)
self.target = target self.target = target
@ -215,8 +229,10 @@ class PytestCase:
for _t in [app.target for app in self.apps]: for _t in [app.target for app in self.apps]:
if _t in self.target_markers: if _t in self.target_markers:
skip = False skip = False
warnings.warn(f'`pytest.mark.[TARGET]` defined in parametrize for multi-dut test cases is deprecated. ' # noqa: W604 warnings.warn(
f'Please use parametrize instead for test case {self.item.nodeid}') f'`pytest.mark.[TARGET]` defined in parametrize for multi-dut test cases is deprecated. ' # noqa: W604
f'Please use parametrize instead for test case {self.item.nodeid}'
)
break break
if not skip: if not skip:
@ -238,7 +254,7 @@ class PytestCase:
return {marker.name for marker in self.item.iter_markers()} return {marker.name for marker in self.item.iter_markers()}
@property @property
def target_markers(self) -> t.Set[str]: def skip_targets(self) -> t.Set[str]:
def _get_temp_markers_disabled_targets(marker_name: str) -> t.Set[str]: def _get_temp_markers_disabled_targets(marker_name: str) -> t.Set[str]:
temp_marker = self.item.get_closest_marker(marker_name) temp_marker = self.item.get_closest_marker(marker_name)
@ -260,11 +276,15 @@ class PytestCase:
# in CI we skip the union of `temp_skip` and `temp_skip_ci` # in CI we skip the union of `temp_skip` and `temp_skip_ci`
if os.getenv('CI_JOB_ID'): if os.getenv('CI_JOB_ID'):
skip_targets = temp_skip_ci_targets.union(temp_skip_targets) _skip_targets = temp_skip_ci_targets.union(temp_skip_targets)
else: # we use `temp_skip` locally else: # we use `temp_skip` locally
skip_targets = temp_skip_targets _skip_targets = temp_skip_targets
return {marker for marker in self.all_markers if marker in TARGET_MARKERS} - skip_targets return _skip_targets
@property
def target_markers(self) -> t.Set[str]:
return {marker for marker in self.all_markers if marker in TARGET_MARKERS} - self.skip_targets
@property @property
def env_markers(self) -> t.Set[str]: def env_markers(self) -> t.Set[str]:
@ -285,10 +305,7 @@ class PytestCase:
if 'jtag' in self.env_markers or 'usb_serial_jtag' in self.env_markers: if 'jtag' in self.env_markers or 'usb_serial_jtag' in self.env_markers:
return True return True
cases_need_elf = [ cases_need_elf = ['panic', 'gdbstub_runtime']
'panic',
'gdbstub_runtime'
]
for case in cases_need_elf: for case in cases_need_elf:
if any(case in Path(app.path).parts for app in self.apps): if any(case in Path(app.path).parts for app in self.apps):

View File

@ -265,12 +265,17 @@ class IdfPytestEmbedded:
# 3.3. CollectMode.MULTI_ALL_WITH_PARAM, intended to be used by `get_pytest_cases` # 3.3. CollectMode.MULTI_ALL_WITH_PARAM, intended to be used by `get_pytest_cases`
else: else:
items[:] = [ filtered_items = []
_item for item in items:
for _item in items case = item_to_case_dict[item]
if not item_to_case_dict[_item].is_single_dut_test_case target = self.get_param(item, 'target', None)
and self.get_param(_item, 'target', None) is not None if (
] not case.is_single_dut_test_case and
target is not None and
target not in case.skip_targets
):
filtered_items.append(item)
items[:] = filtered_items
# 4. filter according to the sdkconfig, if there's param 'config' defined # 4. filter according to the sdkconfig, if there's param 'config' defined
if self.config_name: if self.config_name:

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2022-2024 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
import os import os
@ -6,6 +6,8 @@ import socket
from typing import Any from typing import Any
from typing import List from typing import List
from idf_ci_utils import IDF_PATH
try: try:
import netifaces import netifaces
except ImportError: except ImportError:
@ -17,8 +19,8 @@ except ImportError:
import yaml import yaml
ENV_CONFIG_FILE_SEARCH = [ ENV_CONFIG_FILE_SEARCH = [
os.path.join(os.environ['IDF_PATH'], 'EnvConfig.yml'), os.path.join(IDF_PATH, 'EnvConfig.yml'),
os.path.join(os.environ['IDF_PATH'], 'ci-test-runner-configs', os.environ.get('CI_RUNNER_DESCRIPTION', ''), 'EnvConfig.yml'), os.path.join(IDF_PATH, 'ci-test-runner-configs', os.environ.get('CI_RUNNER_DESCRIPTION', ''), 'EnvConfig.yml'),
] ]
ENV_CONFIG_TEMPLATE = ''' ENV_CONFIG_TEMPLATE = '''
$IDF_PATH/EnvConfig.yml: $IDF_PATH/EnvConfig.yml:

512
tools/ci/test_linter.py Normal file
View File

@ -0,0 +1,512 @@
# SPDX-FileCopyrightText: 2024-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
import ast
import itertools
import os
import typing as t
import warnings
from collections import defaultdict
import pytest
from idf_pytest.constants import PREVIEW_TARGETS
from idf_pytest.constants import SUPPORTED_TARGETS
from idf_pytest.constants import TARGET_MARKERS
from pytest import Config
from pytest import Function
from pytest import Mark
def is_target_in_marker(mark: Mark) -> bool:
    """Return True when the marker selects a target (or a target group alias)."""
    group_aliases = ('supported_targets', 'preview_targets', 'linux')
    name = mark.name
    return name in TARGET_MARKERS or name in group_aliases
def remove_keys(data: t.Dict[str, t.Any], keys_to_remove: t.List[str]) -> t.Dict[str, t.Any]:
    """Return a shallow copy of *data* without the entries named in *keys_to_remove*."""
    excluded = set(keys_to_remove)
    return {k: v for k, v in data.items() if k not in excluded}
def get_values_by_keys(data: t.Dict[str, t.Any], keys: t.List[str]) -> t.Tuple[t.Any, ...]:
    """
    Retrieve values from a dictionary for specified keys.

    Keys missing from *data* are silently skipped, so the result may be
    shorter than *keys*; order follows *keys*, not the dict.
    """
    # Feed tuple() a generator directly instead of building an intermediate list.
    return tuple(data[key] for key in keys if key in data)
def group_by_target(vals: t.List[t.Dict[str, t.Any]]) -> t.List[t.Dict[str, t.Any]]:
    """
    Groups rows by non-target keys and modifies targets to 'supported_targets'
    if all supported targets are present in a group.

    Parameters:
        vals: List of dictionaries to process.

    Returns:
        Processed list of dictionaries with supported targets.

    NOTE(review): rows are mutated in place — the first row of a complete
    group gets its 'target' rewritten to the group alias, the rest are dropped.
    """
    # Nothing to collapse when there are no rows or no 'target' column.
    if not vals or 'target' not in vals[0]:
        return vals

    def _process_group(
        _vals: t.List[t.Dict[str, t.Any]], group: t.List[str], group_name: str
    ) -> t.List[t.Dict[str, t.Any]]:
        # Identify keys excluding 'target'
        non_target_keys = [key for key in sorted(_vals[0].keys()) if key != 'target']
        # Group rows by values of keys excluding 'target'
        grouped_rows = defaultdict(list)
        for index, row in enumerate(_vals):
            key = get_values_by_keys(row, non_target_keys)
            grouped_rows[key].append((index, row['target']))
        # Identify groups that contain all supported targets
        to_skip_lines: t.Set[int] = set()
        to_update_lines: t.Set[int] = set()
        for _, rows in grouped_rows.items():
            lines = []
            # Each target in `group` must be seen once before we collapse.
            remaining_targets = set(group)
            for index, target in rows:
                if target in remaining_targets:
                    lines.append(index)
                    remaining_targets.remove(target)
                if not remaining_targets:
                    to_skip_lines.update(lines[1:])  # Skip all but the first matching line
                    to_update_lines.add(lines[0])  # Update the first matching line
                    break
        # Construct new list of rows with modifications
        new_values = []
        for ind, row in enumerate(_vals):
            if ind in to_update_lines:
                row['target'] = group_name
            if ind not in to_skip_lines:
                new_values.append(row)
        return new_values

    if SUPPORTED_TARGETS:
        vals = _process_group(vals, SUPPORTED_TARGETS, 'supported_targets')
    if PREVIEW_TARGETS:
        vals = _process_group(vals, PREVIEW_TARGETS, 'preview_targets')
    return vals
class CurrentItemContext:
    """Module-level holder for the test function currently being processed.

    Set in pytest_collection_modifyitems; read by restore_params to build
    a useful error message without threading the name through every call.
    """

    # Original name of the test function being linted.
    test_name: str
class PathRestore:
    """Wrapper that renders an app-path string as an f-string source literal."""

    # Class-wide flag: flips to True the first time any path is wrapped,
    # signalling the file formatter that `import os` must be emitted.
    restored: bool = False

    def __init__(self, path: str) -> None:
        PathRestore.restored = True
        self.path = path

    def __repr__(self) -> str:
        # Rendered as source code, e.g. f'...'
        return "f'" + self.path + "'"
def restore_path(vals: t.List[t.Dict[str, t.Any]], file_path: str) -> t.List[t.Dict[str, t.Any]]:
    """Rewrite each row's 'app_path' as os.path.join(...) source relative to *file_path*.

    Multiple app paths separated by '|' are each rewritten; the result is
    wrapped in PathRestore so it renders as an f-string literal. Rows are
    mutated in place and the same list is returned.
    """
    # Guard the empty list: the original indexed vals[0] and raised IndexError.
    # Also `'app_path' not in d` replaces the redundant `.keys()` call.
    if not vals or 'app_path' not in vals[0]:
        return vals
    base_dir = os.path.dirname(os.path.abspath(file_path))
    for row in vals:
        paths = row['app_path'].split('|')
        row['app_path'] = '|'.join(
            f'{{os.path.join(os.path.dirname(__file__), "{os.path.relpath(p, base_dir)}")}}' for p in paths
        )
        row['app_path'] = PathRestore(row['app_path'])
    return vals
def make_hashable(item: t.Any) -> t.Union[t.Tuple[t.Any, ...], t.Any]:
    """Recursively convert *item* into a hashable equivalent.

    Dicts become sorted (key, value) tuples; sets/lists/tuples become tuples
    of converted elements; everything else is returned unchanged.
    """
    if isinstance(item, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in item.items()))
    if isinstance(item, (set, list, tuple)):
        return tuple(make_hashable(element) for element in item)
    # Primitives are already hashable.
    return item
def restore_params(data: t.List[t.Dict[str, t.Any]]) -> t.List[t.Tuple[t.List[str], t.List[t.Any]]]:
    """
    Restore parameters from pytest --collect-only data structure.

    Factors the flat list of per-test parameter dicts into a minimal list of
    (parameter-names, value-rows) pairs, suitable for rendering as
    @idf_parametrize decorators. Mutates the row dicts in place.
    """
    # Ensure all dictionaries have the same number of keys
    if len({len(d) for d in data}) != 1:
        raise ValueError(
            f'Inconsistent parameter {CurrentItemContext.test_name} structure: all rows must have the same number of keys.'
        )
    # Deduplicate each row's marker list; if no row has any markers, drop the column entirely.
    all_markers_is_empty = []
    for d in data:
        if 'markers' in d:
            all_markers_is_empty.append(not (d['markers']))
            d['markers'] = list(set(d['markers']))
    if all(all_markers_is_empty):
        for d in data:
            del d['markers']
    # Maps (column, hashable stand-in) -> original value, for restoration below.
    hashable_to_original: t.Dict[t.Tuple[str, t.Any], t.Any] = {}

    def save_to_hash(key: str, hashable_value: t.Any, original_value: t.Any) -> t.Any:
        """Stores the mapping of hashable values to their original."""
        if isinstance(original_value, list):
            original_value = tuple(original_value)
        hashable_to_original[(key, hashable_value)] = original_value
        return hashable_value

    def restore_from_hash(key: str, hashable_value: t.Any) -> t.Any:
        """Restores the original value from its hashable equivalent."""
        return hashable_to_original.get((key, hashable_value), hashable_value)

    # Convert data to a hashable format
    data = [{k: save_to_hash(k, make_hashable(v), v) for k, v in row.items()} for row in data]
    # Deduplicate rows while preserving order (dicts are unhashable, hence no set).
    unique_data = []
    for d in data:
        if d not in unique_data:
            unique_data.append(d)
    data = unique_data
    # Collapse per-target rows into 'supported_targets'/'preview_targets' aliases.
    data = group_by_target(data)
    params_multiplier: t.List[t.Tuple[t.List[str], t.List[t.Any]]] = []
    # 'markers' sorts last so it is only ever factored out together with real parameters.
    current_keys: t.List[str] = sorted(data[0].keys(), key=lambda x: (x == 'markers', x))
    i = 1
    # Greedily pull out key subsets whose values combine multiplicatively with the rest.
    while len(current_keys) > i:
        # It should be combinations because we are only concerned with the elements, not their order.
        for _ in itertools.combinations(current_keys, i):
            perm: t.List[str] = list(_)
            if perm == ['markers'] or [k for k in current_keys if k not in perm] == ['markers']:
                # The mark_runner must be used together with another parameter.
                continue
            grouped_buckets = defaultdict(list)
            for row in data:
                grouped_buckets[get_values_by_keys(row, perm)].append(remove_keys(row, perm))
            grouped_values = list(grouped_buckets.values())
            # If every bucket leaves the same residue, this subset is independent of the rest.
            if all(v == grouped_values[0] for v in grouped_values):
                current_keys = [k for k in current_keys if k not in perm]
                params_multiplier.append((perm, list(grouped_buckets.keys())))
                data = grouped_values[0]
                break
        else:
            # No subset of size i factored out; try larger subsets.
            i += 1
    # Whatever could not be factored stays as one combined parametrization.
    if data:
        remaining_values = [get_values_by_keys(row, current_keys) for row in data]
        params_multiplier.append((current_keys, remaining_values))
    # Swap the hashable stand-ins back for their original values.
    for key, values in params_multiplier:
        values[:] = [tuple(restore_from_hash(key[i], v) for i, v in enumerate(row)) for row in values]
        output: t.List[t.Any] = []
        # Single-key groups are flattened from 1-tuples to bare values.
        if len(key) == 1:
            for row in values:
                output.extend(row)
            values[:] = output
    # Trim trailing empty marker tuples so they do not render as `()`.
    for p in params_multiplier:
        if 'markers' in p[0]:
            for i, el in enumerate(p[1]):
                if el[-1] == ():
                    p[1][i] = el[:-1]
    return params_multiplier
def format_mark(name: str, args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any]) -> str:
    """Render a `@pytest.mark.<name>(...)` decorator source line (newline-terminated)."""

    def _literal(value: t.Any) -> str:
        # Strings are quoted via repr; everything else uses its str form.
        return repr(value) if isinstance(value, str) else str(value)

    parts = [_literal(arg) for arg in args]
    parts.extend(f'{key}={_literal(value)}' for key, value in kwargs.items())
    if parts:
        return f'@pytest.mark.{name}({", ".join(parts)})\n'
    return f'@pytest.mark.{name}\n'
def format_parametrize(keys: t.Union[str, t.List[str]], values: t.List[t.Any], indirect: t.Sequence[str]) -> str:
    """Render an `@idf_parametrize(...)` decorator for the given keys and value rows.

    A single key is emitted as-is; multiple keys are joined with commas into
    one string, matching pytest.mark.parametrize conventions. Values rely on
    repr(), so special renderings are handled by wrapper classes (PathRestore).
    """
    key_list = [keys] if isinstance(keys, str) else keys
    if len(key_list) == 1:
        key_str = repr(key_list[0])
    else:
        key_str = repr(','.join(key_list))
    body = ',\n'.join(' ' + repr(value) for value in values)
    if indirect:
        closing = f'], indirect={indirect})\n'
    else:
        closing = '])\n'
    return f'@idf_parametrize({key_str}, [\n{body}\n{closing}'
def key_for_item(item: Function) -> t.Tuple[str, str]:
    """Grouping key for a collected item: (original test name, source file path)."""
    name = item.originalname
    path = str(item.fspath)
    return name, path
def collect_markers(item: Function) -> t.Tuple[t.List[Mark], t.List[Mark]]:
    """Separate local and global markers for a pytest item.

    Local markers are those attached to the parametrized callspec
    (pytest.param(..., marks=[...])); global markers are plain decorators on
    the test function. 'parametrize' itself is skipped because it is
    reconstructed separately by the linter.

    Returns:
        (local_markers, global_markers)
    """
    local_markers, global_markers = [], []
    # hasattr() instead of `'callspec' in dir(item)`: same check without
    # building the full attribute list, and hoisted out of the loop since
    # it cannot change between markers.
    has_callspec = hasattr(item, 'callspec')
    for mark in item.iter_markers():
        if mark.name == 'parametrize':
            continue
        if has_callspec and mark in item.callspec.marks:
            local_markers.append(mark)
        else:
            global_markers.append(mark)
    return local_markers, global_markers
class MarkerRepr(str):
    """A str subclass rendering a marker as `pytest.mark.<name>(...)` source.

    str value and repr() are kept identical so the marker can be embedded
    directly into generated parametrize values.
    """

    def __new__(cls, mark_name: str, kwargs_str: str, args_str: str, all_args: str) -> 'MarkerRepr':
        text = f'pytest.mark.{mark_name}({all_args})' if all_args else f'pytest.mark.{mark_name}'
        return super().__new__(cls, text)  # type: ignore

    def __init__(self, mark_name: str, kwargs_str: str, args_str: str, all_args: str) -> None:
        super().__init__()
        self.mark_name = mark_name
        self.args_str = args_str
        self.kwargs_str = kwargs_str
        self.all_args = all_args

    def __hash__(self) -> int:
        # Hash the source form so equal-looking markers deduplicate in sets.
        return hash(repr(self))

    def __repr__(self) -> str:
        if self.all_args:
            return f'pytest.mark.{self.mark_name}({self.all_args})'
        return f'pytest.mark.{self.mark_name}'
def mark_to_source(mark: Mark) -> MarkerRepr:
    """Convert a Mark instance into its `pytest.mark.*` source-code form."""
    positional = ', '.join(repr(arg) for arg in mark.args)
    keyword = ', '.join(f'{k}={v!r}' for k, v in mark.kwargs.items())
    joined = ', '.join(part for part in (positional, keyword) if part)
    return MarkerRepr(mark.name, keyword, positional, joined)
def process_local_markers(local_markers: t.List[Mark]) -> t.Tuple[t.List[str], t.List[MarkerRepr]]:
    """Split local markers into target names and other marker sources.

    Returns:
        (sorted target names, sorted non-target marker source strings)
    """
    targets: t.List[str] = []
    others: t.List[MarkerRepr] = []
    for mark in local_markers:
        if is_target_in_marker(mark):
            targets.append(mark.name)
        else:
            others.append(mark_to_source(mark))
    return sorted(targets), sorted(others)
def validate_global_markers(
    global_markers: t.List[Mark], local_targets: t.List[str], function_name: str
) -> t.List[Mark]:
    """Drop global target markers that local targets override, warning about each.

    Non-target markers — and target markers when no local targets exist —
    are kept untouched.
    """
    kept = []
    for mark in global_markers:
        if is_target_in_marker(mark) and local_targets:
            warnings.warn(f'IN {function_name} IGNORING GLOBAL TARGET {mark.name} DUE TO LOCAL TARGETS')
            continue
        kept.append(mark)
    return kept
def filter_target(_targets: t.List[str]) -> t.List[str]:
    """
    Drop individual targets already covered by a group alias
    ('supported_targets' / 'preview_targets'), warning about each duplicate.
    A single-element list is returned unchanged.
    """
    if len(_targets) == 1:
        return _targets

    def _drop_members(candidates: t.List[str], group: t.List[str], group_name: str) -> t.List[str]:
        # Keep only targets not already implied by the group alias.
        kept = []
        for candidate in candidates:
            if candidate in group:
                warnings.warn(f'{candidate} is already included in {group_name}, no need to specify it separately.')
            else:
                kept.append(candidate)
        return kept

    result = _targets
    if 'supported_targets' in result:
        result = _drop_members(result, SUPPORTED_TARGETS, 'supported_targets')
    if 'preview_targets' in result:
        result = _drop_members(result, PREVIEW_TARGETS, 'preview_targets')
    return result
@pytest.hookimpl(tryfirst=True)
def pytest_collection_modifyitems(config: Config, items: t.List[Function]) -> None:
    """
    Local and Global marks in my diff are as follows:
    - Local: Used with a parameter inside a parameterized function, like:
        parameterized(param(marks=[....]))
    - Global: A regular mark.

    Phase 1 collects, per (test name, file), the parameter rows and global
    markers of every collected item. Phase 2 rewrites each test file:
    removes the old decorators (located via ast) and emits consolidated
    @pytest.mark / @idf_parametrize decorators plus the required imports.
    """
    # (test name, file) -> list of parameter-row dicts / global marks.
    test_name_to_params: t.Dict[t.Tuple[str, str], t.List] = defaultdict(list)
    test_name_to_global_mark: t.Dict[t.Tuple[str, str], t.List] = defaultdict(list)
    test_name_has_local_target_marks = defaultdict(bool)
    # Collect all fixtures to determine if a parameter is regular or a fixture
    # NOTE(review): 'funcmanage' is the registration name of pytest's FixtureManager;
    # _arg2fixturedefs is a private attribute — confirm on pytest upgrades.
    fm = config.pluginmanager.get_plugin('funcmanage')
    known_fixtures = set(fm._arg2fixturedefs.keys())
    # Collecting data
    for item in items:
        collected = []
        item_key = key_for_item(item)
        local_markers, global_markers = collect_markers(item)
        # global_markers.sort(key=lambda x: x.name)
        global_markers.reverse()  # markers of item need to be reverted to save origin order
        local_targets, other_markers = process_local_markers(local_markers)
        if local_targets:
            test_name_has_local_target_marks[item_key] = True
        local_targets = filter_target(local_targets)
        # Always carry a 'markers' column so rows stay the same width.
        other_markers_dict = {'markers': other_markers} if other_markers else {'markers': []}
        if local_targets:
            # One row per local target, merged with the callspec parameters.
            for target in local_targets:
                params = item.callspec.params if 'callspec' in dir(item) else {}
                collected.append({**params, **other_markers_dict, 'target': target})
        else:
            if 'callspec' in dir(item):
                collected.append({**other_markers_dict, **item.callspec.params})
        global_markers = validate_global_markers(global_markers, local_targets, item.name)
        # Just warning if global markers was changed
        if item_key in test_name_to_global_mark:
            if test_name_to_global_mark[item_key] != global_markers:
                warnings.warn(
                    f'{item.originalname} HAS DIFFERENT GLOBAL MARKERS! {test_name_to_global_mark[item_key]} {global_markers}'
                )
        test_name_to_global_mark[item_key] = global_markers
        test_name_to_params[item_key].extend(collected)
    # Post-processing: Modify files based on collected data
    for (function_name, file_path), function_params in test_name_to_params.items():
        # Expose the current test name for error messages in restore_params.
        CurrentItemContext.test_name = function_name
        to_add_lines = []
        global_targets = []
        for mark in test_name_to_global_mark[(function_name, file_path)]:
            if is_target_in_marker(mark):
                global_targets.append(mark.name)
                continue
            to_add_lines.append(format_mark(mark.name, mark.args, mark.kwargs))
        function_params_will_not_update = True
        if test_name_has_local_target_marks[(function_name, file_path)]:
            function_params_will_not_update = False
        # After filter_target, it will lose part of them, but we need them when removing decorators in the file.
        original_global_targets = global_targets
        global_targets = filter_target(global_targets)
        is_target_already_in_params = any({'target' in param for param in function_params})
        extra = []
        if global_targets:
            # If any of param have target then skip add global marker.
            if is_target_already_in_params:
                warnings.warn(f'Function {function_name} already have target params! Skip adding global target')
            else:
                extra = [{'target': _t} for _t in global_targets]

        def _update_file(file_path: str, to_add_lines: t.List[str], lines: t.List[str]) -> None:
            """Rewrite file_path: drop old decorator lines (skip_lines, late-bound
            from the enclosing scope at call time), insert the consolidated
            decorators right before the test's `def`, and add missing imports
            after the leading comment/import block."""
            output = []
            start_with_comment = True
            imports = ['from pytest_embedded_idf.utils import idf_parametrize']
            if PathRestore.restored:
                imports += ['import os']
            for i, line in enumerate(lines):
                # Drop pre-existing copies of our import lines; re-added below.
                if line.strip() in imports:
                    continue
                if start_with_comment:
                    # First line past the header/imports: inject our imports here.
                    if not line == '\n' and not line.startswith(('from', 'import', '#')):
                        output.extend([f'{_imp}\n' for _imp in imports])
                        start_with_comment = False
                if i in skip_lines:
                    continue
                if line.startswith(f'def {function_name}('):
                    output.extend(to_add_lines)
                output.append(line)
            with open(file_path, 'w+') as file:
                file.writelines(output)

        if not function_params_will_not_update:
            buffered_params: t.List[str] = []
            if function_params:
                function_params = restore_path(function_params, file_path)
                for parameter_names, parameter_values in restore_params(function_params):
                    buffered_params.append(
                        format_parametrize(
                            parameter_names,
                            parameter_values,
                            indirect=[p for p in parameter_names if p in known_fixtures],
                        )
                    )
            to_add_lines.extend(buffered_params)
            with open(file_path) as file:
                lines = file.readlines()
            tree = ast.parse(''.join(lines))
            # Lines occupied by the test's existing decorators — to be removed.
            skip_lines: t.Set[int] = set()
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == function_name:
                    for dec in node.decorator_list:
                        assert dec.end_lineno is not None
                        skip_lines.update(list(range(dec.lineno - 1, dec.end_lineno)))  # ast count lines from 1 not 0
                    break
            _update_file(file_path, to_add_lines, lines)
        if global_targets:
            # Re-read: the file may have been rewritten by the branch above.
            with open(file_path) as file:
                lines = file.readlines()
            tree = ast.parse(''.join(lines))
            skip_lines = set()
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == function_name:
                    # Only plain attribute decorators naming a target are removed here.
                    for dec in node.decorator_list:
                        if isinstance(dec, ast.Attribute):
                            if dec.attr in original_global_targets:
                                assert dec.end_lineno is not None
                                skip_lines.update(list(range(dec.lineno - 1, dec.end_lineno)))
                    break
            if extra:
                to_add_lines = [format_parametrize('target', [_t['target'] for _t in extra], ['target'])] if extra else []
            else:
                to_add_lines = []
            _update_file(file_path, to_add_lines, lines)

View File

@ -4,6 +4,7 @@ import os
import os.path as path import os.path as path
import sys import sys
from typing import Any from typing import Any
from typing import Dict
import pytest import pytest
@ -26,7 +27,7 @@ def start_gdb(dut: PanicTestDut) -> None:
dut.start_gdb_for_gdbstub() dut.start_gdb_for_gdbstub()
def run_and_break(dut: PanicTestDut, cmd: str) -> dict[Any, Any]: def run_and_break(dut: PanicTestDut, cmd: str) -> Dict[Any, Any]:
responses = dut.gdb_write(cmd) responses = dut.gdb_write(cmd)
assert dut.find_gdb_response('running', 'result', responses) is not None assert dut.find_gdb_response('running', 'result', responses) is not None
if not dut.find_gdb_response('stopped', 'notify', responses): # have not stopped on breakpoint yet if not dut.find_gdb_response('stopped', 'notify', responses): # have not stopped on breakpoint yet