diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5b756e8af7e..b626503c5b4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -210,8 +210,19 @@ jobs: with: target: aarch64-linux-android + - name: Setup Android NDK + id: setup-ndk + uses: nttld/setup-ndk@v1 + with: + ndk-version: r27 + add-to-path: true + - name: Check compilation for android run: cargo check --target aarch64-linux-android ${{ env.CARGO_ARGS_NO_SSL }} + env: + CC_aarch64_linux_android: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang + AR_aarch64_linux_android: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-ar + CARGO_TARGET_AARCH64_LINUX_ANDROID_LINKER: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang - uses: dtolnay/rust-toolchain@stable with: diff --git a/.github/workflows/cron-ci.yaml b/.github/workflows/cron-ci.yaml index c942364064c..46b45249445 100644 --- a/.github/workflows/cron-ci.yaml +++ b/.github/workflows/cron-ci.yaml @@ -180,7 +180,11 @@ jobs: cd website rm -rf ./assets/criterion cp -r ../target/criterion ./assets/criterion - git add ./assets/criterion + printf '{\n "generated_at": "%s",\n "rustpython_commit": "%s",\n "rustpython_ref": "%s"\n}\n' \ + "$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + "${{ github.sha }}" \ + "${{ github.ref_name }}" > ./_data/criterion-metadata.json + git add ./assets/criterion ./_data/criterion-metadata.json if git -c user.name="Github Actions" -c user.email="actions@github.com" commit -m "Update benchmark results"; then git push fi diff --git a/.github/workflows/pr-auto-commit.yaml b/.github/workflows/pr-auto-commit.yaml index d4d97b19d18..0d85dc78926 100644 --- a/.github/workflows/pr-auto-commit.yaml +++ b/.github/workflows/pr-auto-commit.yaml @@ -14,7 +14,6 @@ concurrency: jobs: auto_format: - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip:ci') }} permissions: contents: write pull-requests: write diff --git a/Cargo.lock b/Cargo.lock index 41ab3280428..8df208e4aaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,6 +142,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ar_archive_writer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c269894b6fe5e9d7ada0cf69b5bf847ff35bc25fc271f08e1d080fce80339a" +dependencies = [ + "object", +] + [[package]] name = "arbitrary" version = "1.4.2" @@ -2081,6 +2090,15 @@ dependencies = [ "syn", ] +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + [[package]] name = "oid-registry" version = "0.8.1" @@ -2448,6 +2466,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11f2fedc3b7dafdc2851bc52f277377c5473d378859be234bc7ebb593144d01" +dependencies = [ + "ar_archive_writer", + "cc", +] + [[package]] name = "pymath" version = "0.1.5" @@ -3280,6 +3308,7 @@ dependencies = [ "optional", "parking_lot", "paste", + "psm", "result-like", "ruff_python_ast", "ruff_python_parser", diff --git a/Lib/_py_warnings.py b/Lib/_py_warnings.py new file mode 100644 index 
00000000000..55f8c069591 --- /dev/null +++ b/Lib/_py_warnings.py @@ -0,0 +1,869 @@ +"""Python part of the warnings subsystem.""" + +import sys +import _contextvars +import _thread + + +__all__ = ["warn", "warn_explicit", "showwarning", + "formatwarning", "filterwarnings", "simplefilter", + "resetwarnings", "catch_warnings", "deprecated"] + + +# Normally '_wm' is sys.modules['warnings'] but for unit tests it can be +# a different module. User code is allowed to reassign global attributes +# of the 'warnings' module, commonly 'filters' or 'showwarning'. So we +# need to lookup these global attributes dynamically on the '_wm' object, +# rather than binding them earlier. The code in this module consistently uses +# '_wm.<attribute>' rather than using the globals of this module. If the +# '_warnings' C extension is in use, some globals are replaced by functions +# and variables defined in that extension. +_wm = None + + +def _set_module(module): + global _wm + _wm = module + + +# filters contains a sequence of filter 5-tuples +# The components of the 5-tuple are: +# - an action: error, ignore, always, all, default, module, or once +# - a compiled regex that must match the warning message +# - a class representing the warning category +# - a compiled regex that must match the module that is being warned +# - a line number for the line being warned about, or 0 to mean any line +# If either of the compiled regexes is None, it matches anything. +filters = [] + + +defaultaction = "default" +onceregistry = {} +_lock = _thread.RLock() +_filters_version = 1 + + +# If true, catch_warnings() will use a context var to hold the modified +# filters list. Otherwise, catch_warnings() will operate on the 'filters' +# global of the warnings module. +_use_context = sys.flags.context_aware_warnings + + +class _Context: + def __init__(self, filters): + self._filters = filters + self.log = None # if set to a list, logging is enabled + + def copy(self): + context = _Context(self._filters[:]) + if self.log is not None: + context.log = self.log + return context + + def _record_warning(self, msg): + self.log.append(msg) + + +class _GlobalContext(_Context): + def __init__(self): + self.log = None + + @property + def _filters(self): + # Since there is quite a lot of code that assigns to + # warnings.filters, this needs to return the current value of + # the module global. + try: + return _wm.filters + except AttributeError: + # 'filters' global was deleted. Do we need to actually handle this case? + return [] + + +_global_context = _GlobalContext() + + +_warnings_context = _contextvars.ContextVar('warnings_context') + + +def _get_context(): + if not _use_context: + return _global_context + try: + return _wm._warnings_context.get() + except LookupError: + return _global_context + + +def _set_context(context): + assert _use_context + _wm._warnings_context.set(context) + + +def _new_context(): + assert _use_context + old_context = _wm._get_context() + new_context = old_context.copy() + _wm._set_context(new_context) + return old_context, new_context + + +def _get_filters(): + """Return the current list of filters.
This is a non-public API used by + module functions and by the unit tests.""" + return _wm._get_context()._filters + + +def _filters_mutated_lock_held(): + _wm._filters_version += 1 + + +def showwarning(message, category, filename, lineno, file=None, line=None): + """Hook to write a warning to a file; replace if you like.""" + msg = _wm.WarningMessage(message, category, filename, lineno, file, line) + _wm._showwarnmsg_impl(msg) + + +def formatwarning(message, category, filename, lineno, line=None): + """Function to format a warning the standard way.""" + msg = _wm.WarningMessage(message, category, filename, lineno, None, line) + return _wm._formatwarnmsg_impl(msg) + + +def _showwarnmsg_impl(msg): + context = _wm._get_context() + if context.log is not None: + context._record_warning(msg) + return + file = msg.file + if file is None: + file = sys.stderr + if file is None: + # sys.stderr is None when run with pythonw.exe: + # warnings get lost + return + text = _wm._formatwarnmsg(msg) + try: + file.write(text) + except OSError: + # the file (probably stderr) is invalid - this warning gets lost. + pass + + +def _formatwarnmsg_impl(msg): + category = msg.category.__name__ + s = f"{msg.filename}:{msg.lineno}: {category}: {msg.message}\n" + + if msg.line is None: + try: + import linecache + line = linecache.getline(msg.filename, msg.lineno) + except Exception: + # When a warning is logged during Python shutdown, linecache + # and the import machinery don't work anymore + line = None + linecache = None + else: + line = msg.line + if line: + line = line.strip() + s += " %s\n" % line + + if msg.source is not None: + try: + import tracemalloc + # Logging a warning should not raise a new exception: + # catch Exception, not only ImportError and RecursionError. + except Exception: + # don't suggest to enable tracemalloc if it's not available + suggest_tracemalloc = False + tb = None + else: + try: + suggest_tracemalloc = not tracemalloc.is_tracing() + tb = tracemalloc.get_object_traceback(msg.source) + except Exception: + # When a warning is logged during Python shutdown, tracemalloc + # and the import machinery don't work anymore + suggest_tracemalloc = False + tb = None + + if tb is not None: + s += 'Object allocated at (most recent call last):\n' + for frame in tb: + s += (' File "%s", lineno %s\n' + % (frame.filename, frame.lineno)) + + try: + if linecache is not None: + line = linecache.getline(frame.filename, frame.lineno) + else: + line = None + except Exception: + line = None + if line: + line = line.strip() + s += ' %s\n' % line + elif suggest_tracemalloc: + s += (f'{category}: Enable tracemalloc to get the object ' + f'allocation traceback\n') + return s + + +# Keep a reference to check if the function was replaced +_showwarning_orig = showwarning + + +def _showwarnmsg(msg): + """Hook to write a warning to a file; replace if you like.""" + try: + sw = _wm.showwarning + except AttributeError: + pass + else: + if sw is not _showwarning_orig: + # warnings.showwarning() was replaced + if not callable(sw): + raise TypeError("warnings.showwarning() must be set to a " + "function or method") + + sw(msg.message, msg.category, msg.filename, msg.lineno, + msg.file, msg.line) + return + _wm._showwarnmsg_impl(msg) + + +# Keep a reference to check if the function was replaced +_formatwarning_orig = formatwarning + + +def _formatwarnmsg(msg): + """Function to format a warning the standard way.""" + try: + fw = _wm.formatwarning + except AttributeError: + pass + else: + if fw is not _formatwarning_orig: + # 
warnings.formatwarning() was replaced + return fw(msg.message, msg.category, + msg.filename, msg.lineno, msg.line) + return _wm._formatwarnmsg_impl(msg) + + +def filterwarnings(action, message="", category=Warning, module="", lineno=0, + append=False): + """Insert an entry into the list of warnings filters (at the front). + + 'action' -- one of "error", "ignore", "always", "all", "default", "module", + or "once" + 'message' -- a regex that the warning message must match + 'category' -- a class that the warning must be a subclass of + 'module' -- a regex that the module name must match + 'lineno' -- an integer line number, 0 matches all warnings + 'append' -- if true, append to the list of filters + """ + if action not in {"error", "ignore", "always", "all", "default", "module", "once"}: + raise ValueError(f"invalid action: {action!r}") + if not isinstance(message, str): + raise TypeError("message must be a string") + if not isinstance(category, type) or not issubclass(category, Warning): + raise TypeError("category must be a Warning subclass") + if not isinstance(module, str): + raise TypeError("module must be a string") + if not isinstance(lineno, int): + raise TypeError("lineno must be an int") + if lineno < 0: + raise ValueError("lineno must be an int >= 0") + + if message or module: + import re + + if message: + message = re.compile(message, re.I) + else: + message = None + if module: + module = re.compile(module) + else: + module = None + + _wm._add_filter(action, message, category, module, lineno, append=append) + + +def simplefilter(action, category=Warning, lineno=0, append=False): + """Insert a simple entry into the list of warnings filters (at the front). + + A simple filter matches all modules and messages. + 'action' -- one of "error", "ignore", "always", "all", "default", "module", + or "once" + 'category' -- a class that the warning must be a subclass of + 'lineno' -- an integer line number, 0 matches all warnings + 'append' -- if true, append to the list of filters + """ + if action not in {"error", "ignore", "always", "all", "default", "module", "once"}: + raise ValueError(f"invalid action: {action!r}") + if not isinstance(lineno, int): + raise TypeError("lineno must be an int") + if lineno < 0: + raise ValueError("lineno must be an int >= 0") + _wm._add_filter(action, None, category, None, lineno, append=append) + + +def _filters_mutated(): + # Even though this function is not part of the public API, it's used by + # a fair amount of user code. + with _wm._lock: + _wm._filters_mutated_lock_held() + + +def _add_filter(*item, append): + with _wm._lock: + filters = _wm._get_filters() + if not append: + # Remove possible duplicate filters, so new one will be placed + # in correct place. If append=True and duplicate exists, do nothing. 
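+ # (Each filter item is the 5-tuple (action, message, category, module, + # lineno) described at the top of this module, so plain tuple equality + # is enough to detect a duplicate.)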
+ try: + filters.remove(item) + except ValueError: + pass + filters.insert(0, item) + else: + if item not in filters: + filters.append(item) + _wm._filters_mutated_lock_held() + + +def resetwarnings(): + """Clear the list of warning filters, so that no filters are active.""" + with _wm._lock: + del _wm._get_filters()[:] + _wm._filters_mutated_lock_held() + + +class _OptionError(Exception): + """Exception used by option processing helpers.""" + pass + + +# Helper to process -W options passed via sys.warnoptions +def _processoptions(args): + for arg in args: + try: + _wm._setoption(arg) + except _wm._OptionError as msg: + print("Invalid -W option ignored:", msg, file=sys.stderr) + + +# Helper for _processoptions() +def _setoption(arg): + parts = arg.split(':') + if len(parts) > 5: + raise _wm._OptionError("too many fields (max 5): %r" % (arg,)) + while len(parts) < 5: + parts.append('') + action, message, category, module, lineno = [s.strip() + for s in parts] + action = _wm._getaction(action) + category = _wm._getcategory(category) + if message or module: + import re + if message: + message = re.escape(message) + if module: + module = re.escape(module) + r'\z' + if lineno: + try: + lineno = int(lineno) + if lineno < 0: + raise ValueError + except (ValueError, OverflowError): + raise _wm._OptionError("invalid lineno %r" % (lineno,)) from None + else: + lineno = 0 + _wm.filterwarnings(action, message, category, module, lineno) + + +# Helper for _setoption() +def _getaction(action): + if not action: + return "default" + for a in ('default', 'always', 'all', 'ignore', 'module', 'once', 'error'): + if a.startswith(action): + return a + raise _wm._OptionError("invalid action: %r" % (action,)) + + +# Helper for _setoption() +def _getcategory(category): + if not category: + return Warning + if '.' 
not in category: + import builtins as m + klass = category + else: + module, _, klass = category.rpartition('.') + try: + m = __import__(module, None, None, [klass]) + except ImportError: + raise _wm._OptionError("invalid module name: %r" % (module,)) from None + try: + cat = getattr(m, klass) + except AttributeError: + raise _wm._OptionError("unknown warning category: %r" % (category,)) from None + if not issubclass(cat, Warning): + raise _wm._OptionError("invalid warning category: %r" % (category,)) + return cat + + +def _is_internal_filename(filename): + return 'importlib' in filename and '_bootstrap' in filename + + +def _is_filename_to_skip(filename, skip_file_prefixes): + return any(filename.startswith(prefix) for prefix in skip_file_prefixes) + + +def _is_internal_frame(frame): + """Signal whether the frame is an internal CPython implementation detail.""" + return _is_internal_filename(frame.f_code.co_filename) + + +def _next_external_frame(frame, skip_file_prefixes): + """Find the next frame that doesn't involve Python or user internals.""" + frame = frame.f_back + while frame is not None and ( + _is_internal_filename(filename := frame.f_code.co_filename) or + _is_filename_to_skip(filename, skip_file_prefixes)): + frame = frame.f_back + return frame + + +# Code typically replaced by _warnings +def warn(message, category=None, stacklevel=1, source=None, + *, skip_file_prefixes=()): + """Issue a warning, or maybe ignore it or raise an exception.""" + # Check if message is already a Warning object + if isinstance(message, Warning): + category = message.__class__ + # Check category argument + if category is None: + category = UserWarning + if not (isinstance(category, type) and issubclass(category, Warning)): + raise TypeError("category must be a Warning subclass, " + "not '{:s}'".format(type(category).__name__)) + if not isinstance(skip_file_prefixes, tuple): + # The C version demands a tuple for implementation performance. + raise TypeError('skip_file_prefixes must be a tuple of strs.') + if skip_file_prefixes: + stacklevel = max(2, stacklevel) + # Get context information + try: + if stacklevel <= 1 or _is_internal_frame(sys._getframe(1)): + # If frame is too small to care or if the warning originated in + # internal code, then do not try to hide any frames. + frame = sys._getframe(stacklevel) + else: + frame = sys._getframe(1) + # Look for one frame less since the above line starts us off. + for x in range(stacklevel-1): + frame = _next_external_frame(frame, skip_file_prefixes) + if frame is None: + raise ValueError + except ValueError: + globals = sys.__dict__ + filename = "<sys>" + lineno = 0 + else: + globals = frame.f_globals + filename = frame.f_code.co_filename + lineno = frame.f_lineno + if '__name__' in globals: + module = globals['__name__'] + else: + module = "<string>" + registry = globals.setdefault("__warningregistry__", {}) + _wm.warn_explicit( + message, + category, + filename, + lineno, + module, + registry, + globals, + source=source, + ) + + +def warn_explicit(message, category, filename, lineno, + module=None, registry=None, module_globals=None, + source=None): + lineno = int(lineno) + if module is None: + module = filename or "<unknown>" + if module[-3:].lower() == ".py": + module = module[:-3] # XXX What about leading pathname?
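+ # Normalize the message below: accept either a Warning instance or a plain + # string, so the registry key is always (text, category, lineno).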
+ if isinstance(message, Warning): + text = str(message) + category = message.__class__ + else: + text = message + message = category(message) + key = (text, category, lineno) + with _wm._lock: + if registry is None: + registry = {} + if registry.get('version', 0) != _wm._filters_version: + registry.clear() + registry['version'] = _wm._filters_version + # Quick test for common case + if registry.get(key): + return + # Search the filters + for item in _wm._get_filters(): + action, msg, cat, mod, ln = item + if ((msg is None or msg.match(text)) and + issubclass(category, cat) and + (mod is None or mod.match(module)) and + (ln == 0 or lineno == ln)): + break + else: + action = _wm.defaultaction + # Early exit actions + if action == "ignore": + return + + if action == "error": + raise message + # Other actions + if action == "once": + registry[key] = 1 + oncekey = (text, category) + if _wm.onceregistry.get(oncekey): + return + _wm.onceregistry[oncekey] = 1 + elif action in {"always", "all"}: + pass + elif action == "module": + registry[key] = 1 + altkey = (text, category, 0) + if registry.get(altkey): + return + registry[altkey] = 1 + elif action == "default": + registry[key] = 1 + else: + # Unrecognized actions are errors + raise RuntimeError( + "Unrecognized action (%r) in warnings.filters:\n %s" % + (action, item)) + + # Prime the linecache for formatting, in case the + # "file" is actually in a zipfile or something. + import linecache + linecache.getlines(filename, module_globals) + + # Print message and context + msg = _wm.WarningMessage(message, category, filename, lineno, source=source) + _wm._showwarnmsg(msg) + + +class WarningMessage(object): + + _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file", + "line", "source") + + def __init__(self, message, category, filename, lineno, file=None, + line=None, source=None): + self.message = message + self.category = category + self.filename = filename + self.lineno = lineno + self.file = file + self.line = line + self.source = source + self._category_name = category.__name__ if category else None + + def __str__(self): + return ("{message : %r, category : %r, filename : %r, lineno : %s, " + "line : %r}" % (self.message, self._category_name, + self.filename, self.lineno, self.line)) + + def __repr__(self): + return f'<{type(self).__qualname__} {self}>' + + +class catch_warnings(object): + + """A context manager that copies and restores the warnings filter upon + exiting the context. + + The 'record' argument specifies whether warnings should be captured by a + custom implementation of warnings.showwarning() and be appended to a list + returned by the context manager. Otherwise None is returned by the context + manager. The objects appended to the list are arguments whose attributes + mirror the arguments to showwarning(). + + The 'module' argument is to specify an alternative module to the module + named 'warnings' and imported under that name. This argument is only useful + when testing the warnings module itself. + + If the 'action' argument is not None, the remaining arguments are passed + to warnings.simplefilter() as if it were called immediately on entering the + context. + """ + + def __init__(self, *, record=False, module=None, + action=None, category=Warning, lineno=0, append=False): + """Specify whether to record warnings and if an alternative module + should be used other than sys.modules['warnings']. 
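+ + A minimal recording sketch (using the public 'warnings' module, + which normally hosts this code): + + >>> import warnings + >>> with warnings.catch_warnings(record=True) as log: + ... warnings.simplefilter("always") + ... warnings.warn("example") + >>> len(log) + 1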
+ + """ + self._record = record + self._module = sys.modules['warnings'] if module is None else module + self._entered = False + if action is None: + self._filter = None + else: + self._filter = (action, category, lineno, append) + + def __repr__(self): + args = [] + if self._record: + args.append("record=True") + if self._module is not sys.modules['warnings']: + args.append("module=%r" % self._module) + name = type(self).__name__ + return "%s(%s)" % (name, ", ".join(args)) + + def __enter__(self): + if self._entered: + raise RuntimeError("Cannot enter %r twice" % self) + self._entered = True + with _wm._lock: + if _use_context: + self._saved_context, context = self._module._new_context() + else: + context = None + self._filters = self._module.filters + self._module.filters = self._filters[:] + self._showwarning = self._module.showwarning + self._showwarnmsg_impl = self._module._showwarnmsg_impl + self._module._filters_mutated_lock_held() + if self._record: + if _use_context: + context.log = log = [] + else: + log = [] + self._module._showwarnmsg_impl = log.append + # Reset showwarning() to the default implementation to make sure + # that _showwarnmsg() calls _showwarnmsg_impl() + self._module.showwarning = self._module._showwarning_orig + else: + log = None + if self._filter is not None: + self._module.simplefilter(*self._filter) + return log + + def __exit__(self, *exc_info): + if not self._entered: + raise RuntimeError("Cannot exit %r without entering first" % self) + with _wm._lock: + if _use_context: + self._module._warnings_context.set(self._saved_context) + else: + self._module.filters = self._filters + self._module.showwarning = self._showwarning + self._module._showwarnmsg_impl = self._showwarnmsg_impl + self._module._filters_mutated_lock_held() + + +class deprecated: + """Indicate that a class, function or overload is deprecated. + + When this decorator is applied to an object, the type checker + will generate a diagnostic on usage of the deprecated object. + + Usage: + + @deprecated("Use B instead") + class A: + pass + + @deprecated("Use g instead") + def f(): + pass + + @overload + @deprecated("int support is deprecated") + def g(x: int) -> int: ... + @overload + def g(x: str) -> int: ... + + The warning specified by *category* will be emitted at runtime + on use of deprecated objects. For functions, that happens on calls; + for classes, on instantiation and on creation of subclasses. + If the *category* is ``None``, no warning is emitted at runtime. + The *stacklevel* determines where the + warning is emitted. If it is ``1`` (the default), the warning + is emitted at the direct caller of the deprecated object; if it + is higher, it is emitted further up the stack. + Static type checker behavior is not affected by the *category* + and *stacklevel* arguments. + + The deprecation message passed to the decorator is saved in the + ``__deprecated__`` attribute on the decorated object. + If applied to an overload, the decorator + must be after the ``@overload`` decorator for the attribute to + exist on the overload as returned by ``get_overloads()``. + + See PEP 702 for details. 
+ + """ + def __init__( + self, + message: str, + /, + *, + category: type[Warning] | None = DeprecationWarning, + stacklevel: int = 1, + ) -> None: + if not isinstance(message, str): + raise TypeError( + f"Expected an object of type str for 'message', not {type(message).__name__!r}" + ) + self.message = message + self.category = category + self.stacklevel = stacklevel + + def __call__(self, arg, /): + # Make sure the inner functions created below don't + # retain a reference to self. + msg = self.message + category = self.category + stacklevel = self.stacklevel + if category is None: + arg.__deprecated__ = msg + return arg + elif isinstance(arg, type): + import functools + from types import MethodType + + original_new = arg.__new__ + + @functools.wraps(original_new) + def __new__(cls, /, *args, **kwargs): + if cls is arg: + _wm.warn(msg, category=category, stacklevel=stacklevel + 1) + if original_new is not object.__new__: + return original_new(cls, *args, **kwargs) + # Mirrors a similar check in object.__new__. + elif cls.__init__ is object.__init__ and (args or kwargs): + raise TypeError(f"{cls.__name__}() takes no arguments") + else: + return original_new(cls) + + arg.__new__ = staticmethod(__new__) + + if "__init_subclass__" in arg.__dict__: + # __init_subclass__ is directly present on the decorated class. + # Synthesize a wrapper that calls this method directly. + original_init_subclass = arg.__init_subclass__ + # We need slightly different behavior if __init_subclass__ + # is a bound method (likely if it was implemented in Python). + # Otherwise, it likely means it's a builtin such as + # object's implementation of __init_subclass__. + if isinstance(original_init_subclass, MethodType): + original_init_subclass = original_init_subclass.__func__ + + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + _wm.warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + else: + def __init_subclass__(cls, *args, **kwargs): + _wm.warn(msg, category=category, stacklevel=stacklevel + 1) + return super(arg, cls).__init_subclass__(*args, **kwargs) + + arg.__init_subclass__ = classmethod(__init_subclass__) + + arg.__deprecated__ = __new__.__deprecated__ = msg + __init_subclass__.__deprecated__ = msg + return arg + elif callable(arg): + import functools + import inspect + + @functools.wraps(arg) + def wrapper(*args, **kwargs): + _wm.warn(msg, category=category, stacklevel=stacklevel + 1) + return arg(*args, **kwargs) + + if inspect.iscoroutinefunction(arg): + wrapper = inspect.markcoroutinefunction(wrapper) + + arg.__deprecated__ = wrapper.__deprecated__ = msg + return wrapper + else: + raise TypeError( + "@deprecated decorator with non-None category must be applied to " + f"a class or callable, not {arg!r}" + ) + + +_DEPRECATED_MSG = "{name!r} is deprecated and slated for removal in Python {remove}" + + +def _deprecated(name, message=_DEPRECATED_MSG, *, remove, _version=sys.version_info): + """Warn that *name* is deprecated or should be removed. + + RuntimeError is raised if *remove* specifies a major/minor tuple older than + the current Python version or the same version but past the alpha. + + The *message* argument is formatted with *name* and *remove* as a Python + version tuple (e.g. (3, 11)). 
+ + """ + remove_formatted = f"{remove[0]}.{remove[1]}" + if (_version[:2] > remove) or (_version[:2] == remove and _version[3] != "alpha"): + msg = f"{name!r} was slated for removal after Python {remove_formatted} alpha" + raise RuntimeError(msg) + else: + msg = message.format(name=name, remove=remove_formatted) + _wm.warn(msg, DeprecationWarning, stacklevel=3) + + +# Private utility function called by _PyErr_WarnUnawaitedCoroutine +def _warn_unawaited_coroutine(coro): + msg_lines = [ + f"coroutine '{coro.__qualname__}' was never awaited\n" + ] + if coro.cr_origin is not None: + import linecache, traceback + def extract(): + for filename, lineno, funcname in reversed(coro.cr_origin): + line = linecache.getline(filename, lineno) + yield (filename, lineno, funcname, line) + msg_lines.append("Coroutine created at (most recent call last)\n") + msg_lines += traceback.format_list(list(extract())) + msg = "".join(msg_lines).rstrip("\n") + # Passing source= here means that if the user happens to have tracemalloc + # enabled and tracking where the coroutine was created, the warning will + # contain that traceback. This does mean that if they have *both* + # coroutine origin tracking *and* tracemalloc enabled, they'll get two + # partially-redundant tracebacks. If we wanted to be clever we could + # probably detect this case and avoid it, but for now we don't bother. + _wm.warn( + msg, category=RuntimeWarning, stacklevel=2, source=coro + ) + + +def _setup_defaults(): + # Several warning categories are ignored by default in regular builds + if hasattr(sys, 'gettotalrefcount'): + return + _wm.filterwarnings("default", category=DeprecationWarning, module="__main__", append=1) + _wm.simplefilter("ignore", category=DeprecationWarning, append=1) + _wm.simplefilter("ignore", category=PendingDeprecationWarning, append=1) + _wm.simplefilter("ignore", category=ImportWarning, append=1) + _wm.simplefilter("ignore", category=ResourceWarning, append=1) diff --git a/Lib/base64.py b/Lib/base64.py old mode 100755 new mode 100644 index 5a7e790a193..f95132a4274 --- a/Lib/base64.py +++ b/Lib/base64.py @@ -1,12 +1,9 @@ -#! 
/usr/bin/env python3 - """Base16, Base32, Base64 (RFC 3548), Base85 and Ascii85 data encodings""" # Modified 04-Oct-1995 by Jack Jansen to use binascii module # Modified 30-Dec-2003 by Barry Warsaw to add full RFC 3548 support # Modified 22-May-2007 by Guido van Rossum to use bytes everywhere -import re import struct import binascii @@ -286,7 +283,7 @@ def b16decode(s, casefold=False): s = _bytes_from_decode_data(s) if casefold: s = s.upper() - if re.search(b'[^0-9A-F]', s): + if s.translate(None, delete=b'0123456789ABCDEF'): raise binascii.Error('Non-base16 digit found') return binascii.unhexlify(s) @@ -465,9 +462,12 @@ def b85decode(b): # Delay the initialization of tables to not waste memory # if the function is never called if _b85dec is None: - _b85dec = [None] * 256 + # we don't assign to _b85dec directly to avoid issues when + # multiple threads call this function simultaneously + b85dec_tmp = [None] * 256 for i, c in enumerate(_b85alphabet): - _b85dec[c] = i + b85dec_tmp[c] = i + _b85dec = b85dec_tmp b = _bytes_from_decode_data(b) padding = (-len(b)) % 5 @@ -604,7 +604,14 @@ def main(): with open(args[0], 'rb') as f: func(f, sys.stdout.buffer) else: - func(sys.stdin.buffer, sys.stdout.buffer) + if sys.stdin.isatty(): + # gh-138775: read terminal input data all at once to detect EOF + import io + data = sys.stdin.buffer.read() + buffer = io.BytesIO(data) + else: + buffer = sys.stdin.buffer + func(buffer, sys.stdout.buffer) if __name__ == '__main__': diff --git a/Lib/bz2.py b/Lib/bz2.py index 2420cd01906..eb58f4da596 100644 --- a/Lib/bz2.py +++ b/Lib/bz2.py @@ -10,9 +10,9 @@ __author__ = "Nadeem Vawda <nadeem.vawda@gmail.com>" from builtins import open as _builtin_open +from compression._common import _streams import io import os -import _compression from _bz2 import BZ2Compressor, BZ2Decompressor @@ -23,7 +23,7 @@ _MODE_WRITE = 3 -class BZ2File(_compression.BaseStream): +class BZ2File(_streams.BaseStream): """A file object providing transparent bzip2 (de)compression. @@ -88,7 +88,7 @@ def __init__(self, filename, mode="r", *, compresslevel=9): raise TypeError("filename must be a str, bytes, file or PathLike object") if self._mode == _MODE_READ: - raw = _compression.DecompressReader(self._fp, + raw = _streams.DecompressReader(self._fp, BZ2Decompressor, trailing_error=OSError) self._buffer = io.BufferedReader(raw) else: @@ -248,7 +248,7 @@ def writelines(self, seq): Line separators are not added between the written byte strings. """ - return _compression.BaseStream.writelines(self, seq) + return _streams.BaseStream.writelines(self, seq) def seek(self, offset, whence=io.SEEK_SET): """Change the file position.
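The bz2, gzip and lzma changes in this series are mechanical: the private Lib/_compression.py helper moves to compression._common._streams while the public modules keep their APIs. A quick sketch of the relocated names (hedged; assumes a build that ships the new package):

    import io
    # The shared base classes simply moved; their behavior is unchanged.
    from compression._common import _streams

    assert issubclass(_streams.BaseStream, io.BufferedIOBase)
    assert issubclass(_streams.DecompressReader, io.RawIOBase)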
diff --git a/Lib/compression/__init__.py b/Lib/compression/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Lib/compression/_common/__init__.py b/Lib/compression/_common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Lib/_compression.py b/Lib/compression/_common/_streams.py similarity index 98% rename from Lib/_compression.py rename to Lib/compression/_common/_streams.py index e8b70aa0a3e..9f367d4e304 100644 --- a/Lib/_compression.py +++ b/Lib/compression/_common/_streams.py @@ -1,4 +1,4 @@ -"""Internal classes used by the gzip, lzma and bz2 modules""" +"""Internal classes used by compression modules""" import io import sys diff --git a/Lib/compression/bz2.py b/Lib/compression/bz2.py new file mode 100644 index 00000000000..16815d6cd20 --- /dev/null +++ b/Lib/compression/bz2.py @@ -0,0 +1,5 @@ +import bz2 +__doc__ = bz2.__doc__ +del bz2 + +from bz2 import * diff --git a/Lib/compression/gzip.py b/Lib/compression/gzip.py new file mode 100644 index 00000000000..552f48f948a --- /dev/null +++ b/Lib/compression/gzip.py @@ -0,0 +1,5 @@ +import gzip +__doc__ = gzip.__doc__ +del gzip + +from gzip import * diff --git a/Lib/compression/lzma.py b/Lib/compression/lzma.py new file mode 100644 index 00000000000..b4bc7ccb1db --- /dev/null +++ b/Lib/compression/lzma.py @@ -0,0 +1,5 @@ +import lzma +__doc__ = lzma.__doc__ +del lzma + +from lzma import * diff --git a/Lib/compression/zlib.py b/Lib/compression/zlib.py new file mode 100644 index 00000000000..3aa7e2db90e --- /dev/null +++ b/Lib/compression/zlib.py @@ -0,0 +1,5 @@ +import zlib +__doc__ = zlib.__doc__ +del zlib + +from zlib import * diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py new file mode 100644 index 00000000000..84b25914b0a --- /dev/null +++ b/Lib/compression/zstd/__init__.py @@ -0,0 +1,242 @@ +"""Python bindings to the Zstandard (zstd) compression library (RFC-8878).""" + +__all__ = ( + # compression.zstd + 'COMPRESSION_LEVEL_DEFAULT', + 'compress', + 'CompressionParameter', + 'decompress', + 'DecompressionParameter', + 'finalize_dict', + 'get_frame_info', + 'Strategy', + 'train_dict', + + # compression.zstd._zstdfile + 'open', + 'ZstdFile', + + # _zstd + 'get_frame_size', + 'zstd_version', + 'zstd_version_info', + 'ZstdCompressor', + 'ZstdDecompressor', + 'ZstdDict', + 'ZstdError', +) + +import _zstd +import enum +from _zstd import (ZstdCompressor, ZstdDecompressor, ZstdDict, ZstdError, + get_frame_size, zstd_version) +from compression.zstd._zstdfile import ZstdFile, open, _nbytes + +# zstd_version_number is (MAJOR * 100 * 100 + MINOR * 100 + RELEASE) +zstd_version_info = (*divmod(_zstd.zstd_version_number // 100, 100), + _zstd.zstd_version_number % 100) +"""Version number of the runtime zstd library as a tuple of integers.""" + +COMPRESSION_LEVEL_DEFAULT = _zstd.ZSTD_CLEVEL_DEFAULT +"""The default compression level for Zstandard, currently '3'.""" + + +class FrameInfo: + """Information about a Zstandard frame.""" + + __slots__ = 'decompressed_size', 'dictionary_id' + + def __init__(self, decompressed_size, dictionary_id): + super().__setattr__('decompressed_size', decompressed_size) + super().__setattr__('dictionary_id', dictionary_id) + + def __repr__(self): + return (f'FrameInfo(decompressed_size={self.decompressed_size}, ' + f'dictionary_id={self.dictionary_id})') + + def __setattr__(self, name, _): + raise AttributeError(f"can't set attribute {name!r}") + + +def get_frame_info(frame_buffer): + """Get Zstandard frame information from a frame 
header. + + *frame_buffer* is a bytes-like object. It should start from the beginning + of a frame, and needs to include at least the frame header (6 to 18 bytes). + + The returned FrameInfo object has two attributes. + 'decompressed_size' is the size in bytes of the data in the frame when + decompressed, or None when the decompressed size is unknown. + 'dictionary_id' is an int in the range (0, 2**32). The special value 0 + means that the dictionary ID was not recorded in the frame header, + the frame may or may not need a dictionary to be decoded, + and the ID of such a dictionary is not specified. + """ + return FrameInfo(*_zstd.get_frame_info(frame_buffer)) + + +def train_dict(samples, dict_size): + """Return a ZstdDict representing a trained Zstandard dictionary. + + *samples* is an iterable of samples, where a sample is a bytes-like + object representing a file. + + *dict_size* is the dictionary's maximum size, in bytes. + """ + if not isinstance(dict_size, int): + ds_cls = type(dict_size).__qualname__ + raise TypeError(f'dict_size must be an int object, not {ds_cls!r}.') + + samples = tuple(samples) + chunks = b''.join(samples) + chunk_sizes = tuple(_nbytes(sample) for sample in samples) + if not chunks: + raise ValueError("samples contained no data; can't train dictionary.") + dict_content = _zstd.train_dict(chunks, chunk_sizes, dict_size) + return ZstdDict(dict_content) + + +def finalize_dict(zstd_dict, /, samples, dict_size, level): + """Return a ZstdDict representing a finalized Zstandard dictionary. + + Given a custom content as a basis for dictionary, and a set of samples, + finalize *zstd_dict* by adding headers and statistics according to the + Zstandard dictionary format. + + You may compose an effective dictionary content by hand, which is used as + basis dictionary, and use some samples to finalize a dictionary. The basis + dictionary may be a "raw content" dictionary. See *is_raw* in ZstdDict. + + *samples* is an iterable of samples, where a sample is a bytes-like object + representing a file. + *dict_size* is the dictionary's maximum size, in bytes. + *level* is the expected compression level. The statistics for each + compression level differ, so tuning the dictionary to the compression level + can provide improvements. + """ + + if not isinstance(zstd_dict, ZstdDict): + raise TypeError('zstd_dict argument should be a ZstdDict object.') + if not isinstance(dict_size, int): + raise TypeError('dict_size argument should be an int object.') + if not isinstance(level, int): + raise TypeError('level argument should be an int object.') + + samples = tuple(samples) + chunks = b''.join(samples) + chunk_sizes = tuple(_nbytes(sample) for sample in samples) + if not chunks: + raise ValueError("The samples are empty content, can't finalize the " + "dictionary.") + dict_content = _zstd.finalize_dict(zstd_dict.dict_content, chunks, + chunk_sizes, dict_size, level) + return ZstdDict(dict_content) + + +def compress(data, level=None, options=None, zstd_dict=None): + """Return Zstandard compressed *data* as bytes. + + *level* is an int specifying the compression level to use, defaulting to + COMPRESSION_LEVEL_DEFAULT ('3'). + *options* is a dict object that contains advanced compression + parameters. See CompressionParameter for more on options. + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See + the function train_dict for how to train a ZstdDict on sample data. + + For incremental compression, use a ZstdCompressor instead. 
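+ + A round-trip sketch: + + >>> decompress(compress(b'zstd data')) + b'zstd data'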
+ """ + comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict) + return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME) + + +def decompress(data, zstd_dict=None, options=None): + """Decompress one or more frames of Zstandard compressed *data*. + + *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See + the function train_dict for how to train a ZstdDict on sample data. + *options* is a dict object that contains advanced compression + parameters. See DecompressionParameter for more on options. + + For incremental decompression, use a ZstdDecompressor instead. + """ + results = [] + while True: + decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict) + results.append(decomp.decompress(data)) + if not decomp.eof: + raise ZstdError('Compressed data ended before the ' + 'end-of-stream marker was reached') + data = decomp.unused_data + if not data: + break + return b''.join(results) + + +class CompressionParameter(enum.IntEnum): + """Compression parameters.""" + + compression_level = _zstd.ZSTD_c_compressionLevel + window_log = _zstd.ZSTD_c_windowLog + hash_log = _zstd.ZSTD_c_hashLog + chain_log = _zstd.ZSTD_c_chainLog + search_log = _zstd.ZSTD_c_searchLog + min_match = _zstd.ZSTD_c_minMatch + target_length = _zstd.ZSTD_c_targetLength + strategy = _zstd.ZSTD_c_strategy + + enable_long_distance_matching = _zstd.ZSTD_c_enableLongDistanceMatching + ldm_hash_log = _zstd.ZSTD_c_ldmHashLog + ldm_min_match = _zstd.ZSTD_c_ldmMinMatch + ldm_bucket_size_log = _zstd.ZSTD_c_ldmBucketSizeLog + ldm_hash_rate_log = _zstd.ZSTD_c_ldmHashRateLog + + content_size_flag = _zstd.ZSTD_c_contentSizeFlag + checksum_flag = _zstd.ZSTD_c_checksumFlag + dict_id_flag = _zstd.ZSTD_c_dictIDFlag + + nb_workers = _zstd.ZSTD_c_nbWorkers + job_size = _zstd.ZSTD_c_jobSize + overlap_log = _zstd.ZSTD_c_overlapLog + + def bounds(self): + """Return the (lower, upper) int bounds of a compression parameter. + + Both the lower and upper bounds are inclusive. + """ + return _zstd.get_param_bounds(self.value, is_compress=True) + + +class DecompressionParameter(enum.IntEnum): + """Decompression parameters.""" + + window_log_max = _zstd.ZSTD_d_windowLogMax + + def bounds(self): + """Return the (lower, upper) int bounds of a decompression parameter. + + Both the lower and upper bounds are inclusive. + """ + return _zstd.get_param_bounds(self.value, is_compress=False) + + +class Strategy(enum.IntEnum): + """Compression strategies, listed from fastest to strongest. + + Note that new strategies might be added in the future. + Only the order (from fast to strong) is guaranteed, + the numeric value might change. 
+ """ + + fast = _zstd.ZSTD_fast + dfast = _zstd.ZSTD_dfast + greedy = _zstd.ZSTD_greedy + lazy = _zstd.ZSTD_lazy + lazy2 = _zstd.ZSTD_lazy2 + btlazy2 = _zstd.ZSTD_btlazy2 + btopt = _zstd.ZSTD_btopt + btultra = _zstd.ZSTD_btultra + btultra2 = _zstd.ZSTD_btultra2 + + +# Check validity of the CompressionParameter & DecompressionParameter types +_zstd.set_parameter_types(CompressionParameter, DecompressionParameter) diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py new file mode 100644 index 00000000000..d709f5efc65 --- /dev/null +++ b/Lib/compression/zstd/_zstdfile.py @@ -0,0 +1,345 @@ +import io +from os import PathLike +from _zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize +from compression._common import _streams + +__all__ = ('ZstdFile', 'open') + +_MODE_CLOSED = 0 +_MODE_READ = 1 +_MODE_WRITE = 2 + + +def _nbytes(dat, /): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + with memoryview(dat) as mv: + return mv.nbytes + + +class ZstdFile(_streams.BaseStream): + """A file-like object providing transparent Zstandard (de)compression. + + A ZstdFile can act as a wrapper for an existing file object, or refer + directly to a named file on disk. + + ZstdFile provides a *binary* file interface. Data is read and returned as + bytes, and may only be written to objects that support the Buffer Protocol. + """ + + FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK + FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME + + def __init__(self, file, /, mode='r', *, + level=None, options=None, zstd_dict=None): + """Open a Zstandard compressed file in binary mode. + + *file* can be either an file-like object, or a file name to open. + + *mode* can be 'r' for reading (default), 'w' for (over)writing, 'x' for + creating exclusively, or 'a' for appending. These can equivalently be + given as 'rb', 'wb', 'xb' and 'ab' respectively. + + *level* is an optional int specifying the compression level to use, + or COMPRESSION_LEVEL_DEFAULT if not given. + + *options* is an optional dict for advanced compression parameters. + See CompressionParameter and DecompressionParameter for the possible + options. + + *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard + dictionary. See train_dict() to train ZstdDict on sample data. 
+ """ + self._fp = None + self._close_fp = False + self._mode = _MODE_CLOSED + self._buffer = None + + if not isinstance(mode, str): + raise ValueError('mode must be a str') + if options is not None and not isinstance(options, dict): + raise TypeError('options must be a dict or None') + mode = mode.removesuffix('b') # handle rb, wb, xb, ab + if mode == 'r': + if level is not None: + raise TypeError('level is illegal in read mode') + self._mode = _MODE_READ + elif mode in {'w', 'a', 'x'}: + if level is not None and not isinstance(level, int): + raise TypeError('level must be int or None') + self._mode = _MODE_WRITE + self._compressor = ZstdCompressor(level=level, options=options, + zstd_dict=zstd_dict) + self._pos = 0 + else: + raise ValueError(f'Invalid mode: {mode!r}') + + if isinstance(file, (str, bytes, PathLike)): + self._fp = io.open(file, f'{mode}b') + self._close_fp = True + elif ((mode == 'r' and hasattr(file, 'read')) + or (mode != 'r' and hasattr(file, 'write'))): + self._fp = file + else: + raise TypeError('file must be a file-like object ' + 'or a str, bytes, or PathLike object') + + if self._mode == _MODE_READ: + raw = _streams.DecompressReader( + self._fp, + ZstdDecompressor, + zstd_dict=zstd_dict, + options=options, + ) + self._buffer = io.BufferedReader(raw) + + def close(self): + """Flush and close the file. + + May be called multiple times. Once the file has been closed, + any other operation on it will raise ValueError. + """ + if self._fp is None: + return + try: + if self._mode == _MODE_READ: + if getattr(self, '_buffer', None): + self._buffer.close() + self._buffer = None + elif self._mode == _MODE_WRITE: + self.flush(self.FLUSH_FRAME) + self._compressor = None + finally: + self._mode = _MODE_CLOSED + try: + if self._close_fp: + self._fp.close() + finally: + self._fp = None + self._close_fp = False + + def write(self, data, /): + """Write a bytes-like object *data* to the file. + + Returns the number of uncompressed bytes written, which is + always the length of data in bytes. Note that due to buffering, + the file on disk may not reflect the data written until .flush() + or .close() is called. + """ + self._check_can_write() + + length = _nbytes(data) + + compressed = self._compressor.compress(data) + self._fp.write(compressed) + self._pos += length + return length + + def flush(self, mode=FLUSH_BLOCK): + """Flush remaining data to the underlying stream. + + The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this + method will reduce compression ratio, use it only when necessary. + + If the program is interrupted afterwards, all data can be recovered. + To ensure saving to disk, also need to use os.fsync(fd). + + This method does nothing in reading mode. + """ + if self._mode == _MODE_READ: + return + self._check_not_closed() + if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}: + raise ValueError('Invalid mode argument, expected either ' + 'ZstdFile.FLUSH_FRAME or ' + 'ZstdFile.FLUSH_BLOCK') + if self._compressor.last_mode == mode: + return + # Flush zstd block/frame, and write. + data = self._compressor.flush(mode) + self._fp.write(data) + if hasattr(self._fp, 'flush'): + self._fp.flush() + + def read(self, size=-1): + """Read up to size uncompressed bytes from the file. + + If size is negative or omitted, read until EOF is reached. + Returns b'' if the file is already at EOF. 
+ """ + if size is None: + size = -1 + self._check_can_read() + return self._buffer.read(size) + + def read1(self, size=-1): + """Read up to size uncompressed bytes, while trying to avoid + making multiple reads from the underlying stream. Reads up to a + buffer's worth of data if size is negative. + + Returns b'' if the file is at EOF. + """ + self._check_can_read() + if size < 0: + # Note this should *not* be io.DEFAULT_BUFFER_SIZE. + # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing + # a full block is read. + size = ZSTD_DStreamOutSize + return self._buffer.read1(size) + + def readinto(self, b): + """Read bytes into b. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto(b) + + def readinto1(self, b): + """Read bytes into b, while trying to avoid making multiple reads + from the underlying stream. + + Returns the number of bytes read (0 for EOF). + """ + self._check_can_read() + return self._buffer.readinto1(b) + + def readline(self, size=-1): + """Read a line of uncompressed bytes from the file. + + The terminating newline (if present) is retained. If size is + non-negative, no more than size bytes will be read (in which + case the line may be incomplete). Returns b'' if already at EOF. + """ + self._check_can_read() + return self._buffer.readline(size) + + def seek(self, offset, whence=io.SEEK_SET): + """Change the file position. + + The new position is specified by offset, relative to the + position indicated by whence. Possible values for whence are: + + 0: start of stream (default): offset must not be negative + 1: current stream position + 2: end of stream; offset must not be positive + + Returns the new file position. + + Note that seeking is emulated, so depending on the arguments, + this operation may be extremely slow. + """ + self._check_can_read() + + # BufferedReader.seek() checks seekable + return self._buffer.seek(offset, whence) + + def peek(self, size=-1): + """Return buffered data without advancing the file position. + + Always returns at least one byte of data, unless at EOF. + The exact number of bytes returned is unspecified. 
+ """ + # Relies on the undocumented fact that BufferedReader.peek() always + # returns at least one byte (except at EOF) + self._check_can_read() + return self._buffer.peek(size) + + def __next__(self): + if ret := self._buffer.readline(): + return ret + raise StopIteration + + def tell(self): + """Return the current file position.""" + self._check_not_closed() + if self._mode == _MODE_READ: + return self._buffer.tell() + elif self._mode == _MODE_WRITE: + return self._pos + + def fileno(self): + """Return the file descriptor for the underlying file.""" + self._check_not_closed() + return self._fp.fileno() + + @property + def name(self): + self._check_not_closed() + return self._fp.name + + @property + def mode(self): + return 'wb' if self._mode == _MODE_WRITE else 'rb' + + @property + def closed(self): + """True if this file is closed.""" + return self._mode == _MODE_CLOSED + + def seekable(self): + """Return whether the file supports seeking.""" + return self.readable() and self._buffer.seekable() + + def readable(self): + """Return whether the file was opened for reading.""" + self._check_not_closed() + return self._mode == _MODE_READ + + def writable(self): + """Return whether the file was opened for writing.""" + self._check_not_closed() + return self._mode == _MODE_WRITE + + +def open(file, /, mode='rb', *, level=None, options=None, zstd_dict=None, + encoding=None, errors=None, newline=None): + """Open a Zstandard compressed file in binary or text mode. + + file can be either a file name (given as a str, bytes, or PathLike object), + in which case the named file is opened, or it can be an existing file object + to read from or write to. + + The mode parameter can be 'r', 'rb' (default), 'w', 'wb', 'x', 'xb', 'a', + 'ab' for binary mode, or 'rt', 'wt', 'xt', 'at' for text mode. + + The level, options, and zstd_dict parameters specify the settings the same + as ZstdFile. + + When using read mode (decompression), the options parameter is a dict + representing advanced decompression options. The level parameter is not + supported in this case. When using write mode (compression), only one of + level, an int representing the compression level, or options, a dict + representing advanced compression options, may be passed. In both modes, + zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary. + + For binary mode, this function is equivalent to the ZstdFile constructor: + ZstdFile(filename, mode, ...). In this case, the encoding, errors and + newline parameters must not be provided. + + For text mode, an ZstdFile object is created, and wrapped in an + io.TextIOWrapper instance with the specified encoding, error handling + behavior, and line ending(s). 
+ """ + + text_mode = 't' in mode + mode = mode.replace('t', '') + + if text_mode: + if 'b' in mode: + raise ValueError(f'Invalid mode: {mode!r}') + else: + if encoding is not None: + raise ValueError('Argument "encoding" not supported in binary mode') + if errors is not None: + raise ValueError('Argument "errors" not supported in binary mode') + if newline is not None: + raise ValueError('Argument "newline" not supported in binary mode') + + binary_file = ZstdFile(file, mode, level=level, options=options, + zstd_dict=zstd_dict) + + if text_mode: + return io.TextIOWrapper(binary_file, encoding, errors, newline) + else: + return binary_file diff --git a/Lib/gzip.py b/Lib/gzip.py index a550c20a7a0..c00f51858de 100644 --- a/Lib/gzip.py +++ b/Lib/gzip.py @@ -5,7 +5,6 @@ # based on Andrew Kuchling's minigzip.py distributed with the zlib module -import _compression import builtins import io import os @@ -14,6 +13,7 @@ import time import weakref import zlib +from compression._common import _streams __all__ = ["BadGzipFile", "GzipFile", "open", "compress", "decompress"] @@ -144,7 +144,7 @@ def writable(self): return True -class GzipFile(_compression.BaseStream): +class GzipFile(_streams.BaseStream): """The GzipFile class simulates most of the methods of a file object with the exception of the truncate() method. @@ -193,6 +193,11 @@ def __init__(self, filename=None, mode=None, """ + # Ensure attributes exist at __del__ + self.mode = None + self.fileobj = None + self._buffer = None + if mode and ('t' in mode or 'U' in mode): raise ValueError("Invalid mode: {!r}".format(mode)) if mode and 'b' not in mode: @@ -332,11 +337,15 @@ def _write_raw(self, data): return length - def read(self, size=-1): - self._check_not_closed() + def _check_read(self, caller): if self.mode != READ: import errno - raise OSError(errno.EBADF, "read() on write-only GzipFile object") + msg = f"{caller}() on write-only GzipFile object" + raise OSError(errno.EBADF, msg) + + def read(self, size=-1): + self._check_not_closed() + self._check_read("read") return self._buffer.read(size) def read1(self, size=-1): @@ -344,19 +353,25 @@ def read1(self, size=-1): Reads up to a buffer's worth of data if size is negative.""" self._check_not_closed() - if self.mode != READ: - import errno - raise OSError(errno.EBADF, "read1() on write-only GzipFile object") + self._check_read("read1") if size < 0: size = io.DEFAULT_BUFFER_SIZE return self._buffer.read1(size) + def readinto(self, b): + self._check_not_closed() + self._check_read("readinto") + return self._buffer.readinto(b) + + def readinto1(self, b): + self._check_not_closed() + self._check_read("readinto1") + return self._buffer.readinto1(b) + def peek(self, n): self._check_not_closed() - if self.mode != READ: - import errno - raise OSError(errno.EBADF, "peek() on write-only GzipFile object") + self._check_read("peek") return self._buffer.peek(n) @property @@ -365,7 +380,9 @@ def closed(self): def close(self): fileobj = self.fileobj - if fileobj is None or self._buffer.closed: + if fileobj is None: + return + if self._buffer is None or self._buffer.closed: return try: if self.mode == WRITE: @@ -445,6 +462,13 @@ def readline(self, size=-1): self._check_not_closed() return self._buffer.readline(size) + def __del__(self): + if self.mode == WRITE and not self.closed: + import warnings + warnings.warn("unclosed GzipFile", + ResourceWarning, source=self, stacklevel=2) + + super().__del__() def _read_exact(fp, n): '''Read exactly *n* bytes from `fp` @@ -499,7 +523,7 @@ def _read_gzip_header(fp): 
return last_mtime -class _GzipReader(_compression.DecompressReader): +class _GzipReader(_streams.DecompressReader): def __init__(self, fp): super().__init__(_PaddedFile(fp), zlib._ZlibDecompressor, wbits=-zlib.MAX_WBITS) @@ -597,12 +621,12 @@ def _rewind(self): self._new_member = True -def compress(data, compresslevel=_COMPRESS_LEVEL_BEST, *, mtime=None): +def compress(data, compresslevel=_COMPRESS_LEVEL_BEST, *, mtime=0): """Compress data in one shot and return the compressed string. compresslevel sets the compression level in range of 0-9. - mtime can be used to set the modification time. The modification time is - set to the current time by default. + mtime can be used to set the modification time. + The modification time is set to 0 by default, for reproducibility. """ # Wbits=31 automatically includes a gzip header and trailer. gzip_data = zlib.compress(data, level=compresslevel, wbits=31) @@ -643,7 +667,9 @@ def main(): from argparse import ArgumentParser parser = ArgumentParser(description= "A simple command line interface for the gzip module: act like gzip, " - "but do not delete the input file.") + "but do not delete the input file.", + color=True, + ) group = parser.add_mutually_exclusive_group() group.add_argument('--fast', action='store_true', help='compress faster') group.add_argument('--best', action='store_true', help='compress better') diff --git a/Lib/json/__init__.py b/Lib/json/__init__.py index c7a6dcdf77e..9eaa4f3fbc1 100644 --- a/Lib/json/__init__.py +++ b/Lib/json/__init__.py @@ -86,13 +86,13 @@ '[2.0, 1.0]' -Using json.tool from the shell to validate and pretty-print:: +Using json from the shell to validate and pretty-print:: - $ echo '{"json":"obj"}' | python -m json.tool + $ echo '{"json":"obj"}' | python -m json { "json": "obj" } - $ echo '{ 1.2:3.4}' | python -m json.tool + $ echo '{ 1.2:3.4}' | python -m json Expecting property name enclosed in double quotes: line 1 column 3 (char 2) """ __version__ = '2.0.9' diff --git a/Lib/json/__main__.py b/Lib/json/__main__.py new file mode 100644 index 00000000000..1808eaddb62 --- /dev/null +++ b/Lib/json/__main__.py @@ -0,0 +1,20 @@ +"""Command-line tool to validate and pretty-print JSON + +Usage:: + + $ echo '{"json":"obj"}' | python -m json + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m json + Expecting property name enclosed in double quotes: line 1 column 3 (char 2) + +""" +import json.tool + + +if __name__ == '__main__': + try: + json.tool.main() + except BrokenPipeError as exc: + raise SystemExit(exc.errno) diff --git a/Lib/json/encoder.py b/Lib/json/encoder.py index 0671500d106..5cf6d64f3ea 100644 --- a/Lib/json/encoder.py +++ b/Lib/json/encoder.py @@ -295,37 +295,40 @@ def _iterencode_list(lst, _current_indent_level): else: newline_indent = None separator = _item_separator - first = True - for value in lst: - if first: - first = False - else: + for i, value in enumerate(lst): + if i: buf = separator - if isinstance(value, str): - yield buf + _encoder(value) - elif value is None: - yield buf + 'null' - elif value is True: - yield buf + 'true' - elif value is False: - yield buf + 'false' - elif isinstance(value, int): - # Subclasses of int/float may override __repr__, but we still - # want to encode them as integers/floats in JSON. One example - # within the standard library is IntEnum. 
- yield buf + _intstr(value) - elif isinstance(value, float): - # see comment above for int - yield buf + _floatstr(value) - else: - yield buf - if isinstance(value, (list, tuple)): - chunks = _iterencode_list(value, _current_indent_level) - elif isinstance(value, dict): - chunks = _iterencode_dict(value, _current_indent_level) + try: + if isinstance(value, str): + yield buf + _encoder(value) + elif value is None: + yield buf + 'null' + elif value is True: + yield buf + 'true' + elif value is False: + yield buf + 'false' + elif isinstance(value, int): + # Subclasses of int/float may override __repr__, but we still + # want to encode them as integers/floats in JSON. One example + # within the standard library is IntEnum. + yield buf + _intstr(value) + elif isinstance(value, float): + # see comment above for int + yield buf + _floatstr(value) else: - chunks = _iterencode(value, _current_indent_level) - yield from chunks + yield buf + if isinstance(value, (list, tuple)): + chunks = _iterencode_list(value, _current_indent_level) + elif isinstance(value, dict): + chunks = _iterencode_dict(value, _current_indent_level) + else: + chunks = _iterencode(value, _current_indent_level) + yield from chunks + except GeneratorExit: + raise + except BaseException as exc: + exc.add_note(f'when serializing {type(lst).__name__} item {i}') + raise if newline_indent is not None: _current_indent_level -= 1 yield '\n' + _indent * _current_indent_level @@ -385,28 +388,34 @@ def _iterencode_dict(dct, _current_indent_level): yield item_separator yield _encoder(key) yield _key_separator - if isinstance(value, str): - yield _encoder(value) - elif value is None: - yield 'null' - elif value is True: - yield 'true' - elif value is False: - yield 'false' - elif isinstance(value, int): - # see comment for int/float in _make_iterencode - yield _intstr(value) - elif isinstance(value, float): - # see comment for int/float in _make_iterencode - yield _floatstr(value) - else: - if isinstance(value, (list, tuple)): - chunks = _iterencode_list(value, _current_indent_level) - elif isinstance(value, dict): - chunks = _iterencode_dict(value, _current_indent_level) + try: + if isinstance(value, str): + yield _encoder(value) + elif value is None: + yield 'null' + elif value is True: + yield 'true' + elif value is False: + yield 'false' + elif isinstance(value, int): + # see comment for int/float in _make_iterencode + yield _intstr(value) + elif isinstance(value, float): + # see comment for int/float in _make_iterencode + yield _floatstr(value) else: - chunks = _iterencode(value, _current_indent_level) - yield from chunks + if isinstance(value, (list, tuple)): + chunks = _iterencode_list(value, _current_indent_level) + elif isinstance(value, dict): + chunks = _iterencode_dict(value, _current_indent_level) + else: + chunks = _iterencode(value, _current_indent_level) + yield from chunks + except GeneratorExit: + raise + except BaseException as exc: + exc.add_note(f'when serializing {type(dct).__name__} item {key!r}') + raise if not first and newline_indent is not None: _current_indent_level -= 1 yield '\n' + _indent * _current_indent_level @@ -439,8 +448,14 @@ def _iterencode(o, _current_indent_level): if markerid in markers: raise ValueError("Circular reference detected") markers[markerid] = o - o = _default(o) - yield from _iterencode(o, _current_indent_level) + newobj = _default(o) + try: + yield from _iterencode(newobj, _current_indent_level) + except GeneratorExit: + raise + except BaseException as exc: + exc.add_note(f'when 
serializing {type(o).__name__} object')
+                raise
         if markers is not None:
             del markers[markerid]
     return _iterencode
diff --git a/Lib/json/tool.py b/Lib/json/tool.py
index fdfc3372bcc..1967817add8 100644
--- a/Lib/json/tool.py
+++ b/Lib/json/tool.py
@@ -1,25 +1,50 @@
-r"""Command-line tool to validate and pretty-print JSON
-
-Usage::
-
-    $ echo '{"json":"obj"}' | python -m json.tool
-    {
-        "json": "obj"
-    }
-    $ echo '{ 1.2:3.4}' | python -m json.tool
-    Expecting property name enclosed in double quotes: line 1 column 3 (char 2)
+"""Command-line tool to validate and pretty-print JSON
+
+See `json.__main__` for a usage example (invocation as
+`python -m json.tool` is supported for backwards compatibility).
 """
 import argparse
 import json
+import re
 import sys
+from _colorize import get_theme, can_colorize
+
+
+# The string we are colorizing is valid JSON,
+# so we can use a looser but simpler regex to match
+# the various parts, most notably strings and numbers,
+# where the regex given by the spec is much more complex.
+_color_pattern = re.compile(r'''
+    (?P<key>"(\\.|[^"\\])*")(?=:) |
+    (?P<string>"(\\.|[^"\\])*") |
+    (?P<number>NaN|-?Infinity|[0-9\-+.Ee]+) |
+    (?P<boolean>true|false) |
+    (?P<null>null)
+''', re.VERBOSE)
+
+_group_to_theme_color = {
+    "key": "definition",
+    "string": "string",
+    "number": "number",
+    "boolean": "keyword",
+    "null": "keyword",
+}
+
+
+def _colorize_json(json_str, theme):
+    def _replace_match_callback(match):
+        for group, color in _group_to_theme_color.items():
+            if m := match.group(group):
+                return f"{theme[color]}{m}{theme.reset}"
+        return match.group()
+
+    return re.sub(_color_pattern, _replace_match_callback, json_str)


 def main():
-    prog = 'python -m json.tool'
     description = ('A simple command line interface for json module '
                    'to validate and pretty-print JSON objects.')
-    parser = argparse.ArgumentParser(prog=prog, description=description)
+    parser = argparse.ArgumentParser(description=description, color=True)
     parser.add_argument('infile', nargs='?',
                         help='a JSON file to be validated or pretty-printed',
                         default='-')
@@ -75,9 +100,16 @@ def main():
     else:
         outfile = open(options.outfile, 'w', encoding='utf-8')
     with outfile:
-        for obj in objs:
-            json.dump(obj, outfile, **dump_args)
-            outfile.write('\n')
+        if can_colorize(file=outfile):
+            t = get_theme(tty_file=outfile).syntax
+            for obj in objs:
+                json_str = json.dumps(obj, **dump_args)
+                outfile.write(_colorize_json(json_str, t))
+                outfile.write('\n')
+        else:
+            for obj in objs:
+                json.dump(obj, outfile, **dump_args)
+                outfile.write('\n')
     except ValueError as e:
         raise SystemExit(e)
@@ -86,4 +118,4 @@ def main():
     try:
         main()
     except BrokenPipeError as exc:
-        sys.exit(exc.errno)
+        raise SystemExit(exc.errno)
diff --git a/Lib/lzma.py b/Lib/lzma.py
index c1e3d33deb6..316066d024e 100644
--- a/Lib/lzma.py
+++ b/Lib/lzma.py
@@ -24,9 +24,9 @@
 import builtins
 import io
 import os
+from compression._common import _streams
 from _lzma import *
-from _lzma import _encode_filter_properties, _decode_filter_properties
-import _compression
+from _lzma import _encode_filter_properties, _decode_filter_properties  # noqa: F401


 # Value 0 no longer used
@@ -35,7 +35,7 @@
 _MODE_WRITE = 3


-class LZMAFile(_compression.BaseStream):
+class LZMAFile(_streams.BaseStream):

     """A file object providing transparent LZMA (de)compression.
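The colorizer added to `json/tool.py` above relies on the input already being valid JSON (it comes from `json.dumps()`), which is why the loose `_color_pattern` regex is safe. A self-contained sketch of the same idea, with hard-coded ANSI escapes standing in for the `_colorize` theme (the color values are illustrative assumptions)::

    import json
    import re

    ANSI = {"key": "\x1b[34m", "string": "\x1b[32m", "number": "\x1b[33m",
            "boolean": "\x1b[35m", "null": "\x1b[35m"}
    RESET = "\x1b[0m"

    # Same shape as _color_pattern: a key is simply a string followed by ':'.
    PATTERN = re.compile(r'''
        (?P<key>"(\\.|[^"\\])*")(?=:) |
        (?P<string>"(\\.|[^"\\])*") |
        (?P<number>NaN|-?Infinity|[0-9\-+.Ee]+) |
        (?P<boolean>true|false) |
        (?P<null>null)
    ''', re.VERBOSE)

    def colorize(json_str):
        def repl(match):
            # Exactly one named group matches per alternative.
            for group, color in ANSI.items():
                if m := match.group(group):
                    return f"{color}{m}{RESET}"
            return match.group()
        return PATTERN.sub(repl, json_str)

    print(colorize(json.dumps({"ok": True, "count": 3, "note": None})))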
@@ -127,7 +127,7 @@ def __init__(self, filename=None, mode="r", *, raise TypeError("filename must be a str, bytes, file or PathLike object") if self._mode == _MODE_READ: - raw = _compression.DecompressReader(self._fp, LZMADecompressor, + raw = _streams.DecompressReader(self._fp, LZMADecompressor, trailing_error=LZMAError, format=format, filters=filters) self._buffer = io.BufferedReader(raw) diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index f6c1f292867..0855e384f24 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -4417,6 +4417,7 @@ def test_shared_memory_across_processes(self): sms.close() + @unittest.skip("TODO: RUSTPYTHON; flaky") @unittest.skipIf(os.name != "posix", "not feasible in non-posix platforms") def test_shared_memory_SharedMemoryServer_ignores_sigint(self): # bpo-36368: protect SharedMemoryManager server process from diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 444ca2219cf..cc5a48738fd 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -3,11 +3,11 @@ if __name__ != 'test.support': raise ImportError('support must be imported from the test package') +import _opcode import contextlib -import dataclasses import functools +import inspect import logging -import _opcode import os import re import stat @@ -19,6 +19,7 @@ import unittest import warnings +import annotationlib __all__ = [ # globals @@ -32,7 +33,7 @@ "is_resource_enabled", "requires", "requires_freebsd_version", "requires_gil_enabled", "requires_linux_version", "requires_mac_ver", "check_syntax_error", - "requires_gzip", "requires_bz2", "requires_lzma", + "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "requires_zlib", "has_fork_support", "requires_fork", @@ -41,10 +42,11 @@ "anticipate_failure", "load_package_tests", "detect_api_mismatch", "check__all__", "skip_if_buggy_ucrt_strfptime", "check_disallow_instantiation", "check_sanitizer", "skip_if_sanitizer", - "requires_limited_api", "requires_specialization", + "requires_limited_api", "requires_specialization", "thread_unsafe", # sys "MS_WINDOWS", "is_jython", "is_android", "is_emscripten", "is_wasi", "is_apple_mobile", "check_impl_detail", "unix_shell", "setswitchinterval", + "support_remote_exec_only", # os "get_pagesize", # network @@ -57,13 +59,16 @@ "run_with_tz", "PGO", "missing_compiler_executable", "ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST", "LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT", - "Py_DEBUG", "exceeds_recursion_limit", "get_c_recursion_limit", - "skip_on_s390x", - "without_optimizer", + "Py_DEBUG", "exceeds_recursion_limit", "skip_on_s390x", + "requires_jit_enabled", + "requires_jit_disabled", "force_not_colorized", "force_not_colorized_test_class", "make_clean_env", "BrokenIter", + "in_systemd_nspawn_sync_suppressed", + "run_no_yield_async_fn", "run_yielding_async_fn", "async_yield", + "reset_code", "on_github_actions" ] @@ -389,6 +394,21 @@ def wrapper(*args, **kw): return decorator +def thread_unsafe(reason): + """Mark a test as not thread safe. 
When the test runner is run with
+    --parallel-threads=N, the test will be run in a single thread."""
+    def decorator(test_item):
+        test_item.__unittest_thread_unsafe__ = True
+        # the reason is not currently used
+        test_item.__unittest_thread_unsafe__why__ = reason
+        return test_item
+    if isinstance(reason, types.FunctionType):
+        test_item = reason
+        reason = ''
+        return decorator(test_item)
+    return decorator
+
+
 def skip_if_buildbot(reason=None):
     """Decorator raising SkipTest if running on a buildbot."""
     import getpass
@@ -401,7 +421,8 @@ def skip_if_buildbot(reason=None):
         isbuildbot = False
     return unittest.skipIf(isbuildbot, reason)

-def check_sanitizer(*, address=False, memory=False, ub=False, thread=False):
+def check_sanitizer(*, address=False, memory=False, ub=False, thread=False,
+                    function=True):
     """Returns True if Python is compiled with sanitizer support"""
     if not (address or memory or ub or thread):
         raise ValueError('At least one of address, memory, ub or thread must be True')
@@ -425,11 +446,15 @@ def check_sanitizer(*, address=False, memory=False, ub=False, thread=False):
         '-fsanitize=thread' in cflags or '--with-thread-sanitizer' in config_args
     )
+    function_sanitizer = (
+        '-fsanitize=function' in cflags
+    )
     return (
         (memory and memory_sanitizer) or
         (address and address_sanitizer) or
         (ub and ub_sanitizer) or
-        (thread and thread_sanitizer)
+        (thread and thread_sanitizer) or
+        (function and function_sanitizer)
     )

@@ -514,13 +539,19 @@ def requires_lzma(reason='requires lzma'):
         lzma = None # XXX: RUSTPYTHON; xz is not supported yet
     return unittest.skipUnless(lzma, reason)

+def requires_zstd(reason='requires zstd'):
+    try:
+        from compression import zstd
+    except ImportError:
+        zstd = None
+    return unittest.skipUnless(zstd, reason)
+
 def has_no_debug_ranges():
     try:
-        import _testinternalcapi
+        import _testcapi
     except ImportError:
         raise unittest.SkipTest("_testinternalcapi required")
-    config = _testinternalcapi.get_config()
-    return not bool(config['code_debug_ranges'])
+    return not _testcapi.config_get('code_debug_ranges')

 def requires_debug_ranges(reason='requires co_positions / debug_ranges'):
     try:
@@ -531,6 +562,7 @@ def requires_debug_ranges(reason='requires co_positions / debug_ranges'):
     return unittest.skipIf(skip, reason)

+# XXX: RUSTPYTHON; this does not belong in 3.14
 def can_use_suppress_immortalization(suppress=True):
     """Check if suppress_immortalization(suppress) can be used.

@@ -583,6 +615,11 @@ def skip_if_suppress_immortalization():

 is_android = sys.platform == "android"

+def skip_android_selinux(name):
+    return unittest.skipIf(
+        sys.platform == "android", f"Android blocks {name} with SELinux"
+    )
+
 if sys.platform not in {"win32", "vxworks", "ios", "tvos", "watchos"}:
     unix_shell = '/system/bin/sh' if is_android else '/bin/sh'
 else:
@@ -593,6 +630,15 @@ def skip_if_suppress_immortalization():

 is_emscripten = sys.platform == "emscripten"
 is_wasi = sys.platform == "wasi"

+# Use is_wasm32 as a generic check for WebAssembly platforms.
+is_wasm32 = is_emscripten or is_wasi + +def skip_emscripten_stack_overflow(): + return unittest.skipIf(is_emscripten, "Exhausts stack on Emscripten") + +def skip_wasi_stack_overflow(): + return unittest.skipIf(is_wasi, "Exhausts stack on WASI") + is_apple_mobile = sys.platform in {"ios", "tvos", "watchos"} is_apple = is_apple_mobile or sys.platform == "darwin" @@ -715,9 +761,11 @@ def sortdict(dict): return "{%s}" % withcommas -def run_code(code: str) -> dict[str, object]: +def run_code(code: str, extra_names: dict[str, object] | None = None) -> dict[str, object]: """Run a piece of code after dedenting it, and return its global namespace.""" ns = {} + if extra_names: + ns.update(extra_names) exec(textwrap.dedent(code), ns) return ns @@ -735,7 +783,9 @@ def check_syntax_error(testcase, statement, errtext='', *, lineno=None, offset=N def open_urlresource(url, *args, **kw): - import urllib.request, urllib.parse + import urllib.parse + import urllib.request + from .os_helper import unlink try: import gzip @@ -953,8 +1003,16 @@ def calcvobjsize(fmt): return struct.calcsize(_vheader + fmt + _align) -_TPFLAGS_HAVE_GC = 1<<14 +_TPFLAGS_STATIC_BUILTIN = 1<<1 +_TPFLAGS_DISALLOW_INSTANTIATION = 1<<7 +_TPFLAGS_IMMUTABLETYPE = 1<<8 _TPFLAGS_HEAPTYPE = 1<<9 +_TPFLAGS_BASETYPE = 1<<10 +_TPFLAGS_READY = 1<<12 +_TPFLAGS_READYING = 1<<13 +_TPFLAGS_HAVE_GC = 1<<14 +_TPFLAGS_BASE_EXC_SUBCLASS = 1<<30 +_TPFLAGS_TYPE_SUBCLASS = 1<<31 def check_sizeof(test, o, size): try: @@ -1318,6 +1376,26 @@ def coverage_wrapper(*args, **kwargs): return coverage_wrapper +def no_rerun(reason): + """Skip rerunning for a particular test. + + WARNING: Use this decorator with care; skipping rerunning makes it + impossible to find reference leaks. Provide a clear reason for skipping the + test using the 'reason' parameter. + """ + def deco(func): + assert not isinstance(func, type), func + _has_run = False + def wrapper(self): + nonlocal _has_run + if _has_run: + self.skipTest(reason) + func(self) + _has_run = True + return wrapper + return deco + + def refcount_test(test): """Decorator for tests which involve reference counting. @@ -1331,8 +1409,9 @@ def refcount_test(test): def requires_limited_api(test): try: - import _testcapi - import _testlimitedcapi + import _testcapi # noqa: F401 + + import _testlimitedcapi # noqa: F401 except ImportError: return unittest.skip('needs _testcapi and _testlimitedcapi modules')(test) return test @@ -1347,6 +1426,18 @@ def requires_specialization(test): _opcode.ENABLE_SPECIALIZATION, "requires specialization")(test) +def requires_specialization_ft(test): + return unittest.skipUnless( + _opcode.ENABLE_SPECIALIZATION_FT, "requires specialization")(test) + + +def reset_code(f: types.FunctionType) -> types.FunctionType: + """Clear all specializations, local instrumentation, and JIT code for the given function.""" + f.__code__ = f.__code__.replace() + return f + +on_github_actions = "GITHUB_ACTIONS" in os.environ + #======================================================================= # Check for the presence of docstrings. @@ -1575,8 +1666,8 @@ def __init__(self, link=None): if sys.platform == "win32": def _platform_specific(self): - import glob import _winapi + import glob if os.path.lexists(self.real) and not os.path.exists(self.real): # App symlink appears to not exist, but we want the @@ -1972,10 +2063,11 @@ def missing_compiler_executable(cmd_names=[]): missing. 
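A brief aside on the `run_code()` extension visible above: the new `extra_names` parameter pre-populates the globals the dedented snippet executes in. A hedged usage sketch (the names `factor` and `result` are made up for the example)::

    from test import support

    ns = support.run_code("""
        result = factor * 2
    """, extra_names={"factor": 21})
    assert ns["result"] == 42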
""" - from setuptools._distutils import ccompiler, sysconfig - from setuptools import errors import shutil + from setuptools import errors + from setuptools._distutils import ccompiler, sysconfig + compiler = ccompiler.new_compiler() sysconfig.customize_compiler(compiler) if compiler.compiler_type == "msvc": @@ -2304,7 +2396,15 @@ def skip_if_broken_multiprocessing_synchronize(): # bpo-38377: On Linux, creating a semaphore fails with OSError # if the current user does not have the permission to create # a file in /dev/shm/ directory. - synchronize.Lock(ctx=None) + import multiprocessing + synchronize.Lock(ctx=multiprocessing.get_context('fork')) + # The explicit fork mp context is required in order for + # TestResourceTracker.test_resource_tracker_reused to work. + # synchronize creates a new multiprocessing.resource_tracker + # process at module import time via the above call in that + # scenario. Awkward. This enables gh-84559. No code involved + # should have threads at that point so fork() should be safe. + except OSError as exc: raise unittest.SkipTest(f"broken multiprocessing SemLock: {exc!r}") @@ -2396,8 +2496,9 @@ def clear_ignored_deprecations(*tokens: object) -> None: raise ValueError("Provide token or tokens returned by ignore_deprecations_from") new_filters = [] + old_filters = warnings._get_filters() endswith = tuple(rf"(?#support{id(token)})" for token in tokens) - for action, message, category, module, lineno in warnings.filters: + for action, message, category, module, lineno in old_filters: if action == "ignore" and category is DeprecationWarning: if isinstance(message, re.Pattern): msg = message.pattern @@ -2406,8 +2507,8 @@ def clear_ignored_deprecations(*tokens: object) -> None: if msg.endswith(endswith): continue new_filters.append((action, message, category, module, lineno)) - if warnings.filters != new_filters: - warnings.filters[:] = new_filters + if old_filters != new_filters: + old_filters[:] = new_filters warnings._filters_mutated() @@ -2415,7 +2516,7 @@ def clear_ignored_deprecations(*tokens: object) -> None: def requires_venv_with_pip(): # ensurepip requires zlib to open ZIP archives (.whl binary wheel packages) try: - import zlib + import zlib # noqa: F401 except ImportError: return unittest.skipIf(True, "venv: ensurepip requires zlib") @@ -2455,6 +2556,7 @@ def _findwheel(pkgname): @contextlib.contextmanager def setup_venv_with_pip_setuptools(venv_dir): import subprocess + from .os_helper import temp_cwd def run_command(cmd): @@ -2610,30 +2712,30 @@ def sleeping_retry(timeout, err_msg=None, /, delay = min(delay * 2, max_delay) -class CPUStopwatch: +class Stopwatch: """Context manager to roughly time a CPU-bound operation. - Disables GC. Uses CPU time if it can (i.e. excludes sleeps & time of - other processes). + Disables GC. Uses perf_counter, which is a clock with the highest + available resolution. It is chosen even though it does include + time elapsed during sleep and is system-wide, because the + resolution of process_time is too coarse on Windows and + process_time does not exist everywhere (for example, WASM). - N.B.: - - This *includes* time spent in other threads. + Note: + - This *includes* time spent in other threads/processes. - Some systems only have a coarse resolution; check - stopwatch.clock_info.rseolution if. + stopwatch.clock_info.resolution when using the results. Usage: - with ProcessStopwatch() as stopwatch: + with Stopwatch() as stopwatch: ... 
elapsed = stopwatch.seconds resolution = stopwatch.clock_info.resolution """ def __enter__(self): - get_time = time.process_time - clock_info = time.get_clock_info('process_time') - if get_time() <= 0: # some platforms like WASM lack process_time() - get_time = time.monotonic - clock_info = time.get_clock_info('monotonic') + get_time = time.perf_counter + clock_info = time.get_clock_info('perf_counter') self.context = disable_gc() self.context.__enter__() self.get_time = get_time @@ -2661,6 +2763,7 @@ def adjust_int_max_str_digits(max_digits): sys.set_int_max_str_digits(current) +# XXX: RUSTPYTHON; removed in 3.14 def get_c_recursion_limit(): try: import _testcapi @@ -2671,7 +2774,7 @@ def get_c_recursion_limit(): def exceeds_recursion_limit(): """For recursion tests, easily exceeds default recursion limit.""" - return get_c_recursion_limit() * 3 + return 150_000 # Windows doesn't have os.uname() but it doesn't support s390x. @@ -2680,21 +2783,9 @@ def exceeds_recursion_limit(): Py_TRACE_REFS = hasattr(sys, 'getobjects') -# Decorator to disable optimizer while a function run -def without_optimizer(func): - try: - from _testinternalcapi import get_optimizer, set_optimizer - except ImportError: - return func - @functools.wraps(func) - def wrapper(*args, **kwargs): - save_opt = get_optimizer() - try: - set_optimizer(None) - return func(*args, **kwargs) - finally: - set_optimizer(save_opt) - return wrapper +_JIT_ENABLED = sys._jit.is_enabled() +requires_jit_enabled = unittest.skipUnless(_JIT_ENABLED, "requires JIT enabled") +requires_jit_disabled = unittest.skipIf(_JIT_ENABLED, "requires JIT disabled") _BASE_COPY_SRC_DIR_IGNORED_NAMES = frozenset({ @@ -2724,19 +2815,121 @@ def copy_python_src_ignore(path, names): return ignored -def iter_builtin_types(): - for obj in __builtins__.values(): - if not isinstance(obj, type): +# XXX Move this to the inspect module? +def walk_class_hierarchy(top, *, topdown=True): + # This is based on the logic in os.walk(). + assert isinstance(top, type), repr(top) + stack = [top] + while stack: + top = stack.pop() + if isinstance(top, tuple): + yield top continue - cls = obj - if cls.__module__ != 'builtins': + + subs = type(top).__subclasses__(top) + if topdown: + # Yield before subclass traversal if going top down. + yield top, subs + # Traverse into subclasses. + for sub in reversed(subs): + stack.append(sub) + else: + # Yield after subclass traversal if going bottom up. + stack.append((top, subs)) + # Traverse into subclasses. + for sub in reversed(subs): + stack.append(sub) + + +def iter_builtin_types(): + # First try the explicit route. + try: + import _testinternalcapi + except ImportError: + _testinternalcapi = None + if _testinternalcapi is not None: + yield from _testinternalcapi.get_static_builtin_types() + return + + # Fall back to making a best-effort guess. + if hasattr(object, '__flags__'): + # Look for any type object with the Py_TPFLAGS_STATIC_BUILTIN flag set. + import datetime + seen = set() + for cls, subs in walk_class_hierarchy(object): + if cls in seen: + continue + seen.add(cls) + if not (cls.__flags__ & _TPFLAGS_STATIC_BUILTIN): + # Do not walk its subclasses. + subs[:] = [] + continue + yield cls + else: + # Fall back to a naive approach. + seen = set() + for obj in __builtins__.values(): + if not isinstance(obj, type): + continue + cls = obj + # XXX? + if cls.__module__ != 'builtins': + continue + if cls == ExceptionGroup: + # It's a heap type. 
+                continue
+            if cls in seen:
+                continue
+            seen.add(cls)
+            yield cls
+
+
+# XXX Move this to the inspect module?
+def iter_name_in_mro(cls, name):
+    """Yield matching items found in base.__dict__ across the MRO.
+
+    The descriptor protocol is not invoked.
+
+    list(iter_name_in_mro(cls, name))[0] is roughly equivalent to
+    find_name_in_mro() in Objects/typeobject.c (AKA PyType_Lookup()).
+
+    inspect.getattr_static() is similar.
+    """
+    # This can fail if "cls" is weird.
+    for base in inspect._static_getmro(cls):
+        # This can fail if "base" is weird.
+        ns = inspect._get_dunder_dict_of_class(base)
+        try:
+            obj = ns[name]
+        except KeyError:
             continue
-        yield cls
+        yield obj, base

-def iter_slot_wrappers(cls):
-    assert cls.__module__ == 'builtins', cls
+# XXX Move this to the inspect module?
+def find_name_in_mro(cls, name, default=inspect._sentinel):
+    for res in iter_name_in_mro(cls, name):
+        # Return the first one.
+        return res
+    if default is not inspect._sentinel:
+        return default, None
+    raise AttributeError(name)
+
+# XXX The return value should always be exactly the same...
+def identify_type_slot_wrappers():
+    try:
+        import _testinternalcapi
+    except ImportError:
+        _testinternalcapi = None
+    if _testinternalcapi is not None:
+        names = {n: None for n in _testinternalcapi.identify_type_slot_wrappers()}
+        return list(names)
+    else:
+        raise NotImplementedError
+
+
+def iter_slot_wrappers(cls):
     def is_slot_wrapper(name, value):
         if not isinstance(value, types.WrapperDescriptorType):
             assert not repr(value).startswith('<slot wrapper '), (cls, name, value)
[...]
 def make_clean_env() -> dict[str, str]:
     return clean_env

-def initialized_with_pyrepl():
-    """Detect whether PyREPL was used during Python initialization."""
-    # If the main module has a __file__ attribute it's a Python module, which means PyREPL.
-    return hasattr(sys.modules["__main__"], "__file__")
+WINDOWS_STATUS = {
+    0xC0000005: "STATUS_ACCESS_VIOLATION",
+    0xC00000FD: "STATUS_STACK_OVERFLOW",
+    0xC000013A: "STATUS_CONTROL_C_EXIT",
+}
+
+def get_signal_name(exitcode):
+    import signal
+
+    if exitcode < 0:
+        signum = -exitcode
+        try:
+            return signal.Signals(signum).name
+        except ValueError:
+            pass
+
+    # Shell exit code (ex: WASI build)
+    if 128 < exitcode < 256:
+        signum = exitcode - 128
+        try:
+            return signal.Signals(signum).name
+        except ValueError:
+            pass
+
+    try:
+        return WINDOWS_STATUS[exitcode]
+    except KeyError:
+        pass
+    return None

 class BrokenIter:
     def __init__(self, init_raises=False, next_raises=False, iter_raises=False):
@@ -2849,222 +3104,166 @@ def __iter__(self):
         return self

-def linked_to_musl():
+def in_systemd_nspawn_sync_suppressed() -> bool:
     """
-    Test if the Python executable is linked to the musl C library.
+    Test whether the test suite is running in systemd-nspawn
+    with ``--suppress-sync=true``.
+
+    This can be used to skip tests that rely on ``fsync()`` calls
+    and similar not being intercepted.
     """
-    if sys.platform != 'linux':
+
+    if not hasattr(os, "O_SYNC"):
         return False
-    import subprocess
-    exe = getattr(sys, '_base_executable', sys.executable)
-    cmd = ['ldd', exe]
     try:
-        stdout = subprocess.check_output(cmd,
-                                         text=True,
-                                         stderr=subprocess.STDOUT)
-    except (OSError, subprocess.CalledProcessError):
+        with open("/run/systemd/container", "rb") as fp:
+            if fp.read().rstrip() != b"systemd-nspawn":
+                return False
+    except FileNotFoundError:
         return False
-    return ('musl' in stdout)
+
+    # If systemd-nspawn is used, O_SYNC flag will immediately
+    # trigger EINVAL. Otherwise, ENOENT will be given instead.
+ import errno + try: + fd = os.open(__file__, os.O_RDONLY | os.O_SYNC) + except OSError as err: + if err.errno == errno.EINVAL: + return True + else: + os.close(fd) -# TODO: RUSTPYTHON -# Every line of code below allowed us to update `Lib/test/support/__init__.py` without -# needing to update `libregtest` and its dependencies. -# Ideally we want to remove all code below and update `libregtest`. -# -# Code below was copied from: https://github.com/RustPython/RustPython/blob/9499d39f55b73535e2405bf208d5380241f79ada/Lib/test/support/__init__.py + return False -from .testresult import get_test_runner +def run_no_yield_async_fn(async_fn, /, *args, **kwargs): + coro = async_fn(*args, **kwargs) + try: + coro.send(None) + except StopIteration as e: + return e.value + else: + raise AssertionError("coroutine did not complete") + finally: + coro.close() -def _filter_suite(suite, pred): - """Recursively filter test cases in a suite based on a predicate.""" - newtests = [] - for test in suite._tests: - if isinstance(test, unittest.TestSuite): - _filter_suite(test, pred) - newtests.append(test) - else: - if pred(test): - newtests.append(test) - suite._tests = newtests -# By default, don't filter tests -_match_test_func = None +@types.coroutine +def async_yield(v): + return (yield v) -_accept_test_patterns = None -_ignore_test_patterns = None -def match_test(test): - # Function used by support.run_unittest() and regrtest --list-cases - if _match_test_func is None: - return True - else: - return _match_test_func(test.id()) +def run_yielding_async_fn(async_fn, /, *args, **kwargs): + coro = async_fn(*args, **kwargs) + try: + while True: + try: + coro.send(None) + except StopIteration as e: + return e.value + finally: + coro.close() -def _is_full_match_test(pattern): - # If a pattern contains at least one dot, it's considered - # as a full test identifier. - # Example: 'test.test_os.FileTests.test_access'. - # - # ignore patterns which contain fnmatch patterns: '*', '?', '[...]' - # or '[!...]'. For example, ignore 'test_access*'. - return ('.' in pattern) and (not re.search(r'[?*\[\]]', pattern)) - -def set_match_tests(accept_patterns=None, ignore_patterns=None): - global _match_test_func, _accept_test_patterns, _ignore_test_patterns - - if accept_patterns is None: - accept_patterns = () - if ignore_patterns is None: - ignore_patterns = () - - accept_func = ignore_func = None - - if accept_patterns != _accept_test_patterns: - accept_patterns, accept_func = _compile_match_function(accept_patterns) - if ignore_patterns != _ignore_test_patterns: - ignore_patterns, ignore_func = _compile_match_function(ignore_patterns) - - # Create a copy since patterns can be mutable and so modified later - _accept_test_patterns = tuple(accept_patterns) - _ignore_test_patterns = tuple(ignore_patterns) - - if accept_func is not None or ignore_func is not None: - def match_function(test_id): - accept = True - ignore = False - if accept_func: - accept = accept_func(test_id) - if ignore_func: - ignore = ignore_func(test_id) - return accept and not ignore - - _match_test_func = match_function - -def _compile_match_function(patterns): - if not patterns: - func = None - # set_match_tests(None) behaves as set_match_tests(()) - patterns = () - elif all(map(_is_full_match_test, patterns)): - # Simple case: all patterns are full test identifier. - # The test.bisect_cmd utility only uses such full test identifiers. 
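The three coroutine helpers just added are easiest to see side by side: `run_no_yield_async_fn()` drives a coroutine that never suspends, while `async_yield()` plus `run_yielding_async_fn()` cover one that does. A sketch with illustrative sample coroutines::

    from test import support

    async def immediate(a, b):
        return a + b          # completes on the first send(None)

    async def suspends_once():
        got = await support.async_yield("checkpoint")   # yields once, then resumes
        return got

    assert support.run_no_yield_async_fn(immediate, 1, 2) == 3
    assert support.run_yielding_async_fn(suspends_once) is None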
- func = set(patterns).__contains__ - else: - import fnmatch - regex = '|'.join(map(fnmatch.translate, patterns)) - # The search *is* case sensitive on purpose: - # don't use flags=re.IGNORECASE - regex_match = re.compile(regex).match - - def match_test_regex(test_id): - if regex_match(test_id): - # The regex matches the whole identifier, for example - # 'test.test_os.FileTests.test_access'. - return True - else: - # Try to match parts of the test identifier. - # For example, split 'test.test_os.FileTests.test_access' - # into: 'test', 'test_os', 'FileTests' and 'test_access'. - return any(map(regex_match, test_id.split("."))) - - func = match_test_regex - - return patterns, func - -def run_unittest(*classes): - """Run tests from unittest.TestCase-derived classes.""" - valid_types = (unittest.TestSuite, unittest.TestCase) - loader = unittest.TestLoader() - suite = unittest.TestSuite() - for cls in classes: - if isinstance(cls, str): - if cls in sys.modules: - suite.addTest(loader.loadTestsFromModule(sys.modules[cls])) - else: - raise ValueError("str arguments must be keys in sys.modules") - elif isinstance(cls, valid_types): - suite.addTest(cls) - else: - suite.addTest(loader.loadTestsFromTestCase(cls)) - _filter_suite(suite, match_test) - return _run_suite(suite) - -def _run_suite(suite): - """Run tests from a unittest.TestSuite-derived class.""" - runner = get_test_runner(sys.stdout, - verbosity=verbose, - capture_output=(junit_xml_list is not None)) - - result = runner.run(suite) - - if junit_xml_list is not None: - junit_xml_list.append(result.get_xml_element()) - - if not result.testsRun and not result.skipped and not result.errors: - raise TestDidNotRun - if not result.wasSuccessful(): - stats = TestStats.from_unittest(result) - if len(result.errors) == 1 and not result.failures: - err = result.errors[0][1] - elif len(result.failures) == 1 and not result.errors: - err = result.failures[0][1] - else: - err = "multiple errors occurred" - if not verbose: err += "; run in verbose mode for details" - errors = [(str(tc), exc_str) for tc, exc_str in result.errors] - failures = [(str(tc), exc_str) for tc, exc_str in result.failures] - raise TestFailedWithDetails(err, errors, failures, stats=stats) - return result -@dataclasses.dataclass(slots=True) -class TestStats: - tests_run: int = 0 - failures: int = 0 - skipped: int = 0 +def is_libssl_fips_mode(): + try: + from _hashlib import get_fips_mode # ask _hashopenssl.c + except ImportError: + return False # more of a maybe, unless we add this to the _ssl module. 
+    return get_fips_mode() != 0
+
+def _supports_remote_attaching():
+    PROCESS_VM_READV_SUPPORTED = False
+
+    try:
+        from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+    except ImportError:
+        pass

-    @staticmethod
-    def from_unittest(result):
-        return TestStats(result.testsRun,
-                         len(result.failures),
-                         len(result.skipped))
+    return PROCESS_VM_READV_SUPPORTED

-    @staticmethod
-    def from_doctest(results):
-        return TestStats(results.attempted,
-                         results.failed)
+def _support_remote_exec_only_impl():
+    if not sys.is_remote_debug_enabled():
+        return unittest.skip("Remote debugging is not enabled")
+    if sys.platform not in ("darwin", "linux", "win32"):
+        return unittest.skip("Test only runs on Linux, Windows and macOS")
+    if sys.platform == "linux" and not _supports_remote_attaching():
+        return unittest.skip("Test only runs on Linux with process_vm_readv support")
+    return _id

-    def accumulate(self, stats):
-        self.tests_run += stats.tests_run
-        self.failures += stats.failures
-        self.skipped += stats.skipped
+def support_remote_exec_only(test):
+    return _support_remote_exec_only_impl()(test)

+class EqualToForwardRef:
+    """Helper to ease use of annotationlib.ForwardRef in tests.

-def run_doctest(module, verbosity=None, optionflags=0):
-    """Run doctest on the given module. Return (#failures, #tests).
+    This checks only attributes that can be set using the constructor.

-    If optional argument verbosity is not specified (or is None), pass
-    support's belief about verbosity on to doctest. Else doctest's
-    usual behavior is used (it searches sys.argv for -v).
     """
-    import doctest

+    def __init__(
+        self,
+        arg,
+        *,
+        module=None,
+        owner=None,
+        is_class=False,
+    ):
+        self.__forward_arg__ = arg
+        self.__forward_is_class__ = is_class
+        self.__forward_module__ = module
+        self.__owner__ = owner

-    if verbosity is None:
-        verbosity = verbose
-    else:
-        verbosity = None
-
-    results = doctest.testmod(module,
-                              verbose=verbosity,
-                              optionflags=optionflags)
-    if results.failed:
-        stats = TestStats.from_doctest(results)
-        raise TestFailed(f"{results.failed} of {results.attempted} "
-                         f"doctests failed",
-                         stats=stats)
-    if verbose:
-        print('doctest (%s) ... %d tests with zero failures' %
-              (module.__name__, results.attempted))
-    return results
+    def __eq__(self, other):
+        if not isinstance(other, (EqualToForwardRef, annotationlib.ForwardRef)):
+            return NotImplemented
+        return (
+            self.__forward_arg__ == other.__forward_arg__
+            and self.__forward_module__ == other.__forward_module__
+            and self.__forward_is_class__ == other.__forward_is_class__
+            and self.__owner__ == other.__owner__
+        )
+
+    def __repr__(self):
+        extra = []
+        if self.__forward_module__ is not None:
+            extra.append(f", module={self.__forward_module__!r}")
+        if self.__forward_is_class__:
+            extra.append(", is_class=True")
+        if self.__owner__ is not None:
+            extra.append(f", owner={self.__owner__!r}")
+        return f"EqualToForwardRef({self.__forward_arg__!r}{''.join(extra)})"
+
+
+_linked_to_musl = None
+def linked_to_musl():
+    """
+    Report if the Python executable is linked to the musl C library.
+
+    Return False if we don't think it is, or a version triple otherwise.
+    """
+    # This can be a relatively expensive check, so we use a cache.
+    global _linked_to_musl
+    if _linked_to_musl is not None:
+        return _linked_to_musl
+
+    # emscripten (at least as far as we're concerned) and wasi use musl,
+    # but platform doesn't know how to get the version, so set it to zero.
+ if is_wasm32: + _linked_to_musl = (0, 0, 0) + return _linked_to_musl + + # On all other non-linux platforms assume no musl. + if sys.platform != 'linux': + _linked_to_musl = False + return _linked_to_musl + + # On linux, we'll depend on the platform module to do the check, so new + # musl platforms should add support in that module if possible. + import platform + lib, version = platform.libc_ver() + if lib != 'musl': + _linked_to_musl = False + return _linked_to_musl + _linked_to_musl = tuple(map(int, version.split('.'))) + return _linked_to_musl diff --git a/Lib/test/support/_hypothesis_stubs/__init__.py b/Lib/test/support/_hypothesis_stubs/__init__.py index 6ba5bb814b9..9a57c309616 100644 --- a/Lib/test/support/_hypothesis_stubs/__init__.py +++ b/Lib/test/support/_hypothesis_stubs/__init__.py @@ -1,6 +1,6 @@ -from enum import Enum import functools import unittest +from enum import Enum __all__ = [ "given", diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py index 8a0415b6aae..98eaf0b2721 100644 --- a/Lib/test/support/ast_helper.py +++ b/Lib/test/support/ast_helper.py @@ -1,5 +1,6 @@ import ast + class ASTTestMixin: """Test mixing to have basic assertions for AST nodes.""" @@ -16,6 +17,9 @@ def traverse_compare(a, b, missing=object()): self.fail(f"{type(a)!r} is not {type(b)!r}") if isinstance(a, ast.AST): for field in a._fields: + if isinstance(a, ast.Constant) and field == "kind": + # Skip the 'kind' field for ast.Constant + continue value1 = getattr(a, field, missing) value2 = getattr(b, field, missing) # Singletons are equal by definition, so further diff --git a/Lib/test/support/asynchat.py b/Lib/test/support/asynchat.py index 38c47a1fda6..a8c6b28a9e1 100644 --- a/Lib/test/support/asynchat.py +++ b/Lib/test/support/asynchat.py @@ -1,5 +1,5 @@ # TODO: This module was deprecated and removed from CPython 3.12 -# Now it is a test-only helper. Any attempts to rewrite exising tests that +# Now it is a test-only helper. Any attempts to rewrite existing tests that # are using this module and remove it completely are appreciated! # See: https://github.com/python/cpython/issues/72719 diff --git a/Lib/test/support/asyncore.py b/Lib/test/support/asyncore.py index b397aca5568..658c22fdcee 100644 --- a/Lib/test/support/asyncore.py +++ b/Lib/test/support/asyncore.py @@ -1,5 +1,5 @@ # TODO: This module was deprecated and removed from CPython 3.12 -# Now it is a test-only helper. Any attempts to rewrite exising tests that +# Now it is a test-only helper. Any attempts to rewrite existing tests that # are using this module and remove it completely are appreciated! # See: https://github.com/python/cpython/issues/72719 @@ -51,17 +51,27 @@ sophisticated high-performance network servers and clients a snap. 
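Looking back at the rewritten `linked_to_musl()` above: it now reports either `False` or a `(major, minor, patch)` tuple, so one call serves both as a boolean test and a version check. A hedged sketch of the calling convention (the sample versions are illustrative)::

    from test import support

    musl = support.linked_to_musl()
    if musl:                   # any tuple is truthy, including (0, 0, 0) on WASM
        assert len(musl) == 3  # e.g. (1, 2, 4) on Alpine; (0, 0, 0) on wasm32
    else:
        pass                   # glibc Linux, or a non-Linux platform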
""" +import os import select import socket import sys import time import warnings - -import os -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ - ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \ - errorcode - +from errno import ( + EAGAIN, + EALREADY, + EBADF, + ECONNABORTED, + ECONNRESET, + EINPROGRESS, + EINVAL, + EISCONN, + ENOTCONN, + EPIPE, + ESHUTDOWN, + EWOULDBLOCK, + errorcode, +) _DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}) diff --git a/Lib/test/support/bytecode_helper.py b/Lib/test/support/bytecode_helper.py index 85bcd1f0f1c..4a3c8c2c4f1 100644 --- a/Lib/test/support/bytecode_helper.py +++ b/Lib/test/support/bytecode_helper.py @@ -1,9 +1,10 @@ """bytecode_helper - support tools for testing correct bytecode generation""" -import unittest import dis import io import opcode +import unittest + try: import _testinternalcapi except ImportError: @@ -71,7 +72,7 @@ class Label: def assertInstructionsMatch(self, actual_seq, expected): # get an InstructionSequence and an expected list, where each - # entry is a label or an instruction tuple. Construct an expcted + # entry is a label or an instruction tuple. Construct an expected # instruction sequence and compare with the one given. self.assertIsInstance(expected, list) diff --git a/Lib/test/support/interpreters/channels.py b/Lib/test/support/channels.py similarity index 73% rename from Lib/test/support/interpreters/channels.py rename to Lib/test/support/channels.py index d2bd93d77f7..3f7b46030fd 100644 --- a/Lib/test/support/interpreters/channels.py +++ b/Lib/test/support/channels.py @@ -1,19 +1,23 @@ """Cross-interpreter Channels High Level Module.""" import time +from concurrent.interpreters import _crossinterp +from concurrent.interpreters._crossinterp import ( + UNBOUND_ERROR, + UNBOUND_REMOVE, +) + import _interpchannels as _channels -from . 
import _crossinterp # aliases: from _interpchannels import ( - ChannelError, ChannelNotFoundError, ChannelClosedError, - ChannelEmptyError, ChannelNotEmptyError, -) -from ._crossinterp import ( - UNBOUND_ERROR, UNBOUND_REMOVE, + ChannelClosedError, + ChannelEmptyError, + ChannelError, + ChannelNotEmptyError, + ChannelNotFoundError, ) - __all__ = [ 'UNBOUND', 'UNBOUND_ERROR', 'UNBOUND_REMOVE', 'create', 'list_all', @@ -55,15 +59,23 @@ def create(*, unbounditems=UNBOUND): """ unbound = _serialize_unbound(unbounditems) unboundop, = unbound - cid = _channels.create(unboundop) - recv, send = RecvChannel(cid), SendChannel(cid, _unbound=unbound) + cid = _channels.create(unboundop, -1) + recv, send = RecvChannel(cid), SendChannel(cid) + send._set_unbound(unboundop, unbounditems) return recv, send def list_all(): """Return a list of (recv, send) for all open channels.""" - return [(RecvChannel(cid), SendChannel(cid, _unbound=unbound)) - for cid, unbound in _channels.list_all()] + channels = [] + for cid, unboundop, _ in _channels.list_all(): + chan = _, send = RecvChannel(cid), SendChannel(cid) + if not hasattr(send, '_unboundop'): + send._set_unbound(unboundop) + else: + assert send._unbound[0] == unboundop + channels.append(chan) + return channels class _ChannelEnd: @@ -97,12 +109,8 @@ def __eq__(self, other): return other._id == self._id # for pickling: - def __getnewargs__(self): - return (int(self._id),) - - # for pickling: - def __getstate__(self): - return None + def __reduce__(self): + return (type(self), (int(self._id),)) @property def id(self): @@ -175,16 +183,33 @@ class SendChannel(_ChannelEnd): _end = 'send' - def __new__(cls, cid, *, _unbound=None): - if _unbound is None: - try: - op = _channels.get_channel_defaults(cid) - _unbound = (op,) - except ChannelNotFoundError: - _unbound = _serialize_unbound(UNBOUND) - self = super().__new__(cls, cid) - self._unbound = _unbound - return self +# def __new__(cls, cid, *, _unbound=None): +# if _unbound is None: +# try: +# op = _channels.get_channel_defaults(cid) +# _unbound = (op,) +# except ChannelNotFoundError: +# _unbound = _serialize_unbound(UNBOUND) +# self = super().__new__(cls, cid) +# self._unbound = _unbound +# return self + + def _set_unbound(self, op, items=None): + assert not hasattr(self, '_unbound') + if items is None: + items = _resolve_unbound(op) + unbound = (op, items) + self._unbound = unbound + return unbound + + @property + def unbounditems(self): + try: + _, items = self._unbound + except AttributeError: + op, _ = _channels.get_queue_defaults(self._id) + _, items = self._set_unbound(op) + return items @property def is_closed(self): @@ -192,61 +217,61 @@ def is_closed(self): return info.closed or info.closing def send(self, obj, timeout=None, *, - unbound=None, + unbounditems=None, ): """Send the object (i.e. its data) to the channel's receiving end. This blocks until the object is received. """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) _channels.send(self._id, obj, unboundop, timeout=timeout, blocking=True) def send_nowait(self, obj, *, - unbound=None, + unbounditems=None, ): """Send the object to the channel's receiving end. If the object is immediately received then return True (else False). Otherwise this is the same as send(). 
""" - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) # XXX Note that at the moment channel_send() only ever returns # None. This should be fixed when channel_send_wait() is added. # See bpo-32604 and gh-19829. return _channels.send(self._id, obj, unboundop, blocking=False) def send_buffer(self, obj, timeout=None, *, - unbound=None, + unbounditems=None, ): """Send the object's buffer to the channel's receiving end. This blocks until the object is received. """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) _channels.send_buffer(self._id, obj, unboundop, timeout=timeout, blocking=True) def send_buffer_nowait(self, obj, *, - unbound=None, + unbounditems=None, ): """Send the object's buffer to the channel's receiving end. If the object is immediately received then return True (else False). Otherwise this is the same as send(). """ - if unbound is None: - unboundop, = self._unbound + if unbounditems is None: + unboundop = -1 else: - unboundop, = _serialize_unbound(unbound) + unboundop, = _serialize_unbound(unbounditems) return _channels.send_buffer(self._id, obj, unboundop, blocking=False) def close(self): diff --git a/Lib/test/support/hashlib_helper.py b/Lib/test/support/hashlib_helper.py index a4e6c92203a..75dc2ba7506 100644 --- a/Lib/test/support/hashlib_helper.py +++ b/Lib/test/support/hashlib_helper.py @@ -1,51 +1,330 @@ import functools import hashlib +import importlib import unittest +from test.support.import_helper import import_module + try: import _hashlib except ImportError: _hashlib = None +try: + import _hmac +except ImportError: + _hmac = None + + +def requires_hashlib(): + return unittest.skipIf(_hashlib is None, "requires _hashlib") + + +def requires_builtin_hmac(): + return unittest.skipIf(_hmac is None, "requires _hmac") + + +def _missing_hash(digestname, implementation=None, *, exc=None): + parts = ["missing", implementation, f"hash algorithm: {digestname!r}"] + msg = " ".join(filter(None, parts)) + raise unittest.SkipTest(msg) from exc + + +def _openssl_availabillity(digestname, *, usedforsecurity): + try: + _hashlib.new(digestname, usedforsecurity=usedforsecurity) + except AttributeError: + assert _hashlib is None + _missing_hash(digestname, "OpenSSL") + except ValueError as exc: + _missing_hash(digestname, "OpenSSL", exc=exc) + + +def _decorate_func_or_class(func_or_class, decorator_func): + if not isinstance(func_or_class, type): + return decorator_func(func_or_class) + + decorated_class = func_or_class + setUpClass = decorated_class.__dict__.get('setUpClass') + if setUpClass is None: + def setUpClass(cls): + super(decorated_class, cls).setUpClass() + setUpClass.__qualname__ = decorated_class.__qualname__ + '.setUpClass' + setUpClass.__module__ = decorated_class.__module__ + else: + setUpClass = setUpClass.__func__ + setUpClass = classmethod(decorator_func(setUpClass)) + decorated_class.setUpClass = setUpClass + return decorated_class + def requires_hashdigest(digestname, openssl=None, usedforsecurity=True): - """Decorator raising SkipTest if a hashing algorithm is not available + """Decorator raising SkipTest if a hashing algorithm is not available. - The hashing algorithm could be missing or blocked by a strict crypto - policy. 
+    The hashing algorithm may be missing, blocked by a strict crypto policy,
+    or Python may be configured with `--with-builtin-hashlib-hashes=no`.

     If 'openssl' is True, then the decorator checks that OpenSSL provides
-    the algorithm. Otherwise the check falls back to built-in
-    implementations. The usedforsecurity flag is passed to the constructor.
+    the algorithm. Otherwise the check falls back to (optional) built-in
+    HACL* implementations.
+    The usedforsecurity flag is passed to the constructor but has no effect
+    on HACL* implementations.
+
     Examples of exceptions being suppressed:

     ValueError: [digital envelope routines: EVP_DigestInit_ex] disabled for FIPS
     ValueError: unsupported hash type md4
     """
+    if openssl and _hashlib is not None:
+        def test_availability():
+            _hashlib.new(digestname, usedforsecurity=usedforsecurity)
+    else:
+        def test_availability():
+            hashlib.new(digestname, usedforsecurity=usedforsecurity)
+
+    def decorator_func(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                test_availability()
+            except ValueError as exc:
+                _missing_hash(digestname, exc=exc)
+            return func(*args, **kwargs)
+        return wrapper
+
+    def decorator(func_or_class):
+        return _decorate_func_or_class(func_or_class, decorator_func)
+    return decorator
+
+
+def requires_openssl_hashdigest(digestname, *, usedforsecurity=True):
+    """Decorator raising SkipTest if an OpenSSL hashing algorithm is missing.
+
+    The hashing algorithm may be missing or blocked by a strict crypto policy.
+    """
+    def decorator_func(func):
+        @requires_hashlib()  # avoid checking at each call
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            _openssl_availabillity(digestname, usedforsecurity=usedforsecurity)
+            return func(*args, **kwargs)
+        return wrapper
+
     def decorator(func_or_class):
-        if isinstance(func_or_class, type):
-            setUpClass = func_or_class.__dict__.get('setUpClass')
-            if setUpClass is None:
-                def setUpClass(cls):
-                    super(func_or_class, cls).setUpClass()
-                setUpClass.__qualname__ = func_or_class.__qualname__ + '.setUpClass'
-                setUpClass.__module__ = func_or_class.__module__
-            else:
-                setUpClass = setUpClass.__func__
-            setUpClass = classmethod(decorator(setUpClass))
-            func_or_class.setUpClass = setUpClass
-            return func_or_class
-
-        @functools.wraps(func_or_class)
+        return _decorate_func_or_class(func_or_class, decorator_func)
+    return decorator
+
+
+def find_openssl_hashdigest_constructor(digestname, *, usedforsecurity=True):
+    """Find the OpenSSL hash function constructor by its name."""
+    assert isinstance(digestname, str), digestname
+    _openssl_availabillity(digestname, usedforsecurity=usedforsecurity)
+    # This returns a function of the form _hashlib.openssl_<name> and
+    # not a lambda function as it is rejected by _hashlib.hmac_new().
+    return getattr(_hashlib, f"openssl_{digestname}")
+
+
+def requires_builtin_hashdigest(
+    module_name, digestname, *, usedforsecurity=True
+):
+    """Decorator raising SkipTest if a HACL* hashing algorithm is missing.
+
+    - The *module_name* is the C extension module name based on HACL*.
+    - The *digestname* is one of its member, e.g., 'md5'.
+    """
+    def decorator_func(func):
+        @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            module = import_module(module_name)
             try:
-                if openssl and _hashlib is not None:
-                    _hashlib.new(digestname, usedforsecurity=usedforsecurity)
-                else:
-                    hashlib.new(digestname, usedforsecurity=usedforsecurity)
-            except ValueError:
-                raise unittest.SkipTest(
-                    f"hash digest '{digestname}' is not available."
- ) - return func_or_class(*args, **kwargs) + getattr(module, digestname) + except AttributeError: + fullname = f'{module_name}.{digestname}' + _missing_hash(fullname, implementation="HACL") + return func(*args, **kwargs) return wrapper + + def decorator(func_or_class): + return _decorate_func_or_class(func_or_class, decorator_func) return decorator + + +def find_builtin_hashdigest_constructor( + module_name, digestname, *, usedforsecurity=True +): + """Find the HACL* hash function constructor. + + - The *module_name* is the C extension module name based on HACL*. + - The *digestname* is one of its member, e.g., 'md5'. + """ + module = import_module(module_name) + try: + constructor = getattr(module, digestname) + constructor(b'', usedforsecurity=usedforsecurity) + except (AttributeError, TypeError, ValueError): + _missing_hash(f'{module_name}.{digestname}', implementation="HACL") + return constructor + + +class HashFunctionsTrait: + """Mixin trait class containing hash functions. + + This class is assumed to have all unitest.TestCase methods but should + not directly inherit from it to prevent the test suite being run on it. + + Subclasses should implement the hash functions by returning an object + that can be recognized as a valid digestmod parameter for both hashlib + and HMAC. In particular, it cannot be a lambda function as it will not + be recognized by hashlib (it will still be accepted by the pure Python + implementation of HMAC). + """ + + ALGORITHMS = [ + 'md5', 'sha1', + 'sha224', 'sha256', 'sha384', 'sha512', + 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512', + ] + + # Default 'usedforsecurity' to use when looking up a hash function. + usedforsecurity = True + + def _find_constructor(self, name): + # By default, a missing algorithm skips the test that uses it. + self.assertIn(name, self.ALGORITHMS) + self.skipTest(f"missing hash function: {name}") + + @property + def md5(self): + return self._find_constructor("md5") + + @property + def sha1(self): + return self._find_constructor("sha1") + + @property + def sha224(self): + return self._find_constructor("sha224") + + @property + def sha256(self): + return self._find_constructor("sha256") + + @property + def sha384(self): + return self._find_constructor("sha384") + + @property + def sha512(self): + return self._find_constructor("sha512") + + @property + def sha3_224(self): + return self._find_constructor("sha3_224") + + @property + def sha3_256(self): + return self._find_constructor("sha3_256") + + @property + def sha3_384(self): + return self._find_constructor("sha3_384") + + @property + def sha3_512(self): + return self._find_constructor("sha3_512") + + +class NamedHashFunctionsTrait(HashFunctionsTrait): + """Trait containing named hash functions. + + Hash functions are available if and only if they are available in hashlib. + """ + + def _find_constructor(self, name): + self.assertIn(name, self.ALGORITHMS) + return name + + +class OpenSSLHashFunctionsTrait(HashFunctionsTrait): + """Trait containing OpenSSL hash functions. + + Hash functions are available if and only if they are available in _hashlib. + """ + + def _find_constructor(self, name): + self.assertIn(name, self.ALGORITHMS) + return find_openssl_hashdigest_constructor( + name, usedforsecurity=self.usedforsecurity + ) + + +class BuiltinHashFunctionsTrait(HashFunctionsTrait): + """Trait containing HACL* hash functions. + + Hash functions are available if and only if they are available in C. 
+ In particular, HACL* HMAC-MD5 may be available even though HACL* md5 + is not since the former is unconditionally built. + """ + + def _find_constructor_in(self, module, name): + self.assertIn(name, self.ALGORITHMS) + return find_builtin_hashdigest_constructor(module, name) + + @property + def md5(self): + return self._find_constructor_in("_md5", "md5") + + @property + def sha1(self): + return self._find_constructor_in("_sha1", "sha1") + + @property + def sha224(self): + return self._find_constructor_in("_sha2", "sha224") + + @property + def sha256(self): + return self._find_constructor_in("_sha2", "sha256") + + @property + def sha384(self): + return self._find_constructor_in("_sha2", "sha384") + + @property + def sha512(self): + return self._find_constructor_in("_sha2", "sha512") + + @property + def sha3_224(self): + return self._find_constructor_in("_sha3", "sha3_224") + + @property + def sha3_256(self): + return self._find_constructor_in("_sha3","sha3_256") + + @property + def sha3_384(self): + return self._find_constructor_in("_sha3","sha3_384") + + @property + def sha3_512(self): + return self._find_constructor_in("_sha3","sha3_512") + + +def find_gil_minsize(modules_names, default=2048): + """Get the largest GIL_MINSIZE value for the given cryptographic modules. + + The valid module names are the following: + + - _hashlib + - _md5, _sha1, _sha2, _sha3, _blake2 + - _hmac + """ + sizes = [] + for module_name in modules_names: + try: + module = importlib.import_module(module_name) + except ImportError: + continue + sizes.append(getattr(module, '_GIL_MINSIZE', default)) + return max(sizes, default=default) diff --git a/Lib/test/support/hypothesis_helper.py b/Lib/test/support/hypothesis_helper.py index a99a4963ffe..6e9e168f63a 100644 --- a/Lib/test/support/hypothesis_helper.py +++ b/Lib/test/support/hypothesis_helper.py @@ -7,9 +7,10 @@ else: # Regrtest changes to use a tempdir as the working directory, so we have # to tell Hypothesis to use the original in order to persist the database. 
+ from hypothesis.configuration import set_hypothesis_home_dir + from test.support import has_socket_support from test.support.os_helper import SAVEDCWD - from hypothesis.configuration import set_hypothesis_home_dir set_hypothesis_home_dir(os.path.join(SAVEDCWD, ".hypothesis")) diff --git a/Lib/test/support/i18n_helper.py b/Lib/test/support/i18n_helper.py index 2e304f29e8b..af97cdc9cb5 100644 --- a/Lib/test/support/i18n_helper.py +++ b/Lib/test/support/i18n_helper.py @@ -3,10 +3,10 @@ import sys import unittest from pathlib import Path + from test.support import REPO_ROOT, TEST_HOME_DIR, requires_subprocess from test.test_tools import skip_if_missing - pygettext = Path(REPO_ROOT) / 'Tools' / 'i18n' / 'pygettext.py' msgid_pattern = re.compile(r'msgid(.*?)(?:msgid_plural|msgctxt|msgstr)', diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index a4ea5bc6af9..2d80b663dd5 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -1,14 +1,16 @@ -import contextlib import _imp +import contextlib import importlib +import importlib.machinery import importlib.util import os import shutil import sys +import textwrap import unittest import warnings -from .os_helper import unlink, temp_dir +from .os_helper import temp_dir, unlink @contextlib.contextmanager @@ -309,3 +311,132 @@ def ready_to_import(name=None, source=""): sys.modules[name] = old_module else: sys.modules.pop(name, None) + + +def ensure_lazy_imports(imported_module, modules_to_block): + """Test that when imported_module is imported, none of the modules in + modules_to_block are imported as a side effect.""" + modules_to_block = frozenset(modules_to_block) + script = textwrap.dedent( + f""" + import sys + modules_to_block = {modules_to_block} + if unexpected := modules_to_block & sys.modules.keys(): + startup = ", ".join(unexpected) + raise AssertionError(f'unexpectedly imported at startup: {{startup}}') + + import {imported_module} + if unexpected := modules_to_block & sys.modules.keys(): + after = ", ".join(unexpected) + raise AssertionError(f'unexpectedly imported after importing {imported_module}: {{after}}') + """ + ) + from .script_helper import assert_python_ok + assert_python_ok("-S", "-c", script) + + +@contextlib.contextmanager +def module_restored(name): + """A context manager that restores a module to the original state.""" + missing = object() + orig = sys.modules.get(name, missing) + if orig is None: + mod = importlib.import_module(name) + else: + mod = type(sys)(name) + mod.__dict__.update(orig.__dict__) + sys.modules[name] = mod + try: + yield mod + finally: + if orig is missing: + sys.modules.pop(name, None) + else: + sys.modules[name] = orig + + +def create_module(name, loader=None, *, ispkg=False): + """Return a new, empty module.""" + spec = importlib.machinery.ModuleSpec( + name, + loader, + origin='', + is_package=ispkg, + ) + return importlib.util.module_from_spec(spec) + + +def _ensure_module(name, ispkg, addparent, clearnone): + try: + mod = orig = sys.modules[name] + except KeyError: + mod = orig = None + missing = True + else: + missing = False + if mod is not None: + # It was already imported. + return mod, orig, missing + # Otherwise, None means it was explicitly disabled. + + assert name != '__main__' + if not missing: + assert orig is None, (name, sys.modules[name]) + if not clearnone: + raise ModuleNotFoundError(name) + del sys.modules[name] + # Try normal import, then fall back to adding the module. 
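# Usage sketch (illustrative, not part of the patch): ensure_lazy_imports()
# above runs a fresh python -S -c subprocess and raises AssertionError if a
# blocked module is in sys.modules at startup or appears after importing the
# target. This mirrors the call this diff adds to test_base64 further down:
from test.support.import_helper import ensure_lazy_imports

ensure_lazy_imports("base64", {"re", "getopt"})  # base64 must not pull these in eagerly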
+ try: + mod = importlib.import_module(name) + except ModuleNotFoundError: + if addparent and not clearnone: + addparent = None + mod = _add_module(name, ispkg, addparent) + return mod, orig, missing + + +def _add_module(spec, ispkg, addparent): + if isinstance(spec, str): + name = spec + mod = create_module(name, ispkg=ispkg) + spec = mod.__spec__ + else: + name = spec.name + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + if addparent is not False and spec.parent: + _ensure_module(spec.parent, True, addparent, bool(addparent)) + return mod + + +def add_module(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, False, parents) + + +def add_package(spec, *, parents=True): + """Return the module after creating it and adding it to sys.modules. + + If parents is True then also create any missing parents. + """ + return _add_module(spec, True, parents) + + +def ensure_module_imported(name, *, clearnone=True): + """Return the corresponding module. + + If it was already imported then return that. Otherwise, try + importing it (optionally clear it first if None). If that fails + then create a new empty module. + + It can be helpful to combine this with ready_to_import() and/or + isolated_modules(). + """ + if sys.modules.get(name) is not None: + mod = sys.modules[name] + else: + mod, _, _ = _ensure_module(name, False, True, clearnone) + return mod diff --git a/Lib/test/support/interpreters/__init__.py b/Lib/test/support/interpreters/__init__.py deleted file mode 100644 index e067f259364..00000000000 --- a/Lib/test/support/interpreters/__init__.py +++ /dev/null @@ -1,258 +0,0 @@ -"""Subinterpreters High Level Module.""" - -import threading -import weakref -import _interpreters - -# aliases: -from _interpreters import ( - InterpreterError, InterpreterNotFoundError, NotShareableError, - is_shareable, -) - - -__all__ = [ - 'get_current', 'get_main', 'create', 'list_all', 'is_shareable', - 'Interpreter', - 'InterpreterError', 'InterpreterNotFoundError', 'ExecutionFailed', - 'NotShareableError', - 'create_queue', 'Queue', 'QueueEmpty', 'QueueFull', -] - - -_queuemod = None - -def __getattr__(name): - if name in ('Queue', 'QueueEmpty', 'QueueFull', 'create_queue'): - global create_queue, Queue, QueueEmpty, QueueFull - ns = globals() - from .queues import ( - create as create_queue, - Queue, QueueEmpty, QueueFull, - ) - return ns[name] - else: - raise AttributeError(name) - - -_EXEC_FAILURE_STR = """ -{superstr} - -Uncaught in the interpreter: - -{formatted} -""".strip() - -class ExecutionFailed(InterpreterError): - """An unhandled exception happened during execution. - - This is raised from Interpreter.exec() and Interpreter.call(). 
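# Usage sketch (illustrative module names, not part of the patch): the
# import_helper additions above register synthetic modules in sys.modules.
from test.support.import_helper import add_module, ensure_module_imported

fake = add_module('spam_helper')       # new empty module, now in sys.modules
child = add_module('pkg.child')        # the missing parent 'pkg' is created too
os_mod = ensure_module_imported('os')  # already imported, returned unchanged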
- """ - - def __init__(self, excinfo): - msg = excinfo.formatted - if not msg: - if excinfo.type and excinfo.msg: - msg = f'{excinfo.type.__name__}: {excinfo.msg}' - else: - msg = excinfo.type.__name__ or excinfo.msg - super().__init__(msg) - self.excinfo = excinfo - - def __str__(self): - try: - formatted = self.excinfo.errdisplay - except Exception: - return super().__str__() - else: - return _EXEC_FAILURE_STR.format( - superstr=super().__str__(), - formatted=formatted, - ) - - -def create(): - """Return a new (idle) Python interpreter.""" - id = _interpreters.create(reqrefs=True) - return Interpreter(id, _ownsref=True) - - -def list_all(): - """Return all existing interpreters.""" - return [Interpreter(id, _whence=whence) - for id, whence in _interpreters.list_all(require_ready=True)] - - -def get_current(): - """Return the currently running interpreter.""" - id, whence = _interpreters.get_current() - return Interpreter(id, _whence=whence) - - -def get_main(): - """Return the main interpreter.""" - id, whence = _interpreters.get_main() - assert whence == _interpreters.WHENCE_RUNTIME, repr(whence) - return Interpreter(id, _whence=whence) - - -_known = weakref.WeakValueDictionary() - -class Interpreter: - """A single Python interpreter. - - Attributes: - - "id" - the unique process-global ID number for the interpreter - "whence" - indicates where the interpreter was created - - If the interpreter wasn't created by this module - then any method that modifies the interpreter will fail, - i.e. .close(), .prepare_main(), .exec(), and .call() - """ - - _WHENCE_TO_STR = { - _interpreters.WHENCE_UNKNOWN: 'unknown', - _interpreters.WHENCE_RUNTIME: 'runtime init', - _interpreters.WHENCE_LEGACY_CAPI: 'legacy C-API', - _interpreters.WHENCE_CAPI: 'C-API', - _interpreters.WHENCE_XI: 'cross-interpreter C-API', - _interpreters.WHENCE_STDLIB: '_interpreters module', - } - - def __new__(cls, id, /, _whence=None, _ownsref=None): - # There is only one instance for any given ID. - if not isinstance(id, int): - raise TypeError(f'id must be an int, got {id!r}') - id = int(id) - if _whence is None: - if _ownsref: - _whence = _interpreters.WHENCE_STDLIB - else: - _whence = _interpreters.whence(id) - assert _whence in cls._WHENCE_TO_STR, repr(_whence) - if _ownsref is None: - _ownsref = (_whence == _interpreters.WHENCE_STDLIB) - try: - self = _known[id] - assert hasattr(self, '_ownsref') - except KeyError: - self = super().__new__(cls) - _known[id] = self - self._id = id - self._whence = _whence - self._ownsref = _ownsref - if _ownsref: - # This may raise InterpreterNotFoundError: - _interpreters.incref(id) - return self - - def __repr__(self): - return f'{type(self).__name__}({self.id})' - - def __hash__(self): - return hash(self._id) - - def __del__(self): - self._decref() - - # for pickling: - def __getnewargs__(self): - return (self._id,) - - # for pickling: - def __getstate__(self): - return None - - def _decref(self): - if not self._ownsref: - return - self._ownsref = False - try: - _interpreters.decref(self._id) - except InterpreterNotFoundError: - pass - - @property - def id(self): - return self._id - - @property - def whence(self): - return self._WHENCE_TO_STR[self._whence] - - def is_running(self): - """Return whether or not the identified interpreter is running.""" - return _interpreters.is_running(self._id) - - # Everything past here is available only to interpreters created by - # interpreters.create(). - - def close(self): - """Finalize and destroy the interpreter. 
- - Attempting to destroy the current interpreter results - in an InterpreterError. - """ - return _interpreters.destroy(self._id, restrict=True) - - def prepare_main(self, ns=None, /, **kwargs): - """Bind the given values into the interpreter's __main__. - - The values must be shareable. - """ - ns = dict(ns, **kwargs) if ns is not None else kwargs - _interpreters.set___main___attrs(self._id, ns, restrict=True) - - def exec(self, code, /): - """Run the given source code in the interpreter. - - This is essentially the same as calling the builtin "exec" - with this interpreter, using the __dict__ of its __main__ - module as both globals and locals. - - There is no return value. - - If the code raises an unhandled exception then an ExecutionFailed - exception is raised, which summarizes the unhandled exception. - The actual exception is discarded because objects cannot be - shared between interpreters. - - This blocks the current Python thread until done. During - that time, the previous interpreter is allowed to run - in other threads. - """ - excinfo = _interpreters.exec(self._id, code, restrict=True) - if excinfo is not None: - raise ExecutionFailed(excinfo) - - def call(self, callable, /): - """Call the object in the interpreter with given args/kwargs. - - Only functions that take no arguments and have no closure - are supported. - - The return value is discarded. - - If the callable raises an exception then the error display - (including full traceback) is send back between the interpreters - and an ExecutionFailed exception is raised, much like what - happens with Interpreter.exec(). - """ - # XXX Support args and kwargs. - # XXX Support arbitrary callables. - # XXX Support returning the return value (e.g. via pickle). - excinfo = _interpreters.call(self._id, callable, restrict=True) - if excinfo is not None: - raise ExecutionFailed(excinfo) - - def call_in_thread(self, callable, /): - """Return a new thread that calls the object in the interpreter. - - The return value and any raised exception are discarded. - """ - def task(): - self.call(callable) - t = threading.Thread(target=task) - t.start() - return t diff --git a/Lib/test/support/interpreters/_crossinterp.py b/Lib/test/support/interpreters/_crossinterp.py deleted file mode 100644 index 544e197ba4c..00000000000 --- a/Lib/test/support/interpreters/_crossinterp.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Common code between queues and channels.""" - - -class ItemInterpreterDestroyed(Exception): - """Raised when trying to get an item whose interpreter was destroyed.""" - - -class classonly: - """A non-data descriptor that makes a value only visible on the class. - - This is like the "classmethod" builtin, but does not show up on - instances of the class. It may be used as a decorator. - """ - - def __init__(self, value): - self.value = value - self.getter = classmethod(value).__get__ - self.name = None - - def __set_name__(self, cls, name): - if self.name is not None: - raise TypeError('already used') - self.name = name - - def __get__(self, obj, cls): - if obj is not None: - raise AttributeError(self.name) - # called on the class - return self.getter(None, cls) - - -class UnboundItem: - """Represents a cross-interpreter item no longer bound to an interpreter. - - An item is unbound when the interpreter that added it to the - cross-interpreter container is destroyed. 
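# Aside (a minimal sketch, not part of the patch): the classonly descriptor
# above, from the file this patch deletes, makes a value callable on the class
# while hiding it from instances, unlike the classmethod builtin:
class Example:
    @classonly
    def make(cls):
        return cls()

Example.make()      # OK: visible on the class
try:
    Example().make  # AttributeError: hidden on instances
except AttributeError:
    pass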
- """ - - __slots__ = () - - @classonly - def singleton(cls, kind, module, name='UNBOUND'): - doc = cls.__doc__.replace('cross-interpreter container', kind) - doc = doc.replace('cross-interpreter', kind) - subclass = type( - f'Unbound{kind.capitalize()}Item', - (cls,), - dict( - _MODULE=module, - _NAME=name, - __doc__=doc, - ), - ) - return object.__new__(subclass) - - _MODULE = __name__ - _NAME = 'UNBOUND' - - def __new__(cls): - raise Exception(f'use {cls._MODULE}.{cls._NAME}') - - def __repr__(self): - return f'{self._MODULE}.{self._NAME}' -# return f'interpreters.queues.UNBOUND' - - -UNBOUND = object.__new__(UnboundItem) -UNBOUND_ERROR = object() -UNBOUND_REMOVE = object() - -_UNBOUND_CONSTANT_TO_FLAG = { - UNBOUND_REMOVE: 1, - UNBOUND_ERROR: 2, - UNBOUND: 3, -} -_UNBOUND_FLAG_TO_CONSTANT = {v: k - for k, v in _UNBOUND_CONSTANT_TO_FLAG.items()} - - -def serialize_unbound(unbound): - op = unbound - try: - flag = _UNBOUND_CONSTANT_TO_FLAG[op] - except KeyError: - raise NotImplementedError(f'unsupported unbound replacement op {op!r}') - return flag, - - -def resolve_unbound(flag, exctype_destroyed): - try: - op = _UNBOUND_FLAG_TO_CONSTANT[flag] - except KeyError: - raise NotImplementedError(f'unsupported unbound replacement op {flag!r}') - if op is UNBOUND_REMOVE: - # "remove" not possible here - raise NotImplementedError - elif op is UNBOUND_ERROR: - raise exctype_destroyed("item's original interpreter destroyed") - elif op is UNBOUND: - return UNBOUND - else: - raise NotImplementedError(repr(op)) diff --git a/Lib/test/support/interpreters/queues.py b/Lib/test/support/interpreters/queues.py deleted file mode 100644 index deb8e8613af..00000000000 --- a/Lib/test/support/interpreters/queues.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Cross-interpreter Queues High Level Module.""" - -import pickle -import queue -import time -import weakref -import _interpqueues as _queues -from . import _crossinterp - -# aliases: -from _interpqueues import ( - QueueError, QueueNotFoundError, -) -from ._crossinterp import ( - UNBOUND_ERROR, UNBOUND_REMOVE, -) - -__all__ = [ - 'UNBOUND', 'UNBOUND_ERROR', 'UNBOUND_REMOVE', - 'create', 'list_all', - 'Queue', - 'QueueError', 'QueueNotFoundError', 'QueueEmpty', 'QueueFull', - 'ItemInterpreterDestroyed', -] - - -class QueueEmpty(QueueError, queue.Empty): - """Raised from get_nowait() when the queue is empty. - - It is also raised from get() if it times out. - """ - - -class QueueFull(QueueError, queue.Full): - """Raised from put_nowait() when the queue is full. - - It is also raised from put() if it times out. - """ - - -class ItemInterpreterDestroyed(QueueError, - _crossinterp.ItemInterpreterDestroyed): - """Raised from get() and get_nowait().""" - - -_SHARED_ONLY = 0 -_PICKLED = 1 - - -UNBOUND = _crossinterp.UnboundItem.singleton('queue', __name__) - - -def _serialize_unbound(unbound): - if unbound is UNBOUND: - unbound = _crossinterp.UNBOUND - return _crossinterp.serialize_unbound(unbound) - - -def _resolve_unbound(flag): - resolved = _crossinterp.resolve_unbound(flag, ItemInterpreterDestroyed) - if resolved is _crossinterp.UNBOUND: - resolved = UNBOUND - return resolved - - -def create(maxsize=0, *, syncobj=False, unbounditems=UNBOUND): - """Return a new cross-interpreter queue. - - The queue may be used to pass data safely between interpreters. - - "syncobj" sets the default for Queue.put() - and Queue.put_nowait(). - - "unbounditems" likewise sets the default. See Queue.put() for - supported values. The default value is UNBOUND, which replaces - the unbound item. 
- """ - fmt = _SHARED_ONLY if syncobj else _PICKLED - unbound = _serialize_unbound(unbounditems) - unboundop, = unbound - qid = _queues.create(maxsize, fmt, unboundop) - return Queue(qid, _fmt=fmt, _unbound=unbound) - - -def list_all(): - """Return a list of all open queues.""" - return [Queue(qid, _fmt=fmt, _unbound=(unboundop,)) - for qid, fmt, unboundop in _queues.list_all()] - - -_known_queues = weakref.WeakValueDictionary() - -class Queue: - """A cross-interpreter queue.""" - - def __new__(cls, id, /, *, _fmt=None, _unbound=None): - # There is only one instance for any given ID. - if isinstance(id, int): - id = int(id) - else: - raise TypeError(f'id must be an int, got {id!r}') - if _fmt is None: - if _unbound is None: - _fmt, op = _queues.get_queue_defaults(id) - _unbound = (op,) - else: - _fmt, _ = _queues.get_queue_defaults(id) - elif _unbound is None: - _, op = _queues.get_queue_defaults(id) - _unbound = (op,) - try: - self = _known_queues[id] - except KeyError: - self = super().__new__(cls) - self._id = id - self._fmt = _fmt - self._unbound = _unbound - _known_queues[id] = self - _queues.bind(id) - return self - - def __del__(self): - try: - _queues.release(self._id) - except QueueNotFoundError: - pass - try: - del _known_queues[self._id] - except KeyError: - pass - - def __repr__(self): - return f'{type(self).__name__}({self.id})' - - def __hash__(self): - return hash(self._id) - - # for pickling: - def __getnewargs__(self): - return (self._id,) - - # for pickling: - def __getstate__(self): - return None - - @property - def id(self): - return self._id - - @property - def maxsize(self): - try: - return self._maxsize - except AttributeError: - self._maxsize = _queues.get_maxsize(self._id) - return self._maxsize - - def empty(self): - return self.qsize() == 0 - - def full(self): - return _queues.is_full(self._id) - - def qsize(self): - return _queues.get_count(self._id) - - def put(self, obj, timeout=None, *, - syncobj=None, - unbound=None, - _delay=10 / 1000, # 10 milliseconds - ): - """Add the object to the queue. - - This blocks while the queue is full. - - If "syncobj" is None (the default) then it uses the - queue's default, set with create_queue(). - - If "syncobj" is false then all objects are supported, - at the expense of worse performance. - - If "syncobj" is true then the object must be "shareable". - Examples of "shareable" objects include the builtin singletons, - str, and memoryview. One benefit is that such objects are - passed through the queue efficiently. - - The key difference, though, is conceptual: the corresponding - object returned from Queue.get() will be strictly equivalent - to the given obj. In other words, the two objects will be - effectively indistinguishable from each other, even if the - object is mutable. The received object may actually be the - same object, or a copy (immutable values only), or a proxy. - Regardless, the received object should be treated as though - the original has been shared directly, whether or not it - actually is. That's a slightly different and stronger promise - than just (initial) equality, which is all "syncobj=False" - can promise. - - "unbound" controls the behavior of Queue.get() for the given - object if the current interpreter (calling put()) is later - destroyed. - - If "unbound" is None (the default) then it uses the - queue's default, set with create_queue(), - which is usually UNBOUND. 
- - If "unbound" is UNBOUND_ERROR then get() will raise an - ItemInterpreterDestroyed exception if the original interpreter - has been destroyed. This does not otherwise affect the queue; - the next call to put() will work like normal, returning the next - item in the queue. - - If "unbound" is UNBOUND_REMOVE then the item will be removed - from the queue as soon as the original interpreter is destroyed. - Be aware that this will introduce an imbalance between put() - and get() calls. - - If "unbound" is UNBOUND then it is returned by get() in place - of the unbound item. - """ - if syncobj is None: - fmt = self._fmt - else: - fmt = _SHARED_ONLY if syncobj else _PICKLED - if unbound is None: - unboundop, = self._unbound - else: - unboundop, = _serialize_unbound(unbound) - if timeout is not None: - timeout = int(timeout) - if timeout < 0: - raise ValueError(f'timeout value must be non-negative') - end = time.time() + timeout - if fmt is _PICKLED: - obj = pickle.dumps(obj) - while True: - try: - _queues.put(self._id, obj, fmt, unboundop) - except QueueFull as exc: - if timeout is not None and time.time() >= end: - raise # re-raise - time.sleep(_delay) - else: - break - - def put_nowait(self, obj, *, syncobj=None, unbound=None): - if syncobj is None: - fmt = self._fmt - else: - fmt = _SHARED_ONLY if syncobj else _PICKLED - if unbound is None: - unboundop, = self._unbound - else: - unboundop, = _serialize_unbound(unbound) - if fmt is _PICKLED: - obj = pickle.dumps(obj) - _queues.put(self._id, obj, fmt, unboundop) - - def get(self, timeout=None, *, - _delay=10 / 1000, # 10 milliseconds - ): - """Return the next object from the queue. - - This blocks while the queue is empty. - - If the next item's original interpreter has been destroyed - then the "next object" is determined by the value of the - "unbound" argument to put(). - """ - if timeout is not None: - timeout = int(timeout) - if timeout < 0: - raise ValueError(f'timeout value must be non-negative') - end = time.time() + timeout - while True: - try: - obj, fmt, unboundop = _queues.get(self._id) - except QueueEmpty as exc: - if timeout is not None and time.time() >= end: - raise # re-raise - time.sleep(_delay) - else: - break - if unboundop is not None: - assert obj is None, repr(obj) - return _resolve_unbound(unboundop) - if fmt == _PICKLED: - obj = pickle.loads(obj) - else: - assert fmt == _SHARED_ONLY - return obj - - def get_nowait(self): - """Return the next object from the channel. - - If the queue is empty then raise QueueEmpty. Otherwise this - is the same as get(). 
- """ - try: - obj, fmt, unboundop = _queues.get(self._id) - except QueueEmpty as exc: - raise # re-raise - if unboundop is not None: - assert obj is None, repr(obj) - return _resolve_unbound(unboundop) - if fmt == _PICKLED: - obj = pickle.loads(obj) - else: - assert fmt == _SHARED_ONLY - return obj - - -_queues._register_heap_types(Queue, QueueEmpty, QueueFull) diff --git a/Lib/test/support/logging_helper.py b/Lib/test/support/logging_helper.py index 12fcca4f0f0..db556c7f5ad 100644 --- a/Lib/test/support/logging_helper.py +++ b/Lib/test/support/logging_helper.py @@ -1,5 +1,6 @@ import logging.handlers + class TestHandler(logging.handlers.BufferingHandler): def __init__(self, matcher): # BufferingHandler takes a "capacity" argument diff --git a/Lib/test/support/os_helper.py b/Lib/test/support/os_helper.py index 26c467a7ad2..d3d6fa632f9 100644 --- a/Lib/test/support/os_helper.py +++ b/Lib/test/support/os_helper.py @@ -13,7 +13,6 @@ from test import support - # Filename used for testing TESTFN_ASCII = '@test' @@ -295,6 +294,33 @@ def skip_unless_working_chmod(test): return test if ok else unittest.skip(msg)(test) +@contextlib.contextmanager +def save_mode(path, *, quiet=False): + """Context manager that restores the mode (permissions) of *path* on exit. + + Arguments: + + path: Path of the file to restore the mode of. + + quiet: if False (the default), the context manager raises an exception + on error. Otherwise, it issues only a warning and keeps the current + working directory the same. + + """ + saved_mode = os.stat(path) + try: + yield + finally: + try: + os.chmod(path, saved_mode.st_mode) + except OSError as exc: + if not quiet: + raise + warnings.warn(f'tests may fail, unable to restore the mode of ' + f'{path!r} to {saved_mode.st_mode}: {exc}', + RuntimeWarning, stacklevel=3) + + # Check whether the current effective user has the capability to override # DAC (discretionary access control). Typically user root is able to # bypass file read, write, and execute permission checks. The capability diff --git a/Lib/test/support/pty_helper.py b/Lib/test/support/pty_helper.py index 6587fd40333..7e1ae9e59b8 100644 --- a/Lib/test/support/pty_helper.py +++ b/Lib/test/support/pty_helper.py @@ -10,6 +10,7 @@ from test.support.import_helper import import_module + def run_pty(script, input=b"dummy input\r", env=None): pty = import_module('pty') output = bytearray() diff --git a/Lib/test/support/script_helper.py b/Lib/test/support/script_helper.py index 04458077d51..a338f484449 100644 --- a/Lib/test/support/script_helper.py +++ b/Lib/test/support/script_helper.py @@ -3,17 +3,16 @@ import collections import importlib -import sys import os import os.path -import subprocess import py_compile - +import subprocess +import sys from importlib.util import source_from_cache + from test import support from test.support.import_helper import make_legacy_pyc - # Cached result of the expensive test performed in the function below. __cached_interp_requires_environment = None @@ -70,23 +69,25 @@ def fail(self, cmd_line): out = b'(... truncated stdout ...)' + out[-maxlen:] if len(err) > maxlen: err = b'(... 
truncated stderr ...)' + err[-maxlen:] - out = out.decode('ascii', 'replace').rstrip() - err = err.decode('ascii', 'replace').rstrip() - raise AssertionError("Process return code is %d\n" - "command line: %r\n" - "\n" - "stdout:\n" - "---\n" - "%s\n" - "---\n" - "\n" - "stderr:\n" - "---\n" - "%s\n" - "---" - % (self.rc, cmd_line, - out, - err)) + out = out.decode('utf8', 'replace').rstrip() + err = err.decode('utf8', 'replace').rstrip() + + exitcode = self.rc + signame = support.get_signal_name(exitcode) + if signame: + exitcode = f"{exitcode} ({signame})" + raise AssertionError(f"Process return code is {exitcode}\n" + f"command line: {cmd_line!r}\n" + f"\n" + f"stdout:\n" + f"---\n" + f"{out}\n" + f"---\n" + f"\n" + f"stderr:\n" + f"---\n" + f"{err}\n" + f"---") # Executing the interpreter in a subprocess diff --git a/Lib/test/support/smtpd.py b/Lib/test/support/smtpd.py old mode 100644 new mode 100755 index 6052232ec2b..cf333aaf6b0 --- a/Lib/test/support/smtpd.py +++ b/Lib/test/support/smtpd.py @@ -7,7 +7,7 @@ --nosetuid -n - This program generally tries to setuid `nobody', unless this flag is + This program generally tries to setuid 'nobody', unless this flag is set. The setuid call will fail if this program is not run as root (in which case, use this flag). @@ -17,7 +17,7 @@ --class classname -c classname - Use `classname' as the concrete SMTP proxy class. Uses `PureProxy' by + Use 'classname' as the concrete SMTP proxy class. Uses 'PureProxy' by default. --size limit @@ -39,8 +39,8 @@ Version: %(__version__)s -If localhost is not given then `localhost' is used, and if localport is not -given then 8025 is used. If remotehost is not given then `localhost' is used, +If localhost is not given then 'localhost' is used, and if localport is not +given then 8025 is used. If remotehost is not given then 'localhost' is used, and if remoteport is not given, then 25 is used. """ @@ -70,16 +70,17 @@ # - Handle more ESMTP extensions # - handle error codes from the backend smtpd -import sys -import os +import collections import errno import getopt -import time +import os import socket -import collections -from test.support import asyncore, asynchat -from warnings import warn +import sys +import time from email._header_value_parser import get_addr_spec, get_angle_addr +from warnings import warn + +from test.support import asynchat, asyncore __all__ = [ "SMTPChannel", "SMTPServer", "DebuggingServer", "PureProxy", @@ -633,7 +634,8 @@ def __init__(self, localaddr, remoteaddr, " be set to True at the same time") asyncore.dispatcher.__init__(self, map=map) try: - gai_results = socket.getaddrinfo(*localaddr, + family = 0 if socket.has_ipv6 else socket.AF_INET + gai_results = socket.getaddrinfo(*localaddr, family=family, type=socket.SOCK_STREAM) self.create_socket(gai_results[0][0], gai_results[0][1]) # try to re-use a server port if possible @@ -672,9 +674,9 @@ def process_message(self, peer, mailfrom, rcpttos, data, **kwargs): message to. data is a string containing the entire full text of the message, - headers (if supplied) and all. It has been `de-transparencied' + headers (if supplied) and all. It has been 'de-transparencied' according to RFC 821, Section 4.5.2. In other words, a line - containing a `.' followed by other text has had the leading dot + containing a '.' followed by other text has had the leading dot removed. kwargs is a dictionary containing additional information. 
It is
@@ -685,7 +687,7 @@ def process_message(self, peer, mailfrom, rcpttos, data, **kwargs):
                      ['BODY=8BITMIME', 'SMTPUTF8'].
             'rcpt_options': same, for the rcpt command.
-        This function should return None for a normal `250 Ok' response;
+        This function should return None for a normal '250 Ok' response;
         otherwise, it should return the desired response string in RFC 821
         format.
diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py
index a41e487f3e4..655ffbea0db 100644
--- a/Lib/test/support/socket_helper.py
+++ b/Lib/test/support/socket_helper.py
@@ -2,8 +2,8 @@
 import errno
 import os.path
 import socket
-import sys
 import subprocess
+import sys
 import tempfile
 import unittest
diff --git a/Lib/test/support/strace_helper.py b/Lib/test/support/strace_helper.py
new file mode 100644
index 00000000000..abc93dee2ce
--- /dev/null
+++ b/Lib/test/support/strace_helper.py
@@ -0,0 +1,210 @@
+import os
+import re
+import sys
+import textwrap
+import unittest
+from dataclasses import dataclass
+from functools import cache
+
+from test import support
+from test.support.script_helper import run_python_until_end
+
+_strace_binary = "/usr/bin/strace"
+_syscall_regex = re.compile(
+    r"(?P<syscall>[^(]*)\((?P<args>[^)]*)\)\s*[=]\s*(?P<returncode>.+)")
+_returncode_regex = re.compile(
+    br"\+\+\+ exited with (?P<returncode>\d+) \+\+\+")
+
+
+@dataclass
+class StraceEvent:
+    syscall: str
+    args: list[str]
+    returncode: str
+
+
+@dataclass
+class StraceResult:
+    strace_returncode: int
+    python_returncode: int
+
+    """The event messages generated by strace. This is very similar to the
+    stderr strace produces with returncode marker section removed."""
+    event_bytes: bytes
+    stdout: bytes
+    stderr: bytes
+
+    def events(self):
+        """Parse event_bytes data into system calls for easier processing.
+
+        This assumes the program under inspection doesn't print any non-utf8
+        strings which would mix into the strace output."""
+        decoded_events = self.event_bytes.decode('utf-8', 'surrogateescape')
+        matches = [
+            _syscall_regex.match(event)
+            for event in decoded_events.splitlines()
+        ]
+        return [
+            StraceEvent(match["syscall"],
+                        [arg.strip() for arg in (match["args"].split(","))],
+                        match["returncode"]) for match in matches if match
+        ]
+
+    def sections(self):
+        """Find all "MARK " writes and use them to make groups of events.
+
+        This is useful to avoid variable / overhead events, like those at
+        interpreter startup or when opening a file so a test can verify just
+        the small case under study."""
+        current_section = "__startup"
+        sections = {current_section: []}
+        for event in self.events():
+            if event.syscall == 'write' and len(
+                    event.args) > 2 and event.args[1].startswith("\"MARK "):
+                # Found a new section, don't include the write in the section
+                # but all events until next mark should be in that section
+                current_section = event.args[1].split(
+                    " ", 1)[1].removesuffix('\\n"')
+                if current_section not in sections:
+                    sections[current_section] = list()
+            else:
+                sections[current_section].append(event)
+
+        return sections
+
+def _filter_memory_call(call):
+    # mmap can operate on a fd or "MAP_ANONYMOUS" which gives a block of memory.
+    # Ignore "MAP_ANONYMOUS" + the "MAP_ANON" alias.
+    if call.syscall == "mmap" and "MAP_ANON" in call.args[3]:
+        return True
+
+    if call.syscall in ("munmap", "mprotect"):
+        return True
+
+    return False
+
+
+def filter_memory(syscalls):
+    """Filter out memory allocation calls from File I/O calls.
+
+    Some calls (mmap, munmap, etc) can be used on files or to just get a block
+    of memory.
Use this function to filter out the memory related calls from + other calls.""" + + return [call for call in syscalls if not _filter_memory_call(call)] + + +@support.requires_subprocess() +def strace_python(code, strace_flags, check=True): + """Run strace and return the trace. + + Sets strace_returncode and python_returncode to `-1` on error.""" + res = None + + def _make_error(reason, details): + return StraceResult( + strace_returncode=-1, + python_returncode=-1, + event_bytes= f"error({reason},details={details!r}) = -1".encode('utf-8'), + stdout=res.out if res else b"", + stderr=res.err if res else b"") + + # Run strace, and get out the raw text + try: + res, cmd_line = run_python_until_end( + "-c", + textwrap.dedent(code), + __run_using_command=[_strace_binary] + strace_flags, + ) + except OSError as err: + return _make_error("Caught OSError", err) + + if check and res.rc: + res.fail(cmd_line) + + # Get out program returncode + stripped = res.err.strip() + output = stripped.rsplit(b"\n", 1) + if len(output) != 2: + return _make_error("Expected strace events and exit code line", + stripped[-50:]) + + returncode_match = _returncode_regex.match(output[1]) + if not returncode_match: + return _make_error("Expected to find returncode in last line.", + output[1][:50]) + + python_returncode = int(returncode_match["returncode"]) + if check and python_returncode: + res.fail(cmd_line) + + return StraceResult(strace_returncode=res.rc, + python_returncode=python_returncode, + event_bytes=output[0], + stdout=res.out, + stderr=res.err) + + +def get_events(code, strace_flags, prelude, cleanup): + # NOTE: The flush is currently required to prevent the prints from getting + # buffered and done all at once at exit + prelude = textwrap.dedent(prelude) + code = textwrap.dedent(code) + cleanup = textwrap.dedent(cleanup) + to_run = f""" +print("MARK prelude", flush=True) +{prelude} +print("MARK code", flush=True) +{code} +print("MARK cleanup", flush=True) +{cleanup} +print("MARK __shutdown", flush=True) + """ + trace = strace_python(to_run, strace_flags) + all_sections = trace.sections() + return all_sections['code'] + + +def get_syscalls(code, strace_flags, prelude="", cleanup="", + ignore_memory=True): + """Get the syscalls which a given chunk of python code generates""" + events = get_events(code, strace_flags, prelude=prelude, cleanup=cleanup) + + if ignore_memory: + events = filter_memory(events) + + return [ev.syscall for ev in events] + + +# Moderately expensive (spawns a subprocess), so share results when possible. +@cache +def _can_strace(): + res = strace_python("import sys; sys.exit(0)", + # --trace option needs strace 5.5 (gh-133741) + ["--trace=%process"], + check=False) + if res.strace_returncode == 0 and res.python_returncode == 0: + assert res.events(), "Should have parsed multiple calls" + return True + return False + + +def requires_strace(): + if sys.platform != "linux": + return unittest.skip("Linux only, requires strace.") + + if "LD_PRELOAD" in os.environ: + # Distribution packaging (ex. Debian `fakeroot` and Gentoo `sandbox`) + # use LD_PRELOAD to intercept system calls, which changes the overall + # set of system calls which breaks tests expecting a specific set of + # system calls). 
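# Usage sketch (illustrative, not part of the patch): driving get_syscalls()
# above from a test method. The probe code, flags, and asserted names are
# assumptions; real syscall sets vary by platform and libc (unlink vs.
# unlinkat), and --trace=%file needs the same strace >= 5.5 as --trace=%process.
from test.support import strace_helper

@strace_helper.requires_strace()
def test_unlink_syscalls(self):  # assumed to live on a unittest.TestCase
    calls = strace_helper.get_syscalls(
        code="os.unlink('probe_file')",
        strace_flags=["--trace=%file"],
        prelude="import os; open('probe_file', 'w').close()",
    )
    self.assertTrue(any('unlink' in name for name in calls))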
+ return unittest.skip("Not supported when LD_PRELOAD is intercepting system calls.") + + if support.check_sanitizer(address=True, memory=True): + return unittest.skip("LeakSanitizer does not work under ptrace (strace, gdb, etc)") + + return unittest.skipUnless(_can_strace(), "Requires working strace") + + +__all__ = ["filter_memory", "get_events", "get_syscalls", "requires_strace", + "strace_python", "StraceEvent", "StraceResult"] diff --git a/Lib/test/support/testcase.py b/Lib/test/support/testcase.py index fd32457d146..e617b19b6ac 100644 --- a/Lib/test/support/testcase.py +++ b/Lib/test/support/testcase.py @@ -1,6 +1,7 @@ from math import copysign, isnan +# XXX: RUSTPYTHON: removed in 3.14 class ExtraAssertions: def assertIsSubclass(self, cls, superclass, msg=None): diff --git a/Lib/test/support/testresult.py b/Lib/test/support/testresult.py deleted file mode 100644 index de23fdd59de..00000000000 --- a/Lib/test/support/testresult.py +++ /dev/null @@ -1,191 +0,0 @@ -'''Test runner and result class for the regression test suite. - -''' - -import functools -import io -import sys -import time -import traceback -import unittest -from test import support - -class RegressionTestResult(unittest.TextTestResult): - USE_XML = False - - def __init__(self, stream, descriptions, verbosity): - super().__init__(stream=stream, descriptions=descriptions, - verbosity=2 if verbosity else 0) - self.buffer = True - if self.USE_XML: - from xml.etree import ElementTree as ET - from datetime import datetime, UTC - self.__ET = ET - self.__suite = ET.Element('testsuite') - self.__suite.set('start', - datetime.now(UTC) - .replace(tzinfo=None) - .isoformat(' ')) - self.__e = None - self.__start_time = None - - @classmethod - def __getId(cls, test): - try: - test_id = test.id - except AttributeError: - return str(test) - try: - return test_id() - except TypeError: - return str(test_id) - return repr(test) - - def startTest(self, test): - super().startTest(test) - if self.USE_XML: - self.__e = e = self.__ET.SubElement(self.__suite, 'testcase') - self.__start_time = time.perf_counter() - - def _add_result(self, test, capture=False, **args): - if not self.USE_XML: - return - e = self.__e - self.__e = None - if e is None: - return - ET = self.__ET - - e.set('name', args.pop('name', self.__getId(test))) - e.set('status', args.pop('status', 'run')) - e.set('result', args.pop('result', 'completed')) - if self.__start_time: - e.set('time', f'{time.perf_counter() - self.__start_time:0.6f}') - - if capture: - if self._stdout_buffer is not None: - stdout = self._stdout_buffer.getvalue().rstrip() - ET.SubElement(e, 'system-out').text = stdout - if self._stderr_buffer is not None: - stderr = self._stderr_buffer.getvalue().rstrip() - ET.SubElement(e, 'system-err').text = stderr - - for k, v in args.items(): - if not k or not v: - continue - e2 = ET.SubElement(e, k) - if hasattr(v, 'items'): - for k2, v2 in v.items(): - if k2: - e2.set(k2, str(v2)) - else: - e2.text = str(v2) - else: - e2.text = str(v) - - @classmethod - def __makeErrorDict(cls, err_type, err_value, err_tb): - if isinstance(err_type, type): - if err_type.__module__ == 'builtins': - typename = err_type.__name__ - else: - typename = f'{err_type.__module__}.{err_type.__name__}' - else: - typename = repr(err_type) - - msg = traceback.format_exception(err_type, err_value, None) - tb = traceback.format_exception(err_type, err_value, err_tb) - - return { - 'type': typename, - 'message': ''.join(msg), - '': ''.join(tb), - } - - def addError(self, test, err): - 
self._add_result(test, True, error=self.__makeErrorDict(*err)) - super().addError(test, err) - - def addExpectedFailure(self, test, err): - self._add_result(test, True, output=self.__makeErrorDict(*err)) - super().addExpectedFailure(test, err) - - def addFailure(self, test, err): - self._add_result(test, True, failure=self.__makeErrorDict(*err)) - super().addFailure(test, err) - if support.failfast: - self.stop() - - def addSkip(self, test, reason): - self._add_result(test, skipped=reason) - super().addSkip(test, reason) - - def addSuccess(self, test): - self._add_result(test) - super().addSuccess(test) - - def addUnexpectedSuccess(self, test): - self._add_result(test, outcome='UNEXPECTED_SUCCESS') - super().addUnexpectedSuccess(test) - - def get_xml_element(self): - if not self.USE_XML: - raise ValueError("USE_XML is false") - e = self.__suite - e.set('tests', str(self.testsRun)) - e.set('errors', str(len(self.errors))) - e.set('failures', str(len(self.failures))) - return e - -class QuietRegressionTestRunner: - def __init__(self, stream, buffer=False): - self.result = RegressionTestResult(stream, None, 0) - self.result.buffer = buffer - - def run(self, test): - test(self.result) - return self.result - -def get_test_runner_class(verbosity, buffer=False): - if verbosity: - return functools.partial(unittest.TextTestRunner, - resultclass=RegressionTestResult, - buffer=buffer, - verbosity=verbosity) - return functools.partial(QuietRegressionTestRunner, buffer=buffer) - -def get_test_runner(stream, verbosity, capture_output=False): - return get_test_runner_class(verbosity, capture_output)(stream) - -if __name__ == '__main__': - import xml.etree.ElementTree as ET - RegressionTestResult.USE_XML = True - - class TestTests(unittest.TestCase): - def test_pass(self): - pass - - def test_pass_slow(self): - time.sleep(1.0) - - def test_fail(self): - print('stdout', file=sys.stdout) - print('stderr', file=sys.stderr) - self.fail('failure message') - - def test_error(self): - print('stdout', file=sys.stdout) - print('stderr', file=sys.stderr) - raise RuntimeError('error message') - - suite = unittest.TestSuite() - suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestTests)) - stream = io.StringIO() - runner_cls = get_test_runner_class(sum(a == '-v' for a in sys.argv)) - runner = runner_cls(sys.stdout) - result = runner.run(suite) - print('Output:', stream.getvalue()) - print('XML: ', end='') - for s in ET.tostringlist(result.get_xml_element()): - print(s.decode(), end='') - print() diff --git a/Lib/test/support/threading_helper.py b/Lib/test/support/threading_helper.py index afa25a76f63..9b2b8f2dff0 100644 --- a/Lib/test/support/threading_helper.py +++ b/Lib/test/support/threading_helper.py @@ -8,7 +8,6 @@ from test import support - #======================================================================= # Threading support to prevent reporting refleaks when running regrtest.py -R @@ -248,3 +247,27 @@ def requires_working_threading(*, module=False): raise unittest.SkipTest(msg) else: return unittest.skipUnless(can_start_thread, msg) + + +def run_concurrently(worker_func, nthreads, args=(), kwargs={}): + """ + Run the worker function concurrently in multiple threads. + """ + barrier = threading.Barrier(nthreads) + + def wrapper_func(*args, **kwargs): + # Wait for all threads to reach this point before proceeding. 
+        barrier.wait()
+        worker_func(*args, **kwargs)
+
+    with catch_threading_exception() as cm:
+        workers = [
+            threading.Thread(target=wrapper_func, args=args, kwargs=kwargs)
+            for _ in range(nthreads)
+        ]
+        with start_threads(workers):
+            pass
+
+        # If a worker thread raises an exception, re-raise it.
+        if cm.exc_value is not None:
+            raise cm.exc_value
diff --git a/Lib/test/support/venv.py b/Lib/test/support/venv.py
index 78e6a51ec18..b60f6097e65 100644
--- a/Lib/test/support/venv.py
+++ b/Lib/test/support/venv.py
@@ -1,8 +1,8 @@
 import contextlib
 import logging
 import os
-import subprocess
 import shlex
+import subprocess
 import sys
 import sysconfig
 import tempfile
@@ -68,3 +68,14 @@ def run(self, *args, **subprocess_args):
             raise
         else:
             return result
+
+
+class VirtualEnvironmentMixin:
+    def venv(self, name=None, **venv_create_args):
+        venv_name = self.id()
+        if name:
+            venv_name += f'-{name}'
+        return VirtualEnvironment.from_tmpdir(
+            prefix=f'{venv_name}-venv-',
+            **venv_create_args,
+        )
diff --git a/Lib/test/support/warnings_helper.py b/Lib/test/support/warnings_helper.py
index c1bf0562300..5f6f14afd74 100644
--- a/Lib/test/support/warnings_helper.py
+++ b/Lib/test/support/warnings_helper.py
@@ -23,8 +23,7 @@ def check_syntax_warning(testcase, statement, errtext='',
     testcase.assertEqual(len(warns), 1, warns)
     warn, = warns
-    testcase.assertTrue(issubclass(warn.category, SyntaxWarning),
-                        warn.category)
+    testcase.assertIsSubclass(warn.category, SyntaxWarning)
     if errtext:
         testcase.assertRegex(str(warn.message), errtext)
     testcase.assertEqual(warn.filename, '<testcase>')
@@ -160,11 +159,12 @@ def _filterwarnings(filters, quiet=False):
     registry = frame.f_globals.get('__warningregistry__')
     if registry:
         registry.clear()
-    with warnings.catch_warnings(record=True) as w:
-        # Set filter "always" to record all warnings. Because
-        # test_warnings swap the module, we need to look up in
-        # the sys.modules dictionary.
-        sys.modules['warnings'].simplefilter("always")
+    # Because test_warnings swaps the module, we need to look up in the
+    # sys.modules dictionary.
+    wmod = sys.modules['warnings']
+    with wmod.catch_warnings(record=True) as w:
+        # Set filter "always" to record all warnings.
+        wmod.simplefilter("always")
         yield WarningsRecorder(w)
     # Filter the recorded warnings
     reraise = list(w)
diff --git a/Lib/test/test__colorize.py b/Lib/test/test__colorize.py
index 31dc60fec35..026277267e0 100644
--- a/Lib/test/test__colorize.py
+++ b/Lib/test/test__colorize.py
@@ -24,7 +24,6 @@ def supports_virtual_terminal():
 class TestTheme(unittest.TestCase):
-    @unittest.expectedFailure # TODO: RUSTPYTHON
     def test_attributes(self):
         # only theme configurations attributes by default
         for field in dataclasses.fields(_colorize.Theme):
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index beda61be8a4..9b8179ef969 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -6824,7 +6824,6 @@ def test_nargs_zero(self):
 class TestImportStar(TestCase):
-    @unittest.expectedFailure # TODO: RUSTPYTHON
     def test(self):
         for name in argparse.__all__:
             self.assertHasAttr(argparse, name)
diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py
index 57c19477ab6..5309939777f 100644
--- a/Lib/test/test_ast/test_ast.py
+++ b/Lib/test/test_ast/test_ast.py
@@ -3023,7 +3023,6 @@ def visit_Expr(self, node: ast.Expr):
         self.assertASTTransformation(YieldRemover, code, expected)
-    @unittest.expectedFailure # TODO: RUSTPYTHON; is not
     def test_node_return_list(self):
         code = """
             class DSL(Base, kw1=True): ...
@@ -3064,7 +3063,6 @@ def visit_Call(self, node: ast.Call): self.assertASTTransformation(PrintToLog, code, expected) - @unittest.expectedFailure # TODO: RUSTPYTHON; is not def test_node_replace(self): code = """ def func(arg): diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py index 409c8c109e8..a6739124571 100644 --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -1,10 +1,18 @@ -import unittest import base64 import binascii import os +import unittest from array import array +from test.support import cpython_only from test.support import os_helper from test.support import script_helper +from test.support.import_helper import ensure_lazy_imports + + +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("base64", {"re", "getopt"}) class LegacyBase64TestCase(unittest.TestCase): @@ -200,18 +208,6 @@ def test_b64decode(self): self.check_other_types(base64.b64decode, b"YWJj", b"abc") self.check_decode_type_errors(base64.b64decode) - # Test with arbitrary alternative characters - tests_altchars = {(b'01a*b$cd', b'*$'): b'\xd3V\xbeo\xf7\x1d', - } - for (data, altchars), res in tests_altchars.items(): - data_str = data.decode('ascii') - altchars_str = altchars.decode('ascii') - - eq(base64.b64decode(data, altchars=altchars), res) - eq(base64.b64decode(data_str, altchars=altchars), res) - eq(base64.b64decode(data, altchars=altchars_str), res) - eq(base64.b64decode(data_str, altchars=altchars_str), res) - # Test standard alphabet for data, res in tests.items(): eq(base64.standard_b64decode(data), res) @@ -232,6 +228,20 @@ def test_b64decode(self): b'\xd3V\xbeo\xf7\x1d') self.check_decode_type_errors(base64.urlsafe_b64decode) + def test_b64decode_altchars(self): + # Test with arbitrary alternative characters + eq = self.assertEqual + res = b'\xd3V\xbeo\xf7\x1d' + for altchars in b'*$', b'+/', b'/+', b'+_', b'-+', b'-/', b'/_': + data = b'01a%cb%ccd' % tuple(altchars) + data_str = data.decode('ascii') + altchars_str = altchars.decode('ascii') + + eq(base64.b64decode(data, altchars=altchars), res) + eq(base64.b64decode(data_str, altchars=altchars), res) + eq(base64.b64decode(data, altchars=altchars_str), res) + eq(base64.b64decode(data_str, altchars=altchars_str), res) + def test_b64decode_padding_error(self): self.assertRaises(binascii.Error, base64.b64decode, b'abc') self.assertRaises(binascii.Error, base64.b64decode, 'abc') @@ -264,9 +274,12 @@ def test_b64decode_invalid_chars(self): base64.b64decode(bstr.decode('ascii'), validate=True) # Normal alphabet characters not discarded when alternative given - res = b'\xFB\xEF\xBE\xFF\xFF\xFF' - self.assertEqual(base64.b64decode(b'++[[//]]', b'[]'), res) - self.assertEqual(base64.urlsafe_b64decode(b'++--//__'), res) + res = b'\xfb\xef\xff' + self.assertEqual(base64.b64decode(b'++//', validate=True), res) + self.assertEqual(base64.b64decode(b'++//', '-_', validate=True), res) + self.assertEqual(base64.b64decode(b'--__', '-_', validate=True), res) + self.assertEqual(base64.urlsafe_b64decode(b'++//'), res) + self.assertEqual(base64.urlsafe_b64decode(b'--__'), res) def test_b32encode(self): eq = self.assertEqual @@ -321,23 +334,33 @@ def test_b32decode_casefold(self): self.assertRaises(binascii.Error, base64.b32decode, b'me======') self.assertRaises(binascii.Error, base64.b32decode, 'me======') + def test_b32decode_map01(self): # Mapping zero and one - eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe') - eq(base64.b32decode('MLO23456'), b'b\xdd\xad\xf3\xbe') - - map_tests = 
{(b'M1023456', b'L'): b'b\xdd\xad\xf3\xbe', - (b'M1023456', b'I'): b'b\x1d\xad\xf3\xbe', - } - for (data, map01), res in map_tests.items(): - data_str = data.decode('ascii') + eq = self.assertEqual + res_L = b'b\xdd\xad\xf3\xbe' + res_I = b'b\x1d\xad\xf3\xbe' + eq(base64.b32decode(b'MLO23456'), res_L) + eq(base64.b32decode('MLO23456'), res_L) + eq(base64.b32decode(b'MIO23456'), res_I) + eq(base64.b32decode('MIO23456'), res_I) + self.assertRaises(binascii.Error, base64.b32decode, b'M1023456') + self.assertRaises(binascii.Error, base64.b32decode, b'M1O23456') + self.assertRaises(binascii.Error, base64.b32decode, b'ML023456') + self.assertRaises(binascii.Error, base64.b32decode, b'MI023456') + + data = b'M1023456' + data_str = data.decode('ascii') + for map01, res in [(b'L', res_L), (b'I', res_I)]: map01_str = map01.decode('ascii') eq(base64.b32decode(data, map01=map01), res) eq(base64.b32decode(data_str, map01=map01), res) eq(base64.b32decode(data, map01=map01_str), res) eq(base64.b32decode(data_str, map01=map01_str), res) - self.assertRaises(binascii.Error, base64.b32decode, data) - self.assertRaises(binascii.Error, base64.b32decode, data_str) + + eq(base64.b32decode(b'M1O23456', map01=map01), res) + eq(base64.b32decode(b'M%c023456' % map01, map01=map01), res) + eq(base64.b32decode(b'M%cO23456' % map01, map01=map01), res) def test_b32decode_error(self): tests = [b'abc', b'ABCDEF==', b'==ABCDEF'] @@ -804,7 +827,7 @@ def test_decode_nonascii_str(self): self.assertRaises(ValueError, f, 'with non-ascii \xcb') def test_ErrorHeritage(self): - self.assertTrue(issubclass(binascii.Error, ValueError)) + self.assertIsSubclass(binascii.Error, ValueError) def test_RFC4648_test_cases(self): # test cases from RFC 4648 section 10 diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index a1a711da70b..fcef9c0c972 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -1288,11 +1288,20 @@ class SubBytes(bytes): self.assertNotEqual(id(s), id(1 * s)) self.assertNotEqual(id(s), id(s * 2)) + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_fromhex(self): + return super().test_fromhex() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_mod(self): + return super().test_mod() + class ByteArrayTest(BaseBytesTest, unittest.TestCase): type2test = bytearray - _testlimitedcapi = import_helper.import_module('_testlimitedcapi') + # XXX: RUSTPYTHON; import_helper.import_module here cause the entire test stopping + _testlimitedcapi = None # import_helper.import_module('_testlimitedcapi') def test_getitem_error(self): b = bytearray(b'python') @@ -1385,6 +1394,7 @@ def by(s): b = by("Hello, world") self.assertEqual(re.findall(br"\w+", b), [by("Hello"), by("world")]) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_setitem(self): def setitem_as_mapping(b, i, val): b[i] = val @@ -1432,6 +1442,7 @@ def do_tests(setitem): with self.subTest("tp_as_sequence"): do_tests(setitem_as_sequence) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_delitem(self): def del_as_mapping(b, i): del b[i] @@ -1618,6 +1629,7 @@ def g(): alloc = b.__alloc__() self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched + @unittest.expectedFailure # TODO: RUSTPYTHON def test_extend(self): orig = b'hello' a = bytearray(orig) @@ -1840,6 +1852,7 @@ def test_repeat_after_setslice(self): self.assertEqual(b1, b) self.assertEqual(b3, b'xcxcxc') + @unittest.expectedFailure # TODO: RUSTPYTHON def test_mutating_index(self): # bytearray slice assignment can call into python code # that reallocates the 
internal buffer @@ -1860,6 +1873,7 @@ def __index__(self): with self.assertRaises(IndexError): self._testlimitedcapi.sequence_setitem(b, 0, Boom()) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_mutating_index_inbounds(self): # gh-91153 continued # Ensure buffer is not broken even if length is correct @@ -1893,6 +1907,14 @@ def __index__(self): self.assertEqual(instance.ba[0], ord("?"), "Assigned bytearray not altered") self.assertEqual(instance.new_ba, bytearray(0x180), "Wrong object altered") + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_fromhex(self): + return super().test_fromhex() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_mod(self): + return super().test_mod() + class AssortedBytesTest(unittest.TestCase): # @@ -1912,6 +1934,7 @@ def test_bytes_repr(self, f=repr): self.assertEqual(f(b"'\"'"), r"""b'\'"\''""") # '\'"\'' self.assertEqual(f(BytesSubclass(b"abc")), "b'abc'") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bytearray_repr(self, f=repr): self.assertEqual(f(bytearray()), "bytearray(b'')") self.assertEqual(f(bytearray(b'abc')), "bytearray(b'abc')") @@ -1933,6 +1956,7 @@ def test_bytearray_repr(self, f=repr): def test_bytes_str(self): self.test_bytes_repr(str) + @unittest.expectedFailure # TODO: RUSTPYTHON @check_bytes_warnings def test_bytearray_str(self): self.test_bytearray_repr(str) @@ -2233,6 +2257,14 @@ class ByteArraySubclassWithSlotsTest(SubclassTest, unittest.TestCase): basetype = bytearray type2test = ByteArraySubclassWithSlots + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_copy(self): + return super().test_copy() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickle(self): + return super().test_pickle() + class BytesSubclassTest(SubclassTest, unittest.TestCase): basetype = bytes type2test = BytesSubclass diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index 3617eba8e8d..26b5e79d337 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -16,7 +16,7 @@ from test.support import import_helper from test.support import threading_helper from test.support.os_helper import unlink, FakePath -import _compression +from compression._common import _streams import sys @@ -126,15 +126,15 @@ def testReadMultiStream(self): def testReadMonkeyMultiStream(self): # Test BZ2File.read() on a multi-stream archive where a stream # boundary coincides with the end of the raw read buffer. 
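# Aside (a minimal sketch, not part of the patch): testReadMonkeyMultiStream
# pins the module-level BUFFER_SIZE so a stream boundary lands exactly at the
# end of one raw read; the hunk below only renames _compression to
# compression._common._streams. The general monkeypatch shape, with an
# illustrative size:
from compression._common import _streams

saved = _streams.BUFFER_SIZE
_streams.BUFFER_SIZE = 4096  # force the reader's chunk boundary
try:
    ...  # exercise BZ2File.read() against a crafted multi-stream file
finally:
    _streams.BUFFER_SIZE = saved  # always restore the real buffer size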
- buffer_size = _compression.BUFFER_SIZE - _compression.BUFFER_SIZE = len(self.DATA) + buffer_size = _streams.BUFFER_SIZE + _streams.BUFFER_SIZE = len(self.DATA) try: self.createTempFile(streams=5) with BZ2File(self.filename) as bz2f: self.assertRaises(TypeError, bz2f.read, float()) self.assertEqual(bz2f.read(), self.TEXT * 5) finally: - _compression.BUFFER_SIZE = buffer_size + _streams.BUFFER_SIZE = buffer_size def testReadTrailingJunk(self): self.createTempFile(suffix=self.BAD_DATA) @@ -184,7 +184,7 @@ def testPeek(self): with BZ2File(self.filename) as bz2f: pdata = bz2f.peek() self.assertNotEqual(len(pdata), 0) - self.assertTrue(self.TEXT.startswith(pdata)) + self.assertStartsWith(self.TEXT, pdata) self.assertEqual(bz2f.read(), self.TEXT) def testReadInto(self): @@ -730,8 +730,7 @@ def testOpenBytesFilename(self): self.assertEqual(f.read(), self.DATA) self.assertEqual(f.name, str_filename) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def testOpenPathLikeFilename(self): filename = FakePath(self.filename) with BZ2File(filename, "wb") as f: @@ -744,7 +743,7 @@ def testOpenPathLikeFilename(self): def testDecompressLimited(self): """Decompressed data buffering should be limited""" bomb = bz2.compress(b'\0' * int(2e6), compresslevel=9) - self.assertLess(len(bomb), _compression.BUFFER_SIZE) + self.assertLess(len(bomb), _streams.BUFFER_SIZE) decomp = BZ2File(BytesIO(bomb)) self.assertEqual(decomp.read(1), b'\0') @@ -770,7 +769,7 @@ def testPeekBytesIO(self): with BZ2File(bio) as bz2f: pdata = bz2f.peek() self.assertNotEqual(len(pdata), 0) - self.assertTrue(self.TEXT.startswith(pdata)) + self.assertStartsWith(self.TEXT, pdata) self.assertEqual(bz2f.read(), self.TEXT) def testWriteBytesIO(self): @@ -1190,8 +1189,7 @@ def test_encoding_error_handler(self): as f: self.assertEqual(f.read(), "foobar") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_newline(self): # Test with explicit newline (universal newline mode disabled). 
text = self.TEXT.decode("ascii") diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py index 35573fd9f01..7ade4271b7a 100644 --- a/Lib/test/test_calendar.py +++ b/Lib/test/test_calendar.py @@ -1090,7 +1090,6 @@ def test_option_months(self): output = run('--months', '1', '2004') self.assertIn(conv('\nMo Tu We Th Fr Sa Su\n'), output) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_option_type(self): self.assertFailure('-t') self.assertFailure('--type') diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py index b0d9613cdbd..4a8813c4da1 100644 --- a/Lib/test/test_gzip.py +++ b/Lib/test/test_gzip.py @@ -9,6 +9,7 @@ import struct import sys import unittest +import warnings from subprocess import PIPE, Popen from test.support import catch_unraisable_exception from test.support import import_helper @@ -143,6 +144,38 @@ def test_read1(self): self.assertEqual(f.tell(), nread) self.assertEqual(b''.join(blocks), data1 * 50) + def test_readinto(self): + # 10MB of uncompressible data to ensure multiple reads + large_data = os.urandom(10 * 2**20) + with gzip.GzipFile(self.filename, 'wb') as f: + f.write(large_data) + + buf = bytearray(len(large_data)) + with gzip.GzipFile(self.filename, 'r') as f: + nbytes = f.readinto(buf) + self.assertEqual(nbytes, len(large_data)) + self.assertEqual(buf, large_data) + + def test_readinto1(self): + # 10MB of uncompressible data to ensure multiple reads + large_data = os.urandom(10 * 2**20) + with gzip.GzipFile(self.filename, 'wb') as f: + f.write(large_data) + + nread = 0 + buf = bytearray(len(large_data)) + memview = memoryview(buf) # Simplifies slicing + with gzip.GzipFile(self.filename, 'r') as f: + for count in range(200): + nbytes = f.readinto1(memview[nread:]) + if not nbytes: + break + nread += nbytes + self.assertEqual(f.tell(), nread) + self.assertEqual(buf, large_data) + # readinto1() should require multiple loops + self.assertGreater(count, 1) + @bigmemtest(size=_4G, memuse=1) def test_read_large(self, size): # Read chunk size over UINT_MAX should be supported, despite zlib's @@ -298,13 +331,13 @@ def test_mode(self): def test_1647484(self): for mode in ('wb', 'rb'): with gzip.GzipFile(self.filename, mode) as f: - self.assertTrue(hasattr(f, "name")) + self.assertHasAttr(f, "name") self.assertEqual(f.name, self.filename) def test_paddedfile_getattr(self): self.test_write() with gzip.GzipFile(self.filename, 'rb') as f: - self.assertTrue(hasattr(f.fileobj, "name")) + self.assertHasAttr(f.fileobj, "name") self.assertEqual(f.fileobj.name, self.filename) def test_mtime(self): @@ -312,7 +345,7 @@ def test_mtime(self): with gzip.GzipFile(self.filename, 'w', mtime = mtime) as fWrite: fWrite.write(data1) with gzip.GzipFile(self.filename) as fRead: - self.assertTrue(hasattr(fRead, 'mtime')) + self.assertHasAttr(fRead, 'mtime') self.assertIsNone(fRead.mtime) dataRead = fRead.read() self.assertEqual(dataRead, data1) @@ -427,7 +460,7 @@ def test_zero_padded_file(self): self.assertEqual(d, data1 * 50, "Incorrect data in file") def test_gzip_BadGzipFile_exception(self): - self.assertTrue(issubclass(gzip.BadGzipFile, OSError)) + self.assertIsSubclass(gzip.BadGzipFile, OSError) def test_bad_gzip_file(self): with open(self.filename, 'wb') as file: @@ -715,6 +748,17 @@ def test_compress_mtime(self): f.read(1) # to set mtime attribute self.assertEqual(f.mtime, mtime) + def test_compress_mtime_default(self): + # test for gh-125260 + datac = gzip.compress(data1, mtime=0) + datac2 = gzip.compress(data1) + self.assertEqual(datac, datac2) + datac3 = 
gzip.compress(data1, mtime=None) + self.assertNotEqual(datac, datac3) + with gzip.GzipFile(fileobj=io.BytesIO(datac3), mode="rb") as f: + f.read(1) # to set mtime attribute + self.assertGreater(f.mtime, 1) + def test_compress_correct_level(self): for mtime in (0, 42): with self.subTest(mtime=mtime): @@ -856,9 +900,10 @@ def test_refloop_unraisable(self): # fileobj would be closed before the GzipFile as the result of a # reference loop. See issue gh-129726 with catch_unraisable_exception() as cm: - gzip.GzipFile(fileobj=io.BytesIO(), mode="w") - gc.collect() - self.assertIsNone(cm.unraisable) + with self.assertWarns(ResourceWarning): + gzip.GzipFile(fileobj=io.BytesIO(), mode="w") + gc.collect() + self.assertIsNone(cm.unraisable) class TestOpen(BaseTest): @@ -991,8 +1036,7 @@ def test_encoding_error_handler(self): as f: self.assertEqual(f.read(), "foobar") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_newline(self): # Test with explicit newline (universal newline mode disabled). uncompressed = data1.decode("ascii") * 50 diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index 1ab7a1fb6dd..7f62ac067c4 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -681,7 +681,7 @@ def test_denial_of_service_prevented_int_to_str(self): digits = 78_268 with ( support.adjust_int_max_str_digits(digits), - support.CPUStopwatch() as sw_convert): + support.Stopwatch() as sw_convert): huge_decimal = str(huge_int) self.assertEqual(len(huge_decimal), digits) # Ensuring that we chose a slow enough conversion to measure. @@ -696,7 +696,7 @@ def test_denial_of_service_prevented_int_to_str(self): with support.adjust_int_max_str_digits(int(.995 * digits)): with ( self.assertRaises(ValueError) as err, - support.CPUStopwatch() as sw_fail_huge): + support.Stopwatch() as sw_fail_huge): str(huge_int) self.assertIn('conversion', str(err.exception)) self.assertLessEqual(sw_fail_huge.seconds, sw_convert.seconds/2) @@ -706,7 +706,7 @@ def test_denial_of_service_prevented_int_to_str(self): extra_huge_int = int(f'0x{"c"*500_000}', base=16) # 602060 digits. with ( self.assertRaises(ValueError) as err, - support.CPUStopwatch() as sw_fail_extra_huge): + support.Stopwatch() as sw_fail_extra_huge): # If not limited, 8 seconds said Zen based cloud VM. str(extra_huge_int) self.assertIn('conversion', str(err.exception)) @@ -722,7 +722,7 @@ def test_denial_of_service_prevented_str_to_int(self): huge = '8'*digits with ( support.adjust_int_max_str_digits(digits), - support.CPUStopwatch() as sw_convert): + support.Stopwatch() as sw_convert): int(huge) # Ensuring that we chose a slow enough conversion to measure. # It takes 0.1 seconds on a Zen based cloud VM in an opt build. @@ -734,7 +734,7 @@ def test_denial_of_service_prevented_str_to_int(self): with support.adjust_int_max_str_digits(digits - 1): with ( self.assertRaises(ValueError) as err, - support.CPUStopwatch() as sw_fail_huge): + support.Stopwatch() as sw_fail_huge): int(huge) self.assertIn('conversion', str(err.exception)) self.assertLessEqual(sw_fail_huge.seconds, sw_convert.seconds/2) @@ -744,7 +744,7 @@ def test_denial_of_service_prevented_str_to_int(self): extra_huge = '7'*1_200_000 with ( self.assertRaises(ValueError) as err, - support.CPUStopwatch() as sw_fail_extra_huge): + support.Stopwatch() as sw_fail_extra_huge): # If not limited, 8 seconds in the Zen based cloud VM. 
int(extra_huge) self.assertIn('conversion', str(err.exception)) diff --git a/Lib/test/test_json/__init__.py b/Lib/test/test_json/__init__.py index b919af2328f..7091364cddb 100644 --- a/Lib/test/test_json/__init__.py +++ b/Lib/test/test_json/__init__.py @@ -41,8 +41,7 @@ def test_pyjson(self): 'json.encoder') class TestCTest(CTest): - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cjson(self): self.assertEqual(self.json.scanner.make_scanner.__module__, '_json') self.assertEqual(self.json.decoder.scanstring.__module__, '_json') diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py index ad37d47f083..b6531b237c0 100644 --- a/Lib/test/test_json/test_decode.py +++ b/Lib/test/test_json/test_decode.py @@ -18,8 +18,7 @@ def test_float(self): self.assertIsInstance(rval, float) self.assertEqual(rval, 1.0) - # TODO: RUSTPYTHON - @unittest.skip("TODO: RUSTPYTHON; called `Result::unwrap()` on an `Err` value: ParseFloatError { kind: Invalid }") + @unittest.skip('TODO: RUSTPYTHON; called `Result::unwrap()` on an `Err` value: ParseFloatError { kind: Invalid }') def test_nonascii_digits_rejected(self): # JSON specifies only ascii digits, see gh-125687 for num in ["1\uff10", "0.\uff10", "0e\uff10"]: @@ -138,9 +137,6 @@ def test_limit_int(self): class TestPyDecode(TestDecode, PyTest): pass class TestCDecode(TestDecode, CTest): - def test_keys_reuse(self): - return super().test_keys_reuse() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_limit_int(self): diff --git a/Lib/test/test_json/test_default.py b/Lib/test/test_json/test_default.py index 3ce16684a08..b576947c4f2 100644 --- a/Lib/test/test_json/test_default.py +++ b/Lib/test/test_json/test_default.py @@ -1,4 +1,5 @@ import collections +import unittest # XXX: RUSTPYTHON; importing to be able to skip tests from test.test_json import PyTest, CTest @@ -8,6 +9,26 @@ def test_default(self): self.dumps(type, default=repr), self.dumps(repr(type))) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_bad_default(self): + def default(obj): + if obj is NotImplemented: + raise ValueError + if obj is ...: + return NotImplemented + if obj is type: + return collections + return [...] 
+ + with self.assertRaises(ValueError) as cm: + self.dumps(type, default=default) + self.assertEqual(cm.exception.__notes__, + ['when serializing ellipsis object', + 'when serializing list item 0', + 'when serializing module object', + 'when serializing type object']) + def test_ordereddict(self): od = collections.OrderedDict(a=1, b=2, c=3, d=4) od.move_to_end('b') diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py index 7a85665c816..cb88e317b4e 100644 --- a/Lib/test/test_json/test_fail.py +++ b/Lib/test/test_json/test_fail.py @@ -1,7 +1,7 @@ -from test.test_json import PyTest, CTest - import unittest # XXX: RUSTPYTHON; importing to be able to skip tests +from test.test_json import PyTest, CTest + # 2007-10-05 JSONDOCS = [ # https://json.org/JSON_checker/test/fail1.json @@ -102,8 +102,27 @@ def test_non_string_keys_dict(self): def test_not_serializable(self): import sys with self.assertRaisesRegex(TypeError, - 'Object of type module is not JSON serializable'): + 'Object of type module is not JSON serializable') as cm: self.dumps(sys) + self.assertNotHasAttr(cm.exception, '__notes__') + + with self.assertRaises(TypeError) as cm: + self.dumps([1, [2, 3, sys]]) + self.assertEqual(cm.exception.__notes__, + ['when serializing list item 2', + 'when serializing list item 1']) + + with self.assertRaises(TypeError) as cm: + self.dumps((1, (2, 3, sys))) + self.assertEqual(cm.exception.__notes__, + ['when serializing tuple item 2', + 'when serializing tuple item 1']) + + with self.assertRaises(TypeError) as cm: + self.dumps({'a': {'b': sys}}) + self.assertEqual(cm.exception.__notes__, + ["when serializing dict item 'b'", + "when serializing dict item 'a'"]) def test_truncated_input(self): test_cases = [ diff --git a/Lib/test/test_json/test_recursion.py b/Lib/test/test_json/test_recursion.py index 59f6f2c4b19..2a24edef629 100644 --- a/Lib/test/test_json/test_recursion.py +++ b/Lib/test/test_json/test_recursion.py @@ -14,8 +14,8 @@ def test_listrecursion(self): x.append(x) try: self.dumps(x) - except ValueError: - pass + except ValueError as exc: + self.assertEqual(exc.__notes__, ["when serializing list item 0"]) else: self.fail("didn't raise ValueError on list recursion") x = [] @@ -23,8 +23,8 @@ def test_listrecursion(self): x.append(y) try: self.dumps(x) - except ValueError: - pass + except ValueError as exc: + self.assertEqual(exc.__notes__, ["when serializing list item 0"]*2) else: self.fail("didn't raise ValueError on alternating list recursion") y = [] @@ -37,8 +37,8 @@ def test_dictrecursion(self): x["test"] = x try: self.dumps(x) - except ValueError: - pass + except ValueError as exc: + self.assertEqual(exc.__notes__, ["when serializing dict item 'test'"]) else: self.fail("didn't raise ValueError on dict recursion") x = {} @@ -62,31 +62,41 @@ def default(self, o): enc.recurse = True try: enc.encode(JSONTestObject) - except ValueError: - pass + except ValueError as exc: + self.assertEqual(exc.__notes__, + ["when serializing list item 0", + "when serializing type object"]) else: self.fail("didn't raise ValueError on default recursion") + # TODO: RUSTPYTHON - @unittest.skip("TODO: RUSTPYTHON; crashes") + @unittest.skip('TODO: RUSTPYTHON; crashes') + # TODO: RUSTPYTHON; needs to upgrade test.support to 3.14 above + # @support.skip_emscripten_stack_overflow() + # @support.skip_wasi_stack_overflow() def test_highly_nested_objects_decoding(self): + very_deep = 200000 # test that loading highly-nested objects doesn't segfault when
C # accelerations are used. See #12017 with self.assertRaises(RecursionError): with support.infinite_recursion(): - self.loads('{"a":' * 100000 + '1' + '}' * 100000) + self.loads('{"a":' * very_deep + '1' + '}' * very_deep) with self.assertRaises(RecursionError): with support.infinite_recursion(): - self.loads('{"a":' * 100000 + '[1]' + '}' * 100000) + self.loads('{"a":' * very_deep + '[1]' + '}' * very_deep) with self.assertRaises(RecursionError): with support.infinite_recursion(): - self.loads('[' * 100000 + '1' + ']' * 100000) + self.loads('[' * very_deep + '1' + ']' * very_deep) + # TODO: RUSTPYTHON; needs to upgrade test.support to 3.14 above + # @support.skip_wasi_stack_overflow() + # @support.skip_emscripten_stack_overflow() @support.requires_resource('cpu') def test_highly_nested_objects_encoding(self): # See #12051 l, d = [], {} - for x in range(100000): + for x in range(200_000): l, d = [l], {'k':d} with self.assertRaises(RecursionError): with support.infinite_recursion(5000): @@ -95,6 +105,9 @@ def test_highly_nested_objects_encoding(self): with support.infinite_recursion(5000): self.dumps(d) + # TODO: RUSTPYTHON; needs to upgrade test.support to 3.14 above + # @support.skip_emscripten_stack_overflow() + # @support.skip_wasi_stack_overflow() def test_endless_recursion(self): # See #12051 class EndlessJSONEncoder(self.json.JSONEncoder): diff --git a/Lib/test/test_json/test_scanstring.py b/Lib/test/test_json/test_scanstring.py index d6922c3b1b9..c7fc30f2235 100644 --- a/Lib/test/test_json/test_scanstring.py +++ b/Lib/test/test_json/test_scanstring.py @@ -144,8 +144,7 @@ def test_bad_escapes(self): with self.assertRaises(self.JSONDecodeError, msg=s): scanstring(s, 1, True) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_overflow(self): with self.assertRaises(OverflowError): self.json.decoder.scanstring("xxx", sys.maxsize+1) diff --git a/Lib/test/test_json/test_speedups.py b/Lib/test/test_json/test_speedups.py index ada96729123..25b51b307e1 100644 --- a/Lib/test/test_json/test_speedups.py +++ b/Lib/test/test_json/test_speedups.py @@ -40,8 +40,7 @@ def test_make_encoder(self): b"\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75", None) - # TODO: RUSTPYTHON; TypeError: 'NoneType' object is not callable - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: 'NoneType' object is not callable def test_bad_str_encoder(self): # Issue #31505: There shouldn't be an assertion failure in case # c_make_encoder() receives a bad encoder() argument.
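Context for the __notes__ assertions in the test_json hunks above: PEP 678 (Python 3.11+) lets code attach notes to an in-flight exception via add_note(), and CPython's json encoder now records the serialization path this way, innermost container first, as the error propagates outward. A minimal sketch of the pattern, using a toy recursive serializer rather than the real json internals:

    import sys

    def encode(obj):
        # Toy stand-in for a recursive serializer; not the json module's code.
        if isinstance(obj, list):
            for i, item in enumerate(obj):
                try:
                    encode(item)
                except TypeError as exc:
                    # PEP 678: annotate the in-flight exception with the path.
                    exc.add_note(f'when serializing list item {i}')
                    raise
            return
        if not isinstance(obj, (str, int, float, bool, type(None))):
            raise TypeError(f'Object of type {type(obj).__name__} '
                            'is not JSON serializable')

    try:
        encode([1, [2, 3, sys]])
    except TypeError as exc:
        # Innermost note first, matching the test_not_serializable assertions.
        assert exc.__notes__ == ['when serializing list item 2',
                                 'when serializing list item 1']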
@@ -63,8 +62,7 @@ def bad_encoder2(*args): with self.assertRaises(ZeroDivisionError): enc('spam', 4) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_markers_argument_to_encoder(self): # https://bugs.python.org/issue45269 with self.assertRaisesRegex( @@ -74,8 +72,7 @@ def test_bad_markers_argument_to_encoder(self): self.json.encoder.c_make_encoder(1, None, None, None, ': ', ', ', False, False, False) - # TODO: RUSTPYTHON; ZeroDivisionError not raised by test - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; ZeroDivisionError not raised by test def test_bad_bool_args(self): def test(name): self.json.encoder.JSONEncoder(**{name: BadBool()}).encode({'a': 1}) diff --git a/Lib/test/test_json/test_tool.py b/Lib/test/test_json/test_tool.py index 2b63810d539..068a7acb651 100644 --- a/Lib/test/test_json/test_tool.py +++ b/Lib/test/test_json/test_tool.py @@ -6,12 +6,23 @@ import subprocess from test import support -from test.support import os_helper +from test.support import force_not_colorized, os_helper from test.support.script_helper import assert_python_ok +from _colorize import get_theme + +# XXX: RUSTPYTHON; force_colorized not available in test.support +def force_colorized(func): + """Placeholder decorator for RustPython - force_colorized not available.""" + import functools + @functools.wraps(func) + def wrapper(*args, **kwargs): + raise unittest.SkipTest("TODO: RUSTPYTHON; force_colorized not available") + return wrapper + @support.requires_subprocess() -class TestTool(unittest.TestCase): +class TestMain(unittest.TestCase): data = """ [["blorpie"],[ "whoops" ] , [ @@ -19,6 +30,7 @@ class TestTool(unittest.TestCase): "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" :"yes"} ] """ + module = 'json' expect_without_sort_keys = textwrap.dedent("""\ [ @@ -86,8 +98,9 @@ class TestTool(unittest.TestCase): } """) + @force_not_colorized def test_stdin_stdout(self): - args = sys.executable, '-m', 'json.tool' + args = sys.executable, '-m', self.module process = subprocess.run(args, input=self.data, capture_output=True, text=True, check=True) self.assertEqual(process.stdout, self.expect) self.assertEqual(process.stderr, '') @@ -101,7 +114,8 @@ def _create_infile(self, data=None): def test_infile_stdout(self): infile = self._create_infile() - rc, out, err = assert_python_ok('-m', 'json.tool', infile) + rc, out, err = assert_python_ok('-m', self.module, infile, + PYTHON_COLORS='0') self.assertEqual(rc, 0) self.assertEqual(out.splitlines(), self.expect.encode().splitlines()) self.assertEqual(err, b'') @@ -115,7 +129,8 @@ def test_non_ascii_infile(self): ''').encode() infile = self._create_infile(data) - rc, out, err = assert_python_ok('-m', 'json.tool', infile) + rc, out, err = assert_python_ok('-m', self.module, infile, + PYTHON_COLORS='0') self.assertEqual(rc, 0) self.assertEqual(out.splitlines(), expect.splitlines()) @@ -124,7 +139,8 @@ def test_non_ascii_infile(self): def test_infile_outfile(self): infile = self._create_infile() outfile = os_helper.TESTFN + '.out' - rc, out, err = assert_python_ok('-m', 'json.tool', infile, outfile) + rc, out, err = assert_python_ok('-m', self.module, infile, outfile, + PYTHON_COLORS='0') self.addCleanup(os.remove, outfile) with open(outfile, "r", encoding="utf-8") as fp: self.assertEqual(fp.read(), self.expect) @@ -134,33 +150,38 @@ def test_infile_outfile(self): def test_writing_in_place(self): infile = self._create_infile() - rc, 
out, err = assert_python_ok('-m', 'json.tool', infile, infile) + rc, out, err = assert_python_ok('-m', self.module, infile, infile, + PYTHON_COLORS='0') with open(infile, "r", encoding="utf-8") as fp: self.assertEqual(fp.read(), self.expect) self.assertEqual(rc, 0) self.assertEqual(out, b'') self.assertEqual(err, b'') + @force_not_colorized def test_jsonlines(self): - args = sys.executable, '-m', 'json.tool', '--json-lines' + args = sys.executable, '-m', self.module, '--json-lines' process = subprocess.run(args, input=self.jsonlines_raw, capture_output=True, text=True, check=True) self.assertEqual(process.stdout, self.jsonlines_expect) self.assertEqual(process.stderr, '') def test_help_flag(self): - rc, out, err = assert_python_ok('-m', 'json.tool', '-h') + rc, out, err = assert_python_ok('-m', self.module, '-h', + PYTHON_COLORS='0') self.assertEqual(rc, 0) - self.assertTrue(out.startswith(b'usage: ')) + self.assertStartsWith(out, b'usage: ') self.assertEqual(err, b'') def test_sort_keys_flag(self): infile = self._create_infile() - rc, out, err = assert_python_ok('-m', 'json.tool', '--sort-keys', infile) + rc, out, err = assert_python_ok('-m', self.module, '--sort-keys', infile, + PYTHON_COLORS='0') self.assertEqual(rc, 0) self.assertEqual(out.splitlines(), self.expect_without_sort_keys.encode().splitlines()) self.assertEqual(err, b'') + @force_not_colorized def test_indent(self): input_ = '[1, 2]' expect = textwrap.dedent('''\ @@ -169,31 +190,34 @@ def test_indent(self): 2 ] ''') - args = sys.executable, '-m', 'json.tool', '--indent', '2' + args = sys.executable, '-m', self.module, '--indent', '2' process = subprocess.run(args, input=input_, capture_output=True, text=True, check=True) self.assertEqual(process.stdout, expect) self.assertEqual(process.stderr, '') + @force_not_colorized def test_no_indent(self): input_ = '[1,\n2]' expect = '[1, 2]\n' - args = sys.executable, '-m', 'json.tool', '--no-indent' + args = sys.executable, '-m', self.module, '--no-indent' process = subprocess.run(args, input=input_, capture_output=True, text=True, check=True) self.assertEqual(process.stdout, expect) self.assertEqual(process.stderr, '') + @force_not_colorized def test_tab(self): input_ = '[1, 2]' expect = '[\n\t1,\n\t2\n]\n' - args = sys.executable, '-m', 'json.tool', '--tab' + args = sys.executable, '-m', self.module, '--tab' process = subprocess.run(args, input=input_, capture_output=True, text=True, check=True) self.assertEqual(process.stdout, expect) self.assertEqual(process.stderr, '') + @force_not_colorized def test_compact(self): input_ = '[ 1 ,\n 2]' expect = '[1,2]\n' - args = sys.executable, '-m', 'json.tool', '--compact' + args = sys.executable, '-m', self.module, '--compact' process = subprocess.run(args, input=input_, capture_output=True, text=True, check=True) self.assertEqual(process.stdout, expect) self.assertEqual(process.stderr, '') @@ -202,7 +226,8 @@ def test_no_ensure_ascii_flag(self): infile = self._create_infile('{"key":"💩"}') outfile = os_helper.TESTFN + '.out' self.addCleanup(os.remove, outfile) - assert_python_ok('-m', 'json.tool', '--no-ensure-ascii', infile, outfile) + assert_python_ok('-m', self.module, '--no-ensure-ascii', infile, + outfile, PYTHON_COLORS='0') with open(outfile, "rb") as f: lines = f.read().splitlines() # asserting utf-8 encoded output file @@ -213,20 +238,99 @@ def test_ensure_ascii_default(self): infile = self._create_infile('{"key":"💩"}') outfile = os_helper.TESTFN + '.out' self.addCleanup(os.remove, outfile) - assert_python_ok('-m', 'json.tool', 
infile, outfile) + assert_python_ok('-m', self.module, infile, outfile, PYTHON_COLORS='0') with open(outfile, "rb") as f: lines = f.read().splitlines() # asserting an ascii encoded output file expected = [b'{', rb' "key": "\ud83d\udca9"', b"}"] self.assertEqual(lines, expected) + @force_not_colorized @unittest.skipIf(sys.platform =="win32", "The test is failed with ValueError on Windows") def test_broken_pipe_error(self): - cmd = [sys.executable, '-m', 'json.tool'] + cmd = [sys.executable, '-m', self.module] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE) - # bpo-39828: Closing before json.tool attempts to write into stdout. + # bpo-39828: Closing before json attempts to write into stdout. proc.stdout.close() proc.communicate(b'"{}"') self.assertEqual(proc.returncode, errno.EPIPE) + + @force_colorized + def test_colors(self): + infile = os_helper.TESTFN + self.addCleanup(os.remove, infile) + + t = get_theme().syntax + ob = "{" + cb = "}" + + cases = ( + ('{}', '{}'), + ('[]', '[]'), + ('null', f'{t.keyword}null{t.reset}'), + ('true', f'{t.keyword}true{t.reset}'), + ('false', f'{t.keyword}false{t.reset}'), + ('NaN', f'{t.number}NaN{t.reset}'), + ('Infinity', f'{t.number}Infinity{t.reset}'), + ('-Infinity', f'{t.number}-Infinity{t.reset}'), + ('"foo"', f'{t.string}"foo"{t.reset}'), + (r'" \"foo\" "', f'{t.string}" \\"foo\\" "{t.reset}'), + ('"α"', f'{t.string}"\\u03b1"{t.reset}'), + ('123', f'{t.number}123{t.reset}'), + ('-1.25e+23', f'{t.number}-1.25e+23{t.reset}'), + (r'{"\\": ""}', + f'''\ +{ob} + {t.definition}"\\\\"{t.reset}: {t.string}""{t.reset} +{cb}'''), + (r'{"\\\\": ""}', + f'''\ +{ob} + {t.definition}"\\\\\\\\"{t.reset}: {t.string}""{t.reset} +{cb}'''), + ('''\ +{ + "foo": "bar", + "baz": 1234, + "qux": [true, false, null], + "xyz": [NaN, -Infinity, Infinity] +}''', + f'''\ +{ob} + {t.definition}"foo"{t.reset}: {t.string}"bar"{t.reset}, + {t.definition}"baz"{t.reset}: {t.number}1234{t.reset}, + {t.definition}"qux"{t.reset}: [ + {t.keyword}true{t.reset}, + {t.keyword}false{t.reset}, + {t.keyword}null{t.reset} + ], + {t.definition}"xyz"{t.reset}: [ + {t.number}NaN{t.reset}, + {t.number}-Infinity{t.reset}, + {t.number}Infinity{t.reset} + ] +{cb}'''), + ) + + for input_, expected in cases: + with self.subTest(input=input_): + with open(infile, "w", encoding="utf-8") as fp: + fp.write(input_) + _, stdout_b, _ = assert_python_ok( + '-m', self.module, infile, FORCE_COLOR='1', __isolated='1' + ) + stdout = stdout_b.decode() + stdout = stdout.replace('\r\n', '\n') # normalize line endings + stdout = stdout.strip() + self.assertEqual(stdout, expected) + + +@support.requires_subprocess() +class TestTool(TestMain): + module = 'json.tool' + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_json/test_unicode.py b/Lib/test/test_json/test_unicode.py index be0ac8823d5..2118c9827ea 100644 --- a/Lib/test/test_json/test_unicode.py +++ b/Lib/test/test_json/test_unicode.py @@ -94,8 +94,7 @@ def test_bytes_encode(self): self.assertRaises(TypeError, self.dumps, b"hi") self.assertRaises(TypeError, self.dumps, [b"hi"]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bytes_decode(self): for encoding, bom in [ ('utf-8', codecs.BOM_UTF8), diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py index 4010ef9c340..1bfc9551ce3 100644 --- a/Lib/test/test_lzma.py +++ b/Lib/test/test_lzma.py @@ -1,4 +1,3 @@ -import _compression import array from io import BytesIO, UnsupportedOperation, 
DEFAULT_BUFFER_SIZE import os @@ -7,6 +6,7 @@ import sys from test import support import unittest +from compression._common import _streams from test.support import _4G, bigmemtest from test.support.import_helper import import_module @@ -22,8 +22,7 @@ class CompressorDecompressorTestCase(unittest.TestCase): # Test error cases. - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_simple_bad_args(self): self.assertRaises(TypeError, LZMACompressor, []) self.assertRaises(TypeError, LZMACompressor, format=3.45) @@ -64,8 +63,7 @@ def test_simple_bad_args(self): lzd.decompress(empty) self.assertRaises(EOFError, lzd.decompress, b"quux") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_filter_spec(self): self.assertRaises(TypeError, LZMACompressor, filters=[b"wobsite"]) self.assertRaises(ValueError, LZMACompressor, filters=[{"xyzzy": 3}]) @@ -82,8 +80,7 @@ def test_decompressor_after_eof(self): lzd.decompress(COMPRESSED_XZ) self.assertRaises(EOFError, lzd.decompress, b"nyan") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_memlimit(self): lzd = LZMADecompressor(memlimit=1024) self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ) @@ -104,8 +101,7 @@ def _test_decompressor(self, lzd, data, check, unused_data=b""): self.assertTrue(lzd.eof) self.assertEqual(lzd.unused_data, unused_data) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_auto(self): lzd = LZMADecompressor() self._test_decompressor(lzd, COMPRESSED_XZ, lzma.CHECK_CRC64) @@ -113,44 +109,37 @@ def test_decompressor_auto(self): lzd = LZMADecompressor() self._test_decompressor(lzd, COMPRESSED_ALONE, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_xz(self): lzd = LZMADecompressor(lzma.FORMAT_XZ) self._test_decompressor(lzd, COMPRESSED_XZ, lzma.CHECK_CRC64) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_alone(self): lzd = LZMADecompressor(lzma.FORMAT_ALONE) self._test_decompressor(lzd, COMPRESSED_ALONE, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_raw_1(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_1) self._test_decompressor(lzd, COMPRESSED_RAW_1, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_raw_2(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_2) self._test_decompressor(lzd, COMPRESSED_RAW_2, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_raw_3(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_3) self._test_decompressor(lzd, COMPRESSED_RAW_3, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_raw_4(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self._test_decompressor(lzd, COMPRESSED_RAW_4, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_chunks(self): lzd = LZMADecompressor() out = [] @@ -163,8 +152,7 @@ def 
test_decompressor_chunks(self): self.assertTrue(lzd.eof) self.assertEqual(lzd.unused_data, b"") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_chunks_empty(self): lzd = LZMADecompressor() out = [] @@ -180,8 +168,7 @@ def test_decompressor_chunks_empty(self): self.assertTrue(lzd.eof) self.assertEqual(lzd.unused_data, b"") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_chunks_maxsize(self): lzd = LZMADecompressor() max_length = 100 @@ -273,16 +260,14 @@ def test_decompressor_inputbuf_3(self): out.append(lzd.decompress(COMPRESSED_XZ[300:])) self.assertEqual(b''.join(out), INPUT) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_unused_data(self): lzd = LZMADecompressor() extra = b"fooblibar" self._test_decompressor(lzd, COMPRESSED_XZ + extra, lzma.CHECK_CRC64, unused_data=extra) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_bad_input(self): lzd = LZMADecompressor() self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_RAW_1) @@ -296,8 +281,7 @@ def test_decompressor_bad_input(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_1) self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_bug_28275(self): # Test coverage for Issue 28275 lzd = LZMADecompressor() @@ -307,32 +291,28 @@ def test_decompressor_bug_28275(self): # Test that LZMACompressor->LZMADecompressor preserves the input data. - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip_xz(self): lzc = LZMACompressor() cdata = lzc.compress(INPUT) + lzc.flush() lzd = LZMADecompressor() self._test_decompressor(lzd, cdata, lzma.CHECK_CRC64) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip_alone(self): lzc = LZMACompressor(lzma.FORMAT_ALONE) cdata = lzc.compress(INPUT) + lzc.flush() lzd = LZMADecompressor() self._test_decompressor(lzd, cdata, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip_raw(self): lzc = LZMACompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) cdata = lzc.compress(INPUT) + lzc.flush() lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self._test_decompressor(lzd, cdata, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip_raw_empty(self): lzc = LZMACompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) cdata = lzc.compress(INPUT) @@ -343,8 +323,7 @@ def test_roundtrip_raw_empty(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self._test_decompressor(lzd, cdata, lzma.CHECK_NONE) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip_chunks(self): lzc = LZMACompressor() cdata = [] @@ -355,8 +334,7 @@ def test_roundtrip_chunks(self): lzd = LZMADecompressor() self._test_decompressor(lzd, cdata, lzma.CHECK_CRC64) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip_empty_chunks(self): lzc = LZMACompressor() cdata = [] @@ -372,8 +350,7 @@ def test_roundtrip_empty_chunks(self): # 
LZMADecompressor intentionally does not handle concatenated streams. - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompressor_multistream(self): lzd = LZMADecompressor() self._test_decompressor(lzd, COMPRESSED_XZ + COMPRESSED_ALONE, @@ -434,8 +411,7 @@ class CompressDecompressFunctionTestCase(unittest.TestCase): # Test error cases: - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_args(self): self.assertRaises(TypeError, lzma.compress) self.assertRaises(TypeError, lzma.compress, []) @@ -463,24 +439,22 @@ def test_bad_args(self): lzma.decompress(b"", format=lzma.FORMAT_XZ, filters=FILTERS_RAW_1) with self.assertRaises(ValueError): lzma.decompress( - b"", format=lzma.FORMAT_ALONE, filters=FILTERS_RAW_1) + b"", format=lzma.FORMAT_ALONE, filters=FILTERS_RAW_1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompress_memlimit(self): with self.assertRaises(LZMAError): lzma.decompress(COMPRESSED_XZ, memlimit=1024) with self.assertRaises(LZMAError): lzma.decompress( - COMPRESSED_XZ, format=lzma.FORMAT_XZ, memlimit=1024) + COMPRESSED_XZ, format=lzma.FORMAT_XZ, memlimit=1024) with self.assertRaises(LZMAError): lzma.decompress( - COMPRESSED_ALONE, format=lzma.FORMAT_ALONE, memlimit=1024) + COMPRESSED_ALONE, format=lzma.FORMAT_ALONE, memlimit=1024) # Test LZMADecompressor on known-good input data. - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompress_good_input(self): ddata = lzma.decompress(COMPRESSED_XZ) self.assertEqual(ddata, INPUT) @@ -495,23 +469,22 @@ def test_decompress_good_input(self): self.assertEqual(ddata, INPUT) ddata = lzma.decompress( - COMPRESSED_RAW_1, lzma.FORMAT_RAW, filters=FILTERS_RAW_1) + COMPRESSED_RAW_1, lzma.FORMAT_RAW, filters=FILTERS_RAW_1) self.assertEqual(ddata, INPUT) ddata = lzma.decompress( - COMPRESSED_RAW_2, lzma.FORMAT_RAW, filters=FILTERS_RAW_2) + COMPRESSED_RAW_2, lzma.FORMAT_RAW, filters=FILTERS_RAW_2) self.assertEqual(ddata, INPUT) ddata = lzma.decompress( - COMPRESSED_RAW_3, lzma.FORMAT_RAW, filters=FILTERS_RAW_3) + COMPRESSED_RAW_3, lzma.FORMAT_RAW, filters=FILTERS_RAW_3) self.assertEqual(ddata, INPUT) ddata = lzma.decompress( - COMPRESSED_RAW_4, lzma.FORMAT_RAW, filters=FILTERS_RAW_4) + COMPRESSED_RAW_4, lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self.assertEqual(ddata, INPUT) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompress_incomplete_input(self): self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_XZ[:128]) self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_ALONE[:128]) @@ -524,8 +497,7 @@ def test_decompress_incomplete_input(self): self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_RAW_4[:128], format=lzma.FORMAT_RAW, filters=FILTERS_RAW_4) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompress_bad_input(self): with self.assertRaises(LZMAError): lzma.decompress(COMPRESSED_BOGUS) @@ -541,8 +513,7 @@ def test_decompress_bad_input(self): # Test that compress()->decompress() preserves the input data. 
- # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_roundtrip(self): cdata = lzma.compress(INPUT) ddata = lzma.decompress(cdata) @@ -568,14 +539,12 @@ def test_decompress_multistream(self): # Test robust handling of non-LZMA data following the compressed stream(s). - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompress_trailing_junk(self): ddata = lzma.decompress(COMPRESSED_XZ + COMPRESSED_BOGUS) self.assertEqual(ddata, INPUT) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_decompress_multistream_trailing_junk(self): ddata = lzma.decompress(COMPRESSED_XZ * 3 + COMPRESSED_BOGUS) self.assertEqual(ddata, INPUT * 3) @@ -612,8 +581,7 @@ def test_init(self): self.assertIsInstance(f, LZMAFile) self.assertEqual(f.mode, "wb") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_init_with_PathLike_filename(self): filename = FakePath(TESTFN) with TempFile(filename, COMPRESSED_XZ): @@ -694,8 +662,7 @@ def test_init_bad_mode(self): with self.assertRaises(ValueError): LZMAFile(BytesIO(COMPRESSED_XZ), "rw") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_init_bad_check(self): with self.assertRaises(TypeError): LZMAFile(BytesIO(), "w", check=b"asd") @@ -716,6 +683,7 @@ def test_init_bad_check(self): with self.assertRaises(ValueError): LZMAFile(BytesIO(COMPRESSED_XZ), check=lzma.CHECK_UNKNOWN) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_init_bad_preset(self): with self.assertRaises(TypeError): LZMAFile(BytesIO(), "w", preset=4.39) @@ -723,18 +691,19 @@ def test_init_bad_preset(self): LZMAFile(BytesIO(), "w", preset=10) with self.assertRaises(LZMAError): LZMAFile(BytesIO(), "w", preset=23) - with self.assertRaises(OverflowError): + with self.assertRaises(ValueError): LZMAFile(BytesIO(), "w", preset=-1) - with self.assertRaises(OverflowError): + with self.assertRaises(ValueError): LZMAFile(BytesIO(), "w", preset=-7) + with self.assertRaises(OverflowError): + LZMAFile(BytesIO(), "w", preset=2**1000) with self.assertRaises(TypeError): LZMAFile(BytesIO(), "w", preset="foo") # Cannot specify a preset with mode="r". 
with self.assertRaises(ValueError): LZMAFile(BytesIO(COMPRESSED_XZ), preset=3) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_init_bad_filter_spec(self): with self.assertRaises(TypeError): LZMAFile(BytesIO(), "w", filters=[b"wobsite"]) @@ -752,8 +721,7 @@ def test_init_bad_filter_spec(self): LZMAFile(BytesIO(), "w", filters=[{"id": lzma.FILTER_X86, "foo": 0}]) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_init_with_preset_and_filters(self): with self.assertRaises(ValueError): LZMAFile(BytesIO(), "w", format=lzma.FORMAT_RAW, @@ -872,8 +840,7 @@ def test_writable(self): f.close() self.assertRaises(ValueError, f.writable) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_read(self): with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: self.assertEqual(f.read(), INPUT) @@ -921,8 +888,7 @@ def test_read_10(self): chunks.append(result) self.assertEqual(b"".join(chunks), INPUT) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_read_multistream(self): with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f: self.assertEqual(f.read(), INPUT * 5) @@ -935,22 +901,20 @@ def test_read_multistream(self): def test_read_multistream_buffer_size_aligned(self): # Test the case where a stream boundary coincides with the end # of the raw read buffer. - saved_buffer_size = _compression.BUFFER_SIZE - _compression.BUFFER_SIZE = len(COMPRESSED_XZ) + saved_buffer_size = _streams.BUFFER_SIZE + _streams.BUFFER_SIZE = len(COMPRESSED_XZ) try: with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f: self.assertEqual(f.read(), INPUT * 5) finally: - _compression.BUFFER_SIZE = saved_buffer_size + _streams.BUFFER_SIZE = saved_buffer_size - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_read_trailing_junk(self): with LZMAFile(BytesIO(COMPRESSED_XZ + COMPRESSED_BOGUS)) as f: self.assertEqual(f.read(), INPUT) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_read_multistream_trailing_junk(self): with LZMAFile(BytesIO(COMPRESSED_XZ * 5 + COMPRESSED_BOGUS)) as f: self.assertEqual(f.read(), INPUT * 5) @@ -1056,8 +1020,7 @@ def test_read_bad_args(self): with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: self.assertRaises(TypeError, f.read, float()) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_read_bad_data(self): with LZMAFile(BytesIO(COMPRESSED_BOGUS)) as f: self.assertRaises(LZMAError, f.read) @@ -1103,20 +1066,19 @@ def test_peek(self): with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: result = f.peek() self.assertGreater(len(result), 0) - self.assertTrue(INPUT.startswith(result)) + self.assertStartsWith(INPUT, result) self.assertEqual(f.read(), INPUT) with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: result = f.peek(10) self.assertGreater(len(result), 0) - self.assertTrue(INPUT.startswith(result)) + self.assertStartsWith(INPUT, result) self.assertEqual(f.read(), INPUT) def test_peek_bad_args(self): with LZMAFile(BytesIO(), "w") as f: self.assertRaises(ValueError, f.peek) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_iterator(self): with BytesIO(INPUT) as f: lines = f.readlines() @@ -1148,16 +1110,15 @@ def test_readlines(self): def test_decompress_limited(self): """Decompressed data buffering should be limited""" bomb = 
lzma.compress(b'\0' * int(2e6), preset=6) - self.assertLess(len(bomb), _compression.BUFFER_SIZE) + self.assertLess(len(bomb), _streams.BUFFER_SIZE) decomp = LZMAFile(BytesIO(bomb)) self.assertEqual(decomp.read(1), b'\0') max_decomp = 1 + DEFAULT_BUFFER_SIZE self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, - "Excessive amount of data was decompressed") + "Excessive amount of data was decompressed") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_write(self): with BytesIO() as dst: with LZMAFile(dst, "w") as f: @@ -1426,8 +1387,7 @@ def test_tell_bad_args(self): f.close() self.assertRaises(ValueError, f.tell) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_issue21872(self): # sometimes decompress data incompletely @@ -1511,8 +1471,7 @@ def test_filename(self): with lzma.open(TESTFN, "rb") as f: self.assertEqual(f.read(), INPUT * 2) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_with_pathlike_filename(self): filename = FakePath(TESTFN) with TempFile(filename): @@ -1539,8 +1498,7 @@ def test_bad_params(self): with self.assertRaises(ValueError): lzma.open(TESTFN, "rb", newline="\n") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_format_and_filters(self): # Test non-default format and filter chain. options = {"format": lzma.FORMAT_RAW, "filters": FILTERS_RAW_1} @@ -1571,8 +1529,7 @@ def test_encoding_error_handler(self): with lzma.open(bio, "rt", encoding="ascii", errors="ignore") as f: self.assertEqual(f.read(), "foobar") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_newline(self): # Test with explicit newline (universal newline mode disabled). text = INPUT.decode("ascii") @@ -1597,8 +1554,7 @@ def test_x_mode(self): class MiscellaneousTestCase(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_is_check_supported(self): # CHECK_NONE and CHECK_CRC32 should always be supported, # regardless of the options liblzma was compiled with. @@ -1611,8 +1567,7 @@ def test_is_check_supported(self): # This value should not be a valid check ID. self.assertFalse(lzma.is_check_supported(lzma.CHECK_UNKNOWN)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test__encode_filter_properties(self): with self.assertRaises(TypeError): lzma._encode_filter_properties(b"not a dict") @@ -1622,20 +1577,19 @@ def test__encode_filter_properties(self): lzma._encode_filter_properties({"id": lzma.FILTER_LZMA2, "junk": 12}) with self.assertRaises(lzma.LZMAError): lzma._encode_filter_properties({"id": lzma.FILTER_DELTA, - "dist": 9001}) + "dist": 9001}) # Test with parameters used by zipfile module. props = lzma._encode_filter_properties({ - "id": lzma.FILTER_LZMA1, - "pb": 2, - "lp": 0, - "lc": 3, - "dict_size": 8 << 20, - }) + "id": lzma.FILTER_LZMA1, + "pb": 2, + "lp": 0, + "lc": 3, + "dict_size": 8 << 20, + }) self.assertEqual(props, b"]\x00\x00\x80\x00") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test__decode_filter_properties(self): with self.assertRaises(TypeError): lzma._decode_filter_properties(lzma.FILTER_X86, {"should be": bytes}) @@ -1644,7 +1598,7 @@ def test__decode_filter_properties(self): # Test with parameters used by zipfile module. 
filterspec = lzma._decode_filter_properties( - lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00") + lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00") self.assertEqual(filterspec["id"], lzma.FILTER_LZMA1) self.assertEqual(filterspec["pb"], 2) self.assertEqual(filterspec["lp"], 0) @@ -1659,11 +1613,10 @@ def test__decode_filter_properties(self): filterspec = lzma._decode_filter_properties(f, b"") self.assertEqual(filterspec, {"id": f}) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_filter_properties_roundtrip(self): spec1 = lzma._decode_filter_properties( - lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00") + lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00") reencoded = lzma._encode_filter_properties(spec1) spec2 = lzma._decode_filter_properties(lzma.FILTER_LZMA1, reencoded) self.assertEqual(spec1, spec2) diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 7c5bfba41bf..04c8ee71a99 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -1,7 +1,7 @@ from test.support import (gc_collect, bigmemtest, _2G, cpython_only, captured_stdout, check_disallow_instantiation, is_emscripten, is_wasi, - warnings_helper, SHORT_TIMEOUT, CPUStopwatch, requires_resource) + warnings_helper, SHORT_TIMEOUT, Stopwatch, requires_resource) import locale import re import string @@ -2487,7 +2487,7 @@ def test_bug_40736(self): @requires_resource('cpu') def test_search_anchor_at_beginning(self): s = 'x'*10**7 - with CPUStopwatch() as stopwatch: + with Stopwatch() as stopwatch: for p in r'\Ay', r'^y': self.assertIsNone(re.search(p, s)) self.assertEqual(re.split(p, s), [s]) diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py index 9f09413f59c..ba038f18b4c 100644 --- a/Lib/test/test_regrtest.py +++ b/Lib/test/test_regrtest.py @@ -24,7 +24,7 @@ from xml.etree import ElementTree from test import support -from test.support import os_helper, without_optimizer +from test.support import os_helper, requires_jit_disabled from test.libregrtest import cmdline from test.libregrtest import main from test.libregrtest import setup @@ -1245,7 +1245,7 @@ def test_run(self): stats=TestStats(4, 1), forever=True) - @without_optimizer + @requires_jit_disabled def check_leak(self, code, what, *, run_workers=False): test = self.create_test('huntrleaks', code=code) diff --git a/Lib/test/test_sqlite3/test_factory.py b/Lib/test/test_sqlite3/test_factory.py index c13a7481520..20cff8b585f 100644 --- a/Lib/test/test_sqlite3/test_factory.py +++ b/Lib/test/test_sqlite3/test_factory.py @@ -80,8 +80,6 @@ def setUp(self): def tearDown(self): self.con.close() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_is_instance(self): cur = self.con.cursor() self.assertIsInstance(cur, sqlite.Cursor) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index d66afdc833c..481be2bff92 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -5,6 +5,7 @@ import logging import os import shutil +import signal import socket import stat import subprocess @@ -16,12 +17,13 @@ import warnings from test import support -from test.support import import_helper -from test.support import os_helper -from test.support import script_helper -from test.support import socket_helper -from test.support import warnings_helper -from test.support.testcase import ExtraAssertions +from test.support import ( + import_helper, + os_helper, + script_helper, + socket_helper, + warnings_helper, +) TESTFN = os_helper.TESTFN @@ -51,26 +53,26 @@ def _caplog(): root_logger.removeHandler(handler) -class 
TestSupport(unittest.TestCase, ExtraAssertions): +class TestSupport(unittest.TestCase): @classmethod def setUpClass(cls): - orig_filter_len = len(warnings.filters) + orig_filter_len = len(warnings._get_filters()) cls._warnings_helper_token = support.ignore_deprecations_from( "test.support.warnings_helper", like=".*used in test_support.*" ) cls._test_support_token = support.ignore_deprecations_from( __name__, like=".*You should NOT be seeing this.*" ) - assert len(warnings.filters) == orig_filter_len + 2 + assert len(warnings._get_filters()) == orig_filter_len + 2 @classmethod def tearDownClass(cls): - orig_filter_len = len(warnings.filters) + orig_filter_len = len(warnings._get_filters()) support.clear_ignored_deprecations( cls._warnings_helper_token, cls._test_support_token, ) - assert len(warnings.filters) == orig_filter_len - 2 + assert len(warnings._get_filters()) == orig_filter_len - 2 def test_ignored_deprecations_are_silent(self): """Test support.ignore_deprecations_from() silences warnings""" @@ -98,7 +100,7 @@ def test_get_original_stdout(self): self.assertEqual(support.get_original_stdout(), sys.stdout) def test_unload(self): - import sched + import sched # noqa: F401 self.assertIn("sched", sys.modules) import_helper.unload("sched") self.assertNotIn("sched", sys.modules) @@ -407,10 +409,10 @@ class Obj: with support.swap_attr(obj, "y", 5) as y: self.assertEqual(obj.y, 5) self.assertIsNone(y) - self.assertFalse(hasattr(obj, 'y')) + self.assertNotHasAttr(obj, 'y') with support.swap_attr(obj, "y", 5): del obj.y - self.assertFalse(hasattr(obj, 'y')) + self.assertNotHasAttr(obj, 'y') def test_swap_item(self): D = {"x":1} @@ -458,6 +460,7 @@ def test_detect_api_mismatch__ignore(self): self.OtherClass, self.RefClass, ignore=ignore) self.assertEqual(set(), missing_items) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_check__all__(self): extra = {'tempdir'} not_exported = {'template'} @@ -469,7 +472,6 @@ def test_check__all__(self): extra = { 'TextTestResult', 'installHandler', - 'IsolatedAsyncioTestCase', } not_exported = {'load_tests', "TestProgram", "BaseTestSuite"} support.check__all__(self, @@ -562,6 +564,7 @@ def test_args_from_interpreter_flags(self): ['-Wignore', '-X', 'dev'], ['-X', 'faulthandler'], ['-X', 'importtime'], + ['-X', 'importtime=2'], ['-X', 'showrefcount'], ['-X', 'tracemalloc'], ['-X', 'tracemalloc=3'], @@ -586,7 +589,6 @@ def test_optim_args_from_interpreter_flags(self): self.check_options(opts, 'optim_args_from_interpreter_flags') @unittest.skipIf(support.is_apple_mobile, "Unstable on Apple Mobile") - @unittest.skipIf(support.is_emscripten, "Unstable in Emscripten") @unittest.skipIf(support.is_wasi, "Unavailable on WASI") def test_fd_count(self): # We cannot test the absolute value of fd_count(): on old Linux kernel @@ -614,17 +616,14 @@ def test_print_warning(self): self.check_print_warning("a\nb", 'Warning -- a\nWarning -- b\n') - # TODO: RUSTPYTHON - strftime extension not fully supported on non-Windows - @unittest.skipUnless(sys.platform == "win32" or support.is_emscripten, - "strftime extension not fully supported on non-Windows") + @unittest.expectedFailureIf(sys.platform != "win32", "TODO: RUSTPYTHON; no has_strftime_extensions yet") def test_has_strftime_extensions(self): - if support.is_emscripten or sys.platform == "win32": + if sys.platform == "win32": self.assertFalse(support.has_strftime_extensions) else: self.assertTrue(support.has_strftime_extensions) - # TODO: RUSTPYTHON - _testinternalcapi module not available - @unittest.expectedFailure 
+ @unittest.expectedFailure # TODO: RUSTPYTHON; _testinternalcapi module not available def test_get_recursion_depth(self): # test support.get_recursion_depth() code = textwrap.dedent(""" @@ -668,8 +667,7 @@ def test_recursive(depth, limit): """) script_helper.assert_python_ok("-c", code) - # TODO: RUSTPYTHON - stack overflow in debug mode with deep recursion - @unittest.skip("TODO: RUSTPYTHON - causes segfault in debug builds") + @unittest.skip('TODO: RUSTPYTHON; causes segfault in debug builds') def test_recursion(self): # Test infinite_recursion() and get_recursion_available() functions. def recursive_function(depth): @@ -778,9 +776,31 @@ def test_copy_python_src_ignore(self): self.assertEqual(support.copy_python_src_ignore(path, os.listdir(path)), ignored) + def test_get_signal_name(self): + for exitcode, expected in ( + (-int(signal.SIGINT), 'SIGINT'), + (-int(signal.SIGSEGV), 'SIGSEGV'), + (128 + int(signal.SIGABRT), 'SIGABRT'), + (3221225477, "STATUS_ACCESS_VIOLATION"), + (0xC00000FD, "STATUS_STACK_OVERFLOW"), + ): + self.assertEqual(support.get_signal_name(exitcode), expected, + exitcode) + def test_linked_to_musl(self): linked = support.linked_to_musl() - self.assertIsInstance(linked, bool) + self.assertIsNotNone(linked) + if support.is_wasm32: + self.assertTrue(linked) + # The value is cached, so make sure it returns the same value again. + self.assertIs(linked, support.linked_to_musl()) + # The musl version is either triple or just a major version number. + if linked: + self.assertIsInstance(linked, tuple) + self.assertIn(len(linked), (1, 3)) + for v in linked: + self.assertIsInstance(v, int) + # XXX -follows a list of untested API # make_legacy_pyc diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index b65e1291bbc..77300fbe0bb 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -7,19 +7,24 @@ import operator import os import random +import socket import struct import subprocess import sys import sysconfig import test.support +from io import StringIO +from unittest import mock from test import support from test.support import os_helper from test.support.script_helper import assert_python_ok, assert_python_failure +from test.support.socket_helper import find_unused_port from test.support import threading_helper from test.support import import_helper from test.support import force_not_colorized +from test.support import SHORT_TIMEOUT try: - from test.support import interpreters + from concurrent import interpreters except ImportError: interpreters = None import textwrap @@ -52,7 +57,7 @@ def test_original_displayhook(self): dh(None) self.assertEqual(out.getvalue(), "") - self.assertTrue(not hasattr(builtins, "_")) + self.assertNotHasAttr(builtins, "_") # sys.displayhook() requires arguments self.assertRaises(TypeError, dh) @@ -167,7 +172,7 @@ def test_original_excepthook(self): with support.captured_stderr() as err: sys.__excepthook__(*sys.exc_info()) - self.assertTrue(err.getvalue().endswith("ValueError: 42\n")) + self.assertEndsWith(err.getvalue(), "ValueError: 42\n") self.assertRaises(TypeError, sys.__excepthook__) @@ -188,7 +193,7 @@ def test_excepthook_bytes_filename(self): err = err.getvalue() self.assertIn(""" File "b'bytes_filename'", line 123\n""", err) self.assertIn(""" text\n""", err) - self.assertTrue(err.endswith("SyntaxError: msg\n")) + self.assertEndsWith(err, "SyntaxError: msg\n") def test_excepthook(self): with test.support.captured_output("stderr") as stderr: @@ -206,7 +211,7 @@ class SysModuleTest(unittest.TestCase): def tearDown(self):
test.support.reap_children() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_exit(self): # call with two arguments self.assertRaises(TypeError, sys.exit, 42, 42) @@ -266,8 +271,7 @@ def check_exit_message(code, expected, **env_vars): rc, out, err = assert_python_failure('-c', code, **env_vars) self.assertEqual(rc, 1) self.assertEqual(out, b'') - self.assertTrue(err.startswith(expected), - "%s doesn't start with %s" % (ascii(err), ascii(expected))) + self.assertStartsWith(err, expected) # test that stderr buffer is flushed before the exit message is written # into stderr @@ -399,36 +403,6 @@ def test_setrecursionlimit_to_depth(self): finally: sys.setrecursionlimit(old_limit) - @unittest.skipUnless(support.Py_GIL_DISABLED, "only meaningful if the GIL is disabled") - @threading_helper.requires_working_threading() - def test_racing_recursion_limit(self): - from threading import Thread - def something_recursive(): - def count(n): - if n > 0: - return count(n - 1) + 1 - return 0 - - count(50) - - def set_recursion_limit(): - for limit in range(100, 200): - sys.setrecursionlimit(limit) - - threads = [] - for _ in range(5): - threads.append(Thread(target=set_recursion_limit)) - - for _ in range(5): - threads.append(Thread(target=something_recursive)) - - with threading_helper.catch_threading_exception() as cm: - with threading_helper.start_threads(threads): - pass - - if cm.exc_value: - raise cm.exc_value - @unittest.expectedFailure # TODO: RUSTPYTHON def test_getwindowsversion(self): # Raise SkipTest if sys doesn't have getwindowsversion attribute @@ -467,7 +441,7 @@ def test_call_tracing(self): @unittest.skipUnless(hasattr(sys, "setdlopenflags"), 'test needs sys.setdlopenflags()') def test_dlopenflags(self): - self.assertTrue(hasattr(sys, "getdlopenflags")) + self.assertHasAttr(sys, "getdlopenflags") self.assertRaises(TypeError, sys.getdlopenflags, 42) oldflags = sys.getdlopenflags() self.assertRaises(TypeError, sys.setdlopenflags) @@ -656,8 +630,7 @@ def g456(): # And the next record must be for g456(). filename, lineno, funcname, sourceline = stack[i+1] self.assertEqual(funcname, "g456") - self.assertTrue((sourceline.startswith("if leave_g.wait(") or - sourceline.startswith("g_raised.set()"))) + self.assertStartsWith(sourceline, ("if leave_g.wait(", "g_raised.set()")) finally: # Reap the spawned thread. 
leave_g.set() @@ -757,6 +730,8 @@ def test_attributes(self): self.assertIn(sys.float_repr_style, ('short', 'legacy')) if not sys.platform.startswith('win'): self.assertIsInstance(sys.abiflags, str) + else: + self.assertFalse(hasattr(sys, 'abiflags')) def test_thread_info(self): info = sys.thread_info @@ -882,6 +857,7 @@ def test_subinterp_intern_singleton(self): ''')) self.assertTrue(sys._is_interned(s)) + @unittest.expectedFailure # TODO: RUSTPYTHON; needs update for context_aware_warnings def test_sys_flags(self): self.assertTrue(sys.flags) attrs = ("debug", @@ -889,11 +865,10 @@ def test_sys_flags(self): "dont_write_bytecode", "no_user_site", "no_site", "ignore_environment", "verbose", "bytes_warning", "quiet", "hash_randomization", "isolated", "dev_mode", "utf8_mode", - "warn_default_encoding", "safe_path", "int_max_str_digits", - "thread_inherit_context") + "warn_default_encoding", "safe_path", "int_max_str_digits") for attr in attrs: - self.assertTrue(hasattr(sys.flags, attr), attr) - attr_type = bool if attr in ("dev_mode", "safe_path", "thread_inherit_context") else int + self.assertHasAttr(sys.flags, attr) + attr_type = bool if attr in ("dev_mode", "safe_path") else int self.assertEqual(type(getattr(sys.flags, attr)), attr_type, attr) self.assertTrue(repr(sys.flags)) self.assertEqual(len(sys.flags), len(attrs)) @@ -919,9 +894,11 @@ def test_sys_getwindowsversion_no_instantiation(self): @test.support.cpython_only def test_clear_type_cache(self): - sys._clear_type_cache() + with self.assertWarnsRegex(DeprecationWarning, + r"sys\._clear_type_cache\(\) is deprecated.*"): + sys._clear_type_cache() - @unittest.skip("TODO: RUSTPYTHON; cp424 encoding not supported, causes panic") + @unittest.skip('TODO: RUSTPYTHON; cp424 encoding not supported, causes panic') @force_not_colorized @support.requires_subprocess() def test_ioencoding(self): @@ -1101,10 +1078,11 @@ def test_implementation(self): levels = {'alpha': 0xA, 'beta': 0xB, 'candidate': 0xC, 'final': 0xF} - self.assertTrue(hasattr(sys.implementation, 'name')) - self.assertTrue(hasattr(sys.implementation, 'version')) - self.assertTrue(hasattr(sys.implementation, 'hexversion')) - self.assertTrue(hasattr(sys.implementation, 'cache_tag')) + self.assertHasAttr(sys.implementation, 'name') + self.assertHasAttr(sys.implementation, 'version') + self.assertHasAttr(sys.implementation, 'hexversion') + self.assertHasAttr(sys.implementation, 'cache_tag') + self.assertHasAttr(sys.implementation, 'supports_isolated_interpreters') version = sys.implementation.version self.assertEqual(version[:2], (version.major, version.minor)) @@ -1118,6 +1096,15 @@ def test_implementation(self): self.assertEqual(sys.implementation.name, sys.implementation.name.lower()) + # https://peps.python.org/pep-0734 + sii = sys.implementation.supports_isolated_interpreters + self.assertIsInstance(sii, bool) + if test.support.check_impl_detail(cpython=True): + if test.support.is_emscripten or test.support.is_wasi: + self.assertFalse(sii) + else: + self.assertTrue(sii) + @test.support.cpython_only def test_debugmallocstats(self): # Test sys._debugmallocstats() @@ -1128,14 +1115,10 @@ def test_debugmallocstats(self): # Output of sys._debugmallocstats() depends on configure flags. # The sysconfig vars are not available on Windows. 
if sys.platform != "win32": - with_freelists = sysconfig.get_config_var("WITH_FREELISTS") with_pymalloc = sysconfig.get_config_var("WITH_PYMALLOC") - if with_freelists: - self.assertIn(b"free PyDictObjects", err) + self.assertIn(b"free PyDictObjects", err) if with_pymalloc: self.assertIn(b'Small block threshold', err) - if not with_freelists and not with_pymalloc: - self.assertFalse(err) # The function has no parameter self.assertRaises(TypeError, sys._debugmallocstats, True) @@ -1166,18 +1149,20 @@ def test_getallocatedblocks(self): # about the underlying implementation: the function might # return 0 or something greater. self.assertGreaterEqual(a, 0) + gc.collect() + b = sys.getallocatedblocks() + self.assertLessEqual(b, a) try: - # While we could imagine a Python session where the number of - # multiple buffer objects would exceed the sharing of references, - # it is unlikely to happen in a normal test run. - self.assertLess(a, sys.gettotalrefcount()) + # The reported blocks will include immortalized strings, but the + # total ref count will not. This will sanity check that among all + # other objects (those eligible for garbage collection) there + # are more references being tracked than allocated blocks. + interned_immortal = sys.getunicodeinternedsize(_only_immortal=True) + self.assertLess(a - interned_immortal, sys.gettotalrefcount()) except AttributeError: # gettotalrefcount() not available pass gc.collect() - b = sys.getallocatedblocks() - self.assertLessEqual(b, a) - gc.collect() c = sys.getallocatedblocks() self.assertIn(c, range(b - 50, b + 50)) @@ -1444,7 +1429,7 @@ def __del__(self): else: self.assertIn("ValueError", report) self.assertIn("del is broken", report) - self.assertTrue(report.endswith("\n")) + self.assertEndsWith(report, "\n") def test_original_unraisablehook_exception_qualname(self): # See bpo-41031, bpo-45083. 
@@ -1689,15 +1674,19 @@ class C(object): pass # float check(float(0), size('d')) # sys.floatinfo - check(sys.float_info, vsize('') + self.P * len(sys.float_info)) + check(sys.float_info, self.P + vsize('') + self.P * len(sys.float_info)) # frame def func(): return sys._getframe() x = func() - check(x, size('3Pi2c2P7P2ic??2P')) + if support.Py_GIL_DISABLED: + INTERPRETER_FRAME = '9PihcP' + else: + INTERPRETER_FRAME = '9PhcP' + check(x, size('3PiccPPP' + INTERPRETER_FRAME + 'P')) # function def func(): pass - check(func, size('15Pi')) + check(func, size('16Pi')) class c(): @staticmethod def foo(): @@ -1711,7 +1700,7 @@ def bar(cls): check(bar, size('PP')) # generator def get_gen(): yield 1 - check(get_gen(), size('PP4P4c7P2ic??2P')) + check(get_gen(), size('6P4c' + INTERPRETER_FRAME + 'P')) # iterator check(iter('abc'), size('lP')) # callable-iterator @@ -1793,13 +1782,14 @@ def delx(self): del self.__x # super check(super(int), size('3P')) # tuple - check((), vsize('')) - check((1,2,3), vsize('') + 3*self.P) + check((), vsize('') + self.P) + check((1,2,3), vsize('') + self.P + 3*self.P) # type # static type: PyTypeObject fmt = 'P2nPI13Pl4Pn9Pn12PIPc' s = vsize(fmt) check(int, s) + typeid = 'n' if support.Py_GIL_DISABLED else '' # class s = vsize(fmt + # PyTypeObject '4P' # PyAsyncMethods @@ -1807,8 +1797,9 @@ def delx(self): del self.__x '3P' # PyMappingMethods '10P' # PySequenceMethods '2P' # PyBufferProcs - '6P' - '1PIP' # Specializer cache + '7P' + '1PIP' # Specializer cache + + typeid # heap type id (free-threaded only) ) class newstyleclass(object): pass # Separate block for PyDictKeysObject with 8 keys and 5 entries @@ -1913,8 +1904,10 @@ def test_pythontypes(self): # symtable entry # XXX # sys.flags - # FIXME: The +1 will not be necessary once gh-122575 is fixed - check(sys.flags, vsize('') + self.P * (1 + len(sys.flags))) + # FIXME: The +3 is for the 'gil', 'thread_inherit_context' and + # 'context_aware_warnings' flags and will not be necessary once + # gh-122575 is fixed + check(sys.flags, vsize('') + self.P + self.P * (3 + len(sys.flags))) def test_asyncgen_hooks(self): old = sys.get_asyncgen_hooks() @@ -1972,5 +1965,318 @@ def write(self, s): self.assertEqual(out, b"") self.assertEqual(err, b"") +@test.support.support_remote_exec_only +@test.support.cpython_only +class TestRemoteExec(unittest.TestCase): + def tearDown(self): + test.support.reap_children() + + def _run_remote_exec_test(self, script_code, python_args=None, env=None, + prologue='', + script_path=os_helper.TESTFN + '_remote.py'): + # Create the script that will be remotely executed + self.addCleanup(os_helper.unlink, script_path) + + with open(script_path, 'w') as f: + f.write(script_code) + + # Create and run the target process + target = os_helper.TESTFN + '_target.py' + self.addCleanup(os_helper.unlink, target) + + port = find_unused_port() + + with open(target, 'w') as f: + f.write(f''' +import sys +import time +import socket + +# Connect to the test process +sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +sock.connect(('localhost', {port})) + +{prologue} + +# Signal that the process is ready +sock.sendall(b"ready") + +print("Target process running...") + +# Wait for remote script to be executed +# (the execution will happen as the following +# code is processed as soon as the recv call +# unblocks) +sock.recv(1024) + +# Do a bunch of work to give the remote script time to run +x = 0 +for i in range(100): + x += i + +# Write confirmation back +sock.sendall(b"executed") +sock.close() +''') + + # Start 
the target process and capture its output + cmd = [sys.executable] + if python_args: + cmd.extend(python_args) + cmd.append(target) + + # Create a socket server to communicate with the target process + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.bind(('localhost', port)) + server_socket.settimeout(SHORT_TIMEOUT) + server_socket.listen(1) + + with subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) as proc: + client_socket = None + try: + # Accept connection from target process + client_socket, _ = server_socket.accept() + server_socket.close() + + response = client_socket.recv(1024) + self.assertEqual(response, b"ready") + + # Try remote exec on the target process + sys.remote_exec(proc.pid, script_path) + + # Signal script to continue + client_socket.sendall(b"continue") + + # Wait for execution confirmation + response = client_socket.recv(1024) + self.assertEqual(response, b"executed") + + # Return output for test verification + stdout, stderr = proc.communicate(timeout=10.0) + return proc.returncode, stdout, stderr + except PermissionError: + self.skipTest("Insufficient permissions to execute code in remote process") + finally: + if client_socket is not None: + client_socket.close() + proc.kill() + proc.terminate() + proc.wait(timeout=SHORT_TIMEOUT) + + def test_remote_exec(self): + """Test basic remote exec functionality""" + script = 'print("Remote script executed successfully!")' + returncode, stdout, stderr = self._run_remote_exec_test(script) + # self.assertEqual(returncode, 0) + self.assertIn(b"Remote script executed successfully!", stdout) + self.assertEqual(stderr, b"") + + def test_remote_exec_bytes(self): + script = 'print("Remote script executed successfully!")' + script_path = os.fsencode(os_helper.TESTFN) + b'_bytes_remote.py' + returncode, stdout, stderr = self._run_remote_exec_test(script, + script_path=script_path) + self.assertIn(b"Remote script executed successfully!", stdout) + self.assertEqual(stderr, b"") + + @unittest.skipUnless(os_helper.TESTFN_UNDECODABLE, 'requires undecodable path') + @unittest.skipIf(sys.platform == 'darwin', + 'undecodable paths are not supported on macOS') + def test_remote_exec_undecodable(self): + script = 'print("Remote script executed successfully!")' + script_path = os_helper.TESTFN_UNDECODABLE + b'_undecodable_remote.py' + for script_path in [script_path, os.fsdecode(script_path)]: + returncode, stdout, stderr = self._run_remote_exec_test(script, + script_path=script_path) + self.assertIn(b"Remote script executed successfully!", stdout) + self.assertEqual(stderr, b"") + + def test_remote_exec_with_self_process(self): + """Test remote exec with the target process being the same as the test process""" + + code = 'import sys;print("Remote script executed successfully!", file=sys.stderr)' + file = os_helper.TESTFN + '_remote_self.py' + with open(file, 'w') as f: + f.write(code) + self.addCleanup(os_helper.unlink, file) + with mock.patch('sys.stderr', new_callable=StringIO) as mock_stderr: + with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout: + sys.remote_exec(os.getpid(), os.path.abspath(file)) + print("Done") + self.assertEqual(mock_stderr.getvalue(), "Remote script executed successfully!\n") + self.assertEqual(mock_stdout.getvalue(), "Done\n") + + def test_remote_exec_raises_audit_event(self): + """Test remote exec raises an audit event""" + prologue = '''\ +import sys +def audit_hook(event, arg): + print(f"Audit event: {event}, arg: 
{arg}".encode("ascii", errors="replace")) +sys.addaudithook(audit_hook) +''' + script = ''' +print("Remote script executed successfully!") +''' + returncode, stdout, stderr = self._run_remote_exec_test(script, prologue=prologue) + self.assertEqual(returncode, 0) + self.assertIn(b"Remote script executed successfully!", stdout) + self.assertIn(b"Audit event: cpython.remote_debugger_script, arg: ", stdout) + self.assertEqual(stderr, b"") + + def test_remote_exec_with_exception(self): + """Test remote exec with an exception raised in the target process + + The exception should be raised in the main thread of the target process + but not crash the target process. + """ + script = ''' +raise Exception("Remote script exception") +''' + returncode, stdout, stderr = self._run_remote_exec_test(script) + self.assertEqual(returncode, 0) + self.assertIn(b"Remote script exception", stderr) + self.assertEqual(stdout.strip(), b"Target process running...") + + def test_new_namespace_for_each_remote_exec(self): + """Test that each remote_exec call gets its own namespace.""" + script = textwrap.dedent( + """ + assert globals() is not __import__("__main__").__dict__ + print("Remote script executed successfully!") + """ + ) + returncode, stdout, stderr = self._run_remote_exec_test(script) + self.assertEqual(returncode, 0) + self.assertEqual(stderr, b"") + self.assertIn(b"Remote script executed successfully", stdout) + + def test_remote_exec_disabled_by_env(self): + """Test remote exec is disabled when PYTHON_DISABLE_REMOTE_DEBUG is set""" + env = os.environ.copy() + env['PYTHON_DISABLE_REMOTE_DEBUG'] = '1' + with self.assertRaisesRegex(RuntimeError, "Remote debugging is not enabled in the remote process"): + self._run_remote_exec_test("print('should not run')", env=env) + + def test_remote_exec_disabled_by_xoption(self): + """Test remote exec is disabled with -Xdisable-remote-debug""" + with self.assertRaisesRegex(RuntimeError, "Remote debugging is not enabled in the remote process"): + self._run_remote_exec_test("print('should not run')", python_args=['-Xdisable-remote-debug']) + + def test_remote_exec_invalid_pid(self): + """Test remote exec with invalid process ID""" + with self.assertRaises(OSError): + sys.remote_exec(99999, "print('should not run')") + + def test_remote_exec_invalid_script(self): + """Test remote exec with invalid script type""" + with self.assertRaises(TypeError): + sys.remote_exec(0, None) + with self.assertRaises(TypeError): + sys.remote_exec(0, 123) + + def test_remote_exec_syntax_error(self): + """Test remote exec with syntax error in script""" + script = ''' +this is invalid python code +''' + returncode, stdout, stderr = self._run_remote_exec_test(script) + self.assertEqual(returncode, 0) + self.assertIn(b"SyntaxError", stderr) + self.assertEqual(stdout.strip(), b"Target process running...") + + def test_remote_exec_invalid_script_path(self): + """Test remote exec with invalid script path""" + with self.assertRaises(OSError): + sys.remote_exec(os.getpid(), "invalid_script_path") + + def test_remote_exec_in_process_without_debug_fails_envvar(self): + """Test remote exec in a process without remote debugging enabled""" + script = os_helper.TESTFN + '_remote.py' + self.addCleanup(os_helper.unlink, script) + with open(script, 'w') as f: + f.write('print("Remote script executed successfully!")') + env = os.environ.copy() + env['PYTHON_DISABLE_REMOTE_DEBUG'] = '1' + + _, out, err = assert_python_failure('-c', f'import os, sys; sys.remote_exec(os.getpid(), "{script}")', **env) + 
self.assertIn(b"Remote debugging is not enabled", err) + self.assertEqual(out, b"") + + def test_remote_exec_in_process_without_debug_fails_xoption(self): + """Test remote exec in a process without remote debugging enabled""" + script = os_helper.TESTFN + '_remote.py' + self.addCleanup(os_helper.unlink, script) + with open(script, 'w') as f: + f.write('print("Remote script executed successfully!")') + + _, out, err = assert_python_failure('-Xdisable-remote-debug', '-c', f'import os, sys; sys.remote_exec(os.getpid(), "{script}")') + self.assertIn(b"Remote debugging is not enabled", err) + self.assertEqual(out, b"") + +class TestSysJIT(unittest.TestCase): + + def test_jit_is_available(self): + available = sys._jit.is_available() + script = f"import sys; assert sys._jit.is_available() is {available}" + assert_python_ok("-c", script, PYTHON_JIT="0") + assert_python_ok("-c", script, PYTHON_JIT="1") + + def test_jit_is_enabled(self): + available = sys._jit.is_available() + script = "import sys; assert sys._jit.is_enabled() is {enabled}" + assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0") + assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1") + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_jit_is_active(self): + available = sys._jit.is_available() + script = textwrap.dedent( + """ + import _testcapi + import _testinternalcapi + import sys + + def frame_0_interpreter() -> None: + assert sys._jit.is_active() is False + + def frame_1_interpreter() -> None: + assert sys._jit.is_active() is False + frame_0_interpreter() + assert sys._jit.is_active() is False + + def frame_2_jit(expected: bool) -> None: + # Inlined into the last loop of frame_3_jit: + assert sys._jit.is_active() is expected + # Insert C frame: + _testcapi.pyobject_vectorcall(frame_1_interpreter, None, None) + assert sys._jit.is_active() is expected + + def frame_3_jit() -> None: + # JITs just before the last loop: + for i in range(_testinternalcapi.TIER2_THRESHOLD + 1): + # Careful, doing this in the reverse order breaks tracing: + expected = {enabled} and i == _testinternalcapi.TIER2_THRESHOLD + assert sys._jit.is_active() is expected + frame_2_jit(expected) + assert sys._jit.is_active() is expected + + def frame_4_interpreter() -> None: + assert sys._jit.is_active() is False + frame_3_jit() + assert sys._jit.is_active() is False + + assert sys._jit.is_active() is False + frame_4_interpreter() + assert sys._jit.is_active() is False + """ + ) + assert_python_ok("-c", script.format(enabled=False), PYTHON_JIT="0") + assert_python_ok("-c", script.format(enabled=available), PYTHON_JIT="1") + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_unittest/__init__.py b/Lib/test/test_unittest/__init__.py index bc502ef32d2..365f26d6438 100644 --- a/Lib/test/test_unittest/__init__.py +++ b/Lib/test/test_unittest/__init__.py @@ -1,4 +1,5 @@ import os.path + from test.support import load_package_tests diff --git a/Lib/test/test_unittest/__main__.py b/Lib/test/test_unittest/__main__.py index 40a23a297ec..0d53bfab847 100644 --- a/Lib/test/test_unittest/__main__.py +++ b/Lib/test/test_unittest/__main__.py @@ -1,4 +1,5 @@ -from . import load_tests import unittest +from . 
import load_tests + unittest.main() diff --git a/Lib/test/test_unittest/_test_warnings.py b/Lib/test/test_unittest/_test_warnings.py index 08b846ee47e..d9f41a4144b 100644 --- a/Lib/test/test_unittest/_test_warnings.py +++ b/Lib/test/test_unittest/_test_warnings.py @@ -14,6 +14,7 @@ import unittest import warnings + def warnfun(): warnings.warn('rw', RuntimeWarning) diff --git a/Lib/test/test_unittest/test_assertions.py b/Lib/test/test_unittest/test_assertions.py index 1dec947ea76..3d782573d7b 100644 --- a/Lib/test/test_unittest/test_assertions.py +++ b/Lib/test/test_unittest/test_assertions.py @@ -1,10 +1,11 @@ import datetime +import unittest import warnings import weakref -import unittest -from test.support import gc_collect from itertools import product +from test.support import gc_collect + class Test_Assertions(unittest.TestCase): def test_AlmostEqual(self): diff --git a/Lib/test/test_unittest/test_async_case.py b/Lib/test/test_unittest/test_async_case.py index 8f73f466379..7b3cb949e4c 100644 --- a/Lib/test/test_unittest/test_async_case.py +++ b/Lib/test/test_unittest/test_async_case.py @@ -1,7 +1,9 @@ import asyncio import contextvars import unittest + from test import support +from test.support import force_not_colorized support.requires_working_socket(module=True) @@ -11,7 +13,9 @@ class MyException(Exception): def tearDownModule(): - asyncio.set_event_loop_policy(None) + # XXX: RUSTPYTHON; asyncio.events._set_event_loop_policy is not implemented + # asyncio.events._set_event_loop_policy(None) + pass class TestCM: @@ -253,6 +257,7 @@ async def on_cleanup(self): test.doCleanups() self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + @force_not_colorized def test_exception_in_tear_clean_up(self): class Test(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): @@ -314,18 +319,21 @@ async def test3(self): self.assertIn('It is deprecated to return a value that is not None', str(w.warning)) self.assertIn('test1', str(w.warning)) self.assertEqual(w.filename, __file__) + self.assertIn("returned 'int'", str(w.warning)) with self.assertWarns(DeprecationWarning) as w: Test('test2').run() self.assertIn('It is deprecated to return a value that is not None', str(w.warning)) self.assertIn('test2', str(w.warning)) self.assertEqual(w.filename, __file__) + self.assertIn("returned 'async_generator'", str(w.warning)) with self.assertWarns(DeprecationWarning) as w: Test('test3').run() self.assertIn('It is deprecated to return a value that is not None', str(w.warning)) self.assertIn('test3', str(w.warning)) self.assertEqual(w.filename, __file__) + self.assertIn(f'returned {Nothing.__name__!r}', str(w.warning)) def test_cleanups_interleave_order(self): events = [] @@ -470,6 +478,7 @@ async def cleanup(self, fut): test.doCleanups() self.assertEqual(events, ['asyncSetUp', 'test', 'cleanup']) + @unittest.expectedFailure def test_setup_get_event_loop(self): # See https://github.com/python/cpython/issues/95736 # Make sure the default event loop is not used @@ -477,7 +486,7 @@ def test_setup_get_event_loop(self): class TestCase1(unittest.IsolatedAsyncioTestCase): def setUp(self): - asyncio.get_event_loop_policy().get_event_loop() + asyncio.events._get_event_loop_policy().get_event_loop() async def test_demo1(self): pass @@ -486,8 +495,9 @@ async def test_demo1(self): result = test.run() self.assertTrue(result.wasSuccessful()) + @unittest.expectedFailure # TODO: RUSTPYTHON; asyncio.events._set_event_loop_policy is not implemented def 
test_loop_factory(self): - asyncio.set_event_loop_policy(None) + asyncio.events._set_event_loop_policy(None) class TestCase1(unittest.IsolatedAsyncioTestCase): loop_factory = asyncio.EventLoop diff --git a/Lib/test/test_unittest/test_break.py b/Lib/test/test_unittest/test_break.py index 1da98af3e74..8aa20008ac7 100644 --- a/Lib/test/test_unittest/test_break.py +++ b/Lib/test/test_unittest/test_break.py @@ -1,10 +1,10 @@ import gc import io import os -import sys import signal -import weakref +import sys import unittest +import weakref from test import support diff --git a/Lib/test/test_unittest/test_case.py b/Lib/test/test_unittest/test_case.py index 82a442a04e6..6e77040c265 100644 --- a/Lib/test/test_unittest/test_case.py +++ b/Lib/test/test_unittest/test_case.py @@ -1,26 +1,27 @@ import contextlib import difflib -import pprint +import inspect +import logging import pickle +import pprint import re import sys -import logging +import types +import unittest import warnings import weakref -import inspect -import types - +from collections import UserString from copy import deepcopy -from test import support - -import unittest +from test import support +from test.support import captured_stderr, gc_collect from test.test_unittest.support import ( - TestEquality, TestHashing, LoggingResult, LegacyLoggingResult, - ResultWithNoStartTestRunStopTestRun + LegacyLoggingResult, + LoggingResult, + ResultWithNoStartTestRunStopTestRun, + TestEquality, + TestHashing, ) -from test.support import captured_stderr, gc_collect - log_foo = logging.getLogger('foo') log_foobar = logging.getLogger('foo.bar') @@ -54,6 +55,10 @@ def tearDown(self): self.events.append('tearDown') +class List(list): + pass + + class Test_TestCase(unittest.TestCase, TestEquality, TestHashing): ### Set up attributes used by inherited tests @@ -85,7 +90,7 @@ class Test(unittest.TestCase): def runTest(self): raise MyException() def test(self): pass - self.assertEqual(Test().id()[-13:], '.Test.runTest') + self.assertEndsWith(Test().id(), '.Test.runTest') # test that TestCase can be instantiated with no args # primarily for use at the interactive interpreter @@ -106,7 +111,7 @@ class Test(unittest.TestCase): def runTest(self): raise MyException() def test(self): pass - self.assertEqual(Test('test').id()[-10:], '.Test.test') + self.assertEndsWith(Test('test').id(), '.Test.test') # "class TestCase([methodName])" # ... 
@@ -325,18 +330,40 @@ def test3(self):
         self.assertIn('It is deprecated to return a value that is not None', str(w.warning))
         self.assertIn('test1', str(w.warning))
         self.assertEqual(w.filename, __file__)
+        self.assertIn("returned 'int'", str(w.warning))
 
         with self.assertWarns(DeprecationWarning) as w:
             Foo('test2').run()
         self.assertIn('It is deprecated to return a value that is not None', str(w.warning))
         self.assertIn('test2', str(w.warning))
         self.assertEqual(w.filename, __file__)
+        self.assertIn("returned 'generator'", str(w.warning))
 
         with self.assertWarns(DeprecationWarning) as w:
             Foo('test3').run()
         self.assertIn('It is deprecated to return a value that is not None', str(w.warning))
         self.assertIn('test3', str(w.warning))
         self.assertEqual(w.filename, __file__)
+        self.assertIn(f'returned {Nothing.__name__!r}', str(w.warning))
+
+    def test_deprecation_of_return_val_from_test_async_method(self):
+        class Foo(unittest.TestCase):
+            async def test1(self):
+                return 1
+
+        with self.assertWarns(DeprecationWarning) as w:
+            warnings.filterwarnings('ignore',
+                    'coroutine .* was never awaited', RuntimeWarning)
+            Foo('test1').run()
+            support.gc_collect()
+        self.assertIn('It is deprecated to return a value that is not None', str(w.warning))
+        self.assertIn('test1', str(w.warning))
+        self.assertEqual(w.filename, __file__)
+        self.assertIn("returned 'coroutine'", str(w.warning))
+        self.assertIn(
+            'Maybe you forgot to use IsolatedAsyncioTestCase as the base class?',
+            str(w.warning),
+        )
 
     def _check_call_order__subtests(self, result, events, expected_events):
         class Foo(Test.LoggingTestCase):
@@ -678,16 +705,136 @@ def testAssertIsNot(self):
         self.assertRaises(self.failureException, self.assertIsNot, thing, thing)
 
     def testAssertIsInstance(self):
-        thing = []
+        thing = List()
         self.assertIsInstance(thing, list)
-        self.assertRaises(self.failureException, self.assertIsInstance,
-                          thing, dict)
+        self.assertIsInstance(thing, (int, list))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsInstance(thing, int)
+        self.assertEqual(str(cm.exception),
+                "[] is not an instance of <class 'int'>")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsInstance(thing, (int, float))
+        self.assertEqual(str(cm.exception),
+                "[] is not an instance of any of (<class 'int'>, <class 'float'>)")
+
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsInstance(thing, int, 'ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsInstance(thing, int, msg='ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
 
     def testAssertNotIsInstance(self):
-        thing = []
-        self.assertNotIsInstance(thing, dict)
-        self.assertRaises(self.failureException, self.assertNotIsInstance,
-                          thing, list)
+        thing = List()
+        self.assertNotIsInstance(thing, int)
+        self.assertNotIsInstance(thing, (int, float))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsInstance(thing, list)
+        self.assertEqual(str(cm.exception),
+                "[] is an instance of <class 'list'>")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsInstance(thing, (int, list))
+        self.assertEqual(str(cm.exception),
+                "[] is an instance of <class 'list'>")
+
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsInstance(thing, list, 'ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsInstance(thing, list, msg='ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+
+    def testAssertIsSubclass(self):
+        self.assertIsSubclass(List, list)
+        self.assertIsSubclass(List, (int, list))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsSubclass(List, int)
+        self.assertEqual(str(cm.exception),
+                f"{List!r} is not a subclass of <class 'int'>")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsSubclass(List, (int, float))
+        self.assertEqual(str(cm.exception),
+                f"{List!r} is not a subclass of any of (<class 'int'>, <class 'float'>)")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsSubclass(1, int)
+        self.assertEqual(str(cm.exception), "1 is not a class")
+
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsSubclass(List, int, 'ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertIsSubclass(List, int, msg='ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+
+    def testAssertNotIsSubclass(self):
+        self.assertNotIsSubclass(List, int)
+        self.assertNotIsSubclass(List, (int, float))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsSubclass(List, list)
+        self.assertEqual(str(cm.exception),
+                f"{List!r} is a subclass of <class 'list'>")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsSubclass(List, (int, list))
+        self.assertEqual(str(cm.exception),
+                f"{List!r} is a subclass of <class 'list'>")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsSubclass(1, int)
+        self.assertEqual(str(cm.exception), "1 is not a class")
+
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsSubclass(List, list, 'ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotIsSubclass(List, list, msg='ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+
+    def testAssertHasAttr(self):
+        a = List()
+        a.x = 1
+        self.assertHasAttr(a, 'x')
+        with self.assertRaises(self.failureException) as cm:
+            self.assertHasAttr(a, 'y')
+        self.assertEqual(str(cm.exception),
+                "'List' object has no attribute 'y'")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertHasAttr(List, 'spam')
+        self.assertEqual(str(cm.exception),
+                "type object 'List' has no attribute 'spam'")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertHasAttr(sys, 'nonexistent')
+        self.assertEqual(str(cm.exception),
+                "module 'sys' has no attribute 'nonexistent'")
+
+        with self.assertRaises(self.failureException) as cm:
+            self.assertHasAttr(a, 'y', 'ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+        with self.assertRaises(self.failureException) as cm:
+            self.assertHasAttr(a, 'y', msg='ababahalamaha')
+        self.assertIn('ababahalamaha', str(cm.exception))
+
+    def testAssertNotHasAttr(self):
+        a = List()
+        a.x = 1
+        self.assertNotHasAttr(a, 'y')
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotHasAttr(a, 'x')
+        self.assertEqual(str(cm.exception),
+                "'List' object has unexpected attribute 'x'")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotHasAttr(List, 'append')
+        self.assertEqual(str(cm.exception),
+                "type object 'List' has unexpected attribute 'append'")
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotHasAttr(sys, 'modules')
+        self.assertEqual(str(cm.exception),
+                "module 'sys' has unexpected attribute 'modules'")
+
+        with self.assertRaises(self.failureException) as cm:
+            self.assertNotHasAttr(a, 'x', 'ababahalamaha')
self.assertIn('ababahalamaha', str(cm.exception)) + with self.assertRaises(self.failureException) as cm: + self.assertNotHasAttr(a, 'x', msg='ababahalamaha') + self.assertIn('ababahalamaha', str(cm.exception)) def testAssertIn(self): animals = {'monkey': 'banana', 'cow': 'grass', 'seal': 'fish'} @@ -1842,6 +1989,186 @@ def testAssertNoLogsYieldsNone(self): pass self.assertIsNone(value) + def testAssertStartsWith(self): + self.assertStartsWith('ababahalamaha', 'ababa') + self.assertStartsWith('ababahalamaha', ('x', 'ababa', 'y')) + self.assertStartsWith(UserString('ababahalamaha'), 'ababa') + self.assertStartsWith(UserString('ababahalamaha'), ('x', 'ababa', 'y')) + self.assertStartsWith(bytearray(b'ababahalamaha'), b'ababa') + self.assertStartsWith(bytearray(b'ababahalamaha'), (b'x', b'ababa', b'y')) + self.assertStartsWith(b'ababahalamaha', bytearray(b'ababa')) + self.assertStartsWith(b'ababahalamaha', + (bytearray(b'x'), bytearray(b'ababa'), bytearray(b'y'))) + + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith('ababahalamaha', 'amaha') + self.assertEqual(str(cm.exception), + "'ababahalamaha' doesn't start with 'amaha'") + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith('ababahalamaha', ('x', 'y')) + self.assertEqual(str(cm.exception), + "'ababahalamaha' doesn't start with any of ('x', 'y')") + + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith(b'ababahalamaha', 'ababa') + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith(b'ababahalamaha', ('amaha', 'ababa')) + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith([], 'ababa') + self.assertEqual(str(cm.exception), 'Expected str, not list') + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith('ababahalamaha', b'ababa') + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith('ababahalamaha', (b'amaha', b'ababa')) + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(TypeError): + self.assertStartsWith('ababahalamaha', ord('a')) + + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith('ababahalamaha', 'amaha', 'abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + with self.assertRaises(self.failureException) as cm: + self.assertStartsWith('ababahalamaha', 'amaha', msg='abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + + def testAssertNotStartsWith(self): + self.assertNotStartsWith('ababahalamaha', 'amaha') + self.assertNotStartsWith('ababahalamaha', ('x', 'amaha', 'y')) + self.assertNotStartsWith(UserString('ababahalamaha'), 'amaha') + self.assertNotStartsWith(UserString('ababahalamaha'), ('x', 'amaha', 'y')) + self.assertNotStartsWith(bytearray(b'ababahalamaha'), b'amaha') + self.assertNotStartsWith(bytearray(b'ababahalamaha'), (b'x', b'amaha', b'y')) + self.assertNotStartsWith(b'ababahalamaha', bytearray(b'amaha')) + self.assertNotStartsWith(b'ababahalamaha', + (bytearray(b'x'), bytearray(b'amaha'), bytearray(b'y'))) + + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith('ababahalamaha', 'ababa') + self.assertEqual(str(cm.exception), + "'ababahalamaha' starts with 'ababa'") + with self.assertRaises(self.failureException) as cm: + 
self.assertNotStartsWith('ababahalamaha', ('x', 'ababa', 'y')) + self.assertEqual(str(cm.exception), + "'ababahalamaha' starts with 'ababa'") + + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith(b'ababahalamaha', 'ababa') + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith(b'ababahalamaha', ('amaha', 'ababa')) + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith([], 'ababa') + self.assertEqual(str(cm.exception), 'Expected str, not list') + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith('ababahalamaha', b'ababa') + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith('ababahalamaha', (b'amaha', b'ababa')) + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(TypeError): + self.assertNotStartsWith('ababahalamaha', ord('a')) + + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith('ababahalamaha', 'ababa', 'abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + with self.assertRaises(self.failureException) as cm: + self.assertNotStartsWith('ababahalamaha', 'ababa', msg='abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + + def testAssertEndsWith(self): + self.assertEndsWith('ababahalamaha', 'amaha') + self.assertEndsWith('ababahalamaha', ('x', 'amaha', 'y')) + self.assertEndsWith(UserString('ababahalamaha'), 'amaha') + self.assertEndsWith(UserString('ababahalamaha'), ('x', 'amaha', 'y')) + self.assertEndsWith(bytearray(b'ababahalamaha'), b'amaha') + self.assertEndsWith(bytearray(b'ababahalamaha'), (b'x', b'amaha', b'y')) + self.assertEndsWith(b'ababahalamaha', bytearray(b'amaha')) + self.assertEndsWith(b'ababahalamaha', + (bytearray(b'x'), bytearray(b'amaha'), bytearray(b'y'))) + + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith('ababahalamaha', 'ababa') + self.assertEqual(str(cm.exception), + "'ababahalamaha' doesn't end with 'ababa'") + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith('ababahalamaha', ('x', 'y')) + self.assertEqual(str(cm.exception), + "'ababahalamaha' doesn't end with any of ('x', 'y')") + + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith(b'ababahalamaha', 'amaha') + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith(b'ababahalamaha', ('ababa', 'amaha')) + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith([], 'amaha') + self.assertEqual(str(cm.exception), 'Expected str, not list') + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith('ababahalamaha', b'amaha') + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith('ababahalamaha', (b'ababa', b'amaha')) + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(TypeError): + self.assertEndsWith('ababahalamaha', ord('a')) + + with self.assertRaises(self.failureException) as cm: + self.assertEndsWith('ababahalamaha', 'ababa', 'abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + with 
self.assertRaises(self.failureException) as cm: + self.assertEndsWith('ababahalamaha', 'ababa', msg='abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + + def testAssertNotEndsWith(self): + self.assertNotEndsWith('ababahalamaha', 'ababa') + self.assertNotEndsWith('ababahalamaha', ('x', 'ababa', 'y')) + self.assertNotEndsWith(UserString('ababahalamaha'), 'ababa') + self.assertNotEndsWith(UserString('ababahalamaha'), ('x', 'ababa', 'y')) + self.assertNotEndsWith(bytearray(b'ababahalamaha'), b'ababa') + self.assertNotEndsWith(bytearray(b'ababahalamaha'), (b'x', b'ababa', b'y')) + self.assertNotEndsWith(b'ababahalamaha', bytearray(b'ababa')) + self.assertNotEndsWith(b'ababahalamaha', + (bytearray(b'x'), bytearray(b'ababa'), bytearray(b'y'))) + + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith('ababahalamaha', 'amaha') + self.assertEqual(str(cm.exception), + "'ababahalamaha' ends with 'amaha'") + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith('ababahalamaha', ('x', 'amaha', 'y')) + self.assertEqual(str(cm.exception), + "'ababahalamaha' ends with 'amaha'") + + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith(b'ababahalamaha', 'amaha') + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith(b'ababahalamaha', ('ababa', 'amaha')) + self.assertEqual(str(cm.exception), 'Expected str, not bytes') + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith([], 'amaha') + self.assertEqual(str(cm.exception), 'Expected str, not list') + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith('ababahalamaha', b'amaha') + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith('ababahalamaha', (b'ababa', b'amaha')) + self.assertEqual(str(cm.exception), 'Expected bytes, not str') + with self.assertRaises(TypeError): + self.assertNotEndsWith('ababahalamaha', ord('a')) + + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith('ababahalamaha', 'amaha', 'abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + with self.assertRaises(self.failureException) as cm: + self.assertNotEndsWith('ababahalamaha', 'amaha', msg='abracadabra') + self.assertIn('ababahalamaha', str(cm.exception)) + def testDeprecatedFailMethods(self): """Test that the deprecated fail* methods get removed in 3.12""" deprecated_names = [ diff --git a/Lib/test/test_unittest/test_discovery.py b/Lib/test/test_unittest/test_discovery.py index a44b18406c0..9ed3d04b1f8 100644 --- a/Lib/test/test_unittest/test_discovery.py +++ b/Lib/test/test_unittest/test_discovery.py @@ -1,15 +1,17 @@ import os.path -from os.path import abspath +import pickle import re import sys import types -import pickle -from test import support -from test.support import import_helper - import unittest import unittest.mock +from importlib._bootstrap_external import NamespaceLoader +from os.path import abspath + import test.test_unittest +from test import support +from test.support import import_helper +from test.test_importlib import util as test_util class TestableTestProgram(unittest.TestProgram): @@ -395,7 +397,7 @@ def restore_isdir(): self.addCleanup(restore_isdir) _find_tests_args = [] - def _find_tests(start_dir, pattern): + def _find_tests(start_dir, pattern, namespace=None): _find_tests_args.append((start_dir, 
pattern)) return ['tests'] loader._find_tests = _find_tests @@ -815,7 +817,7 @@ def test_discovery_from_dotted_path(self): expectedPath = os.path.abspath(os.path.dirname(test.test_unittest.__file__)) self.wasRun = False - def _find_tests(start_dir, pattern): + def _find_tests(start_dir, pattern, namespace=None): self.wasRun = True self.assertEqual(start_dir, expectedPath) return tests @@ -848,6 +850,55 @@ def restore(): 'Can not use builtin modules ' 'as dotted module names') + def test_discovery_from_dotted_namespace_packages(self): + loader = unittest.TestLoader() + + package = types.ModuleType('package') + package.__name__ = "tests" + package.__path__ = ['/a', '/b'] + package.__file__ = None + package.__spec__ = types.SimpleNamespace( + name=package.__name__, + loader=NamespaceLoader(package.__name__, package.__path__, None), + submodule_search_locations=['/a', '/b'] + ) + + def _import(packagename, *args, **kwargs): + sys.modules[packagename] = package + return package + + _find_tests_args = [] + def _find_tests(start_dir, pattern, namespace=None): + _find_tests_args.append((start_dir, pattern)) + return ['%s/tests' % start_dir] + + loader._find_tests = _find_tests + loader.suiteClass = list + + with unittest.mock.patch('builtins.__import__', _import): + # Since loader.discover() can modify sys.path, restore it when done. + with import_helper.DirsOnSysPath(): + # Make sure to remove 'package' from sys.modules when done. + with test_util.uncache('package'): + suite = loader.discover('package') + + self.assertEqual(suite, ['/a/tests', '/b/tests']) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_discovery_start_dir_is_namespace(self): + """Subdirectory discovery not affected if start_dir is a namespace pkg.""" + loader = unittest.TestLoader() + with ( + import_helper.DirsOnSysPath(os.path.join(os.path.dirname(__file__))), + test_util.uncache('namespace_test_pkg') + ): + suite = loader.discover('namespace_test_pkg') + self.assertEqual( + {list(suite)[0]._tests[0].__module__ for suite in suite._tests if list(suite)}, + # files under namespace_test_pkg.noop not discovered. 
+ {'namespace_test_pkg.test_foo', 'namespace_test_pkg.bar.test_bar'}, + ) + def test_discovery_failed_discovery(self): from test.test_importlib import util diff --git a/Lib/test/test_unittest/test_loader.py b/Lib/test/test_unittest/test_loader.py index 83dd25ca546..0acefccf7f6 100644 --- a/Lib/test/test_unittest/test_loader.py +++ b/Lib/test/test_unittest/test_loader.py @@ -1,9 +1,9 @@ import functools import sys import types - import unittest + class Test_TestLoader(unittest.TestCase): ### Basic object tests @@ -76,7 +76,7 @@ def runTest(self): loader = unittest.TestLoader() # This has to be false for the test to succeed - self.assertFalse('runTest'.startswith(loader.testMethodPrefix)) + self.assertNotStartsWith('runTest', loader.testMethodPrefix) suite = loader.loadTestsFromTestCase(Foo) self.assertIsInstance(suite, loader.suiteClass) diff --git a/Lib/test/test_unittest/test_program.py b/Lib/test/test_unittest/test_program.py index 2e3a7508478..99c5ec48b67 100644 --- a/Lib/test/test_unittest/test_program.py +++ b/Lib/test/test_unittest/test_program.py @@ -1,9 +1,10 @@ import os -import sys import subprocess -from test import support +import sys import unittest + import test.test_unittest +from test import support from test.test_unittest.test_result import BufferedWriter @@ -135,14 +136,14 @@ def test_NonExit(self): argv=["foobar"], testRunner=unittest.TextTestRunner(stream=stream), testLoader=self.TestLoader(self.FooBar)) - self.assertTrue(hasattr(program, 'result')) + self.assertHasAttr(program, 'result') out = stream.getvalue() self.assertIn('\nFAIL: testFail ', out) self.assertIn('\nERROR: testError ', out) self.assertIn('\nUNEXPECTED SUCCESS: testUnexpectedSuccess ', out) expected = ('\n\nFAILED (failures=1, errors=1, skipped=1, ' 'expected failures=1, unexpected successes=1)\n') - self.assertTrue(out.endswith(expected)) + self.assertEndsWith(out, expected) def test_Exit(self): stream = BufferedWriter() @@ -159,7 +160,7 @@ def test_Exit(self): self.assertIn('\nUNEXPECTED SUCCESS: testUnexpectedSuccess ', out) expected = ('\n\nFAILED (failures=1, errors=1, skipped=1, ' 'expected failures=1, unexpected successes=1)\n') - self.assertTrue(out.endswith(expected)) + self.assertEndsWith(out, expected) def test_ExitAsDefault(self): stream = BufferedWriter() @@ -174,7 +175,7 @@ def test_ExitAsDefault(self): self.assertIn('\nUNEXPECTED SUCCESS: testUnexpectedSuccess ', out) expected = ('\n\nFAILED (failures=1, errors=1, skipped=1, ' 'expected failures=1, unexpected successes=1)\n') - self.assertTrue(out.endswith(expected)) + self.assertEndsWith(out, expected) def test_ExitSkippedSuite(self): stream = BufferedWriter() @@ -186,7 +187,7 @@ def test_ExitSkippedSuite(self): self.assertEqual(cm.exception.code, 0) out = stream.getvalue() expected = '\n\nOK (skipped=1)\n' - self.assertTrue(out.endswith(expected)) + self.assertEndsWith(out, expected) def test_ExitEmptySuite(self): stream = BufferedWriter() @@ -506,7 +507,6 @@ def testParseArgsSelectedTestNames(self): self.assertEqual(program.testNamePatterns, ['*foo*', '*bar*', '*pat*']) @unittest.expectedFailureIf(sys.platform != 'win32', 'TODO: RUSTPYTHON') - def testSelectedTestNamesFunctionalTest(self): def run_unittest(args): # Use -E to ignore PYTHONSAFEPATH env var diff --git a/Lib/test/test_unittest/test_result.py b/Lib/test/test_unittest/test_result.py index 4d552d54e9a..c260f90bf03 100644 --- a/Lib/test/test_unittest/test_result.py +++ b/Lib/test/test_unittest/test_result.py @@ -4,8 +4,12 @@ import traceback import unittest from unittest.util 
import strclass -from test.support import warnings_helper -from test.support import captured_stdout, force_not_colorized_test_class + +from test.support import ( + captured_stdout, + force_not_colorized_test_class, + warnings_helper, +) from test.test_unittest.support import BufferedWriter @@ -13,7 +17,7 @@ class MockTraceback(object): class TracebackException: def __init__(self, *args, **kwargs): self.capture_locals = kwargs.get('capture_locals', False) - def format(self): + def format(self, **kwargs): result = ['A traceback'] if self.capture_locals: result.append('locals') @@ -186,7 +190,7 @@ def test_1(self): test = Foo('test_1') try: test.fail("foo") - except: + except AssertionError: exc_info_tuple = sys.exc_info() result = unittest.TestResult() @@ -214,7 +218,7 @@ def test_1(self): def get_exc_info(): try: test.fail("foo") - except: + except AssertionError: return sys.exc_info() exc_info_tuple = get_exc_info() @@ -241,9 +245,9 @@ def get_exc_info(): try: try: test.fail("foo") - except: + except AssertionError: raise ValueError(42) - except: + except ValueError: return sys.exc_info() exc_info_tuple = get_exc_info() @@ -271,7 +275,7 @@ def get_exc_info(): loop.__cause__ = loop loop.__context__ = loop raise loop - except: + except Exception: return sys.exc_info() exc_info_tuple = get_exc_info() @@ -300,7 +304,7 @@ def get_exc_info(): ex1.__cause__ = ex2 ex2.__context__ = ex1 raise C - except: + except Exception: return sys.exc_info() exc_info_tuple = get_exc_info() @@ -345,7 +349,7 @@ def test_1(self): test = Foo('test_1') try: raise TypeError() - except: + except TypeError: exc_info_tuple = sys.exc_info() result = unittest.TestResult() @@ -454,7 +458,7 @@ def test(result): self.assertTrue(result.failfast) result = runner.run(test) stream.flush() - self.assertTrue(stream.getvalue().endswith('\n\nOK\n')) + self.assertEndsWith(stream.getvalue(), '\n\nOK\n') @force_not_colorized_test_class diff --git a/Lib/test/test_unittest/test_runner.py b/Lib/test/test_unittest/test_runner.py index 4d3cfd60b8d..b215a3664d1 100644 --- a/Lib/test/test_unittest/test_runner.py +++ b/Lib/test/test_unittest/test_runner.py @@ -1,13 +1,12 @@ import io import os -import sys import pickle import subprocess -from test import support - +import sys import unittest from unittest.case import _Outcome +from test import support from test.test_unittest.support import ( BufferedWriter, LoggingResult, diff --git a/Lib/test/test_unittest/test_setups.py b/Lib/test/test_unittest/test_setups.py index 2df703ed934..2468681003b 100644 --- a/Lib/test/test_unittest/test_setups.py +++ b/Lib/test/test_unittest/test_setups.py @@ -1,6 +1,5 @@ import io import sys - import unittest diff --git a/Lib/test/test_unittest/test_skipping.py b/Lib/test/test_unittest/test_skipping.py index f146dcac18e..f5cb860c60b 100644 --- a/Lib/test/test_unittest/test_skipping.py +++ b/Lib/test/test_unittest/test_skipping.py @@ -1,5 +1,6 @@ import unittest +from test.support import force_not_colorized from test.test_unittest.support import LoggingResult @@ -293,6 +294,7 @@ def test_die(self): self.assertFalse(result.unexpectedSuccesses) self.assertTrue(result.wasSuccessful()) + @force_not_colorized def test_expected_failure_and_fail_in_cleanup(self): class Foo(unittest.TestCase): @unittest.expectedFailure @@ -372,6 +374,7 @@ def test_die(self): self.assertEqual(result.unexpectedSuccesses, [test]) self.assertFalse(result.wasSuccessful()) + @force_not_colorized def test_unexpected_success_and_fail_in_cleanup(self): class Foo(unittest.TestCase): 
@unittest.expectedFailure diff --git a/Lib/test/test_unittest/test_suite.py b/Lib/test/test_unittest/test_suite.py index ca52ee9d9c0..11c8c859f3d 100644 --- a/Lib/test/test_unittest/test_suite.py +++ b/Lib/test/test_unittest/test_suite.py @@ -1,10 +1,9 @@ -import unittest - import gc import sys +import unittest import weakref -from test.test_unittest.support import LoggingResult, TestEquality +from test.test_unittest.support import LoggingResult, TestEquality ### Support code for Test_TestSuite ################################################################ diff --git a/Lib/test/test_unittest/test_util.py b/Lib/test/test_unittest/test_util.py index d590a333930..abadcb96601 100644 --- a/Lib/test/test_unittest/test_util.py +++ b/Lib/test/test_unittest/test_util.py @@ -1,5 +1,9 @@ import unittest -from unittest.util import safe_repr, sorted_list_difference, unorderable_list_difference +from unittest.util import ( + safe_repr, + sorted_list_difference, + unorderable_list_difference, +) class TestUtil(unittest.TestCase): diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py index ae9a365e1c9..abdf7b32df2 100644 --- a/Lib/test/test_warnings/__init__.py +++ b/Lib/test/test_warnings/__init__.py @@ -24,10 +24,13 @@ from warnings import deprecated -py_warnings = import_helper.import_fresh_module('warnings', - blocked=['_warnings']) -c_warnings = import_helper.import_fresh_module('warnings', - fresh=['_warnings']) +py_warnings = import_helper.import_fresh_module('_py_warnings') +py_warnings._set_module(py_warnings) + +c_warnings = import_helper.import_fresh_module( + "warnings", fresh=["_warnings", "_py_warnings"] +) +c_warnings._set_module(c_warnings) @contextmanager def warnings_state(module): @@ -43,15 +46,21 @@ def warnings_state(module): except NameError: pass original_warnings = warning_tests.warnings - original_filters = module.filters - try: + if module._use_context: + saved_context, context = module._new_context() + else: + original_filters = module.filters module.filters = original_filters[:] + try: module.simplefilter("once") warning_tests.warnings = module yield finally: warning_tests.warnings = original_warnings - module.filters = original_filters + if module._use_context: + module._set_context(saved_context) + else: + module.filters = original_filters class TestWarning(Warning): @@ -93,7 +102,7 @@ class PublicAPITests(BaseTest): """ def test_module_all_attribute(self): - self.assertTrue(hasattr(self.module, '__all__')) + self.assertHasAttr(self.module, '__all__') target_api = ["warn", "warn_explicit", "showwarning", "formatwarning", "filterwarnings", "simplefilter", "resetwarnings", "catch_warnings", "deprecated"] @@ -111,14 +120,14 @@ class FilterTests(BaseTest): """Testing the filtering functionality.""" def test_error(self): - with original_warnings.catch_warnings(module=self.module) as w: + with self.module.catch_warnings() as w: self.module.resetwarnings() self.module.filterwarnings("error", category=UserWarning) self.assertRaises(UserWarning, self.module.warn, "FilterTests.test_error") def test_error_after_default(self): - with original_warnings.catch_warnings(module=self.module) as w: + with self.module.catch_warnings() as w: self.module.resetwarnings() message = "FilterTests.test_ignore_after_default" def f(): @@ -136,8 +145,7 @@ def f(): self.assertRaises(UserWarning, f) def test_ignore(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: 
self.module.resetwarnings() self.module.filterwarnings("ignore", category=UserWarning) self.module.warn("FilterTests.test_ignore", UserWarning) @@ -145,8 +153,7 @@ def test_ignore(self): self.assertEqual(list(__warningregistry__), ['version']) def test_ignore_after_default(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() message = "FilterTests.test_ignore_after_default" def f(): @@ -157,44 +164,43 @@ def f(): f() self.assertEqual(len(w), 1) - def test_always(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: - self.module.resetwarnings() - self.module.filterwarnings("always", category=UserWarning) - message = "FilterTests.test_always" - def f(): - self.module.warn(message, UserWarning) - f() - self.assertEqual(len(w), 1) - self.assertEqual(w[-1].message.args[0], message) - f() - self.assertEqual(len(w), 2) - self.assertEqual(w[-1].message.args[0], message) + def test_always_and_all(self): + for mode in {"always", "all"}: + with self.module.catch_warnings(record=True) as w: + self.module.resetwarnings() + self.module.filterwarnings(mode, category=UserWarning) + message = "FilterTests.test_always_and_all" + def f(): + self.module.warn(message, UserWarning) + f() + self.assertEqual(len(w), 1) + self.assertEqual(w[-1].message.args[0], message) + f() + self.assertEqual(len(w), 2) + self.assertEqual(w[-1].message.args[0], message) - def test_always_after_default(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: - self.module.resetwarnings() - message = "FilterTests.test_always_after_ignore" - def f(): - self.module.warn(message, UserWarning) - f() - self.assertEqual(len(w), 1) - self.assertEqual(w[-1].message.args[0], message) - f() - self.assertEqual(len(w), 1) - self.module.filterwarnings("always", category=UserWarning) - f() - self.assertEqual(len(w), 2) - self.assertEqual(w[-1].message.args[0], message) - f() - self.assertEqual(len(w), 3) - self.assertEqual(w[-1].message.args[0], message) + def test_always_and_all_after_default(self): + for mode in {"always", "all"}: + with self.module.catch_warnings(record=True) as w: + self.module.resetwarnings() + message = "FilterTests.test_always_and_all_after_ignore" + def f(): + self.module.warn(message, UserWarning) + f() + self.assertEqual(len(w), 1) + self.assertEqual(w[-1].message.args[0], message) + f() + self.assertEqual(len(w), 1) + self.module.filterwarnings(mode, category=UserWarning) + f() + self.assertEqual(len(w), 2) + self.assertEqual(w[-1].message.args[0], message) + f() + self.assertEqual(len(w), 3) + self.assertEqual(w[-1].message.args[0], message) def test_default(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.filterwarnings("default", category=UserWarning) message = UserWarning("FilterTests.test_default") @@ -209,8 +215,7 @@ def test_default(self): raise ValueError("loop variant unhandled") def test_module(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.filterwarnings("module", category=UserWarning) message = UserWarning("FilterTests.test_module") @@ -221,8 +226,7 @@ def test_module(self): self.assertEqual(len(w), 0) def test_once(self): - with 
original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.filterwarnings("once", category=UserWarning) message = UserWarning("FilterTests.test_once") @@ -237,9 +241,88 @@ def test_once(self): 42) self.assertEqual(len(w), 0) + @unittest.expectedFailure # TODO: RUSTPYTHON re.PatternError: bad escape \z at position 15 + def test_filter_module(self): + MS_WINDOWS = (sys.platform == 'win32') + with self.module.catch_warnings(record=True) as w: + self.module.simplefilter('error') + self.module.filterwarnings('always', module=r'package\.module\z') + self.module.warn_explicit('msg', UserWarning, 'filename', 42, + module='package.module') + self.assertEqual(len(w), 1) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, '/path/to/package/module', 42) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, '/path/to/package/module.py', 42) + + with self.module.catch_warnings(record=True) as w: + self.module.simplefilter('error') + self.module.filterwarnings('always', module='package') + self.module.warn_explicit('msg', UserWarning, 'filename', 42, + module='package.module') + self.assertEqual(len(w), 1) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, 'filename', 42, + module='other.package.module') + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, '/path/to/otherpackage/module.py', 42) + + with self.module.catch_warnings(record=True) as w: + self.module.simplefilter('error') + self.module.filterwarnings('always', module=r'/path/to/package/module\z') + self.module.warn_explicit('msg', UserWarning, '/path/to/package/module', 42) + self.assertEqual(len(w), 1) + self.module.warn_explicit('msg', UserWarning, '/path/to/package/module.py', 42) + self.assertEqual(len(w), 2) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, '/PATH/TO/PACKAGE/MODULE', 42) + if MS_WINDOWS: + if self.module is py_warnings: + self.module.warn_explicit('msg', UserWarning, r'/path/to/package/module.PY', 42) + self.assertEqual(len(w), 3) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'/path/to/package/module/__init__.py', 42) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'/path/to/package/module.pyw', 42) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'\path\to\package\module', 42) + + with self.module.catch_warnings(record=True) as w: + self.module.simplefilter('error') + self.module.filterwarnings('always', module=r'/path/to/package/__init__\z') + self.module.warn_explicit('msg', UserWarning, '/path/to/package/__init__.py', 42) + self.assertEqual(len(w), 1) + self.module.warn_explicit('msg', UserWarning, '/path/to/package/__init__', 42) + self.assertEqual(len(w), 2) + + if MS_WINDOWS: + with self.module.catch_warnings(record=True) as w: + self.module.simplefilter('error') + self.module.filterwarnings('always', module=r'C:\\path\\to\\package\\module\z') + self.module.warn_explicit('msg', UserWarning, r'C:\path\to\package\module', 42) + self.assertEqual(len(w), 1) + self.module.warn_explicit('msg', UserWarning, r'C:\path\to\package\module.py', 42) + self.assertEqual(len(w), 2) + if self.module is py_warnings: + self.module.warn_explicit('msg', UserWarning, r'C:\path\to\package\module.PY', 42) + self.assertEqual(len(w), 3) + with 
self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'C:\path\to\package\module.pyw', 42) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'C:\PATH\TO\PACKAGE\MODULE', 42) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'C:/path/to/package/module', 42) + with self.assertRaises(UserWarning): + self.module.warn_explicit('msg', UserWarning, r'C:\path\to\package\module\__init__.py', 42) + + with self.module.catch_warnings(record=True) as w: + self.module.simplefilter('error') + self.module.filterwarnings('always', module=r'\z') + self.module.warn_explicit('msg', UserWarning, '', 42) + self.assertEqual(len(w), 1) + def test_module_globals(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter("always", UserWarning) # bpo-33509: module_globals=None must not crash @@ -259,15 +342,14 @@ def test_module_globals(self): self.assertEqual(len(w), 2) def test_inheritance(self): - with original_warnings.catch_warnings(module=self.module) as w: + with self.module.catch_warnings() as w: self.module.resetwarnings() self.module.filterwarnings("error", category=Warning) self.assertRaises(UserWarning, self.module.warn, "FilterTests.test_inheritance", UserWarning) def test_ordering(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.filterwarnings("ignore", category=UserWarning) self.module.filterwarnings("error", category=UserWarning, @@ -282,8 +364,7 @@ def test_ordering(self): def test_filterwarnings(self): # Test filterwarnings(). # Implicitly also tests resetwarnings(). 
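# A note on the r'...\z' patterns in test_filter_module above: CPython 3.14's
# re module accepts \z as an end-of-string anchor (equivalent to \Z), and the
# new test passes module= patterns ending in it; the expectedFailure records
# that RustPython's regex engine still rejects that escape. A minimal sketch
# of the anchoring behaviour, written with the portable \Z spelling:
import re

assert re.match(r"package\.module\Z", "package.module")
assert not re.match(r"package\.module\Z", "package.module.sub")
# On CPython >= 3.14 the same two checks hold for the new spelling:
#     re.match(r"package\.module\z", "package.module")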
- with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.filterwarnings("error", "", Warning, "", 0) self.assertRaises(UserWarning, self.module.warn, 'convert to error') @@ -307,8 +388,7 @@ def test_filterwarnings(self): self.assertIs(w[-1].category, UserWarning) def test_message_matching(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter("ignore", UserWarning) self.module.filterwarnings("error", "match", UserWarning) self.assertRaises(UserWarning, self.module.warn, "match") @@ -324,54 +404,52 @@ def match(self, a): L[:] = [] L = [("default",X(),UserWarning,X(),0) for i in range(2)] - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.filters = L self.module.warn_explicit(UserWarning("b"), None, "f.py", 42) self.assertEqual(str(w[-1].message), "b") def test_filterwarnings_duplicate_filters(self): - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.resetwarnings() self.module.filterwarnings("error", category=UserWarning) - self.assertEqual(len(self.module.filters), 1) + self.assertEqual(len(self.module._get_filters()), 1) self.module.filterwarnings("ignore", category=UserWarning) self.module.filterwarnings("error", category=UserWarning) self.assertEqual( - len(self.module.filters), 2, + len(self.module._get_filters()), 2, "filterwarnings inserted duplicate filter" ) self.assertEqual( - self.module.filters[0][0], "error", + self.module._get_filters()[0][0], "error", "filterwarnings did not promote filter to " "the beginning of list" ) def test_simplefilter_duplicate_filters(self): - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.resetwarnings() self.module.simplefilter("error", category=UserWarning) - self.assertEqual(len(self.module.filters), 1) + self.assertEqual(len(self.module._get_filters()), 1) self.module.simplefilter("ignore", category=UserWarning) self.module.simplefilter("error", category=UserWarning) self.assertEqual( - len(self.module.filters), 2, + len(self.module._get_filters()), 2, "simplefilter inserted duplicate filter" ) self.assertEqual( - self.module.filters[0][0], "error", + self.module._get_filters()[0][0], "error", "simplefilter did not promote filter to the beginning of list" ) def test_append_duplicate(self): - with original_warnings.catch_warnings(module=self.module, - record=True) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.simplefilter("ignore") self.module.simplefilter("error", append=True) self.module.simplefilter("ignore", append=True) self.module.warn("test_append_duplicate", category=UserWarning) - self.assertEqual(len(self.module.filters), 2, + self.assertEqual(len(self.module._get_filters()), 2, "simplefilter inserted duplicate filter" ) self.assertEqual(len(w), 0, @@ -401,19 +479,17 @@ def test_argument_validation(self): self.module.simplefilter('ignore', lineno=-1) def test_catchwarnings_with_simplefilter_ignore(self): - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(module=self.module): self.module.resetwarnings() self.module.simplefilter("error") - with self.module.catch_warnings( - module=self.module, action="ignore" - ): + with 
self.module.catch_warnings(action="ignore"): self.module.warn("This will be ignored") def test_catchwarnings_with_simplefilter_error(self): - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.resetwarnings() with self.module.catch_warnings( - module=self.module, action="error", category=FutureWarning + action="error", category=FutureWarning ): with support.captured_stderr() as stderr: error_msg = "Other types of warnings are not errors" @@ -435,8 +511,7 @@ class WarnTests(BaseTest): """Test warnings.warn() and warnings.warn_explicit().""" def test_message(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter("once") for i in range(4): text = 'multi %d' %i # Different text on each call. @@ -448,8 +523,7 @@ def test_message(self): def test_warn_nonstandard_types(self): # warn() should handle non-standard types without issue. for ob in (Warning, None, 42): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter("once") self.module.warn(ob) # Don't directly compare objects since @@ -458,8 +532,7 @@ def test_warn_nonstandard_types(self): def test_filename(self): with warnings_state(self.module): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: warning_tests.inner("spam1") self.assertEqual(os.path.basename(w[-1].filename), "stacklevel.py") @@ -471,8 +544,7 @@ def test_stacklevel(self): # Test stacklevel argument # make sure all messages are different, so the warning won't be skipped with warnings_state(self.module): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: warning_tests.inner("spam3", stacklevel=1) self.assertEqual(os.path.basename(w[-1].filename), "stacklevel.py") @@ -494,23 +566,20 @@ def test_stacklevel(self): self.assertEqual(os.path.basename(w[-1].filename), "") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_stacklevel_import(self): # Issue #24305: With stacklevel=2, module-level warnings should work. import_helper.unload('test.test_warnings.data.import_warning') with warnings_state(self.module): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter('always') - import test.test_warnings.data.import_warning + import test.test_warnings.data.import_warning # noqa: F401 self.assertEqual(len(w), 1) self.assertEqual(w[0].filename, __file__) def test_skip_file_prefixes(self): with warnings_state(self.module): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter('always') # Warning never attributed to the data/ package. 
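The hunk that follows adds a regression test for gh-126209, where an entry in skip_file_prefixes naming the warning-emitting file itself must still keep that file out of the reported location. For orientation, a minimal sketch of how warn()'s skip_file_prefixes parameter (Python 3.12+) is typically used; the mylib package layout here is hypothetical:

# mylib/helpers.py
import os
import warnings

_PKG_DIR = os.path.dirname(os.path.abspath(__file__))

def deprecated_helper():
    # Report the warning at the first frame *outside* mylib/, so callers
    # see their own file:line rather than mylib/helpers.py.
    warnings.warn("deprecated_helper() is deprecated",
                  DeprecationWarning,
                  skip_file_prefixes=(_PKG_DIR,))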
@@ -533,6 +602,16 @@ def test_skip_file_prefixes(self): warning_tests.package("prefix02", stacklevel=3) self.assertIn("unittest", w[-1].filename) + def test_skip_file_prefixes_file_path(self): + # see: gh-126209 + with warnings_state(self.module): + skipped = warning_tests.__file__ + with self.module.catch_warnings(record=True) as w: + warning_tests.outer("msg", skip_file_prefixes=(skipped,)) + + self.assertEqual(len(w), 1) + self.assertNotEqual(w[-1].filename, skipped) + def test_skip_file_prefixes_type_errors(self): with warnings_state(self.module): warn = warning_tests.warnings.warn @@ -548,23 +627,16 @@ def test_exec_filename(self): codeobj = compile(("import warnings\n" "warnings.warn('hello', UserWarning)"), filename, "exec") - with original_warnings.catch_warnings(record=True) as w: + with self.module.catch_warnings(record=True) as w: self.module.simplefilter("always", category=UserWarning) exec(codeobj) self.assertEqual(w[0].filename, filename) def test_warn_explicit_non_ascii_filename(self): - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.filterwarnings("always", category=UserWarning) - filenames = ["nonascii\xe9\u20ac"] - if not support.is_emscripten: - # JavaScript does not like surrogates. - # Invalid UTF-8 leading byte 0x80 encountered when - # deserializing a UTF-8 string in wasm memory to a JS - # string! - filenames.append("surrogate\udc80") + filenames = ["nonascii\xe9\u20ac", "surrogate\udc80"] for filename in filenames: try: os.fsencode(filename) @@ -625,7 +697,7 @@ class NonWarningSubclass: self.assertIn('category must be a Warning subclass, not ', str(cm.exception)) - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.resetwarnings() self.module.filterwarnings('default') with self.assertWarns(MyWarningClass) as cm: @@ -641,7 +713,7 @@ class NonWarningSubclass: self.assertIsInstance(cm.warning, Warning) def check_module_globals(self, module_globals): - with original_warnings.catch_warnings(module=self.module, record=True) as w: + with self.module.catch_warnings(record=True) as w: self.module.filterwarnings('default') self.module.warn_explicit( 'eggs', UserWarning, 'bar', 1, @@ -654,7 +726,7 @@ def check_module_globals_error(self, module_globals, errmsg, errtype=ValueError) if self.module is py_warnings: self.check_module_globals(module_globals) return - with original_warnings.catch_warnings(module=self.module, record=True) as w: + with self.module.catch_warnings(record=True) as w: self.module.filterwarnings('always') with self.assertRaisesRegex(errtype, re.escape(errmsg)): self.module.warn_explicit( @@ -666,7 +738,7 @@ def check_module_globals_deprecated(self, module_globals, msg): if self.module is py_warnings: self.check_module_globals(module_globals) return - with original_warnings.catch_warnings(module=self.module, record=True) as w: + with self.module.catch_warnings(record=True) as w: self.module.filterwarnings('always') self.module.warn_explicit( 'eggs', UserWarning, 'bar', 1, @@ -734,53 +806,44 @@ def test_gh86298_no_loader_with_spec_loader_okay(self): class CWarnTests(WarnTests, unittest.TestCase): module = c_warnings - # TODO: RUSTPYTHON - @unittest.expectedFailure # As an early adopter, we sanity check the # test.import_helper.import_fresh_module utility function + @unittest.expectedFailure # TODO: RUSTPYTHON def test_accelerated(self): self.assertIsNot(original_warnings, 
self.module) - self.assertFalse(hasattr(self.module.warn, '__code__')) + self.assertNotHasAttr(self.module.warn, '__code__') - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_gh86298_no_loader_and_spec_is_none(self): - return super().test_gh86298_no_loader_and_spec_is_none() + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_gh86298_loader_and_spec_loader_disagree(self): + return super().test_gh86298_loader_and_spec_loader_disagree() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_gh86298_loader_is_none_and_spec_is_none(self): return super().test_gh86298_loader_is_none_and_spec_is_none() - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_gh86298_loader_is_none_and_spec_loader_is_none(self): return super().test_gh86298_loader_is_none_and_spec_loader_is_none() - - # TODO: RUSTPYTHON - @unittest.expectedFailure + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_gh86298_no_loader_and_no_spec_loader(self): + return super().test_gh86298_no_loader_and_no_spec_loader() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_gh86298_no_loader_and_spec_is_none(self): + return super().test_gh86298_no_loader_and_spec_is_none() + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_gh86298_no_spec(self): return super().test_gh86298_no_spec() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_gh86298_spec_is_none(self): - return super().test_gh86298_spec_is_none() - - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_gh86298_no_spec_loader(self): return super().test_gh86298_no_spec_loader() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_gh86298_loader_and_spec_loader_disagree(self): - return super().test_gh86298_loader_and_spec_loader_disagree() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_gh86298_no_loader_and_no_spec_loader(self): - return super().test_gh86298_no_loader_and_no_spec_loader() + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_gh86298_spec_is_none(self): + return super().test_gh86298_spec_is_none() class PyWarnTests(WarnTests, unittest.TestCase): module = py_warnings @@ -789,7 +852,7 @@ class PyWarnTests(WarnTests, unittest.TestCase): # test.import_helper.import_fresh_module utility function def test_pure_python(self): self.assertIsNot(original_warnings, self.module) - self.assertTrue(hasattr(self.module.warn, '__code__')) + self.assertHasAttr(self.module.warn, '__code__') class WCmdLineTests(BaseTest): @@ -797,7 +860,7 @@ class WCmdLineTests(BaseTest): def test_improper_input(self): # Uses the private _setoption() function to test the parsing # of command-line warning arguments - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.assertRaises(self.module._OptionError, self.module._setoption, '1:2:3:4:5:6') self.assertRaises(self.module._OptionError, @@ -816,7 +879,7 @@ def test_improper_input(self): self.assertRaises(UserWarning, self.module.warn, 'convert to error') def test_import_from_module(self): - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module._setoption('ignore::Warning') with self.assertRaises(self.module._OptionError): self.module._setoption('ignore::TestWarning') @@ -857,11 +920,10 @@ class _WarningsTests(BaseTest, unittest.TestCase): module = c_warnings - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure 
# TODO: RUSTPYTHON def test_filter(self): # Everything should function even if 'filters' is not in warnings. - with original_warnings.catch_warnings(module=self.module) as w: + with self.module.catch_warnings() as w: self.module.filterwarnings("error", "", Warning, "", 0) self.assertRaises(UserWarning, self.module.warn, 'convert to error') @@ -869,8 +931,7 @@ def test_filter(self): self.assertRaises(UserWarning, self.module.warn, 'convert to error') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_onceregistry(self): # Replacing or removing the onceregistry should be okay. global __warningregistry__ @@ -878,8 +939,7 @@ def test_onceregistry(self): try: original_registry = self.module.onceregistry __warningregistry__ = {} - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() self.module.filterwarnings("once", category=UserWarning) self.module.warn_explicit(message, UserWarning, "file", 42) @@ -901,15 +961,13 @@ def test_onceregistry(self): finally: self.module.onceregistry = original_registry - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_default_action(self): # Replacing or removing defaultaction should be okay. message = UserWarning("defaultaction test") original = self.module.defaultaction try: - with original_warnings.catch_warnings(record=True, - module=self.module) as w: + with self.module.catch_warnings(record=True) as w: self.module.resetwarnings() registry = {} self.module.warn_explicit(message, UserWarning, "", 42, @@ -942,8 +1000,12 @@ def test_default_action(self): def test_showwarning_missing(self): # Test that showwarning() missing is okay. + if self.module._use_context: + # If _use_context is true, the warnings module does not + # override/restore showwarning() + return text = 'del showwarning test' - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.filterwarnings("always", category=UserWarning) del self.module.showwarning with support.captured_output('stderr') as stream: @@ -951,12 +1013,11 @@ def test_showwarning_missing(self): result = stream.getvalue() self.assertIn(text, result) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_showwarnmsg_missing(self): # Test that _showwarnmsg() missing is okay. 
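# A quick orientation point for the catch_warnings(module=...) removals in
# these hunks: py_warnings and c_warnings are now self-contained (each had
# _set_module() called on itself in the test setup), so the plain public API
# suffices. A minimal sketch of the record=True behaviour the tests lean on,
# shown against the stock warnings module:
import warnings

with warnings.catch_warnings(record=True) as log:
    warnings.simplefilter("always")
    warnings.warn("captured, not printed", UserWarning)

assert len(log) == 1
assert issubclass(log[0].category, UserWarning)
assert str(log[0].message) == "captured, not printed"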
text = 'del _showwarnmsg test' - with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.filterwarnings("always", category=UserWarning) show = self.module._showwarnmsg @@ -970,50 +1031,55 @@ def test_showwarnmsg_missing(self): self.assertIn(text, result) def test_showwarning_not_callable(self): - with original_warnings.catch_warnings(module=self.module): - self.module.filterwarnings("always", category=UserWarning) - self.module.showwarning = print - with support.captured_output('stdout'): - self.module.warn('Warning!') - self.module.showwarning = 23 - self.assertRaises(TypeError, self.module.warn, "Warning!") + orig = self.module.showwarning + try: + with self.module.catch_warnings(): + self.module.filterwarnings("always", category=UserWarning) + self.module.showwarning = print + with support.captured_output('stdout'): + self.module.warn('Warning!') + self.module.showwarning = 23 + self.assertRaises(TypeError, self.module.warn, "Warning!") + finally: + self.module.showwarning = orig def test_show_warning_output(self): # With showwarning() missing, make sure that output is okay. - text = 'test show_warning' - with original_warnings.catch_warnings(module=self.module): - self.module.filterwarnings("always", category=UserWarning) - del self.module.showwarning - with support.captured_output('stderr') as stream: - warning_tests.inner(text) - result = stream.getvalue() - self.assertEqual(result.count('\n'), 2, - "Too many newlines in %r" % result) - first_line, second_line = result.split('\n', 1) - expected_file = os.path.splitext(warning_tests.__file__)[0] + '.py' - first_line_parts = first_line.rsplit(':', 3) - path, line, warning_class, message = first_line_parts - line = int(line) - self.assertEqual(expected_file, path) - self.assertEqual(warning_class, ' ' + UserWarning.__name__) - self.assertEqual(message, ' ' + text) - expected_line = ' ' + linecache.getline(path, line).strip() + '\n' - assert expected_line - self.assertEqual(second_line, expected_line) - - # TODO: RUSTPYTHON - @unittest.expectedFailure + orig = self.module.showwarning + try: + text = 'test show_warning' + with self.module.catch_warnings(): + self.module.filterwarnings("always", category=UserWarning) + del self.module.showwarning + with support.captured_output('stderr') as stream: + warning_tests.inner(text) + result = stream.getvalue() + self.assertEqual(result.count('\n'), 2, + "Too many newlines in %r" % result) + first_line, second_line = result.split('\n', 1) + expected_file = os.path.splitext(warning_tests.__file__)[0] + '.py' + first_line_parts = first_line.rsplit(':', 3) + path, line, warning_class, message = first_line_parts + line = int(line) + self.assertEqual(expected_file, path) + self.assertEqual(warning_class, ' ' + UserWarning.__name__) + self.assertEqual(message, ' ' + text) + expected_line = ' ' + linecache.getline(path, line).strip() + '\n' + assert expected_line + self.assertEqual(second_line, expected_line) + finally: + self.module.showwarning = orig + def test_filename_none(self): # issue #12467: race condition if a warning is emitted at shutdown globals_dict = globals() oldfile = globals_dict['__file__'] try: - catch = original_warnings.catch_warnings(record=True, - module=self.module) + catch = self.module.catch_warnings(record=True) with catch as w: self.module.filterwarnings("always", category=UserWarning) globals_dict['__file__'] = None - original_warnings.warn('test', UserWarning) + self.module.warn('test', UserWarning) self.assertTrue(len(w)) 
finally: globals_dict['__file__'] = oldfile @@ -1027,8 +1093,7 @@ def test_stderr_none(self): self.assertNotIn(b'Warning!', stderr) self.assertNotIn(b'Error', stderr) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_issue31285(self): # warn_explicit() should neither raise a SystemError nor cause an # assertion failure, in case the return value of get_source() has a @@ -1052,7 +1117,7 @@ def get_source(self, fullname): wmod = self.module - with original_warnings.catch_warnings(module=wmod): + with wmod.catch_warnings(): wmod.filterwarnings('default', category=UserWarning) linecache.clearcache() @@ -1079,7 +1144,7 @@ def test_issue31411(self): # warn_explicit() shouldn't raise a SystemError in case # warnings.onceregistry isn't a dictionary. wmod = self.module - with original_warnings.catch_warnings(module=wmod): + with wmod.catch_warnings(): wmod.filterwarnings('once') with support.swap_attr(wmod, 'onceregistry', None): with self.assertRaises(TypeError): @@ -1090,12 +1155,12 @@ def test_issue31416(self): # warn_explicit() shouldn't cause an assertion failure in case of a # bad warnings.filters or warnings.defaultaction. wmod = self.module - with original_warnings.catch_warnings(module=wmod): - wmod.filters = [(None, None, Warning, None, 0)] + with wmod.catch_warnings(): + wmod._get_filters()[:] = [(None, None, Warning, None, 0)] with self.assertRaises(TypeError): wmod.warn_explicit('foo', Warning, 'bar', 1) - wmod.filters = [] + wmod._get_filters()[:] = [] with support.swap_attr(wmod, 'defaultaction', None), \ self.assertRaises(TypeError): wmod.warn_explicit('foo', Warning, 'bar', 1) @@ -1104,7 +1169,7 @@ def test_issue31416(self): def test_issue31566(self): # warn() shouldn't cause an assertion failure in case of a bad # __name__ global. 
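# Why test_issue31416 above now slice-assigns wmod._get_filters()[:] instead
# of rebinding wmod.filters: with context-aware warnings enabled, the active
# filter list may live in a context variable rather than in the module
# global, so only in-place mutation of the list returned by _get_filters()
# is guaranteed to hit the filters actually consulted. A sketch against the
# private _get_filters() helper this patch introduces:
import warnings

active = warnings._get_filters()   # the module global, or the context's copy
saved = active[:]
try:
    active[:] = [("error", None, Warning, None, 0)]   # in-place override
    try:
        warnings.warn("promoted to an exception")
    except UserWarning:
        pass
finally:
    active[:] = saved   # restore the previous filters in place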
- with original_warnings.catch_warnings(module=self.module): + with self.module.catch_warnings(): self.module.filterwarnings('error', category=UserWarning) with support.swap_item(globals(), '__name__', b'foo'), \ support.swap_item(globals(), '__file__', None): @@ -1181,8 +1246,7 @@ class CWarningsDisplayTests(WarningsDisplayTests, unittest.TestCase): class PyWarningsDisplayTests(WarningsDisplayTests, unittest.TestCase): module = py_warnings - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_tracemalloc(self): self.addCleanup(os_helper.unlink, os_helper.TESTFN) @@ -1234,16 +1298,18 @@ class CatchWarningTests(BaseTest): """Test catch_warnings().""" def test_catch_warnings_restore(self): + if self.module._use_context: + return # test disabled if using context vars wmod = self.module orig_filters = wmod.filters orig_showwarning = wmod.showwarning # Ensure both showwarning and filters are restored when recording - with wmod.catch_warnings(module=wmod, record=True): + with wmod.catch_warnings(record=True): wmod.filters = wmod.showwarning = object() self.assertIs(wmod.filters, orig_filters) self.assertIs(wmod.showwarning, orig_showwarning) # Same test, but with recording disabled - with wmod.catch_warnings(module=wmod, record=False): + with wmod.catch_warnings(record=False): wmod.filters = wmod.showwarning = object() self.assertIs(wmod.filters, orig_filters) self.assertIs(wmod.showwarning, orig_showwarning) @@ -1251,7 +1317,7 @@ def test_catch_warnings_restore(self): def test_catch_warnings_recording(self): wmod = self.module # Ensure warnings are recorded when requested - with wmod.catch_warnings(module=wmod, record=True) as w: + with wmod.catch_warnings(record=True) as w: self.assertEqual(w, []) self.assertIs(type(w), list) wmod.simplefilter("always") @@ -1265,44 +1331,48 @@ def test_catch_warnings_recording(self): self.assertEqual(w, []) # Ensure warnings are not recorded when not requested orig_showwarning = wmod.showwarning - with wmod.catch_warnings(module=wmod, record=False) as w: + with wmod.catch_warnings(record=False) as w: self.assertIsNone(w) self.assertIs(wmod.showwarning, orig_showwarning) def test_catch_warnings_reentry_guard(self): wmod = self.module # Ensure catch_warnings is protected against incorrect usage - x = wmod.catch_warnings(module=wmod, record=True) + x = wmod.catch_warnings(record=True) self.assertRaises(RuntimeError, x.__exit__) with x: self.assertRaises(RuntimeError, x.__enter__) # Same test, but with recording disabled - x = wmod.catch_warnings(module=wmod, record=False) + x = wmod.catch_warnings(record=False) self.assertRaises(RuntimeError, x.__exit__) with x: self.assertRaises(RuntimeError, x.__enter__) def test_catch_warnings_defaults(self): wmod = self.module - orig_filters = wmod.filters + orig_filters = wmod._get_filters() orig_showwarning = wmod.showwarning # Ensure default behaviour is not to record warnings - with wmod.catch_warnings(module=wmod) as w: + with wmod.catch_warnings() as w: self.assertIsNone(w) self.assertIs(wmod.showwarning, orig_showwarning) - self.assertIsNot(wmod.filters, orig_filters) - self.assertIs(wmod.filters, orig_filters) + self.assertIsNot(wmod._get_filters(), orig_filters) + self.assertIs(wmod._get_filters(), orig_filters) if wmod is sys.modules['warnings']: # Ensure the default module is this one with wmod.catch_warnings() as w: self.assertIsNone(w) self.assertIs(wmod.showwarning, orig_showwarning) - self.assertIsNot(wmod.filters, orig_filters) - self.assertIs(wmod.filters, 
orig_filters) + self.assertIsNot(wmod._get_filters(), orig_filters) + self.assertIs(wmod._get_filters(), orig_filters) def test_record_override_showwarning_before(self): # Issue #28835: If warnings.showwarning() was overridden, make sure # that catch_warnings(record=True) overrides it again. + if self.module._use_context: + # If _use_context is true, the warnings module does not restore + # showwarning() + return text = "This is a warning" wmod = self.module my_log = [] @@ -1313,7 +1383,7 @@ def my_logger(message, category, filename, lineno, file=None, line=None): # Override warnings.showwarning() before calling catch_warnings() with support.swap_attr(wmod, 'showwarning', my_logger): - with wmod.catch_warnings(module=wmod, record=True) as log: + with wmod.catch_warnings(record=True) as log: self.assertIsNot(wmod.showwarning, my_logger) wmod.simplefilter("always") @@ -1328,6 +1398,10 @@ def my_logger(message, category, filename, lineno, file=None, line=None): def test_record_override_showwarning_inside(self): # Issue #28835: It is possible to override warnings.showwarning() # in the catch_warnings(record=True) context manager. + if self.module._use_context: + # If _use_context is true, the warnings module does not restore + # showwarning() + return text = "This is a warning" wmod = self.module my_log = [] @@ -1336,7 +1410,7 @@ def my_logger(message, category, filename, lineno, file=None, line=None): nonlocal my_log my_log.append(message) - with wmod.catch_warnings(module=wmod, record=True) as log: + with wmod.catch_warnings(record=True) as log: wmod.simplefilter("always") wmod.showwarning = my_logger wmod.warn(text) @@ -1385,8 +1459,7 @@ class PyCatchWarningTests(CatchWarningTests, unittest.TestCase): class EnvironmentVariableTests(BaseTest): - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_single_warning(self): rc, stdout, stderr = assert_python_ok("-c", "import sys; sys.stdout.write(str(sys.warnoptions))", @@ -1394,8 +1467,7 @@ def test_single_warning(self): PYTHONDEVMODE="") self.assertEqual(stdout, b"['ignore::DeprecationWarning']") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_comma_separated_warnings(self): rc, stdout, stderr = assert_python_ok("-c", "import sys; sys.stdout.write(str(sys.warnoptions))", @@ -1404,8 +1476,7 @@ def test_comma_separated_warnings(self): self.assertEqual(stdout, b"['ignore::DeprecationWarning', 'ignore::UnicodeWarning']") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @force_not_colorized def test_envvar_and_command_line(self): rc, stdout, stderr = assert_python_ok("-Wignore::UnicodeWarning", "-c", @@ -1415,8 +1486,7 @@ def test_envvar_and_command_line(self): self.assertEqual(stdout, b"['ignore::DeprecationWarning', 'ignore::UnicodeWarning']") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @force_not_colorized def test_conflicting_envvar_and_command_line(self): rc, stdout, stderr = assert_python_failure("-Werror::DeprecationWarning", "-c", @@ -1458,7 +1528,7 @@ def test_default_filter_configuration(self): code = "import sys; sys.modules.pop('warnings', None); sys.modules['_warnings'] = None; " else: code = "" - code += "import warnings; [print(f) for f in warnings.filters]" + code += "import warnings; [print(f) for f in warnings._get_filters()]" rc, stdout, stderr = assert_python_ok("-c", code, __isolated=True) stdout_lines = [line.strip() for line 
in stdout.splitlines()] @@ -1466,8 +1536,7 @@ def test_default_filter_configuration(self): self.assertEqual(stdout_lines, expected_output) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON @unittest.skipUnless(sys.getfilesystemencoding() != 'ascii', 'requires non-ascii filesystemencoding') def test_nonascii(self): @@ -1481,18 +1550,24 @@ def test_nonascii(self): class CEnvironmentVariableTests(EnvironmentVariableTests, unittest.TestCase): module = c_warnings - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_default_filter_configuration(self): - # XXX: RUSTPYHTON; remove the entire function when fixed - super().test_default_filter_configuration() - + @unittest.expectedFailure # TODO: RUSTPYTHON Lists differ + def test_default_filter_configuration(self): super().test_default_filter_configuration() # TODO: RUSTPYTHON class PyEnvironmentVariableTests(EnvironmentVariableTests, unittest.TestCase): module = py_warnings +class LocksTest(unittest.TestCase): + @support.cpython_only + @unittest.skipUnless(c_warnings, 'C module is required') + def test_release_lock_no_lock(self): + with self.assertRaisesRegex( + RuntimeError, + 'cannot release un-acquired lock', + ): + c_warnings._release_lock() + + class _DeprecatedTest(BaseTest, unittest.TestCase): """Test _deprecated().""" @@ -1547,8 +1622,7 @@ def test_issue_8766(self): class FinalizationTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_finalization(self): # Issue #19421: warnings.warn() should not crash # during Python finalization @@ -1566,8 +1640,7 @@ def __del__(self): self.assertEqual(err.decode().rstrip(), ':7: UserWarning: test') - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_late_resource_warning(self): # Issue #21925: Emitting a ResourceWarning late during the Python # shutdown must be logged. @@ -1578,15 +1651,178 @@ def test_late_resource_warning(self): # (_warnings will try to import it) code = "f = open(%a)" % __file__ rc, out, err = assert_python_ok("-Wd", "-c", code) - self.assertTrue(err.startswith(expected), ascii(err)) + self.assertStartsWith(err, expected) # import the warnings module code = "import warnings; f = open(%a)" % __file__ rc, out, err = assert_python_ok("-Wd", "-c", code) - self.assertTrue(err.startswith(expected), ascii(err)) + self.assertStartsWith(err, expected) + + +class AsyncTests(BaseTest): + """Verifies that the catch_warnings() context manager behaves + as expected when used inside async co-routines. This requires + that the context_aware_warnings flag is enabled, so that + the context manager uses a context variable. + """ + + def setUp(self): + super().setUp() + self.module.resetwarnings() + + @unittest.skipIf(not sys.flags.context_aware_warnings, + "requires context aware warnings") + def test_async_context(self): + import asyncio + + # Events to force the execution interleaving we want. + step_a1 = asyncio.Event() + step_a2 = asyncio.Event() + step_b1 = asyncio.Event() + step_b2 = asyncio.Event() + + async def run_a(): + with self.module.catch_warnings(record=True) as w: + await step_a1.wait() + # The warning emitted here should be caught by the enclosing + # context manager.
+ self.module.warn('run_a warning', UserWarning) + step_b1.set() + await step_a2.wait() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].message.args[0], 'run_a warning') + step_b2.set() + + async def run_b(): + with self.module.catch_warnings(record=True) as w: + step_a1.set() + await step_b1.wait() + # The warning emitted here should be caught by the enclosing + # context manager. + self.module.warn('run_b warning', UserWarning) + step_a2.set() + await step_b2.wait() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].message.args[0], 'run_b warning') + + async def run_tasks(): + await asyncio.gather(run_a(), run_b()) + asyncio.run(run_tasks()) -class DeprecatedTests(unittest.TestCase): + @unittest.skipIf(not sys.flags.context_aware_warnings, + "requires context aware warnings") + def test_async_task_inherit(self): + """Check that a new asyncio task inherits warnings context from the + coroutine that spawns it. + """ + import asyncio + + step1 = asyncio.Event() + step2 = asyncio.Event() + + async def run_child1(): + await step1.wait() + # This should be recorded by the run_parent() catch_warnings + # context. + self.module.warn('child warning', UserWarning) + step2.set() + + async def run_child2(): + # This establishes a new catch_warnings() context. The + # run_child1() task should still be using the context from + # run_parent() if context-aware warnings are enabled. + with self.module.catch_warnings(record=True) as w: + step1.set() + await step2.wait() + + async def run_parent(): + with self.module.catch_warnings(record=True) as w: + await asyncio.gather(run_child1(), run_child2()) + self.assertEqual(len(w), 1) + self.assertEqual(w[0].message.args[0], 'child warning') + + asyncio.run(run_parent()) + + +class CAsyncTests(AsyncTests, unittest.TestCase): + module = c_warnings + + +class PyAsyncTests(AsyncTests, unittest.TestCase): + module = py_warnings + + +class ThreadTests(BaseTest): + """Verifies that the catch_warnings() context manager behaves as + expected when used within threads. This requires that both the + context_aware_warnings flag and thread_inherit_context flags are enabled. + """ + + ENABLE_THREAD_TESTS = (sys.flags.context_aware_warnings and + sys.flags.thread_inherit_context) + + def setUp(self): + super().setUp() + self.module.resetwarnings() + + @unittest.skipIf(not ENABLE_THREAD_TESTS, + "requires thread-safe warnings flags") + def test_threaded_context(self): + import threading + + barrier = threading.Barrier(2, timeout=2) + + def run_a(): + with self.module.catch_warnings(record=True) as w: + barrier.wait() + # The warning emitted here should be caught by the enclosing + # context manager. + self.module.warn('run_a warning', UserWarning) + barrier.wait() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].message.args[0], 'run_a warning') + # Should be caught by the catch_warnings() context manager of run_threads() + self.module.warn('main warning', UserWarning) + + def run_b(): + with self.module.catch_warnings(record=True) as w: + barrier.wait() + # The warning emitted here should be caught by the enclosing + # context manager.
+ barrier.wait() + self.module.warn('run_b warning', UserWarning) + self.assertEqual(len(w), 1) + self.assertEqual(w[0].message.args[0], 'run_b warning') + # Should be caught by the catch_warnings() context manager of run_threads() + self.module.warn('main warning', UserWarning) + + def run_threads(): + threads = [ + threading.Thread(target=run_a), + threading.Thread(target=run_b), + ] + with self.module.catch_warnings(record=True) as w: + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertEqual(len(w), 2) + self.assertEqual(w[0].message.args[0], 'main warning') + self.assertEqual(w[1].message.args[0], 'main warning') + + run_threads() + + +class CThreadTests(ThreadTests, unittest.TestCase): + module = c_warnings + + +class PyThreadTests(ThreadTests, unittest.TestCase): + module = py_warnings + + +class DeprecatedTests(PyPublicAPITests): def test_dunder_deprecated(self): @deprecated("A will go away soon") class A: @@ -1614,6 +1850,7 @@ def h(x): self.assertEqual(len(overloads), 2) self.assertEqual(overloads[0].__deprecated__, "no more ints") + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_class(self): @deprecated("A will go away soon") class A: @@ -1625,6 +1862,7 @@ class A: with self.assertRaises(TypeError): A(42) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_class_with_init(self): @deprecated("HasInit will go away soon") class HasInit: @@ -1635,6 +1873,7 @@ def __init__(self, x): instance = HasInit(42) self.assertEqual(instance.x, 42) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_class_with_new(self): has_new_called = False @@ -1653,6 +1892,7 @@ def __init__(self, x) -> None: self.assertEqual(instance.x, 42) self.assertTrue(has_new_called) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_class_with_inherited_new(self): new_base_called = False @@ -1674,6 +1914,7 @@ class HasInheritedNew(NewBase): self.assertEqual(instance.x, 42) self.assertTrue(new_base_called) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_class_with_new_but_no_init(self): new_called = False @@ -1691,6 +1932,7 @@ def __new__(cls, x): self.assertEqual(instance.x, 42) self.assertTrue(new_called) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_mixin_class(self): @deprecated("Mixin will go away soon") class Mixin: @@ -1707,6 +1949,7 @@ class Child(Base, Mixin): instance = Child(42) self.assertEqual(instance.a, 42) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_do_not_shadow_user_arguments(self): new_called = False new_called_cls = None @@ -1726,6 +1969,7 @@ class Foo(metaclass=MyMeta, cls='haha'): self.assertTrue(new_called) self.assertEqual(new_called_cls, 'haha') + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_existing_init_subclass(self): @deprecated("C will go away soon") class C: @@ -1742,6 +1986,7 @@ class D(C): self.assertTrue(D.inited) self.assertIsInstance(D(), D) # no deprecation + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_existing_init_subclass_in_base(self): class Base: def __init_subclass__(cls, x) -> None: @@ -1762,6 +2007,27 @@ class D(C, x=3): self.assertEqual(D.inited, 3) + + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered + def
test_existing_init_subclass_in_sibling_base(self): + @deprecated("A will go away soon") + class A: + pass + class B: + def __init_subclass__(cls, x): + super().__init_subclass__() + cls.inited = x + + with self.assertWarnsRegex(DeprecationWarning, "A will go away soon"): + class C(A, B, x=42): + pass + self.assertEqual(C.inited, 42) + + with self.assertWarnsRegex(DeprecationWarning, "A will go away soon"): + class D(B, A, x=42): + pass + self.assertEqual(D.inited, 42) + + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_init_subclass_has_correct_cls(self): init_subclass_saw = None @@ -1779,6 +2045,7 @@ class C(Base): self.assertIs(init_subclass_saw, C) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_init_subclass_with_explicit_classmethod(self): init_subclass_saw = None @@ -1797,6 +2064,7 @@ class C(Base): self.assertIs(init_subclass_saw, C) + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_function(self): @deprecated("b will go away soon") def b(): @@ -1805,6 +2073,7 @@ def b(): with self.assertWarnsRegex(DeprecationWarning, "b will go away soon"): b() + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_method(self): class Capybara: @deprecated("x will go away soon") @@ -1815,6 +2084,7 @@ def x(self): with self.assertWarnsRegex(DeprecationWarning, "x will go away soon"): instance.x() + @unittest.expectedFailure # TODO: RUSTPYTHON DeprecationWarning not triggered def test_property(self): class Capybara: @property @@ -1842,6 +2112,7 @@ def no_more_setting(self, value): with self.assertWarnsRegex(DeprecationWarning, "no more setting"): instance.no_more_setting = 42 + @unittest.expectedFailure # TODO: RUSTPYTHON RuntimeWarning not triggered def test_category(self): @deprecated("c will go away soon", category=RuntimeWarning) def c(): diff --git a/Lib/test/test_warnings/data/stacklevel.py b/Lib/test/test_warnings/data/stacklevel.py index c6dd24733b3..fe36242d3d2 100644 --- a/Lib/test/test_warnings/data/stacklevel.py +++ b/Lib/test/test_warnings/data/stacklevel.py @@ -4,11 +4,13 @@ import warnings from test.test_warnings.data import package_helper -def outer(message, stacklevel=1): - inner(message, stacklevel) -def inner(message, stacklevel=1): - warnings.warn(message, stacklevel=stacklevel) +def outer(message, stacklevel=1, skip_file_prefixes=()): + inner(message, stacklevel, skip_file_prefixes) + +def inner(message, stacklevel=1, skip_file_prefixes=()): + warnings.warn(message, stacklevel=stacklevel, + skip_file_prefixes=skip_file_prefixes) def package(message, *, stacklevel): package_helper.inner_api(message, stacklevel=stacklevel, diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py index c9f7b183408..bb1366cb21c 100644 --- a/Lib/test/test_zlib.py +++ b/Lib/test/test_zlib.py @@ -3,7 +3,6 @@ from test.support import import_helper import binascii import copy -import os import pickle import random import sys @@ -13,11 +12,11 @@ zlib = import_helper.import_module('zlib') requires_Compress_copy = unittest.skipUnless( - hasattr(zlib.compressobj(), "copy"), - 'requires Compress.copy()') + hasattr(zlib.compressobj(), "copy"), + 'requires Compress.copy()') requires_Decompress_copy = unittest.skipUnless( - hasattr(zlib.decompressobj(), "copy"), - 'requires Decompress.copy()') + hasattr(zlib.decompressobj(), "copy"), + 'requires Decompress.copy()') def _zlib_runtime_version_tuple(zlib_version=zlib.ZLIB_RUNTIME_VERSION): @@ -154,7 +153,7 @@ def 
test_badcompressobj(self): self.assertRaises(ValueError, zlib.compressobj, 1, zlib.DEFLATED, 0) # specifying total bits too large causes an error self.assertRaises(ValueError, - zlib.compressobj, 1, zlib.DEFLATED, zlib.MAX_WBITS + 1) + zlib.compressobj, 1, zlib.DEFLATED, zlib.MAX_WBITS + 1) def test_baddecompressobj(self): # verify failure on building decompress object with bad params @@ -242,8 +241,8 @@ def test_incomplete_stream(self): # A useful error message is given x = zlib.compress(HAMLET_SCENE) self.assertRaisesRegex(zlib.error, - "Error -5 while decompressing data: incomplete or truncated stream", - zlib.decompress, x[:-1]) + "Error -5 while decompressing data: incomplete or truncated stream", + zlib.decompress, x[:-1]) # Memory use of the following functions takes into account overallocation @@ -377,7 +376,7 @@ def test_decompinc(self, flush=False, source=None, cx=256, dcx=64): bufs.append(dco.decompress(combuf[i:i+dcx])) self.assertEqual(b'', dco.unconsumed_tail, ######## "(A) uct should be b'': not %d long" % - len(dco.unconsumed_tail)) + len(dco.unconsumed_tail)) self.assertEqual(b'', dco.unused_data) if flush: bufs.append(dco.flush()) @@ -390,7 +389,7 @@ def test_decompinc(self, flush=False, source=None, cx=256, dcx=64): break self.assertEqual(b'', dco.unconsumed_tail, ######## "(B) uct should be b'': not %d long" % - len(dco.unconsumed_tail)) + len(dco.unconsumed_tail)) self.assertEqual(b'', dco.unused_data) self.assertEqual(data, b''.join(bufs)) # Failure means: "decompressobj with init options failed" @@ -419,7 +418,7 @@ def test_decompimax(self, source=None, cx=256, dcx=64): #max_length = 1 + len(cb)//10 chunk = dco.decompress(cb, dcx) self.assertFalse(len(chunk) > dcx, - 'chunk too big (%d>%d)' % (len(chunk), dcx)) + 'chunk too big (%d>%d)' % (len(chunk), dcx)) bufs.append(chunk) cb = dco.unconsumed_tail bufs.append(dco.flush()) @@ -444,7 +443,7 @@ def test_decompressmaxlen(self, flush=False): max_length = 1 + len(cb)//10 chunk = dco.decompress(cb, max_length) self.assertFalse(len(chunk) > max_length, - 'chunk too big (%d>%d)' % (len(chunk),max_length)) + 'chunk too big (%d>%d)' % (len(chunk),max_length)) bufs.append(chunk) cb = dco.unconsumed_tail if flush: @@ -453,7 +452,7 @@ def test_decompressmaxlen(self, flush=False): while chunk: chunk = dco.decompress(b'', max_length) self.assertFalse(len(chunk) > max_length, - 'chunk too big (%d>%d)' % (len(chunk),max_length)) + 'chunk too big (%d>%d)' % (len(chunk),max_length)) bufs.append(chunk) self.assertEqual(data, b''.join(bufs), 'Wrong data retrieved') @@ -490,8 +489,7 @@ def test_clear_unconsumed_tail(self): ddata += dco.decompress(dco.unconsumed_tail) self.assertEqual(dco.unconsumed_tail, b"") - # TODO: RUSTPYTHON: Z_BLOCK support in flate2 - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; Z_BLOCK support in flate2 def test_flushes(self): # Test flush() with the various options, using all the # different levels in order to provide more variations. 
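The zlib hunks above and below are almost entirely indentation cleanup; the substantive edits are the reworded expectedFailure markers for RustPython's flate2 backend, which lacks Z_BLOCK (exercised by test_flushes) and wbits=0 (exercised by test_wbits). For reference, a minimal round trip through flush modes that are not implicated by those markers:

import zlib

co = zlib.compressobj(level=6)
part1 = co.compress(b"spam" * 1000)
part1 += co.flush(zlib.Z_SYNC_FLUSH)   # byte-align and emit pending output
part2 = co.flush(zlib.Z_FINISH)        # terminate the deflate stream
assert zlib.decompress(part1 + part2) == b"spam" * 1000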
@@ -633,7 +631,7 @@ def test_decompress_unused_data(self): self.assertEqual(dco.unconsumed_tail, b'') else: data += dco.decompress( - dco.unconsumed_tail + x[i : i + step], maxlen) + dco.unconsumed_tail + x[i : i + step], maxlen) data += dco.flush() self.assertTrue(dco.eof) self.assertEqual(data, source) @@ -811,8 +809,7 @@ def test_large_unconsumed_tail(self, size): finally: comp = uncomp = data = None - # TODO: RUSTPYTHON: wbits=0 support in flate2 - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON; wbits=0 support in flate2 def test_wbits(self): # wbits=0 only supported since zlib v1.2.3.5 supports_wbits_0 = ZLIB_RUNTIME_VERSION_TUPLE >= (1, 2, 3, 5) @@ -941,6 +938,7 @@ def choose_lines(source, number, seed=None, generator=random): Farewell. """ + class ZlibDecompressorTest(unittest.TestCase): # Test adopted from test_bz2.py TEXT = HAMLET_SCENE @@ -1015,7 +1013,7 @@ def testDecompressorChunksMaxsize(self): # Feed some input len_ = len(self.BIG_DATA) - 64 out.append(zlibd.decompress(self.BIG_DATA[:len_], - max_length=max_length)) + max_length=max_length)) self.assertFalse(zlibd.needs_input) self.assertEqual(len(out[-1]), max_length) @@ -1026,7 +1024,7 @@ def testDecompressorChunksMaxsize(self): # Retrieve more data while providing more input out.append(zlibd.decompress(self.BIG_DATA[len_:], - max_length=max_length)) + max_length=max_length)) self.assertLessEqual(len(out[-1]), max_length) # Retrieve remaining uncompressed data @@ -1046,7 +1044,7 @@ def test_decompressor_inputbuf_1(self): # Create input buffer and fill it self.assertEqual(zlibd.decompress(self.DATA[:100], - max_length=0), b'') + max_length=0), b'') # Retrieve some results, freeing capacity at beginning # of input buffer @@ -1068,7 +1066,7 @@ def test_decompressor_inputbuf_2(self): # Create input buffer and empty it self.assertEqual(zlibd.decompress(self.DATA[:200], - max_length=0), b'') + max_length=0), b'') out.append(zlibd.decompress(b'')) # Fill buffer with new data @@ -1112,6 +1110,7 @@ def test_refleaks_in___init__(self): zlibd.__init__() self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10) + class CustomInt: def __index__(self): return 100 diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py new file mode 100644 index 00000000000..cf618534add --- /dev/null +++ b/Lib/test/test_zstd.py @@ -0,0 +1,2802 @@ +import array +import gc +import io +import pathlib +import random +import re +import os +import unittest +import tempfile +import threading + +from test.support.import_helper import import_module +from test.support import threading_helper +from test.support import _1M +from test.support import Py_GIL_DISABLED + +_zstd = import_module("_zstd") +zstd = import_module("compression.zstd") + +from compression.zstd import ( + open, + compress, + decompress, + ZstdCompressor, + ZstdDecompressor, + ZstdDict, + ZstdError, + zstd_version, + zstd_version_info, + COMPRESSION_LEVEL_DEFAULT, + get_frame_info, + get_frame_size, + finalize_dict, + train_dict, + CompressionParameter, + DecompressionParameter, + Strategy, + ZstdFile, +) + +_1K = 1024 +_130_1K = 130 * _1K +DICT_SIZE1 = 3*_1K + +DAT_130K_D = None +DAT_130K_C = None + +DECOMPRESSED_DAT = None +COMPRESSED_DAT = None + +DECOMPRESSED_100_PLUS_32KB = None +COMPRESSED_100_PLUS_32KB = None + +SKIPPABLE_FRAME = None + +THIS_FILE_BYTES = None +THIS_FILE_STR = None +COMPRESSED_THIS_FILE = None + +COMPRESSED_BOGUS = None + +SAMPLES = None + +TRAINED_DICT = None + +# Cannot be deferred to setup as it is used to check whether or not to 
skip +# tests +try: + SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0) +except Exception: + SUPPORT_MULTITHREADING = False + +C_INT_MIN = -(2**31) +C_INT_MAX = (2**31) - 1 + + +def setUpModule(): + # uncompressed size 130KB, more than a zstd block. + # with a frame epilogue, 4 bytes checksum. + global DAT_130K_D + DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)]) + + global DAT_130K_C + DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1}) + + global DECOMPRESSED_DAT + DECOMPRESSED_DAT = b'abcdefg123456' * 1000 + + global COMPRESSED_DAT + COMPRESSED_DAT = compress(DECOMPRESSED_DAT) + + global DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K) + + global COMPRESSED_100_PLUS_32KB + COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB) + + global SKIPPABLE_FRAME + SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \ + (32*_1K).to_bytes(4, byteorder='little') + \ + b'a' * (32*_1K) + + global THIS_FILE_BYTES, THIS_FILE_STR + with io.open(os.path.abspath(__file__), 'rb') as f: + THIS_FILE_BYTES = f.read() + THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES) + THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8') + + global COMPRESSED_THIS_FILE + COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES) + + global COMPRESSED_BOGUS + COMPRESSED_BOGUS = DECOMPRESSED_DAT + + # dict data + words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue', + b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive', + b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird'] + lst = [] + for i in range(300): + sample = [b'%s = %d' % (random.choice(words), random.randrange(100)) + for j in range(20)] + sample = b'\n'.join(sample) + + lst.append(sample) + global SAMPLES + SAMPLES = lst + assert len(SAMPLES) > 10 + + global TRAINED_DICT + TRAINED_DICT = train_dict(SAMPLES, 3*_1K) + assert len(TRAINED_DICT.dict_content) <= 3*_1K + + +class FunctionsTestCase(unittest.TestCase): + + def test_version(self): + s = ".".join((str(i) for i in zstd_version_info)) + self.assertEqual(s, zstd_version) + + def test_compressionLevel_values(self): + min, max = CompressionParameter.compression_level.bounds() + self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int) + self.assertIs(type(min), int) + self.assertIs(type(max), int) + self.assertLess(min, max) + + def test_roundtrip_default(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + dat1 = compress(raw_dat) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_roundtrip_level(self): + raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6] + level_min, level_max = CompressionParameter.compression_level.bounds() + + for level in range(max(-20, level_min), level_max + 1): + dat1 = compress(raw_dat, level) + dat2 = decompress(dat1) + self.assertEqual(dat2, raw_dat) + + def test_get_frame_info(self): + # no dict + info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20]) + self.assertEqual(info.decompressed_size, 32 * _1K + 100) + self.assertEqual(info.dictionary_id, 0) + + # use dict + dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT) + info = get_frame_info(dat) + self.assertEqual(info.decompressed_size, 345) + self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id) + + with self.assertRaisesRegex(ZstdError, "not less than the frame header"): + get_frame_info(b"aaaaaaaaaaaaaa") + + def test_get_frame_size(self): + size = get_frame_size(COMPRESSED_100_PLUS_32KB) + self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB)) + + with 
self.assertRaisesRegex(ZstdError, "not less than this complete frame"): + get_frame_size(b"aaaaaaaaaaaaaa") + + def test_decompress_2x130_1K(self): + decompressed_size = get_frame_info(DAT_130K_C).decompressed_size + self.assertEqual(decompressed_size, _130_1K) + + dat = decompress(DAT_130K_C + DAT_130K_C) + self.assertEqual(len(dat), 2 * _130_1K) + + +class CompressorTestCase(unittest.TestCase): + + def test_simple_compress_bad_args(self): + # ZstdCompressor + self.assertRaises(TypeError, ZstdCompressor, []) + self.assertRaises(TypeError, ZstdCompressor, level=3.14) + self.assertRaises(TypeError, ZstdCompressor, level="abc") + self.assertRaises(TypeError, ZstdCompressor, options=b"abc") + + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123) + self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234") + self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4}) + + # valid range for compression level is [-(1<<17), 22] + msg = r'illegal compression level {}; the valid range is \[-?\d+, -?\d+\]' + with self.assertRaisesRegex(ValueError, msg.format(C_INT_MAX)): + ZstdCompressor(C_INT_MAX) + with self.assertRaisesRegex(ValueError, msg.format(C_INT_MIN)): + ZstdCompressor(C_INT_MIN) + msg = r'illegal compression level; the valid range is \[-?\d+, -?\d+\]' + with self.assertRaisesRegex(ValueError, msg): + ZstdCompressor(level=-(2**1000)) + with self.assertRaisesRegex(ValueError, msg): + ZstdCompressor(level=2**1000) + + with self.assertRaises(ValueError): + ZstdCompressor(options={CompressionParameter.window_log: 100}) + with self.assertRaises(ValueError): + ZstdCompressor(options={3333: 100}) + + # Method bad arguments + zc = ZstdCompressor() + self.assertRaises(TypeError, zc.compress) + self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar") + self.assertRaises(TypeError, zc.compress, "str") + self.assertRaises((TypeError, ValueError), zc.flush, b"foo") + self.assertRaises(TypeError, zc.flush, b"blah", 1) + + self.assertRaises(ValueError, zc.compress, b'', -1) + self.assertRaises(ValueError, zc.compress, b'', 3) + self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0 + self.assertRaises(ValueError, zc.flush, 3) + + zc.compress(b'') + zc.compress(b'', zc.CONTINUE) + zc.compress(b'', zc.FLUSH_BLOCK) + zc.compress(b'', zc.FLUSH_FRAME) + empty = zc.flush() + zc.flush(zc.FLUSH_BLOCK) + zc.flush(zc.FLUSH_FRAME) + + def test_compress_parameters(self): + d = {CompressionParameter.compression_level : 10, + + CompressionParameter.window_log : 12, + CompressionParameter.hash_log : 10, + CompressionParameter.chain_log : 12, + CompressionParameter.search_log : 12, + CompressionParameter.min_match : 4, + CompressionParameter.target_length : 12, + CompressionParameter.strategy : Strategy.lazy, + + CompressionParameter.enable_long_distance_matching : 1, + CompressionParameter.ldm_hash_log : 12, + CompressionParameter.ldm_min_match : 11, + CompressionParameter.ldm_bucket_size_log : 5, + CompressionParameter.ldm_hash_rate_log : 12, + + CompressionParameter.content_size_flag : 1, + CompressionParameter.checksum_flag : 1, + CompressionParameter.dict_id_flag : 0, + + CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0, + CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0, + CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0, + } + ZstdCompressor(options=d) + + d1 = d.copy() + # larger than signed int + d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MAX + with self.assertRaises(ValueError): + 
ZstdCompressor(options=d1) + # smaller than signed int + d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MIN + with self.assertRaises(ValueError): + ZstdCompressor(options=d1) + + # out of bounds compression level + level_min, level_max = CompressionParameter.compression_level.bounds() + with self.assertRaises(ValueError): + compress(b'', level_max+1) + with self.assertRaises(ValueError): + compress(b'', level_min-1) + with self.assertRaises(ValueError): + compress(b'', 2**1000) + with self.assertRaises(ValueError): + compress(b'', -(2**1000)) + with self.assertRaises(ValueError): + compress(b'', options={ + CompressionParameter.compression_level: level_max+1}) + with self.assertRaises(ValueError): + compress(b'', options={ + CompressionParameter.compression_level: level_min-1}) + + # zstd lib doesn't support MT compression + if not SUPPORT_MULTITHREADING: + with self.assertRaises(ValueError): + ZstdCompressor(options={CompressionParameter.nb_workers:4}) + with self.assertRaises(ValueError): + ZstdCompressor(options={CompressionParameter.job_size:4}) + with self.assertRaises(ValueError): + ZstdCompressor(options={CompressionParameter.overlap_log:4}) + + # out of bounds error msg + option = {CompressionParameter.window_log:100} + with self.assertRaisesRegex( + ValueError, + "compression parameter 'window_log' received an illegal value 100; " + r'the valid range is \[-?\d+, -?\d+\]', + ): + compress(b'', options=option) + + def test_unknown_compression_parameter(self): + KEY = 100001234 + option = {CompressionParameter.compression_level: 10, + KEY: 200000000} + pattern = rf"invalid compression parameter 'unknown parameter \(key {KEY}\)'" + with self.assertRaisesRegex(ValueError, pattern): + ZstdCompressor(options=option) + + @unittest.skipIf(not SUPPORT_MULTITHREADING, + "zstd build doesn't support multi-threaded compression") + def test_zstd_multithread_compress(self): + size = 40*_1M + b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES)) + + options = {CompressionParameter.compression_level : 4, + CompressionParameter.nb_workers : 2} + + # compress() + dat1 = compress(b, options=options) + dat2 = decompress(dat1) + self.assertEqual(dat2, b) + + # ZstdCompressor + c = ZstdCompressor(options=options) + dat1 = c.compress(b, c.CONTINUE) + dat2 = c.compress(b, c.FLUSH_BLOCK) + dat3 = c.compress(b, c.FLUSH_FRAME) + dat4 = decompress(dat1+dat2+dat3) + self.assertEqual(dat4, b * 3) + + # ZstdFile + with ZstdFile(io.BytesIO(), 'w', options=options) as f: + f.write(b) + + def test_compress_flushblock(self): + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK) + self.assertEqual(c.last_mode, c.FLUSH_BLOCK) + dat2 = c.flush() + pattern = "Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(dat1) + + dat3 = decompress(dat1 + dat2) + + self.assertEqual(dat3, THIS_FILE_BYTES) + + def test_compress_flushframe(self): + # test compress & decompress + point = len(THIS_FILE_BYTES) // 2 + + c = ZstdCompressor() + + dat1 = c.compress(THIS_FILE_BYTES[:point]) + self.assertEqual(c.last_mode, c.CONTINUE) + + dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat1) + self.assertEqual(nt.decompressed_size, None) # no content size + + dat2 = decompress(dat1) + + 
self.assertEqual(dat2, THIS_FILE_BYTES) + + # single .FLUSH_FRAME mode has content size + c = ZstdCompressor() + dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME) + self.assertEqual(c.last_mode, c.FLUSH_FRAME) + + nt = get_frame_info(dat) + self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES)) + + def test_compress_empty(self): + # output empty content frame + self.assertNotEqual(compress(b''), b'') + + c = ZstdCompressor() + self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'') + + def test_set_pledged_input_size(self): + DAT = DECOMPRESSED_100_PLUS_32KB + CHUNK_SIZE = len(DAT) // 3 + + # wrong value + c = ZstdCompressor() + with self.assertRaisesRegex(ValueError, + r'should be a positive int less than \d+'): + c.set_pledged_input_size(-300) + # overflow + with self.assertRaisesRegex(ValueError, + r'should be a positive int less than \d+'): + c.set_pledged_input_size(2**64) + # ZSTD_CONTENTSIZE_ERROR is invalid + with self.assertRaisesRegex(ValueError, + r'should be a positive int less than \d+'): + c.set_pledged_input_size(2**64-2) + # ZSTD_CONTENTSIZE_UNKNOWN should use None + with self.assertRaisesRegex(ValueError, + r'should be a positive int less than \d+'): + c.set_pledged_input_size(2**64-1) + + # check valid values are settable + c.set_pledged_input_size(2**63) + c.set_pledged_input_size(2**64-3) + + # check that zero means empty frame + c = ZstdCompressor(level=1) + c.set_pledged_input_size(0) + c.compress(b'') + dat = c.flush() + ret = get_frame_info(dat) + self.assertEqual(ret.decompressed_size, 0) + + + # wrong mode + c = ZstdCompressor(level=1) + c.compress(b'123456') + self.assertEqual(c.last_mode, c.CONTINUE) + with self.assertRaisesRegex(ValueError, + r'last_mode == FLUSH_FRAME'): + c.set_pledged_input_size(300) + + # None value + c = ZstdCompressor(level=1) + c.set_pledged_input_size(None) + dat = c.compress(DAT) + c.flush() + + ret = get_frame_info(dat) + self.assertEqual(ret.decompressed_size, None) + + # correct value + c = ZstdCompressor(level=1) + c.set_pledged_input_size(len(DAT)) + + chunks = [] + posi = 0 + while posi < len(DAT): + dat = c.compress(DAT[posi:posi+CHUNK_SIZE]) + posi += CHUNK_SIZE + chunks.append(dat) + + dat = c.flush() + chunks.append(dat) + chunks = b''.join(chunks) + + ret = get_frame_info(chunks) + self.assertEqual(ret.decompressed_size, len(DAT)) + self.assertEqual(decompress(chunks), DAT) + + c.set_pledged_input_size(len(DAT)) # the second frame + dat = c.compress(DAT) + c.flush() + + ret = get_frame_info(dat) + self.assertEqual(ret.decompressed_size, len(DAT)) + self.assertEqual(decompress(dat), DAT) + + # not enough data + c = ZstdCompressor(level=1) + c.set_pledged_input_size(len(DAT)+1) + + for start in range(0, len(DAT), CHUNK_SIZE): + end = min(start+CHUNK_SIZE, len(DAT)) + _dat = c.compress(DAT[start:end]) + + with self.assertRaises(ZstdError): + c.flush() + + # too much data + c = ZstdCompressor(level=1) + c.set_pledged_input_size(len(DAT)) + + for start in range(0, len(DAT), CHUNK_SIZE): + end = min(start+CHUNK_SIZE, len(DAT)) + _dat = c.compress(DAT[start:end]) + + with self.assertRaises(ZstdError): + c.compress(b'extra', ZstdCompressor.FLUSH_FRAME) + + # content size not set if content_size_flag == 0 + c = ZstdCompressor(options={CompressionParameter.content_size_flag: 0}) + c.set_pledged_input_size(10) + dat1 = c.compress(b"hello") + dat2 = c.compress(b"world") + dat3 = c.flush() + frame_data = get_frame_info(dat1 + dat2 + dat3) + self.assertIsNone(frame_data.decompressed_size) + + +class 
DecompressorTestCase(unittest.TestCase):
+
+    def test_simple_decompress_bad_args(self):
+        # ZstdDecompressor
+        self.assertRaises(TypeError, ZstdDecompressor, ())
+        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123)
+        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc')
+        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4})
+
+        self.assertRaises(TypeError, ZstdDecompressor, options=123)
+        self.assertRaises(TypeError, ZstdDecompressor, options='abc')
+        self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')
+
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(options={C_INT_MAX: 100})
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(options={C_INT_MIN: 100})
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(options={0: C_INT_MAX})
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor(options={2**1000: 100})
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor(options={-(2**1000): 100})
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor(options={0: -(2**1000)})
+
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(options={3333: 100})
+
+        empty = compress(b'')
+        lzd = ZstdDecompressor()
+        self.assertRaises(TypeError, lzd.decompress)
+        self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
+        self.assertRaises(TypeError, lzd.decompress, "str")
+        lzd.decompress(empty)
+
+    def test_decompress_parameters(self):
+        d = {DecompressionParameter.window_log_max : 15}
+        ZstdDecompressor(options=d)
+
+        d1 = d.copy()
+        # larger than signed int
+        d1[DecompressionParameter.window_log_max] = 2**1000
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor(None, d1)
+        # smaller than signed int
+        d1[DecompressionParameter.window_log_max] = -(2**1000)
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor(None, d1)
+
+        d1[DecompressionParameter.window_log_max] = C_INT_MAX
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(None, d1)
+        d1[DecompressionParameter.window_log_max] = C_INT_MIN
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(None, d1)
+
+        # out of bounds error msg
+        options = {DecompressionParameter.window_log_max:100}
+        with self.assertRaisesRegex(
+            ValueError,
+            "decompression parameter 'window_log_max' received an illegal value 100; "
+            r'the valid range is \[-?\d+, -?\d+\]',
+        ):
+            decompress(b'', options=options)
+
+        # out of bounds decompression parameter
+        options[DecompressionParameter.window_log_max] = C_INT_MAX
+        with self.assertRaises(ValueError):
+            decompress(b'', options=options)
+        options[DecompressionParameter.window_log_max] = C_INT_MIN
+        with self.assertRaises(ValueError):
+            decompress(b'', options=options)
+        options[DecompressionParameter.window_log_max] = 2**1000
+        with self.assertRaises(OverflowError):
+            decompress(b'', options=options)
+        options[DecompressionParameter.window_log_max] = -(2**1000)
+        with self.assertRaises(OverflowError):
+            decompress(b'', options=options)
+
+    def test_unknown_decompression_parameter(self):
+        KEY = 100001234
+        options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
+                   KEY: 200000000}
+        pattern = rf"invalid decompression parameter 'unknown parameter \(key {KEY}\)'"
+        with self.assertRaisesRegex(ValueError, pattern):
+            ZstdDecompressor(options=options)
+
+    def test_decompress_epilogue_flags(self):
+        # DAT_130K_C has a 4-byte checksum at the frame epilogue
+
+        # full unlimited
+        d = 
ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'') + + # full limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + with self.assertRaises(EOFError): + dat = d.decompress(b'', 0) + + # [:-4] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-4] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-3] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-3] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-3], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + # [:-1] unlimited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1]) + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.needs_input) + + dat = d.decompress(b'') + self.assertEqual(len(dat), 0) + self.assertTrue(d.needs_input) + + # [:-1] limited + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-1], _130_1K) + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.needs_input) + + dat = d.decompress(b'', 0) + self.assertEqual(len(dat), 0) + self.assertFalse(d.needs_input) + + def test_decompressor_arg(self): + zd = ZstdDict(b'12345678', is_raw=True) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(zstd_dict={}) + + with self.assertRaises(TypeError): + d = ZstdDecompressor(options=zd) + + ZstdDecompressor() + ZstdDecompressor(zd, {}) + ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25}) + + def test_decompressor_1(self): + # empty + d = ZstdDecompressor() + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + + # 130_1K full + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K full, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + + # 130_1K, without 4 bytes checksum + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4]) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + + # above, limit output + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C[:-4], _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + + # full, unused_data + TRAIL = b'89234893abcd' + d = ZstdDecompressor() + dat = d.decompress(DAT_130K_C + TRAIL, _130_1K) + + self.assertEqual(len(dat), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, TRAIL) + + def test_decompressor_chunks_read_300(self): + TRAIL 
= b'89234893abcd' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(300) + if not dat: + break + else: + raise Exception('should not get here') + + ret = d.decompress(dat) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + def test_decompressor_chunks_read_3(self): + TRAIL = b'89234893' + DAT = DAT_130K_C + TRAIL + d = ZstdDecompressor() + + bi = io.BytesIO(DAT) + lst = [] + while True: + if d.needs_input: + dat = bi.read(3) + if not dat: + break + else: + dat = b'' + + ret = d.decompress(dat, 1) + lst.append(ret) + if d.eof: + break + + ret = b''.join(lst) + + self.assertEqual(len(ret), _130_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data + bi.read(), TRAIL) + + + def test_decompress_empty(self): + with self.assertRaises(ZstdError): + decompress(b'') + + d = ZstdDecompressor() + self.assertEqual(d.decompress(b''), b'') + self.assertFalse(d.eof) + + def test_decompress_empty_content_frame(self): + DAT = compress(b'') + # decompress + self.assertGreaterEqual(len(DAT), 4) + self.assertEqual(decompress(DAT), b'') + + with self.assertRaises(ZstdError): + decompress(DAT[:-1]) + + # ZstdDecompressor + d = ZstdDecompressor() + dat = d.decompress(DAT) + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + d = ZstdDecompressor() + dat = d.decompress(DAT[:-1]) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + +class DecompressorFlagsTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + options = {CompressionParameter.checksum_flag:1} + c = ZstdCompressor(options=options) + + cls.DECOMPRESSED_42 = b'a'*42 + cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME) + + cls.DECOMPRESSED_60 = b'a'*60 + cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME) + + cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60 + cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60 + + cls._130_1K = 130*_1K + + c = ZstdCompressor() + cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush() + cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush() + cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60 + + cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|' + + def test_function_decompress(self): + + self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K) + + # 1 frame + self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42) + + pattern = r"Compressed data ended before the end-of-stream marker" + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:1]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42[:-1]) + + # 2 frames + self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_42_60) + 
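+        # Note: decompress() walks every frame in the input buffer and
+        # concatenates the outputs, so frames with known and unknown content
+        # sizes can be freely mixed, as the asserts around here show. A
+        # minimal sketch of that behaviour (the payloads are hypothetical,
+        # not fixtures of this test):
+        #   assert decompress(compress(b'spam') + compress(b'eggs')) == b'spameggs'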
+ self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60), + self.DECOMPRESSED_42_60) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.FRAME_42_60[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(self.UNKNOWN_FRAME_42_60[:-1]) + + # 130_1K + self.assertEqual(decompress(DAT_130K_C), DAT_130K_D) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-4]) + + with self.assertRaisesRegex(ZstdError, pattern): + decompress(DAT_130K_C[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa') + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(self.FRAME_42 + b'aaaaaaaaa') + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa') + + # doesn't match checksum + checksum = DAT_130K_C[-4:] + if checksum[0] == 255: + wrong_checksum = bytes([254]) + checksum[1:] + else: + wrong_checksum = bytes([checksum[0]+1]) + checksum[1:] + + dat = DAT_130K_C[:-4] + wrong_checksum + + with self.assertRaisesRegex(ZstdError, "doesn't match checksum"): + decompress(dat) + + def test_function_skippable(self): + self.assertEqual(decompress(SKIPPABLE_FRAME), b'') + self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'') + + # 1 frame + 2 skippable + self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)), + self._130_1K) + + self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)), + self._130_1K) + + self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)), + self._130_1K) + + # unknown size + self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60), + self.DECOMPRESSED_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_60) + + # 2 frames + 1 skippable + self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60), + self.DECOMPRESSED_42_60) + + self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME), + self.DECOMPRESSED_42_60) + + # incomplete + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:1]) + + with self.assertRaises(ZstdError): + decompress(SKIPPABLE_FRAME[:-1]) + + with self.assertRaises(ZstdError): + decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1]) + + # Unknown frame descriptor + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME) + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(SKIPPABLE_FRAME + b'aaaaaaaaa') + + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa') + + def test_decompressor_1(self): + # empty 1 + d = ZstdDecompressor() + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + 
self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # empty 2 + d = ZstdDecompressor() + + dat = d.decompress(b'', 0) + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a') + self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'a') + self.assertEqual(d.unused_data, b'a') # twice + + # 1 frame + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, trail + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_42 + self.TRAIL) + + self.assertEqual(dat, self.DECOMPRESSED_42) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # 1 frame, 32_1K + temp = compress(b'a'*(32*_1K)) + d = ZstdDecompressor() + dat = d.decompress(temp, 32*_1K) + + self.assertEqual(dat, b'a'*(32*_1K)) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # 1 frame, 32_1K+100, trail + d = ZstdDecompressor() + dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes + + self.assertEqual(len(dat), 100) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + + dat = d.decompress(b'') # 32_1K + + self.assertEqual(len(dat), 32*_1K) + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + with self.assertRaises(EOFError): + d.decompress(b'') + + # incomplete 1 + d = ZstdDecompressor() + dat = d.decompress(self.FRAME_60[:1]) + + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 2 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete 3 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-1]) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + + # incomplete 4 + d = ZstdDecompressor() + + dat = d.decompress(self.FRAME_60[:-4], 60) + self.assertEqual(dat, self.DECOMPRESSED_60) + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # 
twice + + # Unknown frame descriptor + d = ZstdDecompressor() + with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"): + d.decompress(b'aaaaaaaaa') + + def test_decompressor_skippable(self): + # 1 skippable + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, max_length=0 + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME, 0) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # 1 skippable, trail + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL) + + self.assertEqual(dat, b'') + self.assertTrue(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, self.TRAIL) + self.assertEqual(d.unused_data, self.TRAIL) # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1]) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + # incomplete + d = ZstdDecompressor() + dat = d.decompress(SKIPPABLE_FRAME[:-1], 0) + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertFalse(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + dat = d.decompress(b'') + + self.assertEqual(dat, b'') + self.assertFalse(d.eof) + self.assertTrue(d.needs_input) + self.assertEqual(d.unused_data, b'') + self.assertEqual(d.unused_data, b'') # twice + + + +class ZstdDictTestCase(unittest.TestCase): + + def test_is_raw(self): + # must be passed as a keyword argument + with self.assertRaises(TypeError): + ZstdDict(bytes(8), True) + + # content < 8 + b = b'1234567' + with self.assertRaises(ValueError): + ZstdDict(b) + + # content == 8 + b = b'12345678' + zd = ZstdDict(b, is_raw=True) + self.assertEqual(zd.dict_id, 0) + + temp = compress(b'aaa12345678', level=3, zstd_dict=zd) + self.assertEqual(b'aaa12345678', decompress(temp, zd)) + + # is_raw == False + b = b'12345678abcd' + with self.assertRaises(ValueError): + ZstdDict(b) + + # read only attributes + with self.assertRaises(AttributeError): + zd.dict_content = b + + with self.assertRaises(AttributeError): + zd.dict_id = 10000 + + # ZstdDict arguments + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False) + self.assertNotEqual(zd.dict_id, 0) + + zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True) + self.assertNotEqual(zd.dict_id, 0) # note this assertion + + with self.assertRaises(TypeError): + ZstdDict("12345678abcdef", is_raw=True) + with self.assertRaises(TypeError): + ZstdDict(TRAINED_DICT) + + # invalid parameter + with self.assertRaises(TypeError): + ZstdDict(desk333=345) + + def test_invalid_dict(self): + DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little') + dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz' + + # corrupted + zd = ZstdDict(dict_content, is_raw=False) + with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?content\.$'): + ZstdCompressor(zstd_dict=zd.as_digested_dict) + with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?content\.$'): + ZstdDecompressor(zd) + + # wrong type + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=[zd, 1]) + with self.assertRaisesRegex(TypeError, 
r'should be a ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, 1.0))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd,))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, 1, 2))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, -1))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, 3))
+        with self.assertRaises(OverflowError):
+            ZstdCompressor(zstd_dict=(zd, 2**1000))
+        with self.assertRaises(OverflowError):
+            ZstdCompressor(zstd_dict=(zd, -2**1000))
+
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdDecompressor(zstd_dict=[zd, 1])
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdDecompressor(zstd_dict=(zd, 1.0))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdDecompressor((zd,))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdDecompressor((zd, 1, 2))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdDecompressor((zd, -1))
+        with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
+            ZstdDecompressor((zd, 3))
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor((zd, 2**1000))
+        with self.assertRaises(OverflowError):
+            ZstdDecompressor((zd, -2**1000))
+
+    def test_train_dict(self):
+        TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
+        ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
+
+        self.assertNotEqual(TRAINED_DICT.dict_id, 0)
+        self.assertGreater(len(TRAINED_DICT.dict_content), 0)
+        self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1)
+        self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT)))
+
+        # compress/decompress
+        c = ZstdCompressor(zstd_dict=TRAINED_DICT)
+        for sample in SAMPLES:
+            dat1 = compress(sample, zstd_dict=TRAINED_DICT)
+            dat2 = decompress(dat1, TRAINED_DICT)
+            self.assertEqual(sample, dat2)
+
+            dat1 = c.compress(sample)
+            dat1 += c.flush()
+            dat2 = decompress(dat1, TRAINED_DICT)
+            self.assertEqual(sample, dat2)
+
+    def test_finalize_dict(self):
+        DICT_SIZE2 = 200*_1K
+        C_LEVEL = 6
+
+        try:
+            dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL)
+        except NotImplementedError:
+            # < v1.4.5 at compile-time, >= v1.4.5 at run-time
+            return
+
+        self.assertNotEqual(dic2.dict_id, 0)
+        self.assertGreater(len(dic2.dict_content), 0)
+        self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2)
+
+        # compress/decompress
+        c = ZstdCompressor(C_LEVEL, zstd_dict=dic2)
+        for sample in SAMPLES:
+            dat1 = compress(sample, C_LEVEL, zstd_dict=dic2)
+            dat2 = decompress(dat1, dic2)
+            self.assertEqual(sample, dat2)
+
+            dat1 = c.compress(sample)
+            dat1 += c.flush()
+            dat2 = decompress(dat1, dic2)
+            self.assertEqual(sample, dat2)
+
+        # dict mismatch
+        self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id)
+
+        dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT)
+        with self.assertRaises(ZstdError):
+            decompress(dat1, dic2)
+
+    def test_train_dict_arguments(self):
+        with self.assertRaises(ValueError):
+            train_dict([], 100*_1K)
+
+        with self.assertRaises(ValueError):
+            train_dict(SAMPLES, -100)
+
+        with self.assertRaises(ValueError):
+            train_dict(SAMPLES, 0)
+
+    def test_finalize_dict_arguments(self):
+        with self.assertRaises(TypeError):
+            finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2)
+
+        with self.assertRaises(ValueError):
+            finalize_dict(TRAINED_DICT, [], 100*_1K, 2)
+
+        with 
self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, -100, 2) + + with self.assertRaises(ValueError): + finalize_dict(TRAINED_DICT, SAMPLES, 0, 2) + + def test_train_dict_c(self): + # argument wrong type + with self.assertRaises(TypeError): + _zstd.train_dict({}, (), 100) + with self.assertRaises(TypeError): + _zstd.train_dict(bytearray(), (), 100) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', 99, 100) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', [], 100) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', (), 100.1) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', (99.1,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'abc', (4, -1), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'abc', (2,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (99,), 100) + + # size > size_t + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (2**1000,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (-2**1000,), 100) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (), 0) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (), -1) + + with self.assertRaises(ZstdError): + _zstd.train_dict(b'', (), 1) + + def test_finalize_dict_c(self): + with self.assertRaises(TypeError): + _zstd.finalize_dict(1, 2, 3, 4, 5) + + # argument wrong type + with self.assertRaises(TypeError): + _zstd.finalize_dict({}, b'', (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1) + + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5) + + # size > size_t + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5) + + # dict_size <= 0 + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5) + + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000) + + with self.assertRaises(ZstdError): + 
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5) + + def test_train_buffer_protocol_samples(self): + def _nbytes(dat): + if isinstance(dat, (bytes, bytearray)): + return len(dat) + return memoryview(dat).nbytes + + # prepare samples + chunk_lst = [] + wrong_size_lst = [] + correct_size_lst = [] + for _ in range(300): + arr = array.array('Q', [random.randint(0, 20) for i in range(20)]) + chunk_lst.append(arr) + correct_size_lst.append(_nbytes(arr)) + wrong_size_lst.append(len(arr)) + concatenation = b''.join(chunk_lst) + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size tuple doesn't match the concatenation's size"): + _zstd.train_dict(concatenation, tuple(wrong_size_lst), 100*_1K) + + # correct size list + _zstd.train_dict(concatenation, tuple(correct_size_lst), 3*_1K) + + # wrong size list + with self.assertRaisesRegex(ValueError, + "The samples size tuple doesn't match the concatenation's size"): + _zstd.finalize_dict(TRAINED_DICT.dict_content, + concatenation, tuple(wrong_size_lst), 300*_1K, 5) + + # correct size list + _zstd.finalize_dict(TRAINED_DICT.dict_content, + concatenation, tuple(correct_size_lst), 300*_1K, 5) + + def test_as_prefix(self): + # V1 + V1 = THIS_FILE_BYTES + zd = ZstdDict(V1, is_raw=True) + + # V2 + mid = len(V1) // 2 + V2 = V1[:mid] + \ + (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \ + V1[mid+1:] + + # compress + dat = compress(V2, zstd_dict=zd.as_prefix) + self.assertEqual(get_frame_info(dat).dictionary_id, 0) + + # decompress + self.assertEqual(decompress(dat, zd.as_prefix), V2) + + # use wrong prefix + zd2 = ZstdDict(SAMPLES[0], is_raw=True) + try: + decompressed = decompress(dat, zd2.as_prefix) + except ZstdError: # expected + pass + else: + self.assertNotEqual(decompressed, V2) + + # read only attribute + with self.assertRaises(AttributeError): + zd.as_prefix = b'1234' + + def test_as_digested_dict(self): + zd = TRAINED_DICT + + # test .as_digested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict) + self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_digested_dict = b'1234' + + # test .as_undigested_dict + dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict) + self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0]) + with self.assertRaises(AttributeError): + zd.as_undigested_dict = b'1234' + + def test_advanced_compression_parameters(self): + options = {CompressionParameter.compression_level: 6, + CompressionParameter.window_log: 20, + CompressionParameter.enable_long_distance_matching: 1} + + # automatically select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + # explicitly select + dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict) + self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0]) + + def test_len(self): + self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content)) + self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT)) + +class FileTestCase(unittest.TestCase): + def setUp(self): + self.DECOMPRESSED_42 = b'a'*42 + self.FRAME_42 = compress(self.DECOMPRESSED_42) + + def test_init(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + pass + with ZstdFile(io.BytesIO(), "w") as f: + pass + with ZstdFile(io.BytesIO(), "x") as f: + pass + with ZstdFile(io.BytesIO(), "a") as f: + pass + + with ZstdFile(io.BytesIO(), "w", level=12) as f: + pass + with 
ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "w", options={}) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f:
+            pass
+
+        with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f:
+            pass
+
+    def test_init_with_PathLike_filename(self):
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        with ZstdFile(filename, "a") as f:
+            f.write(DECOMPRESSED_100_PLUS_32KB)
+        with ZstdFile(filename) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+        with ZstdFile(filename, "a") as f:
+            f.write(DECOMPRESSED_100_PLUS_32KB)
+        with ZstdFile(filename) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2)
+
+        os.remove(filename)
+
+    def test_init_with_filename(self):
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        with ZstdFile(filename) as f:
+            pass
+        with ZstdFile(filename, "w") as f:
+            pass
+        with ZstdFile(filename, "a") as f:
+            pass
+
+        os.remove(filename)
+
+    def test_init_mode(self):
+        bi = io.BytesIO()
+
+        with ZstdFile(bi, "r"):
+            pass
+        with ZstdFile(bi, "rb"):
+            pass
+        with ZstdFile(bi, "w"):
+            pass
+        with ZstdFile(bi, "wb"):
+            pass
+        with ZstdFile(bi, "a"):
+            pass
+        with ZstdFile(bi, "ab"):
+            pass
+
+    def test_init_with_x_mode(self):
+        with tempfile.NamedTemporaryFile() as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        for mode in ("x", "xb"):
+            with ZstdFile(filename, mode):
+                pass
+            with self.assertRaises(FileExistsError):
+                with ZstdFile(filename, mode):
+                    pass
+            os.remove(filename)
+
+    def test_init_bad_mode(self):
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x"))
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")
+
+        with self.assertRaisesRegex(TypeError,
+                                    r"not be a CompressionParameter"):
+            ZstdFile(io.BytesIO(), 'rb',
+                     options={CompressionParameter.compression_level:5})
+        with self.assertRaisesRegex(TypeError,
+                                    r"not be a DecompressionParameter"):
+            ZstdFile(io.BytesIO(), 'wb',
+                     options={DecompressionParameter.window_log_max:21})
+
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12)
+
+    def test_init_bad_check(self):
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(), "w", level='asd')
+        # Unknown option keys and out-of-range option values should be
+        # rejected with ValueError.
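+        # A sketch of how a caller could pre-validate an option value with
+        # the same bounds() API used elsewhere in this file (window_log is
+        # just an example key, `value` a hypothetical variable):
+        #   lo, hi = CompressionParameter.window_log.bounds()
+        #   ok = lo <= value <= hi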
+ with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(), "w", options={999:9999}) + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33) + + with self.assertRaises(OverflowError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DecompressionParameter.window_log_max:2**31}) + + with self.assertRaises(ValueError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={444:333}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2}) + + with self.assertRaises(TypeError): + ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456') + + def test_init_close_fp(self): + # get a temp file name + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + tmp_f.write(DAT_130K_C) + filename = tmp_f.name + + with self.assertRaises(TypeError): + ZstdFile(filename, options={'a':'b'}) + + # for PyPy + gc.collect() + + os.remove(filename) + + def test_close(self): + with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src: + f = ZstdFile(src) + f.close() + # ZstdFile.close() should not close the underlying file object. + self.assertFalse(src.closed) + # Try closing an already-closed ZstdFile. + f.close() + self.assertFalse(src.closed) + + # Test with a real file on disk, opened directly by ZstdFile. + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + fp = f._fp + f.close() + # Here, ZstdFile.close() *should* close the underlying file object. + self.assertTrue(fp.closed) + # Try closing an already-closed ZstdFile. + f.close() + + os.remove(filename) + + def test_closed(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.closed) + f.read() + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.closed) + finally: + f.close() + self.assertTrue(f.closed) + + def test_fileno(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertRaises(io.UnsupportedOperation, f.fileno) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.fileno(), f._fp.fileno()) + self.assertIsInstance(f.fileno(), int) + finally: + f.close() + self.assertRaises(ValueError, f.fileno) + + os.remove(filename) + + # 3, no .fileno() method + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'fileno'): + f.fileno() + + def test_name(self): + # 1 + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + with self.assertRaises(AttributeError): + f.name + finally: + f.close() + with self.assertRaises(ValueError): + f.name + + # 2 + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + filename = pathlib.Path(tmp_f.name) + + f = ZstdFile(filename) + try: + self.assertEqual(f.name, f._fp.name) + self.assertIsInstance(f.name, str) + finally: + f.close() + with self.assertRaises(ValueError): + f.name + + os.remove(filename) + + # 3, no .filename property + class C: + def read(self, size=-1): + return b'123' + with ZstdFile(C(), 'rb') as f: + with self.assertRaisesRegex(AttributeError, r'name'): + f.name + + def test_seekable(self): + 
f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.seekable()) + f.read() + self.assertTrue(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + src = io.BytesIO(COMPRESSED_100_PLUS_32KB) + src.seekable = lambda: False + f = ZstdFile(src) + try: + self.assertFalse(f.seekable()) + finally: + f.close() + self.assertRaises(ValueError, f.seekable) + + def test_readable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertTrue(f.readable()) + f.read() + self.assertTrue(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertFalse(f.readable()) + finally: + f.close() + self.assertRaises(ValueError, f.readable) + + def test_writable(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + try: + self.assertFalse(f.writable()) + f.read() + self.assertFalse(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + f = ZstdFile(io.BytesIO(), "w") + try: + self.assertTrue(f.writable()) + finally: + f.close() + self.assertRaises(ValueError, f.writable) + + def test_read_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertEqual(f.read(0), b"") + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), + options={DecompressionParameter.window_log_max:20}) as f: + self.assertEqual(f.read(0), b"") + + # empty file + with ZstdFile(io.BytesIO(b'')) as f: + self.assertEqual(f.read(0), b"") + with self.assertRaises(EOFError): + f.read(10) + + with ZstdFile(io.BytesIO(b'')) as f: + with self.assertRaises(EOFError): + f.read(10) + + def test_read_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + chunks = [] + while True: + result = f.read(10) + if not result: + break + self.assertLessEqual(len(result), 10) + chunks.append(result) + self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB) + + def test_read_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT) + + def test_read_incomplete(self): + with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f: + self.assertRaises(EOFError, f.read) + + # Trailing data isn't a valid compressed stream + with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f: + self.assertRaises(ZstdError, f.read) + + with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f: + self.assertRaises(ZstdError, f.read) + + def test_read_truncated(self): + # Drop stream epilogue: 4 bytes checksum + truncated = DAT_130K_C[:-4] + with ZstdFile(io.BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + + with ZstdFile(io.BytesIO(truncated)) as f: + # this is an important test, make sure it doesn't raise EOFError. 
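+            # Rationale: only the 4-byte checksum epilogue was dropped, so the
+            # full 130 KiB payload is still decodable; a bounded read() that
+            # stops at the payload boundary must succeed, and EOFError may only
+            # surface on the read(1) that follows.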
+ self.assertEqual(f.read(130*_1K), DAT_130K_D) + with self.assertRaises(EOFError): + f.read(1) + + # Incomplete header + for i in range(1, 20): + with ZstdFile(io.BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + + def test_read_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_DAT)) + f.close() + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read) + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertRaises(TypeError, f.read, float()) + + def test_read_bad_data(self): + with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f: + self.assertRaises(ZstdError, f.read) + + def test_read_exception(self): + class C: + def read(self, size=-1): + raise OSError + with ZstdFile(C()) as f: + with self.assertRaises(OSError): + f.read(10) + + def test_read1(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DAT_130K_D) + self.assertEqual(f.read1(), b"") + + def test_read1_0(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + self.assertEqual(f.read1(0), b"") + + def test_read1_10(self): + with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f: + blocks = [] + while True: + result = f.read1(10) + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT) + self.assertEqual(f.read1(), b"") + + def test_read1_multistream(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f: + blocks = [] + while True: + result = f.read1() + if not result: + break + blocks.append(result) + self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5) + self.assertEqual(f.read1(), b"") + + def test_read1_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.read1) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(TypeError, f.read1, None) + + def test_readinto(self): + arr = array.array("I", range(100)) + self.assertEqual(len(arr), 100) + self.assertEqual(len(arr) * arr.itemsize, 400) + ba = bytearray(300) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + # 0 length output buffer + self.assertEqual(f.readinto(ba[0:0]), 0) + + # use correct length for buffer protocol object + self.assertEqual(f.readinto(arr), 400) + self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400]) + + # normal readinto + self.assertEqual(f.readinto(ba), 300) + self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700]) + + def test_peek(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek() + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + result = f.peek(10) + self.assertGreater(len(result), 0) + self.assertTrue(DAT_130K_D.startswith(result)) + self.assertEqual(f.read(), DAT_130K_D) + + def test_peek_bad_args(self): + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.peek) + + def test_iterator(self): + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + compressed = compress(THIS_FILE_BYTES) + + # iter + with ZstdFile(io.BytesIO(compressed)) as f: + self.assertListEqual(list(iter(f)), lines) + + # readline + with ZstdFile(io.BytesIO(compressed)) as f: + for line in lines: + 
self.assertEqual(f.readline(), line)
+            self.assertEqual(f.readline(), b'')
+            self.assertEqual(f.readline(), b'')
+
+        # readlines
+        with ZstdFile(io.BytesIO(compressed)) as f:
+            self.assertListEqual(f.readlines(), lines)
+
+    def test_decompress_limited(self):
+        _ZSTD_DStreamInSize = 128*_1K + 3
+
+        bomb = compress(b'\0' * int(2e6), level=10)
+        self.assertLess(len(bomb), _ZSTD_DStreamInSize)
+
+        decomp = ZstdFile(io.BytesIO(bomb))
+        self.assertEqual(decomp.read(1), b'\0')
+
+        # BufferedReader uses 128 KiB buffer in __init__.py
+        max_decomp = 128*_1K
+        self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
+                             "Excessive amount of data was decompressed")
+
+    def test_write(self):
+        raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w") as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor()
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w", level=12) as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor(12)
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1})
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+        with io.BytesIO() as dst:
+            options = {CompressionParameter.compression_level:-5,
+                       CompressionParameter.checksum_flag:1}
+            with ZstdFile(dst, "w",
+                          options=options) as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor(options=options)
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+    def test_write_empty_frame(self):
+        # .FLUSH_FRAME generates an empty content frame
+        c = ZstdCompressor()
+        self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+        self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+
+        # don't generate empty content frame
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            pass
+        self.assertEqual(bo.getvalue(), b'')
+
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.flush(f.FLUSH_FRAME)
+        self.assertEqual(bo.getvalue(), b'')
+
+        # if .write(b''), generate empty content frame
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.write(b'')
+        self.assertNotEqual(bo.getvalue(), b'')
+
+        # has an empty content frame
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.flush(f.FLUSH_BLOCK)
+        self.assertNotEqual(bo.getvalue(), b'')
+
+    def test_write_empty_block(self):
+        # If there is no internal data, .FLUSH_BLOCK returns b''.
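+        # (FLUSH_BLOCK corresponds to libzstd's ZSTD_e_flush directive, which
+        # is a no-op when nothing is buffered -- hence the b'' results below;
+        # this note is an explanation, not part of the tested API surface.)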
+ c = ZstdCompressor() + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK), + b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.compress(b''), b'') + self.assertEqual(c.flush(c.FLUSH_BLOCK), b'') + + # mode = .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + f.flush(f.FLUSH_BLOCK) + fp_pos = f._fp.tell() + self.assertNotEqual(fp_pos, 0) + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), fp_pos) + + # mode != .last_mode + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + f.write(b'') + f.flush(f.FLUSH_BLOCK) + self.assertEqual(f._fp.tell(), 0) + + def test_write_101(self): + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + for start in range(0, len(THIS_FILE_BYTES), 101): + f.write(THIS_FILE_BYTES[start:start+101]) + + comp = ZstdCompressor() + expected = comp.compress(THIS_FILE_BYTES) + comp.flush() + self.assertEqual(dst.getvalue(), expected) + + def test_write_append(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + part1 = THIS_FILE_BYTES[:_1K] + part2 = THIS_FILE_BYTES[_1K:1536] + part3 = THIS_FILE_BYTES[1536:] + expected = b"".join(comp(x) for x in (part1, part2, part3)) + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.write(part1) + with ZstdFile(dst, "a") as f: + f.write(part2) + with ZstdFile(dst, "a") as f: + f.write(part3) + self.assertEqual(dst.getvalue(), expected) + + def test_write_bad_args(self): + f = ZstdFile(io.BytesIO(), "w") + f.close() + self.assertRaises(ValueError, f.write, b"foo") + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f: + self.assertRaises(ValueError, f.write, b"bar") + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(TypeError, f.write, None) + self.assertRaises(TypeError, f.write, "text") + self.assertRaises(TypeError, f.write, 789) + + def test_writelines(self): + def comp(data): + comp = ZstdCompressor() + return comp.compress(data) + comp.flush() + + with io.BytesIO(THIS_FILE_BYTES) as f: + lines = f.readlines() + with io.BytesIO() as dst: + with ZstdFile(dst, "w") as f: + f.writelines(lines) + expected = comp(THIS_FILE_BYTES) + self.assertEqual(dst.getvalue(), expected) + + def test_seek_forward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(555) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:]) + + def test_seek_forward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:]) + + def test_seek_forward_relative_to_current(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(100) + f.seek(1236, 1) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:]) + + def test_seek_forward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-555, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:]) + + def test_seek_backward(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.read(1001) + f.seek(211) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:]) + + def test_seek_backward_across_streams(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f: + f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333) + f.seek(737) + self.assertEqual(f.read(), + 
DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB) + + def test_seek_backward_relative_to_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-150, 2) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:]) + + def test_seek_past_end(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001) + self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB)) + self.assertEqual(f.read(), b"") + + def test_seek_past_start(self): + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + f.seek(-88) + self.assertEqual(f.tell(), 0) + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + def test_seek_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(), "w") as f: + self.assertRaises(ValueError, f.seek, 0) + with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f: + self.assertRaises(ValueError, f.seek, 0, 3) + # io.BufferedReader raises TypeError instead of ValueError + self.assertRaises((TypeError, ValueError), f.seek, 9, ()) + self.assertRaises(TypeError, f.seek, None) + self.assertRaises(TypeError, f.seek, b"derp") + + def test_seek_not_seekable(self): + class C(io.BytesIO): + def seekable(self): + return False + obj = C(COMPRESSED_100_PLUS_32KB) + with ZstdFile(obj, 'r') as f: + d = f.read(1) + self.assertFalse(f.seekable()) + with self.assertRaisesRegex(io.UnsupportedOperation, + 'File or stream is not seekable'): + f.seek(0) + d += f.read() + self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB) + + def test_tell(self): + with ZstdFile(io.BytesIO(DAT_130K_C)) as f: + pos = 0 + while True: + self.assertEqual(f.tell(), pos) + result = f.read(random.randint(171, 189)) + if not result: + break + pos += len(result) + self.assertEqual(f.tell(), len(DAT_130K_D)) + with ZstdFile(io.BytesIO(), "w") as f: + for pos in range(0, len(DAT_130K_D), 143): + self.assertEqual(f.tell(), pos) + f.write(DAT_130K_D[pos:pos+143]) + self.assertEqual(f.tell(), len(DAT_130K_D)) + + def test_tell_bad_args(self): + f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) + f.close() + self.assertRaises(ValueError, f.tell) + + def test_file_dict(self): + # default + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_file_prefix(self): + bi = io.BytesIO() + with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_UnsupportedOperation(self): + # 1 + with ZstdFile(io.BytesIO(), 'r') as f: + with self.assertRaises(io.UnsupportedOperation): + f.write(b'1234') + + # 2 + class T: + def read(self, size): + return b'a' * size + + with self.assertRaises(TypeError): # on creation + with ZstdFile(T(), 'w') as f: + pass + + # 3 + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.read(100) + with self.assertRaises(io.UnsupportedOperation): + f.seek(100) + self.assertEqual(f.closed, True) 
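The seek tests above pin down the behaviour ZstdFile gets from the io.BufferedReader/DecompressReader stack it is built on (the same stack that test_decompress_limited reaches through decomp._buffer.raw): a backward seek rewinds the stream and re-decompresses from the start, and a seek past EOF clamps to the end of the decompressed data. A minimal sketch of the same semantics, using the stdlib gzip module as an illustrative stand-in for ZstdFile since it sits on the same machinery:

    import gzip
    import io

    raw = b"0123456789" * 1000
    with gzip.GzipFile(fileobj=io.BytesIO(gzip.compress(raw))) as f:
        f.read(1001)
        f.seek(211)                  # backward: rewind, then re-decompress
        assert f.read() == raw[211:]
        f.seek(len(raw) + 9001)      # past EOF: position clamps to stream end
        assert f.tell() == len(raw)
        assert f.read() == b""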
+ with self.assertRaises(ValueError): + f.readable() + with self.assertRaises(ValueError): + f.tell() + with self.assertRaises(ValueError): + f.read(100) + + def test_read_readinto_readinto1(self): + lst = [] + with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f: + while True: + method = random.randint(0, 2) + size = random.randint(0, 300) + + if method == 0: + dat = f.read(size) + if not dat and size: + break + lst.append(dat) + elif method == 1: + ba = bytearray(size) + read_size = f.readinto(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + elif method == 2: + ba = bytearray(size) + read_size = f.readinto1(ba) + if read_size == 0 and size: + break + lst.append(bytes(ba[:read_size])) + self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5) + + def test_zstdfile_flush(self): + # closed + f = ZstdFile(io.BytesIO(), 'w') + f.close() + with self.assertRaises(ValueError): + f.flush() + + # read + with ZstdFile(io.BytesIO(), 'r') as f: + # does nothing for read-only stream + f.flush() + + # write + DAT = b'abcd' + bi = io.BytesIO() + with ZstdFile(bi, 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + self.assertEqual(bi.tell(), 0) # not enough for a block + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + self.assertGreater(bi.tell(), 0) # flushed + + # write, no .flush() method + class C: + def write(self, b): + return len(b) + with ZstdFile(C(), 'w') as f: + self.assertEqual(f.write(DAT), len(DAT)) + self.assertEqual(f.tell(), len(DAT)) + + self.assertEqual(f.flush(), None) + self.assertEqual(f.tell(), len(DAT)) + + def test_zstdfile_flush_mode(self): + self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK) + self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME) + with self.assertRaises(AttributeError): + ZstdFile.CONTINUE + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + # flush block + self.assertEqual(f.write(b'123'), 3) + self.assertIsNone(f.flush(f.FLUSH_BLOCK)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush()) + p2 = bo.tell() + self.assertEqual(p1, p2) + # flush frame + self.assertEqual(f.write(b'456'), 3) + self.assertIsNone(f.flush(mode=f.FLUSH_FRAME)) + # flush frame + self.assertEqual(f.write(b'789'), 3) + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p1 = bo.tell() + # mode == .last_mode, should return + self.assertIsNone(f.flush(f.FLUSH_FRAME)) + p2 = bo.tell() + self.assertEqual(p1, p2) + self.assertEqual(decompress(bo.getvalue()), b'123456789') + + bo = io.BytesIO() + with ZstdFile(bo, 'w') as f: + f.write(b'123') + with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'): + f.flush(ZstdCompressor.CONTINUE) + with self.assertRaises(ValueError): + f.flush(-1) + with self.assertRaises(ValueError): + f.flush(123456) + with self.assertRaises(TypeError): + f.flush(node=ZstdCompressor.CONTINUE) + with self.assertRaises((TypeError, ValueError)): + f.flush('FLUSH_FRAME') + with self.assertRaises(TypeError): + f.flush(b'456', f.FLUSH_BLOCK) + + def test_zstdfile_truncate(self): + with ZstdFile(io.BytesIO(), 'w') as f: + with self.assertRaises(io.UnsupportedOperation): + f.truncate(200) + + def test_zstdfile_iter_issue45475(self): + lines = [l for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))] + self.assertGreater(len(lines), 0) + + def test_append_new_file(self): + with tempfile.NamedTemporaryFile(delete=True) as tmp_f: + filename = tmp_f.name + + with ZstdFile(filename, 'a') as f: + pass + 
self.assertTrue(os.path.isfile(filename)) + + os.remove(filename) + +class OpenTestCase(unittest.TestCase): + + def test_binary_modes(self): + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + with io.BytesIO() as bio: + with open(bio, "wb") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + with open(bio, "ab") as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2) + + def test_text_modes(self): + # empty input + with self.assertRaises(EOFError): + with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader: + for _ in reader: + pass + + # read + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f: + self.assertEqual(f.read(), uncompressed) + + with io.BytesIO() as bio: + # write + with open(bio, "wt", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + # append + with open(bio, "at", encoding="utf-8") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-8") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2) + + def test_bad_params(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + with self.assertRaises(ValueError): + open(TESTFN, "") + with self.assertRaises(ValueError): + open(TESTFN, "rbt") + with self.assertRaises(ValueError): + open(TESTFN, "rb", encoding="utf-8") + with self.assertRaises(ValueError): + open(TESTFN, "rb", errors="ignore") + with self.assertRaises(ValueError): + open(TESTFN, "rb", newline="\n") + + os.remove(TESTFN) + + def test_option(self): + options = {DecompressionParameter.window_log_max:25} + with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f: + self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB) + + options = {CompressionParameter.compression_level:12} + with io.BytesIO() as bio: + with open(bio, "wb", options=options) as f: + f.write(DECOMPRESSED_100_PLUS_32KB) + file_data = decompress(bio.getvalue()) + self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB) + + def test_encoding(self): + uncompressed = THIS_FILE_STR.replace(os.linesep, "\n") + + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-16-le") as f: + f.write(uncompressed) + file_data = decompress(bio.getvalue()).decode("utf-16-le") + self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed) + bio.seek(0) + with open(bio, "rt", encoding="utf-16-le") as f: + self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed) + + def test_encoding_error_handler(self): + with io.BytesIO(compress(b"foo\xffbar")) as bio: + with open(bio, "rt", encoding="ascii", errors="ignore") as f: + self.assertEqual(f.read(), "foobar") + + def test_newline(self): + # Test with explicit newline (universal newline mode disabled). 
+ text = THIS_FILE_STR.replace(os.linesep, "\n") + with io.BytesIO() as bio: + with open(bio, "wt", encoding="utf-8", newline="\n") as f: + f.write(text) + bio.seek(0) + with open(bio, "rt", encoding="utf-8", newline="\r") as f: + self.assertEqual(f.readlines(), [text]) + + def test_x_mode(self): + with tempfile.NamedTemporaryFile(delete=False) as tmp_f: + TESTFN = pathlib.Path(tmp_f.name) + + for mode in ("x", "xb", "xt"): + os.remove(TESTFN) + + if mode == "xt": + encoding = "utf-8" + else: + encoding = None + with open(TESTFN, mode, encoding=encoding): + pass + with self.assertRaises(FileExistsError): + with open(TESTFN, mode): + pass + + os.remove(TESTFN) + + def test_open_dict(self): + # default + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # .as_(un)digested_dict + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + # invalid dictionary + bi = io.BytesIO() + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict={1:2, 2:3}) + + with self.assertRaisesRegex(TypeError, 'zstd_dict'): + open(bi, 'w', zstd_dict=b'1234567890') + + def test_open_prefix(self): + bi = io.BytesIO() + with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f: + f.write(SAMPLES[0]) + bi.seek(0) + with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f: + dat = f.read() + self.assertEqual(dat, SAMPLES[0]) + + def test_buffer_protocol(self): + # don't use len() for buffer protocol objects + arr = array.array("i", range(1000)) + LENGTH = len(arr) * arr.itemsize + + with open(io.BytesIO(), "wb") as f: + self.assertEqual(f.write(arr), LENGTH) + self.assertEqual(f.tell(), LENGTH) + +class FreeThreadingMethodTests(unittest.TestCase): + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_locking(self): + input = b'a'* (16*_1K) + num_threads = 8 + + # gh-136394: the first output of .compress() includes the frame header + # we run the first .compress() call outside of the threaded portion + # to make the test order-independent + + comp = ZstdCompressor() + parts = [comp.compress(input, ZstdCompressor.FLUSH_BLOCK)] + for _ in range(num_threads): + res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK) + if res: + parts.append(res) + rest1 = comp.flush() + expected = b''.join(parts) + rest1 + + comp = ZstdCompressor() + output = [comp.compress(input, ZstdCompressor.FLUSH_BLOCK)] + def run_method(method, input_data, output_data): + res = method(input_data, ZstdCompressor.FLUSH_BLOCK) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.compress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + rest2 = comp.flush() + self.assertEqual(rest1, rest2) + actual = b''.join(output) + rest2 + self.assertEqual(expected, actual) + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_locking(self): + input = compress(b'a'* (16*_1K)) + num_threads = 8 + # to ensure we decompress over multiple calls, set maxsize + window_size = _1K * 16//num_threads + + decomp = ZstdDecompressor() + parts = [] + for _ in range(num_threads): + res = 
decomp.decompress(input, window_size) + if res: + parts.append(res) + expected = b''.join(parts) + + comp = ZstdDecompressor() + output = [] + def run_method(method, input_data, output_data): + res = method(input_data, window_size) + if res: + output_data.append(res) + threads = [] + + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(comp.decompress, input, output)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + actual = b''.join(output) + self.assertEqual(expected, actual) + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_compress_shared_dict(self): + num_threads = 8 + + def run_method(b): + level = threading.get_ident() % 4 + # sync threads to increase chance of contention on + # capsule storing dictionary levels + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_digested_dict) + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_undigested_dict) + b.wait() + ZstdCompressor(level=level, + zstd_dict=TRAINED_DICT.as_prefix) + threads = [] + + b = threading.Barrier(num_threads) + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(b,)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + @threading_helper.reap_threads + @threading_helper.requires_working_threading() + def test_decompress_shared_dict(self): + num_threads = 8 + + def run_method(b): + # sync threads to increase chance of contention on + # decompression dictionary + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_digested_dict) + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_undigested_dict) + b.wait() + ZstdDecompressor(zstd_dict=TRAINED_DICT.as_prefix) + threads = [] + + b = threading.Barrier(num_threads) + for i in range(num_threads): + thread = threading.Thread(target=run_method, args=(b,)) + + threads.append(thread) + + with threading_helper.start_threads(threads): + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index 6878c2a8f5a..b049402eed7 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -27,7 +27,7 @@ def testMultiply(self): http://docs.python.org/library/unittest.html Copyright (c) 1999-2003 Steve Purcell -Copyright (c) 2003-2010 Python Software Foundation +Copyright (c) 2003 Python Software Foundation This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. 
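The Lib/unittest/__init__.py hunk below is mostly an import re-sort, but it ends with the note that IsolatedAsyncioTestCase is imported lazily. That deferral uses a PEP 562 module-level __getattr__; a sketch of the mechanism, under the assumption that it mirrors CPython's approach (demo_lazy.py is a hypothetical module name, not part of this diff):

    # demo_lazy.py
    def __getattr__(name):
        # Invoked only when normal module attribute lookup fails, so the
        # asyncio-heavy import does not run until first access.
        if name == "IsolatedAsyncioTestCase":
            global IsolatedAsyncioTestCase
            from unittest.async_case import IsolatedAsyncioTestCase
            return IsolatedAsyncioTestCase
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")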
@@ -53,15 +53,30 @@ def testMultiply(self): __unittest = True -from .result import TestResult -from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip, - skipIf, skipUnless, expectedFailure, doModuleCleanups, - enterModuleContext) -from .suite import BaseTestSuite, TestSuite +from .case import ( + FunctionTestCase, + SkipTest, + TestCase, + addModuleCleanup, + doModuleCleanups, + enterModuleContext, + expectedFailure, + skip, + skipIf, + skipUnless, +) from .loader import TestLoader, defaultTestLoader -from .main import TestProgram, main -from .runner import TextTestRunner, TextTestResult -from .signals import installHandler, registerResult, removeResult, removeHandler +from .main import TestProgram, main # noqa: F401 +from .result import TestResult +from .runner import TextTestResult, TextTestRunner +from .signals import ( + installHandler, + registerResult, + removeHandler, + removeResult, +) +from .suite import BaseTestSuite, TestSuite # noqa: F401 + # IsolatedAsyncioTestCase will be imported lazily. diff --git a/Lib/unittest/__main__.py b/Lib/unittest/__main__.py index e5876f569b5..50111190eee 100644 --- a/Lib/unittest/__main__.py +++ b/Lib/unittest/__main__.py @@ -1,6 +1,7 @@ """Main entry point""" import sys + if sys.argv[0].endswith("__main__.py"): import os.path # We change sys.argv[0] to make help message more useful diff --git a/Lib/unittest/_log.py b/Lib/unittest/_log.py index 94868e5bb95..c61abb15745 100644 --- a/Lib/unittest/_log.py +++ b/Lib/unittest/_log.py @@ -1,9 +1,8 @@ -import logging import collections +import logging from .case import _BaseTestCaseContext - _LoggingWatcher = collections.namedtuple("_LoggingWatcher", ["records", "output"]) diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py index e761ba7e53c..a1c0d6c368c 100644 --- a/Lib/unittest/async_case.py +++ b/Lib/unittest/async_case.py @@ -75,9 +75,17 @@ async def enterAsyncContext(self, cm): enter = cls.__aenter__ exit = cls.__aexit__ except AttributeError: - raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " - f"not support the asynchronous context manager protocol" - ) from None + msg = (f"'{cls.__module__}.{cls.__qualname__}' object does " + "not support the asynchronous context manager protocol") + try: + cls.__enter__ + cls.__exit__ + except AttributeError: + pass + else: + msg += (" but it supports the context manager protocol. 
" + "Did you mean to use enterContext()?") + raise TypeError(msg) from None result = await enter(cm) self.addAsyncCleanup(exit, cm, None, None, None) return result @@ -91,9 +99,13 @@ def _callSetUp(self): self._callAsync(self.asyncSetUp) def _callTestMethod(self, method): - if self._callMaybeAsync(method) is not None: - warnings.warn(f'It is deprecated to return a value that is not None from a ' - f'test case ({method})', DeprecationWarning, stacklevel=4) + result = self._callMaybeAsync(method) + if result is not None: + msg = ( + f'It is deprecated to return a value that is not None ' + f'from a test case ({method} returned {type(result).__name__!r})', + ) + warnings.warn(msg, DeprecationWarning, stacklevel=4) def _callTearDown(self): self._callAsync(self.asyncTearDown) diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 36daa61fa31..b09836d6747 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -1,20 +1,25 @@ """Test case implementation""" -import sys -import functools +import collections +import contextlib import difflib +import functools import pprint import re -import warnings -import collections -import contextlib -import traceback +import sys import time +import traceback import types +import warnings from . import result -from .util import (strclass, safe_repr, _count_diff_all_purpose, - _count_diff_hashable, _common_shorten_repr) +from .util import ( + _common_shorten_repr, + _count_diff_all_purpose, + _count_diff_hashable, + safe_repr, + strclass, +) __unittest = True @@ -111,8 +116,17 @@ def _enter_context(cm, addcleanup): enter = cls.__enter__ exit = cls.__exit__ except AttributeError: - raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does " - f"not support the context manager protocol") from None + msg = (f"'{cls.__module__}.{cls.__qualname__}' object does " + "not support the context manager protocol") + try: + cls.__aenter__ + cls.__aexit__ + except AttributeError: + pass + else: + msg += (" but it supports the asynchronous context manager " + "protocol. Did you mean to use enterAsyncContext()?") + raise TypeError(msg) from None result = enter(cm) addcleanup(exit, cm, None, None, None) return result @@ -603,9 +617,18 @@ def _callSetUp(self): self.setUp() def _callTestMethod(self, method): - if method() is not None: - warnings.warn(f'It is deprecated to return a value that is not None from a ' - f'test case ({method})', DeprecationWarning, stacklevel=3) + result = method() + if result is not None: + import inspect + msg = ( + f'It is deprecated to return a value that is not None ' + f'from a test case ({method} returned {type(result).__name__!r})' + ) + if inspect.iscoroutine(result): + msg += ( + '. Maybe you forgot to use IsolatedAsyncioTestCase as the base class?' 
+ ) + warnings.warn(msg, DeprecationWarning, stacklevel=3) def _callTearDown(self): self.tearDown() @@ -1312,13 +1335,71 @@ def assertIsInstance(self, obj, cls, msg=None): """Same as self.assertTrue(isinstance(obj, cls)), with a nicer default message.""" if not isinstance(obj, cls): - standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls) + if isinstance(cls, tuple): + standardMsg = f'{safe_repr(obj)} is not an instance of any of {cls!r}' + else: + standardMsg = f'{safe_repr(obj)} is not an instance of {cls!r}' self.fail(self._formatMessage(msg, standardMsg)) def assertNotIsInstance(self, obj, cls, msg=None): """Included for symmetry with assertIsInstance.""" if isinstance(obj, cls): - standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls) + if isinstance(cls, tuple): + for x in cls: + if isinstance(obj, x): + cls = x + break + standardMsg = f'{safe_repr(obj)} is an instance of {cls!r}' + self.fail(self._formatMessage(msg, standardMsg)) + + def assertIsSubclass(self, cls, superclass, msg=None): + try: + if issubclass(cls, superclass): + return + except TypeError: + if not isinstance(cls, type): + self.fail(self._formatMessage(msg, f'{cls!r} is not a class')) + raise + if isinstance(superclass, tuple): + standardMsg = f'{cls!r} is not a subclass of any of {superclass!r}' + else: + standardMsg = f'{cls!r} is not a subclass of {superclass!r}' + self.fail(self._formatMessage(msg, standardMsg)) + + def assertNotIsSubclass(self, cls, superclass, msg=None): + try: + if not issubclass(cls, superclass): + return + except TypeError: + if not isinstance(cls, type): + self.fail(self._formatMessage(msg, f'{cls!r} is not a class')) + raise + if isinstance(superclass, tuple): + for x in superclass: + if issubclass(cls, x): + superclass = x + break + standardMsg = f'{cls!r} is a subclass of {superclass!r}' + self.fail(self._formatMessage(msg, standardMsg)) + + def assertHasAttr(self, obj, name, msg=None): + if not hasattr(obj, name): + if isinstance(obj, types.ModuleType): + standardMsg = f'module {obj.__name__!r} has no attribute {name!r}' + elif isinstance(obj, type): + standardMsg = f'type object {obj.__name__!r} has no attribute {name!r}' + else: + standardMsg = f'{type(obj).__name__!r} object has no attribute {name!r}' + self.fail(self._formatMessage(msg, standardMsg)) + + def assertNotHasAttr(self, obj, name, msg=None): + if hasattr(obj, name): + if isinstance(obj, types.ModuleType): + standardMsg = f'module {obj.__name__!r} has unexpected attribute {name!r}' + elif isinstance(obj, type): + standardMsg = f'type object {obj.__name__!r} has unexpected attribute {name!r}' + else: + standardMsg = f'{type(obj).__name__!r} object has unexpected attribute {name!r}' self.fail(self._formatMessage(msg, standardMsg)) def assertRaisesRegex(self, expected_exception, expected_regex, @@ -1382,6 +1463,80 @@ def assertNotRegex(self, text, unexpected_regex, msg=None): msg = self._formatMessage(msg, standardMsg) raise self.failureException(msg) + def _tail_type_check(self, s, tails, msg): + if not isinstance(tails, tuple): + tails = (tails,) + for tail in tails: + if isinstance(tail, str): + if not isinstance(s, str): + self.fail(self._formatMessage(msg, + f'Expected str, not {type(s).__name__}')) + elif isinstance(tail, (bytes, bytearray)): + if not isinstance(s, (bytes, bytearray)): + self.fail(self._formatMessage(msg, + f'Expected bytes, not {type(s).__name__}')) + + def assertStartsWith(self, s, prefix, msg=None): + try: + if s.startswith(prefix): + return + except (AttributeError, 
TypeError): + self._tail_type_check(s, prefix, msg) + raise + a = safe_repr(s, short=True) + b = safe_repr(prefix) + if isinstance(prefix, tuple): + standardMsg = f"{a} doesn't start with any of {b}" + else: + standardMsg = f"{a} doesn't start with {b}" + self.fail(self._formatMessage(msg, standardMsg)) + + def assertNotStartsWith(self, s, prefix, msg=None): + try: + if not s.startswith(prefix): + return + except (AttributeError, TypeError): + self._tail_type_check(s, prefix, msg) + raise + if isinstance(prefix, tuple): + for x in prefix: + if s.startswith(x): + prefix = x + break + a = safe_repr(s, short=True) + b = safe_repr(prefix) + self.fail(self._formatMessage(msg, f"{a} starts with {b}")) + + def assertEndsWith(self, s, suffix, msg=None): + try: + if s.endswith(suffix): + return + except (AttributeError, TypeError): + self._tail_type_check(s, suffix, msg) + raise + a = safe_repr(s, short=True) + b = safe_repr(suffix) + if isinstance(suffix, tuple): + standardMsg = f"{a} doesn't end with any of {b}" + else: + standardMsg = f"{a} doesn't end with {b}" + self.fail(self._formatMessage(msg, standardMsg)) + + def assertNotEndsWith(self, s, suffix, msg=None): + try: + if not s.endswith(suffix): + return + except (AttributeError, TypeError): + self._tail_type_check(s, suffix, msg) + raise + if isinstance(suffix, tuple): + for x in suffix: + if s.endswith(x): + suffix = x + break + a = safe_repr(s, short=True) + b = safe_repr(suffix) + self.fail(self._formatMessage(msg, f"{a} ends with {b}")) class FunctionTestCase(TestCase): diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index 22797b83a68..fa8d647ad8a 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -1,12 +1,11 @@ """Loading unittests.""" +import functools import os import re import sys import traceback import types -import functools - from fnmatch import fnmatch, fnmatchcase from . 
import case, suite, util @@ -274,6 +273,8 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None): self._top_level_dir = top_level_dir is_not_importable = False + is_namespace = False + tests = [] if os.path.isdir(os.path.abspath(start_dir)): start_dir = os.path.abspath(start_dir) if start_dir != top_level_dir: @@ -286,12 +287,25 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None): is_not_importable = True else: the_module = sys.modules[start_dir] - top_part = start_dir.split('.')[0] - try: - start_dir = os.path.abspath( - os.path.dirname((the_module.__file__))) - except AttributeError: - if the_module.__name__ in sys.builtin_module_names: + if not hasattr(the_module, "__file__") or the_module.__file__ is None: + # look for namespace packages + try: + spec = the_module.__spec__ + except AttributeError: + spec = None + + if spec and spec.submodule_search_locations is not None: + is_namespace = True + + for path in the_module.__path__: + if (not set_implicit_top and + not path.startswith(top_level_dir)): + continue + self._top_level_dir = \ + (path.split(the_module.__name__ + .replace(".", os.path.sep))[0]) + tests.extend(self._find_tests(path, pattern, namespace=True)) + elif the_module.__name__ in sys.builtin_module_names: # builtin module raise TypeError('Can not use builtin modules ' 'as dotted module names') from None @@ -300,14 +314,27 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None): f"don't know how to discover from {the_module!r}" ) from None + else: + top_part = start_dir.split('.')[0] + start_dir = os.path.abspath(os.path.dirname((the_module.__file__))) + if set_implicit_top: - self._top_level_dir = self._get_directory_containing_module(top_part) + if not is_namespace: + if sys.modules[top_part].__file__ is None: + self._top_level_dir = os.path.dirname(the_module.__file__) + if self._top_level_dir not in sys.path: + sys.path.insert(0, self._top_level_dir) + else: + self._top_level_dir = \ + self._get_directory_containing_module(top_part) sys.path.remove(top_level_dir) if is_not_importable: raise ImportError('Start directory is not importable: %r' % start_dir) - tests = list(self._find_tests(start_dir, pattern)) + if not is_namespace: + tests = list(self._find_tests(start_dir, pattern)) + self._top_level_dir = original_top_level_dir return self.suiteClass(tests) @@ -343,7 +370,7 @@ def _match_path(self, path, full_path, pattern): # override this method to use alternative matching strategy return fnmatch(path, pattern) - def _find_tests(self, start_dir, pattern): + def _find_tests(self, start_dir, pattern, namespace=False): """Used by discovery. Yields test suites it loads.""" # Handle the __init__ in this package name = self._get_name_from_path(start_dir) @@ -352,7 +379,8 @@ def _find_tests(self, start_dir, pattern): if name != '.' and name not in self._loading_packages: # name is in self._loading_packages while we have called into # loadTestsFromModule with name. 
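The namespace-package branch added to discover() above keys on a module that lacks __file__ but whose spec still has submodule_search_locations set, which is exactly how PEP 420 namespace packages present themselves. A self-contained way to see those attributes (nspkg is a hypothetical package created on the fly, deliberately without an __init__.py):

    import importlib.util
    import os
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        os.mkdir(os.path.join(root, "nspkg"))   # note: no __init__.py
        sys.path.insert(0, root)
        try:
            spec = importlib.util.find_spec("nspkg")
            print(spec.origin)                            # None for a namespace package
            print(list(spec.submodule_search_locations))  # [<root>/nspkg]
        finally:
            sys.path.remove(root)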
- tests, should_recurse = self._find_test_path(start_dir, pattern) + tests, should_recurse = self._find_test_path( + start_dir, pattern, namespace) if tests is not None: yield tests if not should_recurse: @@ -363,7 +391,8 @@ def _find_tests(self, start_dir, pattern): paths = sorted(os.listdir(start_dir)) for path in paths: full_path = os.path.join(start_dir, path) - tests, should_recurse = self._find_test_path(full_path, pattern) + tests, should_recurse = self._find_test_path( + full_path, pattern, False) if tests is not None: yield tests if should_recurse: @@ -371,11 +400,11 @@ def _find_tests(self, start_dir, pattern): name = self._get_name_from_path(full_path) self._loading_packages.add(name) try: - yield from self._find_tests(full_path, pattern) + yield from self._find_tests(full_path, pattern, False) finally: self._loading_packages.discard(name) - def _find_test_path(self, full_path, pattern): + def _find_test_path(self, full_path, pattern, namespace=False): """Used by discovery. Loads tests from a single file, or a directories' __init__.py when @@ -419,7 +448,8 @@ def _find_test_path(self, full_path, pattern): msg % (mod_name, module_dir, expected_dir)) return self.loadTestsFromModule(module, pattern=pattern), False elif os.path.isdir(full_path): - if not os.path.isfile(os.path.join(full_path, '__init__.py')): + if (not namespace and + not os.path.isfile(os.path.join(full_path, '__init__.py'))): return None, False load_tests = None diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index a0cd8a9f7ea..1855fccf336 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -1,8 +1,8 @@ """Unittest main program""" -import sys import argparse import os +import sys from . import loader, runner from .signals import installHandler @@ -197,7 +197,7 @@ def _getParentArgParser(self): return parser def _getMainArgParser(self, parent): - parser = argparse.ArgumentParser(parents=[parent]) + parser = argparse.ArgumentParser(parents=[parent], color=True) parser.prog = self.progName parser.print_help = self._print_help @@ -208,7 +208,7 @@ def _getMainArgParser(self, parent): return parser def _getDiscoveryArgParser(self, parent): - parser = argparse.ArgumentParser(parents=[parent]) + parser = argparse.ArgumentParser(parents=[parent], color=True) parser.prog = '%s discover' % self.progName parser.epilog = ('For test discovery all test modules must be ' 'importable from the top level directory of the ' diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 6cec61ff35c..1089dcb11f1 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -25,19 +25,20 @@ import asyncio +import builtins import contextlib -import io import inspect +import io +import pkgutil import pprint import sys -import builtins -import pkgutil -from asyncio import iscoroutinefunction import threading -from types import CodeType, ModuleType, MethodType -from unittest.util import safe_repr -from functools import wraps, partial +from dataclasses import fields, is_dataclass +from functools import partial, wraps +from inspect import iscoroutinefunction from threading import RLock +from types import CodeType, MethodType, ModuleType +from unittest.util import safe_repr class InvalidSpecError(Exception): @@ -568,6 +569,11 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, __dict__['_mock_methods'] = spec __dict__['_spec_asyncs'] = _spec_asyncs + def _mock_extend_spec_methods(self, spec_methods): + methods = self.__dict__.get('_mock_methods') or [] + methods.extend(spec_methods) + 
self.__dict__['_mock_methods'] = methods + def __get_return_value(self): ret = self._mock_return_value if self._mock_delegate is not None: @@ -1766,7 +1772,7 @@ def patch( the patch is undone. If `new` is omitted, then the target is replaced with an - `AsyncMock if the patched object is an async function or a + `AsyncMock` if the patched object is an async function or a `MagicMock` otherwise. If `patch` is used as a decorator and `new` is omitted, the created mock is passed in as an extra argument to the decorated function. If `patch` is used as a context manager the created @@ -1840,7 +1846,8 @@ def patch( class _patch_dict(object): """ Patch a dictionary, or dictionary like object, and restore the dictionary - to its original state after the test. + to its original state after the test, where the restored dictionary is + a copy of the dictionary as it was before the test. `in_dict` can be a dictionary or a mapping like container. If it is a mapping then it must at least support getting, setting and deleting items @@ -2176,8 +2183,6 @@ def _mock_set_magics(self): if getattr(self, "_mock_methods", None) is not None: these_magics = orig_magics.intersection(self._mock_methods) - - remove_magics = set() remove_magics = orig_magics - these_magics for entry in remove_magics: @@ -2477,7 +2482,7 @@ class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock): recognized as an async function, and the result of a call is an awaitable: >>> mock = AsyncMock() - >>> iscoroutinefunction(mock) + >>> inspect.iscoroutinefunction(mock) True >>> inspect.isawaitable(mock()) True @@ -2767,6 +2772,16 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, f'[object={spec!r}]') is_async_func = _is_async_func(spec) _kwargs = {'spec': spec} + + entries = [(entry, _missing) for entry in dir(spec)] + if is_type and instance and is_dataclass(spec): + is_dataclass_spec = True + dataclass_fields = fields(spec) + entries.extend((f.name, f.type) for f in dataclass_fields) + dataclass_spec_list = [f.name for f in dataclass_fields] + else: + is_dataclass_spec = False + if spec_set: _kwargs = {'spec_set': spec} elif spec is None: @@ -2802,6 +2817,8 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, mock = Klass(parent=_parent, _new_parent=_parent, _new_name=_new_name, name=_name, **_kwargs) + if is_dataclass_spec: + mock._mock_extend_spec_methods(dataclass_spec_list) if isinstance(spec, FunctionTypes): # should only happen at the top level because we don't @@ -2823,7 +2840,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, _name='()', _parent=mock, wraps=wrapped) - for entry in dir(spec): + for entry, original in entries: if _is_magic(entry): # MagicMock already does the useful magic methods for us continue @@ -2837,10 +2854,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, # AttributeError on being fetched? # we could be resilient against it, or catch and propagate the # exception when the attribute is fetched from the mock - try: - original = getattr(spec, entry) - except AttributeError: - continue + if original is _missing: + try: + original = getattr(spec, entry) + except AttributeError: + continue child_kwargs = {'spec': original} # Wrap child attributes also. diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py index 3ace0a5b7bf..8eafb3891c9 100644 --- a/Lib/unittest/result.py +++ b/Lib/unittest/result.py @@ -3,9 +3,9 @@ import io import sys import traceback +from functools import wraps from . 
import util -from functools import wraps __unittest = True @@ -189,7 +189,10 @@ def _exc_info_to_string(self, err, test): tb_e = traceback.TracebackException( exctype, value, tb, capture_locals=self.tb_locals, compact=True) - msgLines = list(tb_e.format()) + from _colorize import can_colorize + + colorize = hasattr(self, "stream") and can_colorize(file=self.stream) + msgLines = list(tb_e.format(colorize=colorize)) if self.buffer: output = sys.stdout.getvalue() diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py index 2bcadf0c998..5f22d91aebd 100644 --- a/Lib/unittest/runner.py +++ b/Lib/unittest/runner.py @@ -4,6 +4,8 @@ import time import warnings +from _colorize import get_theme + from . import result from .case import _SubTest from .signals import registerResult @@ -13,18 +15,18 @@ class _WritelnDecorator(object): """Used to decorate file-like objects with a handy 'writeln' method""" - def __init__(self,stream): + def __init__(self, stream): self.stream = stream def __getattr__(self, attr): if attr in ('stream', '__getstate__'): raise AttributeError(attr) - return getattr(self.stream,attr) + return getattr(self.stream, attr) def writeln(self, arg=None): if arg: self.write(arg) - self.write('\n') # text-mode streams translate to \r\n if needed + self.write('\n') # text-mode streams translate to \r\n if needed class TextTestResult(result.TestResult): @@ -43,6 +45,7 @@ def __init__(self, stream, descriptions, verbosity, *, durations=None): self.showAll = verbosity > 1 self.dots = verbosity == 1 self.descriptions = descriptions + self._theme = get_theme(tty_file=stream).unittest self._newline = True self.durations = durations @@ -76,86 +79,100 @@ def _write_status(self, test, status): def addSubTest(self, test, subtest, err): if err is not None: + t = self._theme if self.showAll: if issubclass(err[0], subtest.failureException): - self._write_status(subtest, "FAIL") + self._write_status(subtest, f"{t.fail}FAIL{t.reset}") else: - self._write_status(subtest, "ERROR") + self._write_status(subtest, f"{t.fail}ERROR{t.reset}") elif self.dots: if issubclass(err[0], subtest.failureException): - self.stream.write('F') + self.stream.write(f"{t.fail}F{t.reset}") else: - self.stream.write('E') + self.stream.write(f"{t.fail}E{t.reset}") self.stream.flush() super(TextTestResult, self).addSubTest(test, subtest, err) def addSuccess(self, test): super(TextTestResult, self).addSuccess(test) + t = self._theme if self.showAll: - self._write_status(test, "ok") + self._write_status(test, f"{t.passed}ok{t.reset}") elif self.dots: - self.stream.write('.') + self.stream.write(f"{t.passed}.{t.reset}") self.stream.flush() def addError(self, test, err): super(TextTestResult, self).addError(test, err) + t = self._theme if self.showAll: - self._write_status(test, "ERROR") + self._write_status(test, f"{t.fail}ERROR{t.reset}") elif self.dots: - self.stream.write('E') + self.stream.write(f"{t.fail}E{t.reset}") self.stream.flush() def addFailure(self, test, err): super(TextTestResult, self).addFailure(test, err) + t = self._theme if self.showAll: - self._write_status(test, "FAIL") + self._write_status(test, f"{t.fail}FAIL{t.reset}") elif self.dots: - self.stream.write('F') + self.stream.write(f"{t.fail}F{t.reset}") self.stream.flush() def addSkip(self, test, reason): super(TextTestResult, self).addSkip(test, reason) + t = self._theme if self.showAll: - self._write_status(test, "skipped {0!r}".format(reason)) + self._write_status(test, f"{t.warn}skipped{t.reset} {reason!r}") elif self.dots: - self.stream.write("s") + 
self.stream.write(f"{t.warn}s{t.reset}") self.stream.flush() def addExpectedFailure(self, test, err): super(TextTestResult, self).addExpectedFailure(test, err) + t = self._theme if self.showAll: - self.stream.writeln("expected failure") + self.stream.writeln(f"{t.warn}expected failure{t.reset}") self.stream.flush() elif self.dots: - self.stream.write("x") + self.stream.write(f"{t.warn}x{t.reset}") self.stream.flush() def addUnexpectedSuccess(self, test): super(TextTestResult, self).addUnexpectedSuccess(test) + t = self._theme if self.showAll: - self.stream.writeln("unexpected success") + self.stream.writeln(f"{t.fail}unexpected success{t.reset}") self.stream.flush() elif self.dots: - self.stream.write("u") + self.stream.write(f"{t.fail}u{t.reset}") self.stream.flush() def printErrors(self): + t = self._theme if self.dots or self.showAll: self.stream.writeln() self.stream.flush() - self.printErrorList('ERROR', self.errors) - self.printErrorList('FAIL', self.failures) - unexpectedSuccesses = getattr(self, 'unexpectedSuccesses', ()) + self.printErrorList(f"{t.fail}ERROR{t.reset}", self.errors) + self.printErrorList(f"{t.fail}FAIL{t.reset}", self.failures) + unexpectedSuccesses = getattr(self, "unexpectedSuccesses", ()) if unexpectedSuccesses: self.stream.writeln(self.separator1) for test in unexpectedSuccesses: - self.stream.writeln(f"UNEXPECTED SUCCESS: {self.getDescription(test)}") + self.stream.writeln( + f"{t.fail}UNEXPECTED SUCCESS{t.fail_info}: " + f"{self.getDescription(test)}{t.reset}" + ) self.stream.flush() def printErrorList(self, flavour, errors): + t = self._theme for test, err in errors: self.stream.writeln(self.separator1) - self.stream.writeln("%s: %s" % (flavour,self.getDescription(test))) + self.stream.writeln( + f"{flavour}{t.fail_info}: {self.getDescription(test)}{t.reset}" + ) self.stream.writeln(self.separator2) self.stream.writeln("%s" % err) self.stream.flush() @@ -232,7 +249,7 @@ def run(self, test): if self.warnings: # if self.warnings is set, use it to filter all the warnings warnings.simplefilter(self.warnings) - startTime = time.perf_counter() + start_time = time.perf_counter() startTestRun = getattr(result, 'startTestRun', None) if startTestRun is not None: startTestRun() @@ -242,8 +259,8 @@ def run(self, test): stopTestRun = getattr(result, 'stopTestRun', None) if stopTestRun is not None: stopTestRun() - stopTime = time.perf_counter() - timeTaken = stopTime - startTime + stop_time = time.perf_counter() + time_taken = stop_time - start_time result.printErrors() if self.durations is not None: self._printDurations(result) @@ -253,10 +270,10 @@ def run(self, test): run = result.testsRun self.stream.writeln("Ran %d test%s in %.3fs" % - (run, run != 1 and "s" or "", timeTaken)) + (run, run != 1 and "s" or "", time_taken)) self.stream.writeln() - expectedFails = unexpectedSuccesses = skipped = 0 + expected_fails = unexpected_successes = skipped = 0 try: results = map(len, (result.expectedFailures, result.unexpectedSuccesses, @@ -264,26 +281,30 @@ def run(self, test): except AttributeError: pass else: - expectedFails, unexpectedSuccesses, skipped = results + expected_fails, unexpected_successes, skipped = results infos = [] + t = get_theme(tty_file=self.stream).unittest + if not result.wasSuccessful(): - self.stream.write("FAILED") + self.stream.write(f"{t.fail_info}FAILED{t.reset}") failed, errored = len(result.failures), len(result.errors) if failed: - infos.append("failures=%d" % failed) + infos.append(f"{t.fail_info}failures={failed}{t.reset}") if errored: - 
infos.append("errors=%d" % errored) + infos.append(f"{t.fail_info}errors={errored}{t.reset}") elif run == 0 and not skipped: - self.stream.write("NO TESTS RAN") + self.stream.write(f"{t.warn}NO TESTS RAN{t.reset}") else: - self.stream.write("OK") + self.stream.write(f"{t.passed}OK{t.reset}") if skipped: - infos.append("skipped=%d" % skipped) - if expectedFails: - infos.append("expected failures=%d" % expectedFails) - if unexpectedSuccesses: - infos.append("unexpected successes=%d" % unexpectedSuccesses) + infos.append(f"{t.warn}skipped={skipped}{t.reset}") + if expected_fails: + infos.append(f"{t.warn}expected failures={expected_fails}{t.reset}") + if unexpected_successes: + infos.append( + f"{t.fail}unexpected successes={unexpected_successes}{t.reset}" + ) if infos: self.stream.writeln(" (%s)" % (", ".join(infos),)) else: diff --git a/Lib/unittest/signals.py b/Lib/unittest/signals.py index e6a5fc52439..4e654c2c5db 100644 --- a/Lib/unittest/signals.py +++ b/Lib/unittest/signals.py @@ -1,6 +1,5 @@ import signal import weakref - from functools import wraps __unittest = True diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py index 6f45b6fe5f6..3c40176f070 100644 --- a/Lib/unittest/suite.py +++ b/Lib/unittest/suite.py @@ -2,8 +2,7 @@ import sys -from . import case -from . import util +from . import case, util __unittest = True diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py index 050eaed0b3f..b81b6a4219b 100644 --- a/Lib/unittest/util.py +++ b/Lib/unittest/util.py @@ -1,6 +1,6 @@ """Various utility functions.""" -from collections import namedtuple, Counter +from collections import Counter, namedtuple from os.path import commonprefix __unittest = True diff --git a/Lib/warnings.py b/Lib/warnings.py index f83aaf231ea..6759857d909 100644 --- a/Lib/warnings.py +++ b/Lib/warnings.py @@ -1,735 +1,99 @@ -"""Python part of the warnings subsystem.""" - import sys +__all__ = [ + "warn", + "warn_explicit", + "showwarning", + "formatwarning", + "filterwarnings", + "simplefilter", + "resetwarnings", + "catch_warnings", + "deprecated", +] + +from _py_warnings import ( + WarningMessage, + _DEPRECATED_MSG, + _OptionError, + _add_filter, + _deprecated, + _filters_mutated, + _filters_mutated_lock_held, + _filters_version, + _formatwarning_orig, + _formatwarnmsg, + _formatwarnmsg_impl, + _get_context, + _get_filters, + _getaction, + _getcategory, + _is_filename_to_skip, + _is_internal_filename, + _is_internal_frame, + _lock, + _new_context, + _next_external_frame, + _processoptions, + _set_context, + _set_module, + _setoption, + _setup_defaults, + _showwarning_orig, + _showwarnmsg, + _showwarnmsg_impl, + _use_context, + _warn_unawaited_coroutine, + _warnings_context, + catch_warnings, + defaultaction, + deprecated, + filters, + filterwarnings, + formatwarning, + onceregistry, + resetwarnings, + showwarning, + simplefilter, + warn, + warn_explicit, +) -__all__ = ["warn", "warn_explicit", "showwarning", - "formatwarning", "filterwarnings", "simplefilter", - "resetwarnings", "catch_warnings", "deprecated"] - -def showwarning(message, category, filename, lineno, file=None, line=None): - """Hook to write a warning to a file; replace if you like.""" - msg = WarningMessage(message, category, filename, lineno, file, line) - _showwarnmsg_impl(msg) - -def formatwarning(message, category, filename, lineno, line=None): - """Function to format a warning the standard way.""" - msg = WarningMessage(message, category, filename, lineno, None, line) - return _formatwarnmsg_impl(msg) - -def _showwarnmsg_impl(msg): 
- file = msg.file - if file is None: - file = sys.stderr - if file is None: - # sys.stderr is None when run with pythonw.exe: - # warnings get lost - return - text = _formatwarnmsg(msg) - try: - file.write(text) - except OSError: - # the file (probably stderr) is invalid - this warning gets lost. - pass - -def _formatwarnmsg_impl(msg): - category = msg.category.__name__ - s = f"{msg.filename}:{msg.lineno}: {category}: {msg.message}\n" - - if msg.line is None: - try: - import linecache - line = linecache.getline(msg.filename, msg.lineno) - except Exception: - # When a warning is logged during Python shutdown, linecache - # and the import machinery don't work anymore - line = None - linecache = None - else: - line = msg.line - if line: - line = line.strip() - s += " %s\n" % line - - if msg.source is not None: - try: - import tracemalloc - # Logging a warning should not raise a new exception: - # catch Exception, not only ImportError and RecursionError. - except Exception: - # don't suggest to enable tracemalloc if it's not available - suggest_tracemalloc = False - tb = None - else: - try: - suggest_tracemalloc = not tracemalloc.is_tracing() - tb = tracemalloc.get_object_traceback(msg.source) - except Exception: - # When a warning is logged during Python shutdown, tracemalloc - # and the import machinery don't work anymore - suggest_tracemalloc = False - tb = None - - if tb is not None: - s += 'Object allocated at (most recent call last):\n' - for frame in tb: - s += (' File "%s", lineno %s\n' - % (frame.filename, frame.lineno)) - - try: - if linecache is not None: - line = linecache.getline(frame.filename, frame.lineno) - else: - line = None - except Exception: - line = None - if line: - line = line.strip() - s += ' %s\n' % line - elif suggest_tracemalloc: - s += (f'{category}: Enable tracemalloc to get the object ' - f'allocation traceback\n') - return s - -# Keep a reference to check if the function was replaced -_showwarning_orig = showwarning - -def _showwarnmsg(msg): - """Hook to write a warning to a file; replace if you like.""" - try: - sw = showwarning - except NameError: - pass - else: - if sw is not _showwarning_orig: - # warnings.showwarning() was replaced - if not callable(sw): - raise TypeError("warnings.showwarning() must be set to a " - "function or method") - - sw(msg.message, msg.category, msg.filename, msg.lineno, - msg.file, msg.line) - return - _showwarnmsg_impl(msg) - -# Keep a reference to check if the function was replaced -_formatwarning_orig = formatwarning - -def _formatwarnmsg(msg): - """Function to format a warning the standard way.""" - try: - fw = formatwarning - except NameError: - pass - else: - if fw is not _formatwarning_orig: - # warnings.formatwarning() was replaced - return fw(msg.message, msg.category, - msg.filename, msg.lineno, msg.line) - return _formatwarnmsg_impl(msg) - -def filterwarnings(action, message="", category=Warning, module="", lineno=0, - append=False): - """Insert an entry into the list of warnings filters (at the front). 
- - 'action' -- one of "error", "ignore", "always", "default", "module", - or "once" - 'message' -- a regex that the warning message must match - 'category' -- a class that the warning must be a subclass of - 'module' -- a regex that the module name must match - 'lineno' -- an integer line number, 0 matches all warnings - 'append' -- if true, append to the list of filters - """ - if action not in {"error", "ignore", "always", "default", "module", "once"}: - raise ValueError(f"invalid action: {action!r}") - if not isinstance(message, str): - raise TypeError("message must be a string") - if not isinstance(category, type) or not issubclass(category, Warning): - raise TypeError("category must be a Warning subclass") - if not isinstance(module, str): - raise TypeError("module must be a string") - if not isinstance(lineno, int): - raise TypeError("lineno must be an int") - if lineno < 0: - raise ValueError("lineno must be an int >= 0") - - if message or module: - import re - - if message: - message = re.compile(message, re.I) - else: - message = None - if module: - module = re.compile(module) - else: - module = None - - _add_filter(action, message, category, module, lineno, append=append) - -def simplefilter(action, category=Warning, lineno=0, append=False): - """Insert a simple entry into the list of warnings filters (at the front). - - A simple filter matches all modules and messages. - 'action' -- one of "error", "ignore", "always", "default", "module", - or "once" - 'category' -- a class that the warning must be a subclass of - 'lineno' -- an integer line number, 0 matches all warnings - 'append' -- if true, append to the list of filters - """ - if action not in {"error", "ignore", "always", "default", "module", "once"}: - raise ValueError(f"invalid action: {action!r}") - if not isinstance(lineno, int): - raise TypeError("lineno must be an int") - if lineno < 0: - raise ValueError("lineno must be an int >= 0") - _add_filter(action, None, category, None, lineno, append=append) - -def _add_filter(*item, append): - # Remove possible duplicate filters, so new one will be placed - # in correct place. If append=True and duplicate exists, do nothing. 
- if not append: - try: - filters.remove(item) - except ValueError: - pass - filters.insert(0, item) - else: - if item not in filters: - filters.append(item) - _filters_mutated() - -def resetwarnings(): - """Clear the list of warning filters, so that no filters are active.""" - filters[:] = [] - _filters_mutated() - -class _OptionError(Exception): - """Exception used by option processing helpers.""" - pass - -# Helper to process -W options passed via sys.warnoptions -def _processoptions(args): - for arg in args: - try: - _setoption(arg) - except _OptionError as msg: - print("Invalid -W option ignored:", msg, file=sys.stderr) - -# Helper for _processoptions() -def _setoption(arg): - parts = arg.split(':') - if len(parts) > 5: - raise _OptionError("too many fields (max 5): %r" % (arg,)) - while len(parts) < 5: - parts.append('') - action, message, category, module, lineno = [s.strip() - for s in parts] - action = _getaction(action) - category = _getcategory(category) - if message or module: - import re - if message: - message = re.escape(message) - if module: - module = re.escape(module) + r'\Z' - if lineno: - try: - lineno = int(lineno) - if lineno < 0: - raise ValueError - except (ValueError, OverflowError): - raise _OptionError("invalid lineno %r" % (lineno,)) from None - else: - lineno = 0 - filterwarnings(action, message, category, module, lineno) - -# Helper for _setoption() -def _getaction(action): - if not action: - return "default" - if action == "all": return "always" # Alias - for a in ('default', 'always', 'ignore', 'module', 'once', 'error'): - if a.startswith(action): - return a - raise _OptionError("invalid action: %r" % (action,)) - -# Helper for _setoption() -def _getcategory(category): - if not category: - return Warning - if '.' not in category: - import builtins as m - klass = category - else: - module, _, klass = category.rpartition('.') - try: - m = __import__(module, None, None, [klass]) - except ImportError: - raise _OptionError("invalid module name: %r" % (module,)) from None - try: - cat = getattr(m, klass) - except AttributeError: - raise _OptionError("unknown warning category: %r" % (category,)) from None - if not issubclass(cat, Warning): - raise _OptionError("invalid warning category: %r" % (category,)) - return cat - - -def _is_internal_filename(filename): - return 'importlib' in filename and '_bootstrap' in filename - - -def _is_filename_to_skip(filename, skip_file_prefixes): - return any(filename.startswith(prefix) for prefix in skip_file_prefixes) - - -def _is_internal_frame(frame): - """Signal whether the frame is an internal CPython implementation detail.""" - return _is_internal_filename(frame.f_code.co_filename) - - -def _next_external_frame(frame, skip_file_prefixes): - """Find the next frame that doesn't involve Python or user internals.""" - frame = frame.f_back - while frame is not None and ( - _is_internal_filename(filename := frame.f_code.co_filename) or - _is_filename_to_skip(filename, skip_file_prefixes)): - frame = frame.f_back - return frame - - -# Code typically replaced by _warnings -def warn(message, category=None, stacklevel=1, source=None, - *, skip_file_prefixes=()): - """Issue a warning, or maybe ignore it or raise an exception.""" - # Check if message is already a Warning object - if isinstance(message, Warning): - category = message.__class__ - # Check category argument - if category is None: - category = UserWarning - if not (isinstance(category, type) and issubclass(category, Warning)): - raise TypeError("category must be a Warning 
subclass, " - "not '{:s}'".format(type(category).__name__)) - if not isinstance(skip_file_prefixes, tuple): - # The C version demands a tuple for implementation performance. - raise TypeError('skip_file_prefixes must be a tuple of strs.') - if skip_file_prefixes: - stacklevel = max(2, stacklevel) - # Get context information - try: - if stacklevel <= 1 or _is_internal_frame(sys._getframe(1)): - # If frame is too small to care or if the warning originated in - # internal code, then do not try to hide any frames. - frame = sys._getframe(stacklevel) - else: - frame = sys._getframe(1) - # Look for one frame less since the above line starts us off. - for x in range(stacklevel-1): - frame = _next_external_frame(frame, skip_file_prefixes) - if frame is None: - raise ValueError - except ValueError: - globals = sys.__dict__ - filename = "" - lineno = 0 - else: - globals = frame.f_globals - filename = frame.f_code.co_filename - lineno = frame.f_lineno - if '__name__' in globals: - module = globals['__name__'] - else: - module = "" - registry = globals.setdefault("__warningregistry__", {}) - warn_explicit(message, category, filename, lineno, module, registry, - globals, source) - -def warn_explicit(message, category, filename, lineno, - module=None, registry=None, module_globals=None, - source=None): - lineno = int(lineno) - if module is None: - module = filename or "" - if module[-3:].lower() == ".py": - module = module[:-3] # XXX What about leading pathname? - if registry is None: - registry = {} - if registry.get('version', 0) != _filters_version: - registry.clear() - registry['version'] = _filters_version - if isinstance(message, Warning): - text = str(message) - category = message.__class__ - else: - text = message - message = category(message) - key = (text, category, lineno) - # Quick test for common case - if registry.get(key): - return - # Search the filters - for item in filters: - action, msg, cat, mod, ln = item - if ((msg is None or msg.match(text)) and - issubclass(category, cat) and - (mod is None or mod.match(module)) and - (ln == 0 or lineno == ln)): - break - else: - action = defaultaction - # Early exit actions - if action == "ignore": - return - - # Prime the linecache for formatting, in case the - # "file" is actually in a zipfile or something. 
- import linecache - linecache.getlines(filename, module_globals) - - if action == "error": - raise message - # Other actions - if action == "once": - registry[key] = 1 - oncekey = (text, category) - if onceregistry.get(oncekey): - return - onceregistry[oncekey] = 1 - elif action == "always": - pass - elif action == "module": - registry[key] = 1 - altkey = (text, category, 0) - if registry.get(altkey): - return - registry[altkey] = 1 - elif action == "default": - registry[key] = 1 - else: - # Unrecognized actions are errors - raise RuntimeError( - "Unrecognized action (%r) in warnings.filters:\n %s" % - (action, item)) - # Print message and context - msg = WarningMessage(message, category, filename, lineno, source=source) - _showwarnmsg(msg) - - -class WarningMessage(object): - - _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file", - "line", "source") - - def __init__(self, message, category, filename, lineno, file=None, - line=None, source=None): - self.message = message - self.category = category - self.filename = filename - self.lineno = lineno - self.file = file - self.line = line - self.source = source - self._category_name = category.__name__ if category else None - - def __str__(self): - return ("{message : %r, category : %r, filename : %r, lineno : %s, " - "line : %r}" % (self.message, self._category_name, - self.filename, self.lineno, self.line)) - - -class catch_warnings(object): - - """A context manager that copies and restores the warnings filter upon - exiting the context. - - The 'record' argument specifies whether warnings should be captured by a - custom implementation of warnings.showwarning() and be appended to a list - returned by the context manager. Otherwise None is returned by the context - manager. The objects appended to the list are arguments whose attributes - mirror the arguments to showwarning(). - - The 'module' argument is to specify an alternative module to the module - named 'warnings' and imported under that name. This argument is only useful - when testing the warnings module itself. - - If the 'action' argument is not None, the remaining arguments are passed - to warnings.simplefilter() as if it were called immediately on entering the - context. - """ - - def __init__(self, *, record=False, module=None, - action=None, category=Warning, lineno=0, append=False): - """Specify whether to record warnings and if an alternative module - should be used other than sys.modules['warnings']. - - For compatibility with Python 3.0, please consider all arguments to be - keyword-only. 
- - """ - self._record = record - self._module = sys.modules['warnings'] if module is None else module - self._entered = False - if action is None: - self._filter = None - else: - self._filter = (action, category, lineno, append) - - def __repr__(self): - args = [] - if self._record: - args.append("record=True") - if self._module is not sys.modules['warnings']: - args.append("module=%r" % self._module) - name = type(self).__name__ - return "%s(%s)" % (name, ", ".join(args)) - - def __enter__(self): - if self._entered: - raise RuntimeError("Cannot enter %r twice" % self) - self._entered = True - self._filters = self._module.filters - self._module.filters = self._filters[:] - self._module._filters_mutated() - self._showwarning = self._module.showwarning - self._showwarnmsg_impl = self._module._showwarnmsg_impl - if self._filter is not None: - simplefilter(*self._filter) - if self._record: - log = [] - self._module._showwarnmsg_impl = log.append - # Reset showwarning() to the default implementation to make sure - # that _showwarnmsg() calls _showwarnmsg_impl() - self._module.showwarning = self._module._showwarning_orig - return log - else: - return None - - def __exit__(self, *exc_info): - if not self._entered: - raise RuntimeError("Cannot exit %r without entering first" % self) - self._module.filters = self._filters - self._module._filters_mutated() - self._module.showwarning = self._showwarning - self._module._showwarnmsg_impl = self._showwarnmsg_impl - - -class deprecated: - """Indicate that a class, function or overload is deprecated. - - When this decorator is applied to an object, the type checker - will generate a diagnostic on usage of the deprecated object. - - Usage: - - @deprecated("Use B instead") - class A: - pass - - @deprecated("Use g instead") - def f(): - pass - - @overload - @deprecated("int support is deprecated") - def g(x: int) -> int: ... - @overload - def g(x: str) -> int: ... - - The warning specified by *category* will be emitted at runtime - on use of deprecated objects. For functions, that happens on calls; - for classes, on instantiation and on creation of subclasses. - If the *category* is ``None``, no warning is emitted at runtime. - The *stacklevel* determines where the - warning is emitted. If it is ``1`` (the default), the warning - is emitted at the direct caller of the deprecated object; if it - is higher, it is emitted further up the stack. - Static type checker behavior is not affected by the *category* - and *stacklevel* arguments. - - The deprecation message passed to the decorator is saved in the - ``__deprecated__`` attribute on the decorated object. - If applied to an overload, the decorator - must be after the ``@overload`` decorator for the attribute to - exist on the overload as returned by ``get_overloads()``. - - See PEP 702 for details. - - """ - def __init__( - self, - message: str, - /, - *, - category: type[Warning] | None = DeprecationWarning, - stacklevel: int = 1, - ) -> None: - if not isinstance(message, str): - raise TypeError( - f"Expected an object of type str for 'message', not {type(message).__name__!r}" - ) - self.message = message - self.category = category - self.stacklevel = stacklevel - - def __call__(self, arg, /): - # Make sure the inner functions created below don't - # retain a reference to self. 
- msg = self.message - category = self.category - stacklevel = self.stacklevel - if category is None: - arg.__deprecated__ = msg - return arg - elif isinstance(arg, type): - import functools - from types import MethodType - - original_new = arg.__new__ - - @functools.wraps(original_new) - def __new__(cls, /, *args, **kwargs): - if cls is arg: - warn(msg, category=category, stacklevel=stacklevel + 1) - if original_new is not object.__new__: - return original_new(cls, *args, **kwargs) - # Mirrors a similar check in object.__new__. - elif cls.__init__ is object.__init__ and (args or kwargs): - raise TypeError(f"{cls.__name__}() takes no arguments") - else: - return original_new(cls) - - arg.__new__ = staticmethod(__new__) - - original_init_subclass = arg.__init_subclass__ - # We need slightly different behavior if __init_subclass__ - # is a bound method (likely if it was implemented in Python) - if isinstance(original_init_subclass, MethodType): - original_init_subclass = original_init_subclass.__func__ - - @functools.wraps(original_init_subclass) - def __init_subclass__(*args, **kwargs): - warn(msg, category=category, stacklevel=stacklevel + 1) - return original_init_subclass(*args, **kwargs) - - arg.__init_subclass__ = classmethod(__init_subclass__) - # Or otherwise, which likely means it's a builtin such as - # object's implementation of __init_subclass__. - else: - @functools.wraps(original_init_subclass) - def __init_subclass__(*args, **kwargs): - warn(msg, category=category, stacklevel=stacklevel + 1) - return original_init_subclass(*args, **kwargs) - - arg.__init_subclass__ = __init_subclass__ - - arg.__deprecated__ = __new__.__deprecated__ = msg - __init_subclass__.__deprecated__ = msg - return arg - elif callable(arg): - import functools - import inspect - - @functools.wraps(arg) - def wrapper(*args, **kwargs): - warn(msg, category=category, stacklevel=stacklevel + 1) - return arg(*args, **kwargs) - - if inspect.iscoroutinefunction(arg): - wrapper = inspect.markcoroutinefunction(wrapper) - - arg.__deprecated__ = wrapper.__deprecated__ = msg - return wrapper - else: - raise TypeError( - "@deprecated decorator with non-None category must be applied to " - f"a class or callable, not {arg!r}" - ) - - -_DEPRECATED_MSG = "{name!r} is deprecated and slated for removal in Python {remove}" - -def _deprecated(name, message=_DEPRECATED_MSG, *, remove, _version=sys.version_info): - """Warn that *name* is deprecated or should be removed. - - RuntimeError is raised if *remove* specifies a major/minor tuple older than - the current Python version or the same version but past the alpha. - - The *message* argument is formatted with *name* and *remove* as a Python - version tuple (e.g. (3, 11)). 
- - """ - remove_formatted = f"{remove[0]}.{remove[1]}" - if (_version[:2] > remove) or (_version[:2] == remove and _version[3] != "alpha"): - msg = f"{name!r} was slated for removal after Python {remove_formatted} alpha" - raise RuntimeError(msg) - else: - msg = message.format(name=name, remove=remove_formatted) - warn(msg, DeprecationWarning, stacklevel=3) - - -# Private utility function called by _PyErr_WarnUnawaitedCoroutine -def _warn_unawaited_coroutine(coro): - msg_lines = [ - f"coroutine '{coro.__qualname__}' was never awaited\n" - ] - if coro.cr_origin is not None: - import linecache, traceback - def extract(): - for filename, lineno, funcname in reversed(coro.cr_origin): - line = linecache.getline(filename, lineno) - yield (filename, lineno, funcname, line) - msg_lines.append("Coroutine created at (most recent call last)\n") - msg_lines += traceback.format_list(list(extract())) - msg = "".join(msg_lines).rstrip("\n") - # Passing source= here means that if the user happens to have tracemalloc - # enabled and tracking where the coroutine was created, the warning will - # contain that traceback. This does mean that if they have *both* - # coroutine origin tracking *and* tracemalloc enabled, they'll get two - # partially-redundant tracebacks. If we wanted to be clever we could - # probably detect this case and avoid it, but for now we don't bother. - warn(msg, category=RuntimeWarning, stacklevel=2, source=coro) - - -# filters contains a sequence of filter 5-tuples -# The components of the 5-tuple are: -# - an action: error, ignore, always, default, module, or once -# - a compiled regex that must match the warning message -# - a class representing the warning category -# - a compiled regex that must match the module that is being warned -# - a line number for the line being warning, or 0 to mean any line -# If either if the compiled regexs are None, match anything. try: - from _warnings import (filters, _defaultaction, _onceregistry, - warn, warn_explicit, _filters_mutated) - defaultaction = _defaultaction - onceregistry = _onceregistry + # Try to use the C extension, this will replace some parts of the + # _py_warnings implementation imported above. 
+ from _warnings import ( + _acquire_lock, + _defaultaction as defaultaction, + _filters_mutated_lock_held, + _onceregistry as onceregistry, + _release_lock, + _warnings_context, + filters, + warn, + warn_explicit, + ) + _warnings_defaults = True -except ImportError: - filters = [] - defaultaction = "default" - onceregistry = {} - _filters_version = 1 + class _Lock: + def __enter__(self): + _acquire_lock() + return self - def _filters_mutated(): - global _filters_version - _filters_version += 1 + def __exit__(self, *args): + _release_lock() + _lock = _Lock() +except ImportError: _warnings_defaults = False # Module initialization +_set_module(sys.modules[__name__]) _processoptions(sys.warnoptions) if not _warnings_defaults: - # Several warning categories are ignored by default in regular builds - if not hasattr(sys, 'gettotalrefcount'): - filterwarnings("default", category=DeprecationWarning, - module="__main__", append=1) - simplefilter("ignore", category=DeprecationWarning, append=1) - simplefilter("ignore", category=PendingDeprecationWarning, append=1) - simplefilter("ignore", category=ImportWarning, append=1) - simplefilter("ignore", category=ResourceWarning, append=1) + _setup_defaults() del _warnings_defaults +del _setup_defaults diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index c531c5ff258..be7807ec411 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -21,18 +21,7 @@ use itertools::Itertools; use malachite_bigint::BigInt; use num_complex::Complex; use num_traits::{Num, ToPrimitive}; -use ruff_python_ast::{ - Alias, Arguments, BoolOp, CmpOp, Comprehension, ConversionFlag, DebugText, Decorator, DictItem, - ExceptHandler, ExceptHandlerExceptHandler, Expr, ExprAttribute, ExprBoolOp, ExprContext, - ExprFString, ExprList, ExprName, ExprSlice, ExprStarred, ExprSubscript, ExprTString, ExprTuple, - ExprUnaryOp, FString, FStringFlags, FStringPart, Identifier, Int, InterpolatedStringElement, - InterpolatedStringElements, Keyword, MatchCase, ModExpression, ModModule, Operator, Parameters, - Pattern, PatternMatchAs, PatternMatchClass, PatternMatchMapping, PatternMatchOr, - PatternMatchSequence, PatternMatchSingleton, PatternMatchStar, PatternMatchValue, Singleton, - Stmt, StmtAnnAssign, StmtExpr, TString, TypeParam, TypeParamParamSpec, TypeParamTypeVar, - TypeParamTypeVarTuple, TypeParams, UnaryOp, WithItem, - visitor::{Visitor, walk_expr}, -}; +use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange}; use std::collections::HashSet; @@ -71,7 +60,7 @@ pub enum FBlockType { pub enum FBlockDatum { None, /// For FinallyTry: stores the finally body statements to compile during unwind - FinallyBody(Vec), + FinallyBody(Vec), /// For HandlerCleanup: stores the exception variable name (e.g., "e" in "except X as e") ExceptionName(String), } @@ -81,8 +70,8 @@ pub enum FBlockDatum { enum SuperCallType<'a> { /// super(class, self) - explicit 2-argument form TwoArg { - class_arg: &'a Expr, - self_arg: &'a Expr, + class_arg: &'a ast::Expr, + self_arg: &'a ast::Expr, }, /// super() - implicit 0-argument form (uses __class__ cell) ZeroArg, @@ -178,7 +167,7 @@ enum ComprehensionType { Dict, } -fn validate_duplicate_params(params: &Parameters) -> Result<(), CodegenErrorType> { +fn validate_duplicate_params(params: &ast::Parameters) -> Result<(), CodegenErrorType> { let mut seen_params = HashSet::new(); for param in params { let param_name = param.name().as_str(); @@ -211,7 +200,7 @@ pub fn compile_top( /// Compile a standard Python program to 
bytecode pub fn compile_program( - ast: &ModModule, + ast: &ast::ModModule, source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { @@ -226,7 +215,7 @@ pub fn compile_program( /// Compile a Python program to bytecode for the context of a REPL pub fn compile_program_single( - ast: &ModModule, + ast: &ast::ModModule, source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { @@ -240,7 +229,7 @@ pub fn compile_program_single( } pub fn compile_block_expression( - ast: &ModModule, + ast: &ast::ModModule, source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { @@ -254,7 +243,7 @@ pub fn compile_block_expression( } pub fn compile_expression( - ast: &ModExpression, + ast: &ast::ModExpression, source_file: SourceFile, opts: CompileOpts, ) -> CompileResult { @@ -436,13 +425,13 @@ impl Compiler { /// Check if the slice is a two-element slice (no step) // = is_two_element_slice - const fn is_two_element_slice(slice: &Expr) -> bool { - matches!(slice, Expr::Slice(s) if s.step.is_none()) + const fn is_two_element_slice(slice: &ast::Expr) -> bool { + matches!(slice, ast::Expr::Slice(s) if s.step.is_none()) } /// Compile a slice expression // = compiler_slice - fn compile_slice(&mut self, s: &ExprSlice) -> CompileResult { + fn compile_slice(&mut self, s: &ast::ExprSlice) -> CompileResult { // Compile lower if let Some(lower) = &s.lower { self.compile_expression(lower)?; @@ -471,9 +460,9 @@ impl Compiler { // = compiler_subscript fn compile_subscript( &mut self, - value: &Expr, - slice: &Expr, - ctx: ExprContext, + value: &ast::Expr, + slice: &ast::Expr, + ctx: ast::ExprContext, ) -> CompileResult<()> { // 1. Check subscripter and index for Load context // 2. VISIT value @@ -481,7 +470,7 @@ impl Compiler { // 4. Otherwise VISIT slice and emit appropriate instruction // For Load context, some checks are skipped for now - // if ctx == ExprContext::Load { + // if ctx == ast::ExprContext::Load { // check_subscripter(value); // check_index(value, slice); // } @@ -490,17 +479,19 @@ impl Compiler { self.compile_expression(value)?; // Handle two-element slice (for Load/Store, not Del) - if Self::is_two_element_slice(slice) && !matches!(ctx, ExprContext::Del) { + if Self::is_two_element_slice(slice) && !matches!(ctx, ast::ExprContext::Del) { let argc = match slice { - Expr::Slice(s) => self.compile_slice(s)?, - _ => unreachable!("is_two_element_slice should only return true for Expr::Slice"), + ast::Expr::Slice(s) => self.compile_slice(s)?, + _ => unreachable!( + "is_two_element_slice should only return true for ast::Expr::Slice" + ), }; match ctx { - ExprContext::Load => { + ast::ExprContext::Load => { emit!(self, Instruction::BuildSlice { argc }); emit!(self, Instruction::Subscript); } - ExprContext::Store => { + ast::ExprContext::Store => { emit!(self, Instruction::BuildSlice { argc }); emit!(self, Instruction::StoreSubscr); } @@ -512,10 +503,10 @@ impl Compiler { // Emit appropriate instruction based on context match ctx { - ExprContext::Load => emit!(self, Instruction::Subscript), - ExprContext::Store => emit!(self, Instruction::StoreSubscr), - ExprContext::Del => emit!(self, Instruction::DeleteSubscr), - ExprContext::Invalid => { + ast::ExprContext::Load => emit!(self, Instruction::Subscript), + ast::ExprContext::Store => emit!(self, Instruction::StoreSubscr), + ast::ExprContext::Del => emit!(self, Instruction::DeleteSubscr), + ast::ExprContext::Invalid => { return Err(self.error(CodegenErrorType::SyntaxError( "Invalid expression context".to_owned(), ))); @@ -528,7 +519,7 @@ impl 
Compiler { /// Helper function for compiling tuples/lists/sets with starred expressions /// - /// Parameters: + /// Parameters: /// - elts: The elements to compile /// - pushed: Number of items already on the stack /// - collection_type: What type of collection to build (tuple, list, set) @@ -536,7 +527,7 @@ impl Compiler { // = starunpack_helper in compile.c fn starunpack_helper( &mut self, - elts: &[Expr], + elts: &[ast::Expr], pushed: u32, collection_type: CollectionType, ) -> CompileResult<()> { @@ -735,13 +726,11 @@ impl Compiler { /// Returns Some(SuperCallType) if optimization is possible, None otherwise fn can_optimize_super_call<'a>( &self, - value: &'a Expr, + value: &'a ast::Expr, attr: &str, ) -> Option<SuperCallType<'a>> { - use ruff_python_ast::*; - // 1. value must be a Call expression - let Expr::Call(ExprCall { + let ast::Expr::Call(ast::ExprCall { func, arguments, .. }) = value else { @@ -749,7 +738,7 @@ }; // 2. func must be Name("super") - let Expr::Name(ExprName { id, .. }) = func.as_ref() else { + let ast::Expr::Name(ast::ExprName { id, .. }) = func.as_ref() else { return None; }; if id.as_str() != "super" { @@ -791,7 +780,7 @@ impl Compiler { let args = &arguments.args; // No starred expressions allowed - if args.iter().any(|arg| matches!(arg, Expr::Starred(_))) { + if args.iter().any(|arg| matches!(arg, ast::Expr::Starred(_))) { return None; } @@ -1113,7 +1102,7 @@ impl Compiler { let _table = self.pop_symbol_table(); // Various scopes can have sub_tables: - // - TypeParams scope can have sub_tables (the function body's symbol table) + // - TypeParams scope can have sub_tables (the function body's symbol table) // - Module scope can have sub_tables (for TypeAlias scopes, nested functions, classes) // - Function scope can have sub_tables (for nested functions, classes) // - Class scope can have sub_tables (for nested classes, methods) @@ -1553,7 +1542,7 @@ impl Compiler { let mut parent_idx = stack_size - 2; let mut parent = &self.code_stack[parent_idx]; - // If parent is TypeParams scope, look at grandparent + // If parent is TypeParams scope, look at grandparent // Check if parent is a type params scope by name pattern if parent.metadata.name.starts_with(" CompileResult<()> { let size_before = self.code_stack.len(); @@ -1672,7 +1661,7 @@ impl Compiler { fn compile_program_single( &mut self, - body: &[Stmt], + body: &[ast::Stmt], symbol_table: SymbolTable, ) -> CompileResult<()> { // Set future_annotations from symbol table (detected during symbol table scan) @@ -1698,7 +1687,7 @@ impl Compiler { if let Some((last, body)) = body.split_last() { for statement in body { - if let Stmt::Expr(StmtExpr { value, .. }) = &statement { + if let ast::Stmt::Expr(ast::StmtExpr { value, .. }) = &statement { self.compile_expression(value)?; emit!( self, @@ -1713,7 +1702,7 @@ impl Compiler { } } - if let Stmt::Expr(StmtExpr { value, ..
}) = &last { self.compile_expression(value)?; emit!(self, Instruction::Copy { index: 1_u32 }); emit!( @@ -1738,7 +1727,7 @@ impl Compiler { fn compile_block_expr( &mut self, - body: &[Stmt], + body: &[ast::Stmt], symbol_table: SymbolTable, ) -> CompileResult<()> { self.symbol_table_stack.push(symbol_table); @@ -1747,10 +1736,10 @@ impl Compiler { if let Some(last_statement) = body.last() { match last_statement { - Stmt::Expr(_) => { + ast::Stmt::Expr(_) => { self.current_block().instructions.pop(); // pop Instruction::PopTop } - Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { + ast::Stmt::FunctionDef(_) | ast::Stmt::ClassDef(_) => { let pop_instructions = self.current_block().instructions.pop(); let store_inst = compiler_unwrap_option(self, pop_instructions); // pop Instruction::Store emit!(self, Instruction::Copy { index: 1_u32 }); @@ -1767,7 +1756,7 @@ impl Compiler { // Compile statement in eval mode: fn compile_eval( &mut self, - expression: &ModExpression, + expression: &ast::ModExpression, symbol_table: SymbolTable, ) -> CompileResult<()> { self.symbol_table_stack.push(symbol_table); @@ -1776,7 +1765,7 @@ Ok(()) } - fn compile_statements(&mut self, statements: &[Stmt]) -> CompileResult<()> { + fn compile_statements(&mut self, statements: &[ast::Stmt]) -> CompileResult<()> { for statement in statements { self.compile_statement(statement)? } @@ -1823,7 +1812,7 @@ impl Compiler { // Determine the operation type based on symbol scope let is_function_like = self.ctx.in_func(); - // Look up the symbol, handling TypeParams and Annotation scopes specially + // Look up the symbol, handling TypeParams and Annotation scopes specially let (symbol_scope, can_see_class_scope) = { let current_table = self.current_symbol_table(); let is_typeparams = current_table.typ == CompilerScope::TypeParams; @@ -1833,7 +1822,7 @@ // First try to find in current table let symbol = current_table.lookup(name.as_ref()); - // If not found and we're in TypeParams or Annotation scope, try parent scope + // If not found and we're in TypeParams or Annotation scope, try parent scope let symbol = if symbol.is_none() && (is_typeparams || is_annotation) { self.symbol_table_stack .get(self.symbol_table_stack.len() - 2) // Try to get parent index @@ -1984,22 +1973,21 @@ impl Compiler { Ok(()) } - fn compile_statement(&mut self, statement: &Stmt) -> CompileResult<()> { - use ruff_python_ast::*; + fn compile_statement(&mut self, statement: &ast::Stmt) -> CompileResult<()> { trace!("Compiling {statement:?}"); self.set_source_range(statement.range()); match &statement { // we do this here because `from __future__` still executes that `from` statement at runtime, // we still need to compile the ImportFrom down below - Stmt::ImportFrom(StmtImportFrom { module, names, .. }) + ast::Stmt::ImportFrom(ast::StmtImportFrom { module, names, .. }) if module.as_ref().map(|id| id.as_str()) == Some("__future__") => { self.compile_future_features(names)? } // ignore module-level doc comments - Stmt::Expr(StmtExpr { value, .. }) - if matches!(&**value, Expr::StringLiteral(..)) + ast::Stmt::Expr(ast::StmtExpr { value, .. }) + if matches!(&**value, ast::Expr::StringLiteral(..)) && matches!(self.done_with_future_stmts, DoneWithFuture::No) => { self.done_with_future_stmts = DoneWithFuture::DoneWithDoc @@ -2009,7 +1997,7 @@ impl Compiler { } match &statement { - Stmt::Import(StmtImport { names, ..
}) => { // import a, b, c as d for name in names { let name = &name; @@ -2030,7 +2018,7 @@ impl Compiler { } } } - Stmt::ImportFrom(StmtImportFrom { + ast::Stmt::ImportFrom(ast::StmtImportFrom { level, module, names, @@ -2096,16 +2084,16 @@ impl Compiler { emit!(self, Instruction::PopTop); } } - Stmt::Expr(StmtExpr { value, .. }) => { + ast::Stmt::Expr(ast::StmtExpr { value, .. }) => { self.compile_expression(value)?; // Pop result of stack, since we not use it: emit!(self, Instruction::PopTop); } - Stmt::Global(_) | Stmt::Nonlocal(_) => { + ast::Stmt::Global(_) | ast::Stmt::Nonlocal(_) => { // Handled during symbol table construction. } - Stmt::If(StmtIf { + ast::Stmt::If(ast::StmtIf { test, body, elif_else_clauses, @@ -2161,16 +2149,16 @@ impl Compiler { } self.leave_conditional_block(); } - Stmt::While(StmtWhile { + ast::Stmt::While(ast::StmtWhile { test, body, orelse, .. }) => self.compile_while(test, body, orelse)?, - Stmt::With(StmtWith { + ast::Stmt::With(ast::StmtWith { items, body, is_async, .. }) => self.compile_with(items, body, *is_async)?, - Stmt::For(StmtFor { + ast::Stmt::For(ast::StmtFor { target, iter, body, @@ -2178,8 +2166,10 @@ impl Compiler { is_async, .. }) => self.compile_for(target, iter, body, orelse, *is_async)?, - Stmt::Match(StmtMatch { subject, cases, .. }) => self.compile_match(subject, cases)?, - Stmt::Raise(StmtRaise { + ast::Stmt::Match(ast::StmtMatch { subject, cases, .. }) => { + self.compile_match(subject, cases)? + } + ast::Stmt::Raise(ast::StmtRaise { exc, cause, range, .. }) => { let kind = match exc { @@ -2198,7 +2188,7 @@ impl Compiler { self.set_source_range(*range); emit!(self, Instruction::RaiseVarargs { kind }); } - Stmt::Try(StmtTry { + ast::Stmt::Try(ast::StmtTry { body, handlers, orelse, @@ -2214,7 +2204,7 @@ impl Compiler { } self.leave_conditional_block(); } - Stmt::FunctionDef(StmtFunctionDef { + ast::Stmt::FunctionDef(ast::StmtFunctionDef { name, parameters, body, @@ -2236,7 +2226,7 @@ impl Compiler { type_params.as_deref(), )? } - Stmt::ClassDef(StmtClassDef { + ast::Stmt::ClassDef(ast::StmtClassDef { name, body, decorator_list, @@ -2250,7 +2240,7 @@ impl Compiler { type_params.as_deref(), arguments.as_deref(), )?, - Stmt::Assert(StmtAssert { test, msg, .. }) => { + ast::Stmt::Assert(ast::StmtAssert { test, msg, .. }) => { // if some flag, ignore all assert statements! if self.opts.optimize == 0 { let after_block = self.new_block(); @@ -2278,15 +2268,15 @@ impl Compiler { self.switch_to_block(after_block); } } - Stmt::Break(_) => { + ast::Stmt::Break(_) => { // Unwind fblock stack until we find a loop, emitting cleanup for each fblock self.compile_break_continue(statement.range(), true)?; } - Stmt::Continue(_) => { + ast::Stmt::Continue(_) => { // Unwind fblock stack until we find a loop, emitting cleanup for each fblock self.compile_break_continue(statement.range(), false)?; } - Stmt::Return(StmtReturn { value, .. }) => { + ast::Stmt::Return(ast::StmtReturn { value, .. }) => { if !self.ctx.in_func() { return Err( self.error_ranged(CodegenErrorType::InvalidReturn, statement.range()) @@ -2318,7 +2308,7 @@ impl Compiler { } } } - Stmt::Assign(StmtAssign { targets, value, .. }) => { + ast::Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { self.compile_expression(value)?; for (i, target) in targets.iter().enumerate() { @@ -2328,25 +2318,25 @@ impl Compiler { self.compile_store(target)?; } } - Stmt::AugAssign(StmtAugAssign { + ast::Stmt::AugAssign(ast::StmtAugAssign { target, op, value, .. 
}) => self.compile_augassign(target, op, value)?, - Stmt::AnnAssign(StmtAnnAssign { + ast::Stmt::AnnAssign(ast::StmtAnnAssign { target, annotation, value, simple, .. }) => self.compile_annotated_assign(target, annotation, value.as_deref(), *simple)?, - Stmt::Delete(StmtDelete { targets, .. }) => { + ast::Stmt::Delete(ast::StmtDelete { targets, .. }) => { for target in targets { self.compile_delete(target)?; } } - Stmt::Pass(_) => { + ast::Stmt::Pass(_) => { // No need to emit any code here :) } - Stmt::TypeAlias(StmtTypeAlias { + ast::Stmt::TypeAlias(ast::StmtTypeAlias { name, type_params, value, @@ -2402,31 +2392,33 @@ impl Compiler { ); self.store_name(&name_string)?; } - Stmt::IpyEscapeCommand(_) => todo!(), + ast::Stmt::IpyEscapeCommand(_) => todo!(), } Ok(()) } - fn compile_delete(&mut self, expression: &Expr) -> CompileResult<()> { - use ruff_python_ast::*; + fn compile_delete(&mut self, expression: &ast::Expr) -> CompileResult<()> { match &expression { - Expr::Name(ExprName { id, .. }) => self.compile_name(id.as_str(), NameUsage::Delete)?, - Expr::Attribute(ExprAttribute { value, attr, .. }) => { + ast::Expr::Name(ast::ExprName { id, .. }) => { + self.compile_name(id.as_str(), NameUsage::Delete)? + } + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { self.compile_expression(value)?; let idx = self.name(attr.as_str()); emit!(self, Instruction::DeleteAttr { idx }); } - Expr::Subscript(ExprSubscript { + ast::Expr::Subscript(ast::ExprSubscript { value, slice, ctx, .. }) => { self.compile_subscript(value, slice, *ctx)?; } - Expr::Tuple(ExprTuple { elts, .. }) | Expr::List(ExprList { elts, .. }) => { + ast::Expr::Tuple(ast::ExprTuple { elts, .. }) + | ast::Expr::List(ast::ExprList { elts, .. }) => { for element in elts { self.compile_delete(element)?; } } - Expr::BinOp(_) | Expr::UnaryOp(_) => { + ast::Expr::BinOp(_) | ast::Expr::UnaryOp(_) => { return Err(self.error(CodegenErrorType::Delete("expression"))); } _ => return Err(self.error(CodegenErrorType::Delete(expression.python_name()))), @@ -2434,7 +2426,7 @@ impl Compiler { Ok(()) } - fn enter_function(&mut self, name: &str, parameters: &Parameters) -> CompileResult<()> { + fn enter_function(&mut self, name: &str, parameters: &ast::Parameters) -> CompileResult<()> { // TODO: partition_in_place let mut kw_without_defaults = vec![]; let mut kw_with_defaults = vec![]; @@ -2478,7 +2470,7 @@ impl Compiler { /// Push decorators onto the stack in source order. /// For @dec1 @dec2 def foo(): stack becomes [dec1, NULL, dec2, NULL] - fn prepare_decorators(&mut self, decorator_list: &[Decorator]) -> CompileResult<()> { + fn prepare_decorators(&mut self, decorator_list: &[ast::Decorator]) -> CompileResult<()> { for decorator in decorator_list { self.compile_expression(&decorator.expression)?; emit!(self, Instruction::PushNull); @@ -2490,7 +2482,7 @@ impl Compiler { /// Stack [dec1, NULL, dec2, NULL, func] -> dec2(func) -> dec1(dec2(func)) /// The forward loop works because each Call pops from TOS, naturally /// applying decorators bottom-up (innermost first). 
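The stack discipline described above reproduces Python's decorator semantics exactly. A quick plain-Python sketch of the observable order (no compiler internals assumed):

```python
def dec1(f):
    print("dec1 applied")
    return f

def dec2(f):
    print("dec2 applied")
    return f

@dec1
@dec2
def foo():
    pass

# Prints "dec2 applied" then "dec1 applied": the stacked form is
# equivalent to foo = dec1(dec2(foo)), so one Call per decorator,
# popping from the top of the stack, applies the innermost one first.
```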
- fn apply_decorators(&mut self, decorator_list: &[Decorator]) { + fn apply_decorators(&mut self, decorator_list: &[ast::Decorator]) { for _ in decorator_list { emit!(self, Instruction::Call { nargs: 1 }); } @@ -2499,7 +2491,7 @@ impl Compiler { /// Compile type parameter bound or default in a separate scope and return closure fn compile_type_param_bound_or_default( &mut self, - expr: &Expr, + expr: &ast::Expr, name: &str, allow_starred: bool, ) -> CompileResult<()> { @@ -2514,8 +2506,8 @@ impl Compiler { self.enter_scope(name, CompilerScope::TypeParams, key, lineno)?; // Compile the expression - if allow_starred && matches!(expr, Expr::Starred(_)) { - if let Expr::Starred(starred) = expr { + if allow_starred && matches!(expr, ast::Expr::Starred(_)) { + if let ast::Expr::Starred(starred) = expr { self.compile_expression(&starred.value)?; emit!(self, Instruction::UnpackSequence { size: 1 }); } @@ -2542,11 +2534,11 @@ impl Compiler { /// Store each type parameter so it is accessible to the current scope, and leave a tuple of /// all the type parameters on the stack. Handles default values per PEP 695. - fn compile_type_params(&mut self, type_params: &TypeParams) -> CompileResult<()> { + fn compile_type_params(&mut self, type_params: &ast::TypeParams) -> CompileResult<()> { // First, compile each type parameter and store it for type_param in &type_params.type_params { match type_param { - TypeParam::TypeVar(TypeParamTypeVar { + ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, bound, default, @@ -2593,7 +2585,7 @@ impl Compiler { emit!(self, Instruction::Copy { index: 1_u32 }); self.store_name(name.as_ref())?; } - TypeParam::ParamSpec(TypeParamParamSpec { name, default, .. }) => { + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, default, .. }) => { self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -2618,7 +2610,9 @@ impl Compiler { emit!(self, Instruction::Copy { index: 1_u32 }); self.store_name(name.as_ref())?; } - TypeParam::TypeVarTuple(TypeParamTypeVarTuple { name, default, .. }) => { + ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { + name, default, .. + }) => { self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -2657,10 +2651,10 @@ impl Compiler { fn compile_try_statement( &mut self, - body: &[Stmt], - handlers: &[ExceptHandler], - orelse: &[Stmt], - finalbody: &[Stmt], + body: &[ast::Stmt], + handlers: &[ast::ExceptHandler], + orelse: &[ast::Stmt], + finalbody: &[ast::Stmt], ) -> CompileResult<()> { let handler_block = self.new_block(); let finally_block = self.new_block(); @@ -2843,8 +2837,11 @@ impl Compiler { // PUSH_EXC_INFO transforms [exc] -> [prev_exc, exc] for PopExcept emit!(self, Instruction::PushExcInfo); for handler in handlers { - let ExceptHandler::ExceptHandler(ExceptHandlerExceptHandler { - type_, name, body, .. + let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + type_, + name, + body, + .. 
}) = &handler; let next_handler = self.new_block(); @@ -3116,10 +3113,10 @@ impl Compiler { fn compile_try_star_except( &mut self, - body: &[Stmt], - handlers: &[ExceptHandler], - orelse: &[Stmt], - finalbody: &[Stmt], + body: &[ast::Stmt], + handlers: &[ast::ExceptHandler], + orelse: &[ast::Stmt], + finalbody: &[ast::Stmt], ) -> CompileResult<()> { // compiler_try_star_except // Stack layout during handler processing: [prev_exc, orig, list, rest] @@ -3176,8 +3173,11 @@ impl Compiler { let n = handlers.len(); for (i, handler) in handlers.iter().enumerate() { - let ExceptHandler::ExceptHandler(ExceptHandlerExceptHandler { - type_, name, body, .. + let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + type_, + name, + body, + .. }) = handler; let no_match_block = self.new_block(); @@ -3197,7 +3197,7 @@ impl Compiler { // Compile exception type if let Some(exc_type) = type_ { // Check for unparenthesized tuple - if let Expr::Tuple(ExprTuple { elts, range, .. }) = exc_type.as_ref() + if let ast::Expr::Tuple(ast::ExprTuple { elts, range, .. }) = exc_type.as_ref() && let Some(first) = elts.first() && range.start().to_u32() == first.range().start().to_u32() { @@ -3434,7 +3434,7 @@ impl Compiler { // = compiler_default_arguments fn compile_default_arguments( &mut self, - parameters: &Parameters, + parameters: &ast::Parameters, ) -> CompileResult { let mut funcflags = bytecode::MakeFunctionFlags::empty(); @@ -3492,8 +3492,8 @@ impl Compiler { fn compile_function_body( &mut self, name: &str, - parameters: &Parameters, - body: &[Stmt], + parameters: &ast::Parameters, + body: &[ast::Stmt], is_async: bool, funcflags: bytecode::MakeFunctionFlags, ) -> CompileResult<()> { @@ -3539,7 +3539,7 @@ impl Compiler { // Emit None at end if needed match body.last() { - Some(Stmt::Return(_)) => {} + Some(ast::Stmt::Return(_)) => {} _ => { self.emit_return_const(ConstantData::None); } @@ -3564,8 +3564,8 @@ impl Compiler { fn compile_annotations_closure( &mut self, func_name: &str, - parameters: &Parameters, - returns: Option<&Expr>, + parameters: &ast::Parameters, + returns: Option<&ast::Expr>, ) -> CompileResult { // Try to enter annotation scope - returns false if no annotation_block exists if !self.enter_annotation_scope(func_name)? { @@ -3630,29 +3630,74 @@ impl Compiler { Ok(true) } - /// Collect simple (non-conditional) annotations from module body + /// Collect simple annotations from module body in AST order (including nested blocks) /// Returns list of (name, annotation_expr) pairs - fn collect_simple_annotations(body: &[Stmt]) -> Vec<(&str, &Expr)> { - let mut annotations = Vec::new(); - for stmt in body { - if let Stmt::AnnAssign(StmtAnnAssign { - target, - annotation, - simple, - .. - }) = stmt - && *simple - && let Expr::Name(ExprName { id, .. }) = target.as_ref() - { - annotations.push((id.as_str(), annotation.as_ref())); + /// This must match the order that annotations are compiled to ensure + /// conditional_annotation_index stays in sync with __annotate__ enumeration. + fn collect_simple_annotations(body: &[ast::Stmt]) -> Vec<(&str, &ast::Expr)> { + fn walk<'a>(stmts: &'a [ast::Stmt], out: &mut Vec<(&'a str, &'a ast::Expr)>) { + for stmt in stmts { + match stmt { + ast::Stmt::AnnAssign(ast::StmtAnnAssign { + target, + annotation, + simple, + .. + }) if *simple && matches!(target.as_ref(), ast::Expr::Name(_)) => { + if let ast::Expr::Name(ast::ExprName { id, .. 
}) = target.as_ref() { + out.push((id.as_str(), annotation.as_ref())); + } + } + ast::Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => { + walk(body, out); + for clause in elif_else_clauses { + walk(&clause.body, out); + } + } + ast::Stmt::For(ast::StmtFor { body, orelse, .. }) + | ast::Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + walk(body, out); + walk(orelse, out); + } + ast::Stmt::With(ast::StmtWith { body, .. }) => walk(body, out), + ast::Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + walk(body, out); + for handler in handlers { + let ast::ExceptHandler::ExceptHandler( + ast::ExceptHandlerExceptHandler { body, .. }, + ) = handler; + walk(body, out); + } + walk(orelse, out); + walk(finalbody, out); + } + ast::Stmt::Match(ast::StmtMatch { cases, .. }) => { + for case in cases { + walk(&case.body, out); + } + } + _ => {} + } } } + let mut annotations = Vec::new(); + walk(body, &mut annotations); annotations } /// Compile module-level __annotate__ function (PEP 649) /// Returns true if __annotate__ was created and stored - fn compile_module_annotate(&mut self, body: &[Stmt]) -> CompileResult { + fn compile_module_annotate(&mut self, body: &[ast::Stmt]) -> CompileResult { // Collect simple annotations from module body first let annotations = Self::collect_simple_annotations(body); @@ -3795,12 +3840,12 @@ impl Compiler { fn compile_function_def( &mut self, name: &str, - parameters: &Parameters, - body: &[Stmt], - decorator_list: &[Decorator], - returns: Option<&Expr>, // TODO: use type hint somehow.. + parameters: &ast::Parameters, + body: &[ast::Stmt], + decorator_list: &[ast::Decorator], + returns: Option<&ast::Expr>, // TODO: use type hint somehow.. is_async: bool, - type_params: Option<&TypeParams>, + type_params: Option<&ast::TypeParams>, ) -> CompileResult<()> { self.prepare_decorators(decorator_list)?; @@ -4116,15 +4161,14 @@ impl Compiler { } // Python/compile.c find_ann - fn find_ann(body: &[Stmt]) -> bool { - use ruff_python_ast::*; + fn find_ann(body: &[ast::Stmt]) -> bool { for statement in body { let res = match &statement { - Stmt::AnnAssign(_) => true, - Stmt::For(StmtFor { body, orelse, .. }) => { + ast::Stmt::AnnAssign(_) => true, + ast::Stmt::For(ast::StmtFor { body, orelse, .. }) => { Self::find_ann(body) || Self::find_ann(orelse) } - Stmt::If(StmtIf { + ast::Stmt::If(ast::StmtIf { body, elif_else_clauses, .. @@ -4132,11 +4176,11 @@ impl Compiler { Self::find_ann(body) || elif_else_clauses.iter().any(|x| Self::find_ann(&x.body)) } - Stmt::While(StmtWhile { body, orelse, .. }) => { + ast::Stmt::While(ast::StmtWhile { body, orelse, .. }) => { Self::find_ann(body) || Self::find_ann(orelse) } - Stmt::With(StmtWith { body, .. }) => Self::find_ann(body), - Stmt::Try(StmtTry { + ast::Stmt::With(ast::StmtWith { body, .. }) => Self::find_ann(body), + ast::Stmt::Try(ast::StmtTry { body, orelse, finalbody, @@ -4156,8 +4200,8 @@ impl Compiler { fn compile_class_body( &mut self, name: &str, - body: &[Stmt], - type_params: Option<&TypeParams>, + body: &[ast::Stmt], + type_params: Option<&ast::TypeParams>, firstlineno: u32, ) -> CompileResult { // 1. 
Enter class scope @@ -4271,10 +4315,10 @@ impl Compiler { fn compile_class_def( &mut self, name: &str, - body: &[Stmt], - decorator_list: &[Decorator], - type_params: Option<&TypeParams>, - arguments: Option<&Arguments>, + body: &[ast::Stmt], + decorator_list: &[ast::Decorator], + type_params: Option<&ast::TypeParams>, + arguments: Option<&ast::Arguments>, ) -> CompileResult<()> { self.prepare_decorators(decorator_list)?; @@ -4343,8 +4387,11 @@ impl Compiler { // Compile bases and call __build_class__ // Check for starred bases or **kwargs - let has_starred = arguments - .is_some_and(|args| args.args.iter().any(|arg| matches!(arg, Expr::Starred(_)))); + let has_starred = arguments.is_some_and(|args| { + args.args + .iter() + .any(|arg| matches!(arg, ast::Expr::Starred(_))) + }); let has_double_star = arguments.is_some_and(|args| args.keywords.iter().any(|kw| kw.arg.is_none())); @@ -4376,11 +4423,12 @@ impl Compiler { } // Build kwargs if needed - let has_kwargs = arguments.is_some_and(|args| !args.keywords.is_empty()); - if has_kwargs { + if arguments.is_some_and(|args| !args.keywords.is_empty()) { self.compile_keywords(&arguments.unwrap().keywords)?; + } else { + emit!(self, Instruction::PushNull); } - emit!(self, Instruction::CallFunctionEx { has_kwargs }); + emit!(self, Instruction::CallFunctionEx); } else { // Simple case: no starred bases, no **kwargs // Compile bases normally @@ -4459,7 +4507,12 @@ impl Compiler { self.store_name(name) } - fn compile_while(&mut self, test: &Expr, body: &[Stmt], orelse: &[Stmt]) -> CompileResult<()> { + fn compile_while( + &mut self, + test: &ast::Expr, + body: &[ast::Stmt], + orelse: &[ast::Stmt], + ) -> CompileResult<()> { self.enter_conditional_block(); let while_block = self.new_block(); @@ -4497,8 +4550,8 @@ impl Compiler { fn compile_with( &mut self, - items: &[WithItem], - body: &[Stmt], + items: &[ast::WithItem], + body: &[ast::Stmt], is_async: bool, ) -> CompileResult<()> { self.enter_conditional_block(); @@ -4721,10 +4774,10 @@ impl Compiler { fn compile_for( &mut self, - target: &Expr, - iter: &Expr, - body: &[Stmt], - orelse: &[Stmt], + target: &ast::Expr, + iter: &ast::Expr, + body: &[ast::Stmt], + orelse: &[ast::Stmt], is_async: bool, ) -> CompileResult<()> { self.enter_conditional_block(); @@ -4909,7 +4962,7 @@ impl Compiler { /// to the list of captured names. fn pattern_helper_store_name( &mut self, - n: Option<&Identifier>, + n: Option<&ast::Identifier>, pc: &mut PatternContext, ) -> CompileResult<()> { match n { @@ -4943,7 +4996,7 @@ impl Compiler { } } - fn pattern_unpack_helper(&mut self, elts: &[Pattern]) -> CompileResult<()> { + fn pattern_unpack_helper(&mut self, elts: &[ast::Pattern]) -> CompileResult<()> { let n = elts.len(); let mut seen_star = false; for (i, elt) in elts.iter().enumerate() { @@ -4979,7 +5032,7 @@ impl Compiler { fn pattern_helper_sequence_unpack( &mut self, - patterns: &[Pattern], + patterns: &[ast::Pattern], _star: Option, pc: &mut PatternContext, ) -> CompileResult<()> { @@ -4999,7 +5052,7 @@ impl Compiler { fn pattern_helper_sequence_subscr( &mut self, - patterns: &[Pattern], + patterns: &[ast::Pattern], star: usize, pc: &mut PatternContext, ) -> CompileResult<()> { @@ -5047,7 +5100,7 @@ impl Compiler { fn compile_pattern_subpattern( &mut self, - p: &Pattern, + p: &ast::Pattern, pc: &mut PatternContext, ) -> CompileResult<()> { // Save the current allow_irrefutable state. 
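The `CallFunctionEx` change earlier in this hunk handles class statements whose bases or keywords cannot be compiled as a plain positional call. A Python-level sketch of a class that takes that path (names are illustrative; extra class keywords reach `__init_subclass__` per PEP 487):

```python
class Base:
    def __init_subclass__(cls, /, flavor=None, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.flavor = flavor

bases = (Base,)

# Starred bases plus a class keyword force the ext-call form, roughly
# __build_class__(<body>, "C", *bases, flavor="starred").
class C(*bases, flavor="starred"):
    pass

print(C.flavor)  # starred
```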
@@ -5063,7 +5116,7 @@ impl Compiler { fn compile_pattern_as( &mut self, - p: &PatternMatchAs, + p: &ast::PatternMatchAs, pc: &mut PatternContext, ) -> CompileResult<()> { // If there is no sub-pattern, then it's an irrefutable match. @@ -5100,7 +5153,7 @@ impl Compiler { fn compile_pattern_star( &mut self, - p: &PatternMatchStar, + p: &ast::PatternMatchStar, pc: &mut PatternContext, ) -> CompileResult<()> { self.pattern_helper_store_name(p.name.as_ref(), pc)?; @@ -5111,8 +5164,8 @@ /// and not duplicated. fn validate_kwd_attrs( &mut self, - attrs: &[Identifier], - _patterns: &[Pattern], + attrs: &[ast::Identifier], + _patterns: &[ast::Pattern], ) -> CompileResult<()> { let n_attrs = attrs.len(); for i in 0..n_attrs { @@ -5135,7 +5188,7 @@ impl Compiler { fn compile_pattern_class( &mut self, - p: &PatternMatchClass, + p: &ast::PatternMatchClass, pc: &mut PatternContext, ) -> CompileResult<()> { // Extract components from the MatchClass pattern. @@ -5214,7 +5267,7 @@ for subpattern in patterns.iter().chain(kwd_patterns.iter()) { // Check if this is a true wildcard (underscore pattern without name binding) let is_true_wildcard = match subpattern { - Pattern::MatchAs(match_as) => { + ast::Pattern::MatchAs(match_as) => { // Only consider it wildcard if both pattern and name are None (i.e., "_") match_as.pattern.is_none() && match_as.name.is_none() } @@ -5237,7 +5290,7 @@ impl Compiler { fn compile_pattern_mapping( &mut self, - p: &PatternMatchMapping, + p: &ast::PatternMatchMapping, pc: &mut PatternContext, ) -> CompileResult<()> { let mapping = p; @@ -5310,14 +5363,14 @@ // Validate and compile keys let mut seen = HashSet::new(); for key in keys { - let is_attribute = matches!(key, Expr::Attribute(_)); + let is_attribute = matches!(key, ast::Expr::Attribute(_)); let is_literal = matches!( key, - Expr::NumberLiteral(_) - | Expr::StringLiteral(_) - | Expr::BytesLiteral(_) - | Expr::BooleanLiteral(_) - | Expr::NoneLiteral(_) + ast::Expr::NumberLiteral(_) + | ast::Expr::StringLiteral(_) + | ast::Expr::BytesLiteral(_) + | ast::Expr::BooleanLiteral(_) + | ast::Expr::NoneLiteral(_) ); let key_repr = if is_literal { UnparseExpr::new(key, &self.source_file).to_string() @@ -5440,7 +5493,7 @@ impl Compiler { fn compile_pattern_or( &mut self, - p: &PatternMatchOr, + p: &ast::PatternMatchOr, pc: &mut PatternContext, ) -> CompileResult<()> { // Ensure the pattern is a MatchOr. @@ -5548,11 +5601,11 @@ impl Compiler { fn compile_pattern_sequence( &mut self, - p: &PatternMatchSequence, + p: &ast::PatternMatchSequence, pc: &mut PatternContext, ) -> CompileResult<()> { // Ensure the pattern is a MatchSequence. - let patterns = &p.patterns; // a slice of Pattern + let patterns = &p.patterns; // a slice of ast::Pattern let size = patterns.len(); let mut star: Option<usize> = None; let mut only_wildcard = true; @@ -5615,7 +5668,7 @@ // Whatever comes next should consume the subject. pc.on_top -= 1; if only_wildcard { - // Patterns like: [] / [_] / [_, _] / [*_] / [_, *_] / [_, _, *_] / etc. + // Patterns like: [] / [_] / [_, _] / [*_] / [_, *_] / [_, _, *_] / etc.
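The `only_wildcard` fast path referenced in the comment above covers sequence patterns that bind nothing, where a length check suffices and the subject can simply be popped; a capturing star pattern instead has to materialize the tail. A small sketch of both flavors (plain `match` semantics, Python 3.10+):

```python
def classify(seq):
    match seq:
        case []:
            return "empty"
        case [_, *_]:          # wildcard-only: length check, nothing bound
            return "non-empty"

def split(seq):
    match seq:
        case [first, *rest]:   # rest is bound, so the tail list is built
            return first, rest
    return None

print(classify([]))            # empty
print(classify([1, 2]))        # non-empty
print(split([1, 2, 3]))        # (1, [2, 3])
```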
emit!(self, Instruction::PopTop); } else if star_wildcard { self.pattern_helper_sequence_subscr(patterns, star.unwrap(), pc)?; @@ -5627,7 +5680,7 @@ impl Compiler { fn compile_pattern_value( &mut self, - p: &PatternMatchValue, + p: &ast::PatternMatchValue, pc: &mut PatternContext, ) -> CompileResult<()> { // TODO: ensure literal or attribute lookup @@ -5645,14 +5698,14 @@ impl Compiler { fn compile_pattern_singleton( &mut self, - p: &PatternMatchSingleton, + p: &ast::PatternMatchSingleton, pc: &mut PatternContext, ) -> CompileResult<()> { // Load the singleton constant value. self.emit_load_const(match p.value { - Singleton::None => ConstantData::None, - Singleton::False => ConstantData::Boolean { value: false }, - Singleton::True => ConstantData::Boolean { value: true }, + ast::Singleton::None => ConstantData::None, + ast::Singleton::False => ConstantData::Boolean { value: false }, + ast::Singleton::True => ConstantData::Boolean { value: true }, }); // Compare using the "Is" operator. emit!(self, Instruction::IsOp(Invert::No)); @@ -5663,32 +5716,32 @@ impl Compiler { fn compile_pattern( &mut self, - pattern_type: &Pattern, + pattern_type: &ast::Pattern, pattern_context: &mut PatternContext, ) -> CompileResult<()> { match &pattern_type { - Pattern::MatchValue(pattern_type) => { + ast::Pattern::MatchValue(pattern_type) => { self.compile_pattern_value(pattern_type, pattern_context) } - Pattern::MatchSingleton(pattern_type) => { + ast::Pattern::MatchSingleton(pattern_type) => { self.compile_pattern_singleton(pattern_type, pattern_context) } - Pattern::MatchSequence(pattern_type) => { + ast::Pattern::MatchSequence(pattern_type) => { self.compile_pattern_sequence(pattern_type, pattern_context) } - Pattern::MatchMapping(pattern_type) => { + ast::Pattern::MatchMapping(pattern_type) => { self.compile_pattern_mapping(pattern_type, pattern_context) } - Pattern::MatchClass(pattern_type) => { + ast::Pattern::MatchClass(pattern_type) => { self.compile_pattern_class(pattern_type, pattern_context) } - Pattern::MatchStar(pattern_type) => { + ast::Pattern::MatchStar(pattern_type) => { self.compile_pattern_star(pattern_type, pattern_context) } - Pattern::MatchAs(pattern_type) => { + ast::Pattern::MatchAs(pattern_type) => { self.compile_pattern_as(pattern_type, pattern_context) } - Pattern::MatchOr(pattern_type) => { + ast::Pattern::MatchOr(pattern_type) => { self.compile_pattern_or(pattern_type, pattern_context) } } @@ -5696,8 +5749,8 @@ impl Compiler { fn compile_match_inner( &mut self, - subject: &Expr, - cases: &[MatchCase], + subject: &ast::Expr, + cases: &[ast::MatchCase], pattern_context: &mut PatternContext, ) -> CompileResult<()> { self.compile_expression(subject)?; @@ -5768,7 +5821,11 @@ impl Compiler { Ok(()) } - fn compile_match(&mut self, subject: &Expr, cases: &[MatchCase]) -> CompileResult<()> { + fn compile_match( + &mut self, + subject: &ast::Expr, + cases: &[ast::MatchCase], + ) -> CompileResult<()> { self.enter_conditional_block(); let mut pattern_context = PatternContext::new(); self.compile_match_inner(subject, cases, &mut pattern_context)?; @@ -5777,21 +5834,21 @@ impl Compiler { } /// [CPython `compiler_addcompare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L2880-L2924) - fn compile_addcompare(&mut self, op: &CmpOp) { + fn compile_addcompare(&mut self, op: &ast::CmpOp) { use bytecode::ComparisonOperator::*; match op { - CmpOp::Eq => emit!(self, Instruction::CompareOp { op: Equal }), - 
CmpOp::NotEq => emit!(self, Instruction::CompareOp { op: NotEqual }), - CmpOp::Lt => emit!(self, Instruction::CompareOp { op: Less }), - CmpOp::LtE => emit!(self, Instruction::CompareOp { op: LessOrEqual }), - CmpOp::Gt => emit!(self, Instruction::CompareOp { op: Greater }), - CmpOp::GtE => { + ast::CmpOp::Eq => emit!(self, Instruction::CompareOp { op: Equal }), + ast::CmpOp::NotEq => emit!(self, Instruction::CompareOp { op: NotEqual }), + ast::CmpOp::Lt => emit!(self, Instruction::CompareOp { op: Less }), + ast::CmpOp::LtE => emit!(self, Instruction::CompareOp { op: LessOrEqual }), + ast::CmpOp::Gt => emit!(self, Instruction::CompareOp { op: Greater }), + ast::CmpOp::GtE => { emit!(self, Instruction::CompareOp { op: GreaterOrEqual }) } - CmpOp::In => emit!(self, Instruction::ContainsOp(Invert::No)), - CmpOp::NotIn => emit!(self, Instruction::ContainsOp(Invert::Yes)), - CmpOp::Is => emit!(self, Instruction::IsOp(Invert::No)), - CmpOp::IsNot => emit!(self, Instruction::IsOp(Invert::Yes)), + ast::CmpOp::In => emit!(self, Instruction::ContainsOp(Invert::No)), + ast::CmpOp::NotIn => emit!(self, Instruction::ContainsOp(Invert::Yes)), + ast::CmpOp::Is => emit!(self, Instruction::IsOp(Invert::No)), + ast::CmpOp::IsNot => emit!(self, Instruction::IsOp(Invert::Yes)), } } @@ -5815,9 +5872,9 @@ impl Compiler { /// - [CPython `compiler_compare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L4678-L4717) fn compile_compare( &mut self, - left: &Expr, - ops: &[CmpOp], - comparators: &[Expr], + left: &ast::Expr, + ops: &[ast::CmpOp], + comparators: &[ast::Expr], ) -> CompileResult<()> { let (last_op, mid_ops) = ops.split_last().unwrap(); let (last_comparator, mid_comparators) = comparators.split_last().unwrap(); @@ -5870,7 +5927,7 @@ impl Compiler { Ok(()) } - fn compile_annotation(&mut self, annotation: &Expr) -> CompileResult<()> { + fn compile_annotation(&mut self, annotation: &ast::Expr) -> CompileResult<()> { if self.future_annotations { self.emit_load_const(ConstantData::Str { value: UnparseExpr::new(annotation, &self.source_file) @@ -5883,7 +5940,7 @@ impl Compiler { // Special handling for starred annotations (*Ts -> Unpack[Ts]) let result = match annotation { - Expr::Starred(ExprStarred { value, .. }) => { + ast::Expr::Starred(ast::ExprStarred { value, .. }) => { // *args: *Ts (where Ts is a TypeVarTuple). // Do [annotation_value] = [*Ts]. self.compile_expression(value)?; @@ -5901,9 +5958,9 @@ impl Compiler { fn compile_annotated_assign( &mut self, - target: &Expr, - annotation: &Expr, - value: Option<&Expr>, + target: &ast::Expr, + annotation: &ast::Expr, + value: Option<&ast::Expr>, simple: bool, ) -> CompileResult<()> { // Perform the actual assignment first @@ -5915,7 +5972,7 @@ impl Compiler { // If we have a simple name in module or class scope, store annotation if simple && !self.ctx.in_func() - && let Expr::Name(ExprName { id, .. }) = target + && let ast::Expr::Name(ast::ExprName { id, .. }) = target { if self.future_annotations { // PEP 563: Store stringified annotation directly to __annotations__ @@ -5964,25 +6021,26 @@ impl Compiler { Ok(()) } - fn compile_store(&mut self, target: &Expr) -> CompileResult<()> { + fn compile_store(&mut self, target: &ast::Expr) -> CompileResult<()> { match &target { - Expr::Name(ExprName { id, .. }) => self.store_name(id.as_str())?, - Expr::Subscript(ExprSubscript { + ast::Expr::Name(ast::ExprName { id, .. 
}) => self.store_name(id.as_str())?, + ast::Expr::Subscript(ast::ExprSubscript { value, slice, ctx, .. }) => { self.compile_subscript(value, slice, *ctx)?; } - Expr::Attribute(ExprAttribute { value, attr, .. }) => { + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { self.compile_expression(value)?; let idx = self.name(attr.as_str()); emit!(self, Instruction::StoreAttr { idx }); } - Expr::List(ExprList { elts, .. }) | Expr::Tuple(ExprTuple { elts, .. }) => { + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { let mut seen_star = false; // Scan for star args: for (i, element) in elts.iter().enumerate() { - if let Expr::Starred(_) = &element { + if let ast::Expr::Starred(_) = &element { if seen_star { return Err(self.error(CodegenErrorType::MultipleStarArgs)); } else { @@ -6012,7 +6070,7 @@ impl Compiler { } for element in elts { - if let Expr::Starred(ExprStarred { value, .. }) = &element { + if let ast::Expr::Starred(ast::ExprStarred { value, .. }) = &element { self.compile_store(value)?; } else { self.compile_store(element)?; @@ -6021,7 +6079,7 @@ impl Compiler { } _ => { return Err(self.error(match target { - Expr::Starred(_) => CodegenErrorType::SyntaxError( + ast::Expr::Starred(_) => CodegenErrorType::SyntaxError( "starred assignment target must be in a list or tuple".to_owned(), ), _ => CodegenErrorType::Assign(target.python_name()), @@ -6034,9 +6092,9 @@ impl Compiler { fn compile_augassign( &mut self, - target: &Expr, - op: &Operator, - value: &Expr, + target: &ast::Expr, + op: &ast::Operator, + value: &ast::Expr, ) -> CompileResult<()> { enum AugAssignKind<'a> { Name { id: &'a str }, @@ -6045,12 +6103,12 @@ impl Compiler { } let kind = match &target { - Expr::Name(ExprName { id, .. }) => { + ast::Expr::Name(ast::ExprName { id, .. }) => { let id = id.as_str(); self.compile_name(id, NameUsage::Load)?; AugAssignKind::Name { id } } - Expr::Subscript(ExprSubscript { + ast::Expr::Subscript(ast::ExprSubscript { value, slice, ctx: _, @@ -6065,7 +6123,7 @@ impl Compiler { emit!(self, Instruction::Subscript); AugAssignKind::Subscript } - Expr::Attribute(ExprAttribute { value, attr, .. }) => { + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) => { let attr = attr.as_str(); self.compile_expression(value)?; emit!(self, Instruction::Copy { index: 1_u32 }); @@ -6102,21 +6160,21 @@ impl Compiler { Ok(()) } - fn compile_op(&mut self, op: &Operator, inplace: bool) { + fn compile_op(&mut self, op: &ast::Operator, inplace: bool) { let bin_op = match op { - Operator::Add => BinaryOperator::Add, - Operator::Sub => BinaryOperator::Subtract, - Operator::Mult => BinaryOperator::Multiply, - Operator::MatMult => BinaryOperator::MatrixMultiply, - Operator::Div => BinaryOperator::TrueDivide, - Operator::FloorDiv => BinaryOperator::FloorDivide, - Operator::Mod => BinaryOperator::Remainder, - Operator::Pow => BinaryOperator::Power, - Operator::LShift => BinaryOperator::Lshift, - Operator::RShift => BinaryOperator::Rshift, - Operator::BitOr => BinaryOperator::Or, - Operator::BitXor => BinaryOperator::Xor, - Operator::BitAnd => BinaryOperator::And, + ast::Operator::Add => BinaryOperator::Add, + ast::Operator::Sub => BinaryOperator::Subtract, + ast::Operator::Mult => BinaryOperator::Multiply, + ast::Operator::MatMult => BinaryOperator::MatrixMultiply, + ast::Operator::Div => BinaryOperator::TrueDivide, + ast::Operator::FloorDiv => BinaryOperator::FloorDivide, + ast::Operator::Mod => BinaryOperator::Remainder, + ast::Operator::Pow => BinaryOperator::Power, + ast::Operator::LShift => BinaryOperator::Lshift, + ast::Operator::RShift => BinaryOperator::Rshift, + ast::Operator::BitOr => BinaryOperator::Or, + ast::Operator::BitXor => BinaryOperator::Xor, + ast::Operator::BitAnd => BinaryOperator::And, }; let op = if inplace { bin_op.as_inplace() } else { bin_op }; @@ -6133,15 +6191,15 @@ impl Compiler { /// (indicated by the condition parameter). fn compile_jump_if( &mut self, - expression: &Expr, + expression: &ast::Expr, condition: bool, target_block: BlockIdx, ) -> CompileResult<()> { // Compile expression for test, and jump to label if false match &expression { - Expr::BoolOp(ExprBoolOp { op, values, .. }) => { + ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) => { match op { - BoolOp::And => { + ast::BoolOp::And => { if condition { // If all values are true. let end_block = self.new_block(); @@ -6162,7 +6220,7 @@ impl Compiler { } } } - BoolOp::Or => { + ast::BoolOp::Or => { if condition { // If any of the values is true. for value in values { @@ -6185,8 +6243,8 @@ impl Compiler { } } } - Expr::UnaryOp(ExprUnaryOp { - op: UnaryOp::Not, + ast::Expr::UnaryOp(ast::ExprUnaryOp { + op: ast::UnaryOp::Not, operand, .. }) => { @@ -6217,7 +6275,7 @@ impl Compiler { /// Compile a boolean operation as an expression. /// This means, that the last value remains on the stack. 
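A note on the short-circuit pattern used by `compile_bool_op` below: every operand except the last is duplicated with `Copy` and tested with `PopJumpIfFalse`/`PopJumpIfTrue`, so whichever operand decides the outcome is the value left on the stack. A minimal standalone sketch of that lowering for `and` (the `Insn` enum and `lower_and` helper are invented for illustration; this is not the crate's emitter):

```rust
// Toy instruction set standing in for `Instruction`, showing the shape of
// short-circuit lowering where the deciding value stays on the stack.
#[derive(Debug)]
enum Insn {
    Load(&'static str),           // push a named value
    Copy,                         // duplicate the top of stack
    PopJumpIfFalse(&'static str), // pop the duplicate; jump if it was falsy
    Pop,                          // drop the operand before evaluating the next
    Label(&'static str),
}

// `a and b and c`: a falsy early operand short-circuits with its original
// copy still on the stack; otherwise the last operand becomes the result.
fn lower_and(values: &[&'static str]) -> Vec<Insn> {
    let (last, rest) = values.split_last().expect("bool op needs operands");
    let mut out = Vec::new();
    for v in rest {
        out.push(Insn::Load(v));
        out.push(Insn::Copy);
        out.push(Insn::PopJumpIfFalse("after"));
        out.push(Insn::Pop);
    }
    out.push(Insn::Load(last));
    out.push(Insn::Label("after"));
    out
}

fn main() {
    for insn in lower_and(&["a", "b", "c"]) {
        println!("{insn:?}");
    }
}
```

The `or` case is symmetric with `PopJumpIfTrue`, which is exactly the split visible in the hunk that follows.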
- fn compile_bool_op(&mut self, op: &BoolOp, values: &[Expr]) -> CompileResult<()> { + fn compile_bool_op(&mut self, op: &ast::BoolOp, values: &[ast::Expr]) -> CompileResult<()> { let after_block = self.new_block(); let (last_value, values) = values.split_last().unwrap(); @@ -6227,7 +6285,7 @@ impl Compiler { emit!(self, Instruction::Copy { index: 1_u32 }); match op { - BoolOp::And => { + ast::BoolOp::And => { emit!( self, Instruction::PopJumpIfFalse { @@ -6235,7 +6293,7 @@ impl Compiler { } ); } - BoolOp::Or => { + ast::BoolOp::Or => { emit!( self, Instruction::PopJumpIfTrue { @@ -6254,7 +6312,7 @@ impl Compiler { Ok(()) } - fn compile_dict(&mut self, items: &[DictItem]) -> CompileResult<()> { + fn compile_dict(&mut self, items: &[ast::DictItem]) -> CompileResult<()> { // FIXME: correct order to build map, etc d = {**a, 'key': 2} should override // 'key' in dict a let mut size = 0; @@ -6348,18 +6406,19 @@ impl Compiler { Ok(()) } - fn compile_expression(&mut self, expression: &Expr) -> CompileResult<()> { - use ruff_python_ast::*; + fn compile_expression(&mut self, expression: &ast::Expr) -> CompileResult<()> { trace!("Compiling {expression:?}"); let range = expression.range(); self.set_source_range(range); match &expression { - Expr::Call(ExprCall { + ast::Expr::Call(ast::ExprCall { func, arguments, .. }) => self.compile_call(func, arguments)?, - Expr::BoolOp(ExprBoolOp { op, values, .. }) => self.compile_bool_op(op, values)?, - Expr::BinOp(ExprBinOp { + ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) => { + self.compile_bool_op(op, values)? + } + ast::Expr::BinOp(ast::ExprBinOp { left, op, right, .. }) => { self.compile_expression(left)?; @@ -6368,31 +6427,31 @@ impl Compiler { // Perform operation: self.compile_op(op, false); } - Expr::Subscript(ExprSubscript { + ast::Expr::Subscript(ast::ExprSubscript { value, slice, ctx, .. }) => { self.compile_subscript(value, slice, *ctx)?; } - Expr::UnaryOp(ExprUnaryOp { op, operand, .. }) => { + ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. }) => { self.compile_expression(operand)?; // Perform operation: match op { - UnaryOp::UAdd => emit!( + ast::UnaryOp::UAdd => emit!( self, Instruction::CallIntrinsic1 { func: bytecode::IntrinsicFunction1::UnaryPositive } ), - UnaryOp::USub => emit!(self, Instruction::UnaryNegative), - UnaryOp::Not => { + ast::UnaryOp::USub => emit!(self, Instruction::UnaryNegative), + ast::UnaryOp::Not => { emit!(self, Instruction::ToBool); emit!(self, Instruction::UnaryNot); } - UnaryOp::Invert => emit!(self, Instruction::UnaryInvert), + ast::UnaryOp::Invert => emit!(self, Instruction::UnaryInvert), }; } - Expr::Attribute(ExprAttribute { value, attr, .. }) => { + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { // Check for super() attribute access optimization if let Some(super_type) = self.can_optimize_super_call(value, attr.as_str()) { // super().attr or super(cls, self).attr optimization @@ -6416,7 +6475,7 @@ impl Compiler { emit!(self, Instruction::LoadAttr { idx }); } } - Expr::Compare(ExprCompare { + ast::Expr::Compare(ast::ExprCompare { left, ops, comparators, @@ -6424,25 +6483,25 @@ impl Compiler { }) => { self.compile_compare(left, ops, comparators)?; } - // Expr::Constant(ExprConstant { value, .. }) => { + // ast::Expr::Constant(ExprConstant { value, .. }) => { // self.emit_load_const(compile_constant(value)); // } - Expr::List(ExprList { elts, .. }) => { + ast::Expr::List(ast::ExprList { elts, .. 
}) => { self.starunpack_helper(elts, 0, CollectionType::List)?; } - Expr::Tuple(ExprTuple { elts, .. }) => { + ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { self.starunpack_helper(elts, 0, CollectionType::Tuple)?; } - Expr::Set(ExprSet { elts, .. }) => { + ast::Expr::Set(ast::ExprSet { elts, .. }) => { self.starunpack_helper(elts, 0, CollectionType::Set)?; } - Expr::Dict(ExprDict { items, .. }) => { + ast::Expr::Dict(ast::ExprDict { items, .. }) => { self.compile_dict(items)?; } - Expr::Slice(ExprSlice { + ast::Expr::Slice(ast::ExprSlice { lower, upper, step, .. }) => { - let mut compile_bound = |bound: Option<&Expr>| match bound { + let mut compile_bound = |bound: Option<&ast::Expr>| match bound { Some(exp) => self.compile_expression(exp), None => { self.emit_load_const(ConstantData::None); @@ -6460,7 +6519,7 @@ impl Compiler { }; emit!(self, Instruction::BuildSlice { argc }); } - Expr::Yield(ExprYield { value, .. }) => { + ast::Expr::Yield(ast::ExprYield { value, .. }) => { if !self.ctx.in_func() { return Err(self.error(CodegenErrorType::InvalidYield)); } @@ -6478,7 +6537,7 @@ impl Compiler { } ); } - Expr::Await(ExprAwait { value, .. }) => { + ast::Expr::Await(ast::ExprAwait { value, .. }) => { if self.ctx.func != FunctionContext::AsyncFunction { return Err(self.error(CodegenErrorType::InvalidAwait)); } @@ -6487,7 +6546,7 @@ impl Compiler { self.emit_load_const(ConstantData::None); self.compile_yield_from_sequence(true)?; } - Expr::YieldFrom(ExprYieldFrom { value, .. }) => { + ast::Expr::YieldFrom(ast::ExprYieldFrom { value, .. }) => { match self.ctx.func { FunctionContext::NoFunction => { return Err(self.error(CodegenErrorType::InvalidYieldFrom)); @@ -6503,11 +6562,11 @@ impl Compiler { self.emit_load_const(ConstantData::None); self.compile_yield_from_sequence(false)?; } - Expr::Name(ExprName { id, .. }) => self.load_name(id.as_str())?, - Expr::Lambda(ExprLambda { + ast::Expr::Name(ast::ExprName { id, .. }) => self.load_name(id.as_str())?, + ast::Expr::Lambda(ast::ExprLambda { parameters, body, .. }) => { - let default_params = Parameters::default(); + let default_params = ast::Parameters::default(); let params = parameters.as_deref().unwrap_or(&default_params); validate_duplicate_params(params).map_err(|e| self.error(e))?; @@ -6586,7 +6645,7 @@ impl Compiler { self.ctx = prev_ctx; } - Expr::ListComp(ExprListComp { + ast::Expr::ListComp(ast::ExprListComp { elt, generators, .. }) => { self.compile_comprehension( @@ -6612,7 +6671,7 @@ impl Compiler { Self::contains_await(elt) || Self::generators_contain_await(generators), )?; } - Expr::SetComp(ExprSetComp { + ast::Expr::SetComp(ast::ExprSetComp { elt, generators, .. }) => { self.compile_comprehension( @@ -6638,7 +6697,7 @@ impl Compiler { Self::contains_await(elt) || Self::generators_contain_await(generators), )?; } - Expr::DictComp(ExprDictComp { + ast::Expr::DictComp(ast::ExprDictComp { key, value, generators, @@ -6673,7 +6732,7 @@ impl Compiler { || Self::generators_contain_await(generators), )?; } - Expr::Generator(ExprGenerator { + ast::Expr::Generator(ast::ExprGenerator { elt, generators, .. }) => { // Check if element or generators contain async content @@ -6707,7 +6766,7 @@ impl Compiler { element_contains_await, )?; } - Expr::Starred(ExprStarred { value, .. }) => { + ast::Expr::Starred(ast::ExprStarred { value, .. 
}) => { if self.in_annotation { // In annotation context, starred expressions are allowed (PEP 646) // For now, just compile the inner value without wrapping with Unpack @@ -6717,7 +6776,7 @@ impl Compiler { return Err(self.error(CodegenErrorType::InvalidStarExpr)); } } - Expr::If(ExprIf { + ast::Expr::If(ast::ExprIf { test, body, orelse, .. }) => { let else_block = self.new_block(); @@ -6741,7 +6800,7 @@ impl Compiler { self.switch_to_block(after_block); } - Expr::Named(ExprNamed { + ast::Expr::Named(ast::ExprNamed { target, value, node_index: _, @@ -6751,13 +6810,13 @@ impl Compiler { emit!(self, Instruction::Copy { index: 1_u32 }); self.compile_store(target)?; } - Expr::FString(fstring) => { + ast::Expr::FString(fstring) => { self.compile_expr_fstring(fstring)?; } - Expr::TString(tstring) => { + ast::Expr::TString(tstring) => { self.compile_expr_tstring(tstring)?; } - Expr::StringLiteral(string) => { + ast::Expr::StringLiteral(string) => { let value = string.value.to_str(); if value.contains(char::REPLACEMENT_CHARACTER) { let value = string @@ -6776,42 +6835,42 @@ impl Compiler { }); } } - Expr::BytesLiteral(bytes) => { + ast::Expr::BytesLiteral(bytes) => { let iter = bytes.value.iter().flat_map(|x| x.iter().copied()); let v: Vec = iter.collect(); self.emit_load_const(ConstantData::Bytes { value: v }); } - Expr::NumberLiteral(number) => match &number.value { - Number::Int(int) => { + ast::Expr::NumberLiteral(number) => match &number.value { + ast::Number::Int(int) => { let value = ruff_int_to_bigint(int).map_err(|e| self.error(e))?; self.emit_load_const(ConstantData::Integer { value }); } - Number::Float(float) => { + ast::Number::Float(float) => { self.emit_load_const(ConstantData::Float { value: *float }); } - Number::Complex { real, imag } => { + ast::Number::Complex { real, imag } => { self.emit_load_const(ConstantData::Complex { value: Complex::new(*real, *imag), }); } }, - Expr::BooleanLiteral(b) => { + ast::Expr::BooleanLiteral(b) => { self.emit_load_const(ConstantData::Boolean { value: b.value }); } - Expr::NoneLiteral(_) => { + ast::Expr::NoneLiteral(_) => { self.emit_load_const(ConstantData::None); } - Expr::EllipsisLiteral(_) => { + ast::Expr::EllipsisLiteral(_) => { self.emit_load_const(ConstantData::Ellipsis); } - Expr::IpyEscapeCommand(_) => { + ast::Expr::IpyEscapeCommand(_) => { panic!("unexpected ipy escape command"); } } Ok(()) } - fn compile_keywords(&mut self, keywords: &[Keyword]) -> CompileResult<()> { + fn compile_keywords(&mut self, keywords: &[ast::Keyword]) -> CompileResult<()> { let mut size = 0; let groupby = keywords.iter().chunk_by(|e| e.arg.is_none()); for (is_unpacking, sub_keywords) in &groupby { @@ -6841,10 +6900,10 @@ impl Compiler { Ok(()) } - fn compile_call(&mut self, func: &Expr, args: &Arguments) -> CompileResult<()> { + fn compile_call(&mut self, func: &ast::Expr, args: &ast::Arguments) -> CompileResult<()> { // Method call: obj → LOAD_ATTR_METHOD → [method, self_or_null] → args → CALL // Regular call: func → PUSH_NULL → args → CALL - if let Expr::Attribute(ExprAttribute { value, attr, .. }) = &func { + if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) = &func { // Check for super() method call optimization if let Some(super_type) = self.can_optimize_super_call(value, attr.as_str()) { // super().method() or super(cls, self).method() optimization @@ -6883,7 +6942,7 @@ impl Compiler { fn compile_call_helper( &mut self, additional_positional: u32, - arguments: &Arguments, + arguments: &ast::Arguments, ) -> CompileResult<()> { let args_count = u32::try_from(arguments.len()).expect("too many arguments"); let count = args_count @@ -6903,11 +6962,12 @@ impl Compiler { } // Create an optional map with kw-args: - let has_kwargs = !arguments.keywords.is_empty(); - if has_kwargs { + if !arguments.keywords.is_empty() { self.compile_keywords(&arguments.keywords)?; + } else { + emit!(self, Instruction::PushNull); } - emit!(self, Instruction::CallFunctionEx { has_kwargs }); + emit!(self, Instruction::CallFunctionEx); } else if !arguments.keywords.is_empty() { // No **kwargs in this branch (has_double_star is false), // so all keywords have arg.is_some() @@ -6945,9 +7005,13 @@ impl Compiler { // Given a vector of expr / star expr generate code which gives either // a list of expressions on the stack, or a list of tuples. - fn gather_elements(&mut self, before: u32, elements: &[Expr]) -> CompileResult<(u32, bool)> { + fn gather_elements( + &mut self, + before: u32, + elements: &[ast::Expr], + ) -> CompileResult<(u32, bool)> { // First determine if we have starred elements: - let has_stars = elements.iter().any(|e| matches!(e, Expr::Starred(_))); + let has_stars = elements.iter().any(|e| matches!(e, ast::Expr::Starred(_))); let size = if has_stars { let mut size = 0; @@ -6955,14 +7019,17 @@ impl Compiler { let mut run_size = before; loop { - if iter.peek().is_none_or(|e| matches!(e, Expr::Starred(_))) { + if iter + .peek() + .is_none_or(|e| matches!(e, ast::Expr::Starred(_))) + { emit!(self, Instruction::BuildTuple { size: run_size }); run_size = 0; size += 1; } match iter.next() { - Some(Expr::Starred(ExprStarred { value, .. })) => { + Some(ast::Expr::Starred(ast::ExprStarred { value, .. 
})) => { self.compile_expression(value)?; // We need to collect each unpacked element into a // tuple, since any side-effects during the conversion @@ -6990,7 +7057,7 @@ impl Compiler { Ok((size, has_stars)) } - fn compile_comprehension_element(&mut self, element: &Expr) -> CompileResult<()> { + fn compile_comprehension_element(&mut self, element: &ast::Expr) -> CompileResult<()> { self.compile_expression(element).map_err(|e| { if let CodegenErrorType::InvalidStarExpr = e.error { self.error(CodegenErrorType::SyntaxError( @@ -7006,7 +7073,7 @@ impl Compiler { &mut self, name: &str, init_collection: Option, - generators: &[Comprehension], + generators: &[ast::Comprehension], compile_element: &dyn Fn(&mut Self) -> CompileResult<()>, comprehension_type: ComprehensionType, element_contains_await: bool, @@ -7197,25 +7264,25 @@ impl Compiler { } /// Collect variable names from an assignment target expression - fn collect_target_names(&self, target: &Expr, names: &mut Vec) { + fn collect_target_names(&self, target: &ast::Expr, names: &mut Vec) { match target { - Expr::Name(name) => { + ast::Expr::Name(name) => { let name_str = name.id.to_string(); if !names.contains(&name_str) { names.push(name_str); } } - Expr::Tuple(tuple) => { + ast::Expr::Tuple(tuple) => { for elt in &tuple.elts { self.collect_target_names(elt, names); } } - Expr::List(list) => { + ast::Expr::List(list) => { for elt in &list.elts { self.collect_target_names(elt, names); } } - Expr::Starred(starred) => { + ast::Expr::Starred(starred) => { self.collect_target_names(&starred.value, names); } _ => { @@ -7229,7 +7296,7 @@ impl Compiler { fn compile_inlined_comprehension( &mut self, init_collection: Option, - generators: &[Comprehension], + generators: &[ast::Comprehension], compile_element: &dyn Fn(&mut Self) -> CompileResult<()>, _has_an_async_gen: bool, ) -> CompileResult<()> { @@ -7414,7 +7481,7 @@ impl Compiler { Ok(()) } - fn compile_future_features(&mut self, features: &[Alias]) -> Result<(), CodegenError> { + fn compile_future_features(&mut self, features: &[ast::Alias]) -> Result<(), CodegenError> { if let DoneWithFuture::Yes = self.done_with_future_stmts { return Err(self.error(CodegenErrorType::InvalidFuturePlacement)); } @@ -7762,26 +7829,28 @@ impl Compiler { /// async for: ... /// ``` /// are statements, so we won't check for them here - fn contains_await(expression: &Expr) -> bool { + fn contains_await(expression: &ast::Expr) -> bool { + use ast::visitor::Visitor; + #[derive(Default)] struct AwaitVisitor { found: bool, } - impl Visitor<'_> for AwaitVisitor { - fn visit_expr(&mut self, expr: &Expr) { + impl ast::visitor::Visitor<'_> for AwaitVisitor { + fn visit_expr(&mut self, expr: &ast::Expr) { if self.found { return; } match expr { - Expr::Await(_) => self.found = true, + ast::Expr::Await(_) => self.found = true, // Note: We do NOT check for async comprehensions here. // Async list/set/dict comprehensions are handled by compile_comprehension // which already awaits the result. A generator expression containing // an async comprehension as its element does NOT become an async generator, // because the async comprehension is awaited when evaluating the element. - _ => walk_expr(self, expr), + _ => ast::visitor::walk_expr(self, expr), } } } @@ -7793,7 +7862,7 @@ impl Compiler { /// Check if any of the generators (except the first one's iter) contains an await expression. /// The first generator's iter is evaluated outside the comprehension scope. 
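For context on the visitor above: the search stops descending as soon as one `await` is found, and `generators_contain_await` below deliberately skips the first generator's `iter`, since that expression is evaluated in the enclosing scope rather than inside the comprehension. A self-contained sketch of the same early-exit search over a toy expression type (not ruff's AST):

```rust
// Toy expression tree mirroring the early-exit search that AwaitVisitor
// performs with its `found` flag.
enum Expr {
    Await(Box<Expr>),
    Name(&'static str),
    Call { func: Box<Expr>, args: Vec<Expr> },
}

fn contains_await(expr: &Expr) -> bool {
    match expr {
        Expr::Await(_) => true,
        Expr::Name(_) => false,
        // `any` short-circuits after the first hit, just as the visitor
        // stops walking once `found` is set.
        Expr::Call { func, args } => contains_await(func) || args.iter().any(contains_await),
    }
}

fn main() {
    // f(await x) contains an await; f(x) does not.
    let hit = Expr::Call {
        func: Box::new(Expr::Name("f")),
        args: vec![Expr::Await(Box::new(Expr::Name("x")))],
    };
    let miss = Expr::Call {
        func: Box::new(Expr::Name("f")),
        args: vec![Expr::Name("x")],
    };
    assert!(contains_await(&hit));
    assert!(!contains_await(&miss));
}
```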
- fn generators_contain_await(generators: &[Comprehension]) -> bool { + fn generators_contain_await(generators: &[ast::Comprehension]) -> bool { for (i, generator) in generators.iter().enumerate() { // First generator's iter is evaluated outside the comprehension if i > 0 && Self::contains_await(&generator.iter) { @@ -7809,7 +7878,7 @@ impl Compiler { false } - fn compile_expr_fstring(&mut self, fstring: &ExprFString) -> CompileResult<()> { + fn compile_expr_fstring(&mut self, fstring: &ast::ExprFString) -> CompileResult<()> { let fstring = &fstring.value; for part in fstring { self.compile_fstring_part(part)?; @@ -7826,9 +7895,9 @@ impl Compiler { Ok(()) } - fn compile_fstring_part(&mut self, part: &FStringPart) -> CompileResult<()> { + fn compile_fstring_part(&mut self, part: &ast::FStringPart) -> CompileResult<()> { match part { - FStringPart::Literal(string) => { + ast::FStringPart::Literal(string) => { if string.value.contains(char::REPLACEMENT_CHARACTER) { // might have a surrogate literal; should reparse to be sure let source = self.source_file.slice(string.range); @@ -7844,24 +7913,24 @@ impl Compiler { } Ok(()) } - FStringPart::FString(fstring) => self.compile_fstring(fstring), + ast::FStringPart::FString(fstring) => self.compile_fstring(fstring), } } - fn compile_fstring(&mut self, fstring: &FString) -> CompileResult<()> { + fn compile_fstring(&mut self, fstring: &ast::FString) -> CompileResult<()> { self.compile_fstring_elements(fstring.flags, &fstring.elements) } fn compile_fstring_elements( &mut self, - flags: FStringFlags, - fstring_elements: &InterpolatedStringElements, + flags: ast::FStringFlags, + fstring_elements: &ast::InterpolatedStringElements, ) -> CompileResult<()> { let mut element_count = 0; for element in fstring_elements { element_count += 1; match element { - InterpolatedStringElement::Literal(string) => { + ast::InterpolatedStringElement::Literal(string) => { if string.value.contains(char::REPLACEMENT_CHARACTER) { // might have a surrogate literal; should reparse to be sure let source = self.source_file.slice(string.range); @@ -7878,15 +7947,15 @@ impl Compiler { }); } } - InterpolatedStringElement::Interpolation(fstring_expr) => { + ast::InterpolatedStringElement::Interpolation(fstring_expr) => { let mut conversion = match fstring_expr.conversion { - ConversionFlag::None => ConvertValueOparg::None, - ConversionFlag::Str => ConvertValueOparg::Str, - ConversionFlag::Repr => ConvertValueOparg::Repr, - ConversionFlag::Ascii => ConvertValueOparg::Ascii, + ast::ConversionFlag::None => ConvertValueOparg::None, + ast::ConversionFlag::Str => ConvertValueOparg::Str, + ast::ConversionFlag::Repr => ConvertValueOparg::Repr, + ast::ConversionFlag::Ascii => ConvertValueOparg::Ascii, }; - if let Some(DebugText { leading, trailing }) = &fstring_expr.debug_text { + if let Some(ast::DebugText { leading, trailing }) = &fstring_expr.debug_text { let range = fstring_expr.expression.range(); let source = self.source_file.slice(range); let text = [leading, source, trailing].concat(); @@ -7946,9 +8015,9 @@ impl Compiler { Ok(()) } - fn compile_expr_tstring(&mut self, expr_tstring: &ExprTString) -> CompileResult<()> { - // TStringValue can contain multiple TString parts (implicit concatenation) - // Each TString part should be compiled and the results merged into a single Template + fn compile_expr_tstring(&mut self, expr_tstring: &ast::ExprTString) -> CompileResult<()> { + // ast::TStringValue can contain multiple ast::TString parts (implicit concatenation) + // Each ast::TString part 
should be compiled and the results merged into a single Template let tstring_value = &expr_tstring.value; // Collect all strings and compile all interpolations @@ -7997,18 +8066,18 @@ impl Compiler { fn compile_tstring_into( &mut self, - tstring: &TString, + tstring: &ast::TString, strings: &mut Vec, current_string: &mut Wtf8Buf, interp_count: &mut u32, ) -> CompileResult<()> { for element in &tstring.elements { match element { - InterpolatedStringElement::Literal(lit) => { + ast::InterpolatedStringElement::Literal(lit) => { // Accumulate literal parts into current_string current_string.push_str(&lit.value); } - InterpolatedStringElement::Interpolation(interp) => { + ast::InterpolatedStringElement::Interpolation(interp) => { // Finish current string segment strings.push(std::mem::take(current_string)); @@ -8024,19 +8093,19 @@ impl Compiler { // Determine conversion code let conversion: u32 = match interp.conversion { - ConversionFlag::None => 0, - ConversionFlag::Str => 1, - ConversionFlag::Repr => 2, - ConversionFlag::Ascii => 3, + ast::ConversionFlag::None => 0, + ast::ConversionFlag::Str => 1, + ast::ConversionFlag::Repr => 2, + ast::ConversionFlag::Ascii => 3, }; // Handle format_spec let has_format_spec = interp.format_spec.is_some(); if let Some(format_spec) = &interp.format_spec { // Compile format_spec as a string using fstring element compilation - // Use default FStringFlags since format_spec syntax is independent of t-string flags + // Use default ast::FStringFlags since format_spec syntax is independent of t-string flags self.compile_fstring_elements( - FStringFlags::empty(), + ast::FStringFlags::empty(), &format_spec.elements, )?; } @@ -8146,12 +8215,12 @@ fn expandtabs(input: &str, tab_size: usize) -> String { expanded_str } -fn split_doc<'a>(body: &'a [Stmt], opts: &CompileOpts) -> (Option, &'a [Stmt]) { - if let Some((Stmt::Expr(expr), body_rest)) = body.split_first() { +fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, &'a [ast::Stmt]) { + if let Some((ast::Stmt::Expr(expr), body_rest)) = body.split_first() { let doc_comment = match &*expr.value { - Expr::StringLiteral(value) => Some(&value.value), + ast::Expr::StringLiteral(value) => Some(&value.value), // f-strings are not allowed in Python doc comments. - Expr::FString(_) => None, + ast::Expr::FString(_) => None, _ => None, }; if let Some(doc) = doc_comment { @@ -8165,7 +8234,7 @@ fn split_doc<'a>(body: &'a [Stmt], opts: &CompileOpts) -> (Option, &'a [ (None, body) } -pub fn ruff_int_to_bigint(int: &Int) -> Result { +pub fn ruff_int_to_bigint(int: &ast::Int) -> Result { if let Some(small) = int.as_u64() { Ok(BigInt::from(small)) } else { @@ -8175,7 +8244,7 @@ pub fn ruff_int_to_bigint(int: &Int) -> Result { /// Converts a `ruff` ast integer into a `BigInt`. /// Unlike small integers, big integers may be stored in one of four possible radix representations. -fn parse_big_integer(int: &Int) -> Result { +fn parse_big_integer(int: &ast::Int) -> Result { // TODO: Improve ruff API // Can we avoid this copy? let s = format!("{int}"); @@ -8218,35 +8287,34 @@ impl ToU32 for usize { #[cfg(test)] mod ruff_tests { use super::*; - use ruff_python_ast::name::Name; - use ruff_python_ast::*; + use ast::name::Name; /// Test if the compiler can correctly identify fstrings containing an `await` expression. 
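Stepping back to the helpers above: `split_doc` peels a leading string-literal expression statement off a body and treats it as the docstring, returning the remaining statements. A reduced sketch of that shape over a toy statement type (not the real `ast::Stmt`; f-strings are excluded by construction here, whereas the real code rejects them explicitly):

```rust
// Toy statement type: a docstring is a leading expression statement whose
// expression is a plain string literal.
enum Stmt {
    StringLit(&'static str),
    Other,
}

// Split off the docstring, if any, and return it with the rest of the body.
fn split_doc(body: &[Stmt]) -> (Option<&'static str>, &[Stmt]) {
    match body.split_first() {
        Some((Stmt::StringLit(doc), rest)) => (Some(*doc), rest),
        _ => (None, body),
    }
}

fn main() {
    let body = [Stmt::StringLit("module docstring"), Stmt::Other];
    let (doc, rest) = split_doc(&body);
    assert_eq!(doc, Some("module docstring"));
    assert_eq!(rest.len(), 1);
}
```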
#[test] fn test_fstring_contains_await() { let range = TextRange::default(); - let flags = FStringFlags::empty(); + let flags = ast::FStringFlags::empty(); // f'{x}' - let expr_x = Expr::Name(ExprName { - node_index: AtomicNodeIndex::NONE, + let expr_x = ast::Expr::Name(ast::ExprName { + node_index: ast::AtomicNodeIndex::NONE, range, id: Name::new("x"), - ctx: ExprContext::Load, + ctx: ast::ExprContext::Load, }); - let not_present = &Expr::FString(ExprFString { - node_index: AtomicNodeIndex::NONE, + let not_present = &ast::Expr::FString(ast::ExprFString { + node_index: ast::AtomicNodeIndex::NONE, range, - value: FStringValue::single(FString { - node_index: AtomicNodeIndex::NONE, + value: ast::FStringValue::single(ast::FString { + node_index: ast::AtomicNodeIndex::NONE, range, - elements: vec![InterpolatedStringElement::Interpolation( - InterpolatedElement { - node_index: AtomicNodeIndex::NONE, + elements: vec![ast::InterpolatedStringElement::Interpolation( + ast::InterpolatedElement { + node_index: ast::AtomicNodeIndex::NONE, range, expression: Box::new(expr_x), debug_text: None, - conversion: ConversionFlag::None, + conversion: ast::ConversionFlag::None, format_spec: None, }, )] @@ -8257,29 +8325,29 @@ mod ruff_tests { assert!(!Compiler::contains_await(not_present)); // f'{await x}' - let expr_await_x = Expr::Await(ExprAwait { - node_index: AtomicNodeIndex::NONE, + let expr_await_x = ast::Expr::Await(ast::ExprAwait { + node_index: ast::AtomicNodeIndex::NONE, range, - value: Box::new(Expr::Name(ExprName { - node_index: AtomicNodeIndex::NONE, + value: Box::new(ast::Expr::Name(ast::ExprName { + node_index: ast::AtomicNodeIndex::NONE, range, id: Name::new("x"), - ctx: ExprContext::Load, + ctx: ast::ExprContext::Load, })), }); - let present = &Expr::FString(ExprFString { - node_index: AtomicNodeIndex::NONE, + let present = &ast::Expr::FString(ast::ExprFString { + node_index: ast::AtomicNodeIndex::NONE, range, - value: FStringValue::single(FString { - node_index: AtomicNodeIndex::NONE, + value: ast::FStringValue::single(ast::FString { + node_index: ast::AtomicNodeIndex::NONE, range, - elements: vec![InterpolatedStringElement::Interpolation( - InterpolatedElement { - node_index: AtomicNodeIndex::NONE, + elements: vec![ast::InterpolatedStringElement::Interpolation( + ast::InterpolatedElement { + node_index: ast::AtomicNodeIndex::NONE, range, expression: Box::new(expr_await_x), debug_text: None, - conversion: ConversionFlag::None, + conversion: ast::ConversionFlag::None, format_spec: None, }, )] @@ -8290,45 +8358,45 @@ mod ruff_tests { assert!(Compiler::contains_await(present)); // f'{x:{await y}}' - let expr_x = Expr::Name(ExprName { - node_index: AtomicNodeIndex::NONE, + let expr_x = ast::Expr::Name(ast::ExprName { + node_index: ast::AtomicNodeIndex::NONE, range, id: Name::new("x"), - ctx: ExprContext::Load, + ctx: ast::ExprContext::Load, }); - let expr_await_y = Expr::Await(ExprAwait { - node_index: AtomicNodeIndex::NONE, + let expr_await_y = ast::Expr::Await(ast::ExprAwait { + node_index: ast::AtomicNodeIndex::NONE, range, - value: Box::new(Expr::Name(ExprName { - node_index: AtomicNodeIndex::NONE, + value: Box::new(ast::Expr::Name(ast::ExprName { + node_index: ast::AtomicNodeIndex::NONE, range, id: Name::new("y"), - ctx: ExprContext::Load, + ctx: ast::ExprContext::Load, })), }); - let present = &Expr::FString(ExprFString { - node_index: AtomicNodeIndex::NONE, + let present = &ast::Expr::FString(ast::ExprFString { + node_index: ast::AtomicNodeIndex::NONE, range, - value: 
FStringValue::single(FString { - node_index: AtomicNodeIndex::NONE, + value: ast::FStringValue::single(ast::FString { + node_index: ast::AtomicNodeIndex::NONE, range, - elements: vec![InterpolatedStringElement::Interpolation( - InterpolatedElement { - node_index: AtomicNodeIndex::NONE, + elements: vec![ast::InterpolatedStringElement::Interpolation( + ast::InterpolatedElement { + node_index: ast::AtomicNodeIndex::NONE, range, expression: Box::new(expr_x), debug_text: None, - conversion: ConversionFlag::None, - format_spec: Some(Box::new(InterpolatedStringFormatSpec { - node_index: AtomicNodeIndex::NONE, + conversion: ast::ConversionFlag::None, + format_spec: Some(Box::new(ast::InterpolatedStringFormatSpec { + node_index: ast::AtomicNodeIndex::NONE, range, - elements: vec![InterpolatedStringElement::Interpolation( - InterpolatedElement { - node_index: AtomicNodeIndex::NONE, + elements: vec![ast::InterpolatedStringElement::Interpolation( + ast::InterpolatedElement { + node_index: ast::AtomicNodeIndex::NONE, range, expression: Box::new(expr_await_y), debug_text: None, - conversion: ConversionFlag::None, + conversion: ast::ConversionFlag::None, format_spec: None, }, )] diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 34d3870ae91..9dd7384170a 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -18,7 +18,7 @@ pub mod symboltable; mod unparse; pub use compile::CompileOpts; -use ruff_python_ast::Expr; +use ruff_python_ast as ast; pub(crate) use compile::InternalResult; @@ -27,7 +27,7 @@ pub trait ToPythonName { fn python_name(&self) -> &'static str; } -impl ToPythonName for Expr { +impl ToPythonName for ast::Expr { fn python_name(&self) -> &'static str { match self { Self::BoolOp { .. } | Self::BinOp { .. } | Self::UnaryOp { .. } => "operator", diff --git a/crates/codegen/src/string_parser.rs b/crates/codegen/src/string_parser.rs index 175e75c1a26..7e1558d2b17 100644 --- a/crates/codegen/src/string_parser.rs +++ b/crates/codegen/src/string_parser.rs @@ -7,7 +7,7 @@ use core::convert::Infallible; -use ruff_python_ast::{AnyStringFlags, StringFlags}; +use ruff_python_ast::{self as ast, StringFlags as _}; use rustpython_wtf8::{CodePoint, Wtf8, Wtf8Buf}; // use ruff_python_parser::{LexicalError, LexicalErrorType}; @@ -24,11 +24,11 @@ struct StringParser { /// Current position of the parser in the source. cursor: usize, /// Flags that can be used to query information about the string. 
- flags: AnyStringFlags, + flags: ast::AnyStringFlags, } impl StringParser { - const fn new(source: Box, flags: AnyStringFlags) -> Self { + const fn new(source: Box, flags: ast::AnyStringFlags) -> Self { Self { source, cursor: 0, @@ -272,7 +272,7 @@ impl StringParser { } } -pub(crate) fn parse_string_literal(source: &str, flags: AnyStringFlags) -> Box { +pub(crate) fn parse_string_literal(source: &str, flags: ast::AnyStringFlags) -> Box { let source = &source[flags.opener_len().to_usize()..]; let source = &source[..source.len() - flags.quote_len().to_usize()]; StringParser::new(source.into(), flags) @@ -280,7 +280,10 @@ pub(crate) fn parse_string_literal(source: &str, flags: AnyStringFlags) -> Box, flags: AnyStringFlags) -> Box { +pub(crate) fn parse_fstring_literal_element( + source: Box, + flags: ast::AnyStringFlags, +) -> Box { StringParser::new(source, flags) .parse_fstring_middle() .unwrap_or_else(|x| match x {}) diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 63b330db92f..129900133fa 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -13,12 +13,7 @@ use crate::{ }; use alloc::{borrow::Cow, fmt}; use bitflags::bitflags; -use ruff_python_ast::{ - self as ast, Comprehension, Decorator, Expr, Identifier, ModExpression, ModModule, Parameter, - ParameterWithDefault, Parameters, Pattern, PatternMatchAs, PatternMatchClass, - PatternMatchMapping, PatternMatchOr, PatternMatchSequence, PatternMatchStar, PatternMatchValue, - Stmt, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, TypeParams, -}; +use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange}; use rustpython_compiler_core::{PositionEncoding, SourceFile, SourceLocation}; use std::collections::HashSet; @@ -97,13 +92,19 @@ impl SymbolTable { } } - pub fn scan_program(program: &ModModule, source_file: SourceFile) -> SymbolTableResult { + pub fn scan_program( + program: &ast::ModModule, + source_file: SourceFile, + ) -> SymbolTableResult { let mut builder = SymbolTableBuilder::new(source_file); builder.scan_statements(program.body.as_ref())?; builder.finish() } - pub fn scan_expr(expr: &ModExpression, source_file: SourceFile) -> SymbolTableResult { + pub fn scan_expr( + expr: &ast::ModExpression, + source_file: SourceFile, + ) -> SymbolTableResult { let mut builder = SymbolTableBuilder::new(source_file); builder.scan_expression(expr.body.as_ref(), ExpressionContext::Load)?; builder.finish() @@ -973,21 +974,21 @@ impl SymbolTableBuilder { .get() as _ } - fn scan_statements(&mut self, statements: &[Stmt]) -> SymbolTableResult { + fn scan_statements(&mut self, statements: &[ast::Stmt]) -> SymbolTableResult { for statement in statements { self.scan_statement(statement)?; } Ok(()) } - fn scan_parameters(&mut self, parameters: &[ParameterWithDefault]) -> SymbolTableResult { + fn scan_parameters(&mut self, parameters: &[ast::ParameterWithDefault]) -> SymbolTableResult { for parameter in parameters { self.scan_parameter(¶meter.parameter)?; } Ok(()) } - fn scan_parameter(&mut self, parameter: &Parameter) -> SymbolTableResult { + fn scan_parameter(&mut self, parameter: &ast::Parameter) -> SymbolTableResult { self.check_name( parameter.name.as_str(), ExpressionContext::Store, @@ -1019,7 +1020,7 @@ impl SymbolTableBuilder { self.register_ident(¶meter.name, usage) } - fn scan_annotation(&mut self, annotation: &Expr) -> SymbolTableResult { + fn scan_annotation(&mut self, annotation: &ast::Expr) -> SymbolTableResult { let current_scope = 
self.tables.last().map(|t| t.typ); // PEP 649: Check if this is a conditional annotation @@ -1074,8 +1075,8 @@ impl SymbolTableBuilder { result } - fn scan_statement(&mut self, statement: &Stmt) -> SymbolTableResult { - use ruff_python_ast::*; + fn scan_statement(&mut self, statement: &ast::Stmt) -> SymbolTableResult { + use ast::*; if let Stmt::ImportFrom(StmtImportFrom { module, names, .. }) = &statement && module.as_ref().map(|id| id.as_str()) == Some("__future__") { @@ -1435,7 +1436,7 @@ impl SymbolTableBuilder { fn scan_decorators( &mut self, - decorators: &[Decorator], + decorators: &[ast::Decorator], context: ExpressionContext, ) -> SymbolTableResult { for decorator in decorators { @@ -1446,7 +1447,7 @@ impl SymbolTableBuilder { fn scan_expressions( &mut self, - expressions: &[Expr], + expressions: &[ast::Expr], context: ExpressionContext, ) -> SymbolTableResult { for expression in expressions { @@ -1457,10 +1458,10 @@ impl SymbolTableBuilder { fn scan_expression( &mut self, - expression: &Expr, + expression: &ast::Expr, context: ExpressionContext, ) -> SymbolTableResult { - use ruff_python_ast::*; + use ast::*; // Check for expressions not allowed in certain contexts // (type parameters, annotations, type aliases, TypeVar bounds/defaults) @@ -1837,9 +1838,9 @@ impl SymbolTableBuilder { fn scan_comprehension( &mut self, scope_name: &str, - elt1: &Expr, - elt2: Option<&Expr>, - generators: &[Comprehension], + elt1: &ast::Expr, + elt2: Option<&ast::Expr>, + generators: &[ast::Comprehension], range: TextRange, is_generator: bool, ) -> SymbolTableResult { @@ -1906,7 +1907,7 @@ impl SymbolTableBuilder { // = symtable_visit_type_param_bound_or_default fn scan_type_param_bound_or_default( &mut self, - expr: &Expr, + expr: &ast::Expr, scope_name: &str, scope_info: &'static str, ) -> SymbolTableResult { @@ -1932,14 +1933,14 @@ impl SymbolTableBuilder { result } - fn scan_type_params(&mut self, type_params: &TypeParams) -> SymbolTableResult { + fn scan_type_params(&mut self, type_params: &ast::TypeParams) -> SymbolTableResult { // Check for duplicate type parameter names let mut seen_names: std::collections::HashSet<&str> = std::collections::HashSet::new(); for type_param in &type_params.type_params { let (name, range) = match type_param { - TypeParam::TypeVar(tv) => (tv.name.as_str(), tv.range), - TypeParam::ParamSpec(ps) => (ps.name.as_str(), ps.range), - TypeParam::TypeVarTuple(tvt) => (tvt.name.as_str(), tvt.range), + ast::TypeParam::TypeVar(tv) => (tv.name.as_str(), tv.range), + ast::TypeParam::ParamSpec(ps) => (ps.name.as_str(), ps.range), + ast::TypeParam::TypeVarTuple(tvt) => (tvt.name.as_str(), tvt.range), }; if !seen_names.insert(name) { return Err(SymbolTableError { @@ -1959,7 +1960,7 @@ impl SymbolTableBuilder { // First register all type parameters for type_param in &type_params.type_params { match type_param { - TypeParam::TypeVar(TypeParamTypeVar { + ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, bound, range: type_var_range, @@ -1991,7 +1992,7 @@ impl SymbolTableBuilder { )?; } } - TypeParam::ParamSpec(TypeParamParamSpec { + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, range: param_spec_range, default, @@ -2009,7 +2010,7 @@ impl SymbolTableBuilder { )?; } } - TypeParam::TypeVarTuple(TypeParamTypeVarTuple { + ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, range: type_var_tuple_range, default, @@ -2032,22 +2033,24 @@ impl SymbolTableBuilder { Ok(()) } - fn scan_patterns(&mut self, patterns: &[Pattern]) -> SymbolTableResult { + fn 
scan_patterns(&mut self, patterns: &[ast::Pattern]) -> SymbolTableResult { for pattern in patterns { self.scan_pattern(pattern)?; } Ok(()) } - fn scan_pattern(&mut self, pattern: &Pattern) -> SymbolTableResult { - use Pattern::*; + fn scan_pattern(&mut self, pattern: &ast::Pattern) -> SymbolTableResult { + use ast::Pattern::*; match pattern { - MatchValue(PatternMatchValue { value, .. }) => { + MatchValue(ast::PatternMatchValue { value, .. }) => { self.scan_expression(value, ExpressionContext::Load)? } MatchSingleton(_) => {} - MatchSequence(PatternMatchSequence { patterns, .. }) => self.scan_patterns(patterns)?, - MatchMapping(PatternMatchMapping { + MatchSequence(ast::PatternMatchSequence { patterns, .. }) => { + self.scan_patterns(patterns)? + } + MatchMapping(ast::PatternMatchMapping { keys, patterns, rest, @@ -2059,19 +2062,19 @@ impl SymbolTableBuilder { self.register_ident(rest, SymbolUsage::Assigned)?; } } - MatchClass(PatternMatchClass { cls, arguments, .. }) => { + MatchClass(ast::PatternMatchClass { cls, arguments, .. }) => { self.scan_expression(cls, ExpressionContext::Load)?; self.scan_patterns(&arguments.patterns)?; for kw in &arguments.keywords { self.scan_pattern(&kw.pattern)?; } } - MatchStar(PatternMatchStar { name, .. }) => { + MatchStar(ast::PatternMatchStar { name, .. }) => { if let Some(name) = name { self.register_ident(name, SymbolUsage::Assigned)?; } } - MatchAs(PatternMatchAs { pattern, name, .. }) => { + MatchAs(ast::PatternMatchAs { pattern, name, .. }) => { if let Some(pattern) = pattern { self.scan_pattern(pattern)?; } @@ -2079,7 +2082,7 @@ impl SymbolTableBuilder { self.register_ident(name, SymbolUsage::Assigned)?; } } - MatchOr(PatternMatchOr { patterns, .. }) => self.scan_patterns(patterns)?, + MatchOr(ast::PatternMatchOr { patterns, .. }) => self.scan_patterns(patterns)?, } Ok(()) } @@ -2087,7 +2090,7 @@ impl SymbolTableBuilder { fn enter_scope_with_parameters( &mut self, name: &str, - parameters: &Parameters, + parameters: &ast::Parameters, line_number: u32, has_return_annotation: bool, ) -> SymbolTableResult { @@ -2174,7 +2177,7 @@ impl SymbolTableBuilder { Ok(()) } - fn register_ident(&mut self, ident: &Identifier, role: SymbolUsage) -> SymbolTableResult { + fn register_ident(&mut self, ident: &ast::Identifier, role: SymbolUsage) -> SymbolTableResult { self.register_name(ident.as_str(), role, ident.range) } diff --git a/crates/codegen/src/unparse.rs b/crates/codegen/src/unparse.rs index 7b26d229187..849544ab946 100644 --- a/crates/codegen/src/unparse.rs +++ b/crates/codegen/src/unparse.rs @@ -1,9 +1,6 @@ use alloc::fmt; use core::fmt::Display as _; -use ruff_python_ast::{ - self as ruff, Arguments, BoolOp, Comprehension, ConversionFlag, Expr, Identifier, Operator, - Parameter, ParameterWithDefault, Parameters, -}; +use ruff_python_ast as ast; use ruff_text_size::Ranged; use rustpython_compiler_core::SourceFile; use rustpython_literal::escape::{AsciiEscape, UnicodeEscape}; @@ -40,7 +37,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.f.write_str(s) } - fn p_id(&mut self, s: &Identifier) -> fmt::Result { + fn p_id(&mut self, s: &ast::Identifier) -> fmt::Result { self.f.write_str(s.as_str()) } @@ -59,7 +56,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.f.write_fmt(f) } - fn unparse_expr(&mut self, ast: &Expr, level: u8) -> fmt::Result { + fn unparse_expr(&mut self, ast: &ast::Expr, level: u8) -> fmt::Result { macro_rules! op_prec { ($op_ty:ident, $x:expr, $enu:path, $($var:ident($op:literal, $prec:ident)),*$(,)?) 
=> { match $x { @@ -83,13 +80,13 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { }}; } match &ast { - Expr::BoolOp(ruff::ExprBoolOp { + ast::Expr::BoolOp(ast::ExprBoolOp { op, values, node_index: _, range: _range, }) => { - let (op, prec) = op_prec!(bin, op, BoolOp, And("and", AND), Or("or", OR)); + let (op, prec) = op_prec!(bin, op, ast::BoolOp, And("and", AND), Or("or", OR)); group_if!(prec, { let mut first = true; for val in values { @@ -98,7 +95,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } }) } - Expr::Named(ruff::ExprNamed { + ast::Expr::Named(ast::ExprNamed { target, value, node_index: _, @@ -110,18 +107,18 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_expr(value, precedence::ATOM)?; }) } - Expr::BinOp(ruff::ExprBinOp { + ast::Expr::BinOp(ast::ExprBinOp { left, op, right, node_index: _, range: _range, }) => { - let right_associative = matches!(op, Operator::Pow); + let right_associative = matches!(op, ast::Operator::Pow); let (op, prec) = op_prec!( bin, op, - Operator, + ast::Operator, Add("+", ARITH), Sub("-", ARITH), Mult("*", TERM), @@ -142,7 +139,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_expr(right, prec + !right_associative as u8)?; }) } - Expr::UnaryOp(ruff::ExprUnaryOp { + ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, node_index: _, @@ -151,7 +148,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { let (op, prec) = op_prec!( un, op, - ruff::UnaryOp, + ast::UnaryOp, Invert("~", FACTOR), Not("not ", NOT), UAdd("+", FACTOR), @@ -162,7 +159,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_expr(operand, prec)?; }) } - Expr::Lambda(ruff::ExprLambda { + ast::Expr::Lambda(ast::ExprLambda { parameters, body, node_index: _, @@ -178,7 +175,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { write!(self, ": {}", UnparseExpr::new(body, self.source))?; }) } - Expr::If(ruff::ExprIf { + ast::Expr::If(ast::ExprIf { test, body, orelse, @@ -193,7 +190,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_expr(orelse, precedence::TEST)?; }) } - Expr::Dict(ruff::ExprDict { + ast::Expr::Dict(ast::ExprDict { items, node_index: _, range: _range, @@ -211,7 +208,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } self.p("}")?; } - Expr::Set(ruff::ExprSet { + ast::Expr::Set(ast::ExprSet { elts, node_index: _, range: _range, @@ -224,7 +221,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } self.p("}")?; } - Expr::ListComp(ruff::ExprListComp { + ast::Expr::ListComp(ast::ExprListComp { elt, generators, node_index: _, @@ -235,7 +232,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_comp(generators)?; self.p("]")?; } - Expr::SetComp(ruff::ExprSetComp { + ast::Expr::SetComp(ast::ExprSetComp { elt, generators, node_index: _, @@ -246,7 +243,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_comp(generators)?; self.p("}")?; } - Expr::DictComp(ruff::ExprDictComp { + ast::Expr::DictComp(ast::ExprDictComp { key, value, generators, @@ -260,7 +257,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_comp(generators)?; self.p("}")?; } - Expr::Generator(ruff::ExprGenerator { + ast::Expr::Generator(ast::ExprGenerator { parenthesized: _, elt, generators, @@ -272,7 +269,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_comp(generators)?; self.p(")")?; } - Expr::Await(ruff::ExprAwait { + ast::Expr::Await(ast::ExprAwait { value, node_index: _, range: _range, @@ -282,7 +279,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_expr(value, precedence::ATOM)?; }) } - Expr::Yield(ruff::ExprYield { + ast::Expr::Yield(ast::ExprYield { value, node_index: _, range: _range, @@ -293,7 
+290,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.p("(yield)")?; } } - Expr::YieldFrom(ruff::ExprYieldFrom { + ast::Expr::YieldFrom(ast::ExprYieldFrom { value, node_index: _, range: _range, @@ -304,7 +301,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { UnparseExpr::new(value, self.source) )?; } - Expr::Compare(ruff::ExprCompare { + ast::Expr::Compare(ast::ExprCompare { left, ops, comparators, @@ -322,9 +319,9 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } }) } - Expr::Call(ruff::ExprCall { + ast::Expr::Call(ast::ExprCall { func, - arguments: Arguments { args, keywords, .. }, + arguments: ast::Arguments { args, keywords, .. }, node_index: _, range: _range, }) => { @@ -332,7 +329,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.p("(")?; if let ( [ - Expr::Generator(ruff::ExprGenerator { + ast::Expr::Generator(ast::ExprGenerator { elt, generators, node_index: _, @@ -365,9 +362,9 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } self.p(")")?; } - Expr::FString(ruff::ExprFString { value, .. }) => self.unparse_fstring(value)?, - Expr::TString(_) => self.p("t\"\"")?, - Expr::StringLiteral(ruff::ExprStringLiteral { value, .. }) => { + ast::Expr::FString(ast::ExprFString { value, .. }) => self.unparse_fstring(value)?, + ast::Expr::TString(_) => self.p("t\"\"")?, + ast::Expr::StringLiteral(ast::ExprStringLiteral { value, .. }) => { if value.is_unicode() { self.p("u")? } @@ -375,12 +372,12 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { .str_repr() .fmt(self.f)? } - Expr::BytesLiteral(ruff::ExprBytesLiteral { value, .. }) => { + ast::Expr::BytesLiteral(ast::ExprBytesLiteral { value, .. }) => { AsciiEscape::new_repr(&value.bytes().collect::>()) .bytes_repr() .fmt(self.f)? } - Expr::NumberLiteral(ruff::ExprNumberLiteral { value, .. }) => { + ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value, .. }) => { #[allow(clippy::correctness, clippy::assertions_on_constants)] const { assert!(f64::MAX_10_EXP == 308) @@ -388,28 +385,28 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { let inf_str = "1e309"; match value { - ruff::Number::Int(int) => int.fmt(self.f)?, - &ruff::Number::Float(fp) => { + ast::Number::Int(int) => int.fmt(self.f)?, + &ast::Number::Float(fp) => { if fp.is_infinite() { self.p(inf_str)? } else { self.p(&rustpython_literal::float::to_string(fp))? } } - &ruff::Number::Complex { real, imag } => self + &ast::Number::Complex { real, imag } => self .p(&rustpython_literal::complex::to_string(real, imag) .replace("inf", inf_str))?, } } - Expr::BooleanLiteral(ruff::ExprBooleanLiteral { value, .. }) => { + ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => { self.p(if *value { "True" } else { "False" })? } - Expr::NoneLiteral(ruff::ExprNoneLiteral { .. }) => self.p("None")?, - Expr::EllipsisLiteral(ruff::ExprEllipsisLiteral { .. }) => self.p("...")?, - Expr::Attribute(ruff::ExprAttribute { value, attr, .. }) => { + ast::Expr::NoneLiteral(ast::ExprNoneLiteral { .. }) => self.p("None")?, + ast::Expr::EllipsisLiteral(ast::ExprEllipsisLiteral { .. }) => self.p("...")?, + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { self.unparse_expr(value, precedence::ATOM)?; - let period = if let Expr::NumberLiteral(ruff::ExprNumberLiteral { - value: ruff::Number::Int(_), + let period = if let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(_), .. }) = value.as_ref() { @@ -420,19 +417,19 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.p(period)?; self.p_id(attr)?; } - Expr::Subscript(ruff::ExprSubscript { value, slice, .. 
}) => { + ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => { self.unparse_expr(value, precedence::ATOM)?; let lvl = precedence::TUPLE; self.p("[")?; self.unparse_expr(slice, lvl)?; self.p("]")?; } - Expr::Starred(ruff::ExprStarred { value, .. }) => { + ast::Expr::Starred(ast::ExprStarred { value, .. }) => { self.p("*")?; self.unparse_expr(value, precedence::EXPR)?; } - Expr::Name(ruff::ExprName { id, .. }) => self.p(id.as_str())?, - Expr::List(ruff::ExprList { elts, .. }) => { + ast::Expr::Name(ast::ExprName { id, .. }) => self.p(id.as_str())?, + ast::Expr::List(ast::ExprList { elts, .. }) => { self.p("[")?; let mut first = true; for elt in elts { @@ -441,7 +438,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } self.p("]")?; } - Expr::Tuple(ruff::ExprTuple { elts, .. }) => { + ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { if elts.is_empty() { self.p("()")?; } else { @@ -455,7 +452,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { }) } } - Expr::Slice(ruff::ExprSlice { + ast::Expr::Slice(ast::ExprSlice { lower, upper, step, @@ -474,12 +471,12 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { self.unparse_expr(step, precedence::TEST)?; } } - Expr::IpyEscapeCommand(_) => {} + ast::Expr::IpyEscapeCommand(_) => {} } Ok(()) } - fn unparse_arguments(&mut self, args: &Parameters) -> fmt::Result { + fn unparse_arguments(&mut self, args: &ast::Parameters) -> fmt::Result { let mut first = true; for (i, arg) in args.posonlyargs.iter().chain(&args.args).enumerate() { self.p_delim(&mut first, ", ")?; @@ -504,7 +501,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } Ok(()) } - fn unparse_function_arg(&mut self, arg: &ParameterWithDefault) -> fmt::Result { + fn unparse_function_arg(&mut self, arg: &ast::ParameterWithDefault) -> fmt::Result { self.unparse_arg(&arg.parameter)?; if let Some(default) = &arg.default { write!(self, "={}", UnparseExpr::new(default, self.source))?; @@ -512,7 +509,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { Ok(()) } - fn unparse_arg(&mut self, arg: &Parameter) -> fmt::Result { + fn unparse_arg(&mut self, arg: &ast::Parameter) -> fmt::Result { self.p_id(&arg.name)?; if let Some(ann) = &arg.annotation { write!(self, ": {}", UnparseExpr::new(ann, self.source))?; @@ -520,7 +517,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { Ok(()) } - fn unparse_comp(&mut self, generators: &[Comprehension]) -> fmt::Result { + fn unparse_comp(&mut self, generators: &[ast::Comprehension]) -> fmt::Result { for comp in generators { self.p(if comp.is_async { " async for " @@ -538,10 +535,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { Ok(()) } - fn unparse_fstring_body( - &mut self, - elements: &[ruff::InterpolatedStringElement], - ) -> fmt::Result { + fn unparse_fstring_body(&mut self, elements: &[ast::InterpolatedStringElement]) -> fmt::Result { for elem in elements { self.unparse_fstring_elem(elem)?; } @@ -550,15 +544,15 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { fn unparse_formatted( &mut self, - val: &Expr, - debug_text: Option<&ruff::DebugText>, - conversion: ConversionFlag, - spec: Option<&ruff::InterpolatedStringFormatSpec>, + val: &ast::Expr, + debug_text: Option<&ast::DebugText>, + conversion: ast::ConversionFlag, + spec: Option<&ast::InterpolatedStringFormatSpec>, ) -> fmt::Result { let buffered = to_string_fmt(|f| { Unparser::new(f, self.source).unparse_expr(val, precedence::TEST + 1) }); - if let Some(ruff::DebugText { leading, trailing }) = debug_text { + if let Some(ast::DebugText { leading, trailing }) = debug_text { self.p(leading)?; self.p(self.source.slice(val.range()))?; 
        self.p(trailing)?;
@@ -573,7 +567,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> {
         self.p(&buffered)?;
         drop(buffered);
 
-        if conversion != ConversionFlag::None {
+        if conversion != ast::ConversionFlag::None {
             self.p("!")?;
             let buf = &[conversion as u8];
             let c = core::str::from_utf8(buf).unwrap();
@@ -590,9 +584,9 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> {
         Ok(())
     }
 
-    fn unparse_fstring_elem(&mut self, elem: &ruff::InterpolatedStringElement) -> fmt::Result {
+    fn unparse_fstring_elem(&mut self, elem: &ast::InterpolatedStringElement) -> fmt::Result {
         match elem {
-            ruff::InterpolatedStringElement::Interpolation(ruff::InterpolatedElement {
+            ast::InterpolatedStringElement::Interpolation(ast::InterpolatedElement {
                 expression,
                 debug_text,
                 conversion,
@@ -604,7 +598,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> {
                 *conversion,
                 format_spec.as_deref(),
             ),
-            ruff::InterpolatedStringElement::Literal(ruff::InterpolatedStringLiteralElement {
+            ast::InterpolatedStringElement::Literal(ast::InterpolatedStringLiteralElement {
                 value,
                 ..
             }) => self.unparse_fstring_str(value),
@@ -616,12 +610,12 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> {
         self.p(&s)
     }
 
-    fn unparse_fstring(&mut self, value: &ruff::FStringValue) -> fmt::Result {
+    fn unparse_fstring(&mut self, value: &ast::FStringValue) -> fmt::Result {
        self.p("f")?;
        let body = to_string_fmt(|f| {
            value.iter().try_for_each(|part| match part {
-                ruff::FStringPart::Literal(lit) => f.write_str(lit),
-                ruff::FStringPart::FString(ruff::FString { elements, .. }) => {
+                ast::FStringPart::Literal(lit) => f.write_str(lit),
+                ast::FStringPart::FString(ast::FString { elements, .. }) => {
                     Unparser::new(f, self.source).unparse_fstring_body(elements)
                 }
             })
@@ -634,12 +628,12 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> {
 }
 
 pub struct UnparseExpr<'a> {
-    expr: &'a Expr,
+    expr: &'a ast::Expr,
     source: &'a SourceFile,
 }
 
 impl<'a> UnparseExpr<'a> {
-    pub const fn new(expr: &'a Expr, source: &'a SourceFile) -> Self {
+    pub const fn new(expr: &'a ast::Expr, source: &'a SourceFile) -> Self {
         Self { expr, source }
     }
 }
diff --git a/crates/compiler-core/src/bytecode/instruction.rs b/crates/compiler-core/src/bytecode/instruction.rs
index fe448f4574b..4faf67d273b 100644
--- a/crates/compiler-core/src/bytecode/instruction.rs
+++ b/crates/compiler-core/src/bytecode/instruction.rs
@@ -95,9 +95,7 @@ pub enum Instruction {
     Call {
         nargs: Arg<u32>,
     } = 53,
-    CallFunctionEx {
-        has_kwargs: Arg<bool>,
-    } = 54,
+    CallFunctionEx = 54,
     CallIntrinsic1 {
         func: Arg<IntrinsicFunction1>,
     } = 55,
@@ -575,8 +573,8 @@ impl InstructionMetadata for Instruction {
             Self::Call { nargs } => -(nargs.get(arg) as i32) - 2 + 1,
             // CallKw: pops kw_names_tuple + nargs + self_or_null + callable, pushes result
             Self::CallKw { nargs } => -1 - (nargs.get(arg) as i32) - 2 + 1,
-            // CallFunctionEx: pops kwargs(if any) + args_tuple + self_or_null + callable, pushes result
-            Self::CallFunctionEx { has_kwargs } => -1 - (has_kwargs.get(arg) as i32) - 2 + 1,
+            // CallFunctionEx: always pops kwargs_or_null + args_tuple + self_or_null + callable, pushes result
+            Self::CallFunctionEx => -4 + 1,
             Self::CheckEgMatch => 0, // pops 2 (exc, type), pushes 2 (rest, match)
             Self::ConvertValue { ..
} => 0, Self::FormatSimple => 0, @@ -880,7 +878,7 @@ impl InstructionMetadata for Instruction { Self::BuildTupleFromIter => w!(BUILD_TUPLE_FROM_ITER), Self::BuildTupleFromTuples { size } => w!(BUILD_TUPLE_FROM_TUPLES, size), Self::Call { nargs } => w!(CALL, nargs), - Self::CallFunctionEx { has_kwargs } => w!(CALL_FUNCTION_EX, has_kwargs), + Self::CallFunctionEx => w!(CALL_FUNCTION_EX), Self::CallKw { nargs } => w!(CALL_KW, nargs), Self::CallIntrinsic1 { func } => w!(CALL_INTRINSIC_1, ?func), Self::CallIntrinsic2 { func } => w!(CALL_INTRINSIC_2, ?func), diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index 57cbf67de5a..e54c1a867fd 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -574,51 +574,80 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result) { + #try_clear_body } - assert_eq!(s, "manual"); - quote! {} - } else { - quote! {#[derive(Traverse)]} - }; - (maybe_trace_code, derive_trace) - } else { - ( - // a dummy impl, which do nothing - // #attrs - quote! { - impl ::rustpython_vm::object::MaybeTraverse for #ident { - fn try_traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) { - // do nothing - } - } - }, - quote! {}, - ) + } } }; @@ -675,7 +704,7 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result) { self.0.try_traverse(traverse_fn) } + + fn try_clear(&mut self, out: &mut ::std::vec::Vec<::rustpython_vm::PyObjectRef>) { + self.0.try_clear(out) + } } // PySubclass for proper inheritance diff --git a/crates/derive-impl/src/util.rs b/crates/derive-impl/src/util.rs index 6be1fcdf7ad..b09ad9c93fe 100644 --- a/crates/derive-impl/src/util.rs +++ b/crates/derive-impl/src/util.rs @@ -372,6 +372,7 @@ impl ItemMeta for ClassItemMeta { "ctx", "impl", "traverse", + "clear", // tp_clear ]; fn from_inner(inner: ItemMetaInner) -> Self { diff --git a/crates/jit/tests/common.rs b/crates/jit/tests/common.rs index ef5a25f0843..49cc7168dd4 100644 --- a/crates/jit/tests/common.rs +++ b/crates/jit/tests/common.rs @@ -102,46 +102,33 @@ fn extract_annotations_from_annotate_code(code: &CodeObject) -> HashMap { - Some(value.as_str().map(|s| s.to_owned()).unwrap_or_else( - |_| value.to_string_lossy().into_owned(), - )) - } - Some(other) => { - eprintln!( - "Warning: Malformed annotation for '{:?}': expected string constant at index {}, got {:?}", - param_name, val_idx, other - ); - None - } - None => { - eprintln!( - "Warning: Malformed annotation for '{:?}': constant index {} out of bounds (len={})", - param_name, - val_idx, - code.constants.len() - ); - None - } + Some(ConstantData::Str { value }) => value + .as_str() + .map(|s| s.to_owned()) + .unwrap_or_else(|_| value.to_string_lossy().into_owned()), + Some(other) => panic!( + "Unsupported annotation const for '{:?}' at idx {}: {:?}", + param_name, val_idx, other + ), + None => panic!( + "Annotation const idx out of bounds for '{:?}': {} (len={})", + param_name, + val_idx, + code.constants.len() + ), } } else { match code.names.get(val_idx) { - Some(name) => Some(name.clone()), - None => { - eprintln!( - "Warning: Malformed annotation for '{}': name index {} out of bounds (len={})", - param_name, - val_idx, - code.names.len() - ); - None - } + Some(name) => name.clone(), + None => panic!( + "Annotation name idx out of bounds for '{:?}': {} (len={})", + param_name, + val_idx, + code.names.len() + ), } }; - if let Some(type_name) = type_name { - annotations - .insert(param_name.clone(), 
StackValue::String(type_name)); - } + annotations.insert(param_name.clone(), StackValue::String(type_name)); } } } diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index 7e0392b1f30..9b7b810b25c 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -425,6 +425,12 @@ mod _sqlite { name: PyStrRef, } + #[derive(FromArgs)] + struct CursorArgs { + #[pyarg(any, default)] + factory: OptionalArg, + } + struct CallbackData { obj: NonNull, vm: *const VirtualMachine, @@ -1023,22 +1029,29 @@ mod _sqlite { #[pymethod] fn cursor( zelf: PyRef, - factory: OptionalArg, + args: CursorArgs, vm: &VirtualMachine, - ) -> PyResult> { + ) -> PyResult { zelf.db_lock(vm).map(drop)?; - let cursor = if let OptionalArg::Present(factory) = factory { - let cursor = factory.invoke((zelf.clone(),), vm)?; - let cursor = cursor.downcast::().map_err(|x| { - vm.new_type_error(format!("factory must return a cursor, not {}", x.class())) - })?; - let _ = unsafe { cursor.row_factory.swap(zelf.row_factory.to_owned()) }; - cursor - } else { - let row_factory = zelf.row_factory.to_owned(); - Cursor::new(zelf, row_factory, vm).into_ref(&vm.ctx) + let factory = match args.factory { + OptionalArg::Present(f) => f, + OptionalArg::Missing => Cursor::class(&vm.ctx).to_owned().into(), }; + + let cursor = factory.call((zelf.clone(),), vm)?; + + if !cursor.class().fast_issubclass(Cursor::class(&vm.ctx)) { + return Err(vm.new_type_error(format!( + "factory must return a cursor, not {}", + cursor.class() + ))); + } + + if let Some(cursor_ref) = cursor.downcast_ref::() { + let _ = unsafe { cursor_ref.row_factory.swap(zelf.row_factory.to_owned()) }; + } + Ok(cursor) } diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 31973a84f89..6c62a24397b 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -50,7 +50,7 @@ mod _ssl { // Import error types used in this module (others are exposed via pymodule(with(...))) use super::error::{ - PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error, + PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error, }; use alloc::sync::Arc; use core::{ @@ -1903,6 +1903,7 @@ mod _ssl { client_hello_buffer: PyMutex::new(None), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), + write_buffered_len: PyMutex::new(0), deferred_cert_error: Arc::new(ParkingRwLock::new(None)), }; @@ -1974,6 +1975,7 @@ mod _ssl { client_hello_buffer: PyMutex::new(None), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), + write_buffered_len: PyMutex::new(0), deferred_cert_error: Arc::new(ParkingRwLock::new(None)), }; @@ -2345,6 +2347,10 @@ mod _ssl { // but the socket cannot accept all the data immediately #[pytraverse(skip)] pub(crate) pending_tls_output: PyMutex>, + // Tracks bytes already buffered in rustls for the current write operation + // Prevents duplicate writes when retrying after WantWrite/WantRead + #[pytraverse(skip)] + pub(crate) write_buffered_len: PyMutex, // Deferred client certificate verification error (for TLS 1.3) // Stores error message if client cert verification failed during handshake // Error is raised on first I/O operation after handshake @@ -2604,6 +2610,36 @@ mod _ssl { Ok(timed_out) } + // Internal implementation with explicit timeout override + pub(crate) fn sock_wait_for_io_with_timeout( + &self, + kind: SelectKind, + timeout: Option, + vm: &VirtualMachine, + ) -> 
PyResult { + if self.is_bio_mode() { + // BIO mode doesn't use select + return Ok(false); + } + + if let Some(t) = timeout + && t.is_zero() + { + // Non-blocking mode - don't use select + return Ok(false); + } + + let py_socket: PyRef = self.sock.clone().try_into_value(vm)?; + let socket = py_socket + .sock() + .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; + + let timed_out = sock_select(&socket, kind, timeout) + .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + + Ok(timed_out) + } + // SNI (Server Name Indication) Helper Methods: // These methods support the server-side handshake SNI callback mechanism @@ -2783,6 +2819,7 @@ mod _ssl { let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false); let mut sent_total = 0; + while sent_total < pending.len() { // Calculate timeout: use deadline if provided, otherwise use socket timeout let timeout_to_use = if let Some(dl) = deadline { @@ -2810,6 +2847,9 @@ mod _ssl { if timed_out { // Keep unsent data in pending buffer *pending = pending[sent_total..].to_vec(); + if is_non_blocking { + return Err(create_ssl_want_write_error(vm).upcast()); + } return Err( timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(), ); @@ -2824,6 +2864,7 @@ mod _ssl { *pending = pending[sent_total..].to_vec(); return Err(create_ssl_want_write_error(vm).upcast()); } + // Socket said ready but sent 0 bytes - retry continue; } sent_total += sent; @@ -2916,6 +2957,9 @@ mod _ssl { pub(crate) fn blocking_flush_all_pending(&self, vm: &VirtualMachine) -> PyResult<()> { // Get socket timeout to respect during flush let timeout = self.get_socket_timeout(vm)?; + if timeout.map(|t| t.is_zero()).unwrap_or(false) { + return self.flush_pending_tls_output(vm, None); + } loop { let pending_data = { @@ -2948,8 +2992,7 @@ mod _ssl { let mut pending = self.pending_tls_output.lock(); pending.drain(..sent); } - // If sent == 0, socket wasn't ready despite select() saying so - // Continue loop to retry - this avoids infinite loops + // If sent == 0, loop will retry with sock_select } Err(e) => { if is_blocking_io_error(&e, vm) { @@ -3515,16 +3558,60 @@ mod _ssl { return_data(buf, &buffer, vm) } Err(crate::ssl::compat::SslError::Eof) => { + // If plaintext is still buffered, return it before EOF. + let pending = { + let mut conn_guard = self.connection.lock(); + let conn = match conn_guard.as_mut() { + Some(conn) => conn, + None => return Err(create_ssl_eof_error(vm).upcast()), + }; + use std::io::BufRead; + let mut reader = conn.reader(); + reader.fill_buf().map(|buf| buf.len()).unwrap_or(0) + }; + if pending > 0 { + let mut buf = vec![0u8; pending.min(len)]; + let read_retry = { + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; + crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) + }; + if let Ok(n) = read_retry { + buf.truncate(n); + return return_data(buf, &buffer, vm); + } + } // EOF occurred in violation of protocol (unexpected closure) - Err(vm - .new_os_subtype_error( - PySSLEOFError::class(&vm.ctx).to_owned(), - None, - "EOF occurred in violation of protocol", - ) - .upcast()) + Err(create_ssl_eof_error(vm).upcast()) } Err(crate::ssl::compat::SslError::ZeroReturn) => { + // If plaintext is still buffered, return it before clean EOF. 
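// A minimal, self-contained sketch of the invariant the EOF arms in this hunk
// enforce: plaintext that rustls has already decrypted is handed back to the
// caller before any EOF condition is surfaced. `MockTls` is a hypothetical
// stand-in used only for illustration, not the real connection type.
use std::collections::VecDeque;

#[derive(Debug, PartialEq)]
enum ReadOutcome {
    Data(Vec<u8>),
    Eof,
}

struct MockTls {
    buffered_plaintext: VecDeque<u8>,
    peer_closed: bool,
}

impl MockTls {
    // Mirrors the patched read path: check the buffer first, then report EOF.
    fn read(&mut self, len: usize) -> ReadOutcome {
        let pending = self.buffered_plaintext.len();
        if pending > 0 {
            let n = pending.min(len);
            return ReadOutcome::Data(self.buffered_plaintext.drain(..n).collect());
        }
        debug_assert!(self.peer_closed);
        ReadOutcome::Eof
    }
}

fn main() {
    let mut tls = MockTls {
        buffered_plaintext: VecDeque::from(vec![b'h', b'i']),
        peer_closed: true,
    };
    assert_eq!(tls.read(16), ReadOutcome::Data(b"hi".to_vec()));
    assert_eq!(tls.read(16), ReadOutcome::Eof); // EOF only after the buffer drains
}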
+ let pending = { + let mut conn_guard = self.connection.lock(); + let conn = match conn_guard.as_mut() { + Some(conn) => conn, + None => return return_data(vec![], &buffer, vm), + }; + use std::io::BufRead; + let mut reader = conn.reader(); + reader.fill_buf().map(|buf| buf.len()).unwrap_or(0) + }; + if pending > 0 { + let mut buf = vec![0u8; pending.min(len)]; + let read_retry = { + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; + crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) + }; + if let Ok(n) = read_retry { + buf.truncate(n); + return return_data(buf, &buffer, vm); + } + } // Clean closure with close_notify - return empty data return_data(vec![], &buffer, vm) } @@ -3580,21 +3667,17 @@ mod _ssl { let data_bytes = data.borrow_buf(); let data_len = data_bytes.len(); - // return 0 immediately for empty write if data_len == 0 { return Ok(0); } - // Ensure handshake is done - if not, complete it first - // This matches OpenSSL behavior where SSL_write() auto-completes handshake + // Ensure handshake is done (SSL_write auto-completes handshake) if !*self.handshake_done.lock() { self.do_handshake(vm)?; } - // Check if connection has been shut down - // After unwrap()/shutdown(), write operations should fail with SSLError - let shutdown_state = *self.shutdown_state.lock(); - if shutdown_state != ShutdownState::NotStarted { + // Check shutdown state + if *self.shutdown_state.lock() != ShutdownState::NotStarted { return Err(vm .new_os_subtype_error( PySSLError::class(&vm.ctx).to_owned(), @@ -3604,76 +3687,32 @@ mod _ssl { .upcast()); } - { + // Call ssl_write (matches CPython's SSL_write_ex loop) + let result = { let mut conn_guard = self.connection.lock(); let conn = conn_guard .as_mut() .ok_or_else(|| vm.new_value_error("Connection not established"))?; - let is_bio = self.is_bio_mode(); - let data: &[u8] = data_bytes.as_ref(); + crate::ssl::compat::ssl_write(conn, data_bytes.as_ref(), self, vm) + }; - // CRITICAL: Flush any pending TLS data before writing new data - // This ensures TLS 1.3 Finished message reaches server before application data - // Without this, server may not be ready to process our data - if !is_bio { - self.flush_pending_tls_output(vm, None)?; + match result { + Ok(n) => { + self.check_deferred_cert_error(vm)?; + Ok(n) } - - // Write data in chunks to avoid filling the internal TLS buffer - // rustls has a limited internal buffer, so we need to flush periodically - const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size) - let mut written = 0; - - while written < data.len() { - let chunk_end = core::cmp::min(written + CHUNK_SIZE, data.len()); - let chunk = &data[written..chunk_end]; - - // Write chunk to TLS layer - { - let mut writer = conn.writer(); - use std::io::Write; - writer - .write_all(chunk) - .map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?; - // Flush to ensure data is converted to TLS records - writer - .flush() - .map_err(|e| vm.new_os_error(format!("Flush failed: {e}")))?; - } - - written = chunk_end; - - // Flush TLS data to socket after each chunk - if conn.wants_write() { - if is_bio { - self.write_pending_tls(conn, vm)?; - } else { - // Socket mode: flush all pending TLS data - // First, try to send any previously pending data - self.flush_pending_tls_output(vm, None)?; - - while conn.wants_write() { - let mut buf = Vec::new(); - conn.write_tls(&mut buf).map_err(|e| { - vm.new_os_error(format!("TLS write failed: 
{e}")) - })?; - - if !buf.is_empty() { - // Try to send TLS data, saving unsent bytes to pending buffer - self.send_tls_output(buf, vm)?; - } - } - } - } + Err(crate::ssl::compat::SslError::WantRead) => { + Err(create_ssl_want_read_error(vm).upcast()) + } + Err(crate::ssl::compat::SslError::WantWrite) => { + Err(create_ssl_want_write_error(vm).upcast()) + } + Err(crate::ssl::compat::SslError::Timeout(msg)) => { + Err(timeout_error_msg(vm, msg).upcast()) } + Err(e) => Err(e.into_py_err(vm)), } - - // Check for deferred certificate verification errors (TLS 1.3) - // Must be checked AFTER write completes, as the error may be set during I/O - self.check_deferred_cert_error(vm)?; - - Ok(data_len) } #[pymethod] @@ -4013,6 +4052,10 @@ mod _ssl { // Write close_notify to outgoing buffer/BIO self.write_pending_tls(conn, vm)?; + // Ensure close_notify and any pending TLS data are flushed + if !is_bio { + self.flush_pending_tls_output(vm, None)?; + } // Update state *self.shutdown_state.lock() = ShutdownState::SentCloseNotify; diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 3c72ccf4e21..322fdde5b9a 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -36,8 +36,8 @@ use super::_ssl::PySSLSocket; // Import error types and helper functions from error module use super::error::{ - PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_want_read_error, - create_ssl_want_write_error, create_ssl_zero_return_error, + PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_syscall_error, + create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, }; // SSL Verification Flags @@ -553,8 +553,8 @@ impl SslError { SslError::WantWrite => create_ssl_want_write_error(vm).upcast(), SslError::Timeout(msg) => timeout_error_msg(vm, msg).upcast(), SslError::Syscall(msg) => { - // Create SSLError with library=None for syscall errors during SSL operations - Self::create_ssl_error_with_reason(vm, None, &msg, msg.clone()) + // SSLSyscallError with errno=SSL_ERROR_SYSCALL (5) + create_ssl_syscall_error(vm, msg).upcast() } SslError::Ssl(msg) => vm .new_os_subtype_error( @@ -1039,6 +1039,36 @@ fn send_all_bytes( return Err(SslError::Timeout("The operation timed out".to_string())); } + // Wait for socket to be writable before sending + let timed_out = if let Some(dl) = deadline { + let now = std::time::Instant::now(); + if now >= dl { + socket + .pending_tls_output + .lock() + .extend_from_slice(&buf[sent_total..]); + return Err(SslError::Timeout( + "The write operation timed out".to_string(), + )); + } + socket + .sock_wait_for_io_with_timeout(SelectKind::Write, Some(dl - now), vm) + .map_err(SslError::Py)? + } else { + socket + .sock_wait_for_io_impl(SelectKind::Write, vm) + .map_err(SslError::Py)? + }; + if timed_out { + socket + .pending_tls_output + .lock() + .extend_from_slice(&buf[sent_total..]); + return Err(SslError::Timeout( + "The write operation timed out".to_string(), + )); + } + match socket.sock_send(&buf[sent_total..], vm) { Ok(result) => { let sent: usize = result @@ -1443,9 +1473,17 @@ pub(super) fn ssl_do_handshake( } } - // If we exit the loop without completing handshake, return error - // Check rustls state to provide better error message + // If we exit the loop without completing handshake, return appropriate error if conn.is_handshaking() { + // For non-blocking sockets, return WantRead/WantWrite to signal caller + // should retry when socket is ready. 
This matches OpenSSL behavior. + if conn.wants_write() { + return Err(SslError::WantWrite); + } + if conn.wants_read() { + return Err(SslError::WantRead); + } + // Neither wants_read nor wants_write - this is a real error Err(SslError::Syscall(format!( "SSL handshake failed: incomplete after {iteration_count} iterations", ))) @@ -1581,6 +1619,14 @@ pub(super) fn ssl_read( if let Some(t) = timeout && t.is_zero() { + // Non-blocking socket: check if peer has closed before returning WantRead + // If close_notify was received, we should return ZeroReturn (EOF), not WantRead + // This is critical for asyncore-based applications that rely on recv() returning + // 0 or raising SSL_ERROR_ZERO_RETURN to detect connection close. + let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?; + if io_state.peer_has_closed() { + return Err(SslError::ZeroReturn); + } // Non-blocking socket: return immediately return Err(SslError::WantRead); } @@ -1605,7 +1651,13 @@ pub(super) fn ssl_read( .unwrap_or(0); if bytes_read == 0 { - // No more data available - connection might be closed + // No more data available - check if this is clean shutdown or unexpected EOF + // If close_notify was already received, return ZeroReturn (clean closure) + // Otherwise, return Eof (unexpected EOF) + let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?; + if io_state.peer_has_closed() { + return Err(SslError::ZeroReturn); + } return Err(SslError::Eof); } @@ -1648,6 +1700,138 @@ pub(super) fn ssl_read( } } +/// Equivalent to OpenSSL's SSL_write() +/// +/// Writes application data to TLS connection. +/// Automatically handles TLS record I/O as needed. +/// +/// = SSL_write_ex() +pub(super) fn ssl_write( + conn: &mut TlsConnection, + data: &[u8], + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + if data.is_empty() { + return Ok(0); + } + + let is_bio = socket.is_bio_mode(); + + // Get socket timeout and calculate deadline (= _PyDeadline_Init) + let deadline = if !is_bio { + match socket.get_socket_timeout(vm).map_err(SslError::Py)? 
{ + Some(timeout) if !timeout.is_zero() => Some(std::time::Instant::now() + timeout), + _ => None, + } + } else { + None + }; + + // Flush any pending TLS output before writing new data + if !is_bio { + socket + .flush_pending_tls_output(vm, deadline) + .map_err(SslError::Py)?; + } + + // Check if we already have data buffered from a previous retry + // (prevents duplicate writes when retrying after WantWrite/WantRead) + let already_buffered = *socket.write_buffered_len.lock(); + + // Only write plaintext if not already buffered + if already_buffered == 0 { + // Write plaintext to rustls (= SSL_write_ex internal buffer write) + { + let mut writer = conn.writer(); + use std::io::Write; + writer + .write_all(data) + .map_err(|e| SslError::Syscall(format!("Write failed: {e}")))?; + } + // Mark data as buffered + *socket.write_buffered_len.lock() = data.len(); + } else if already_buffered != data.len() { + // Caller is retrying with different data - this is a protocol error + // Clear the buffer state and return an SSL error (bad write retry) + *socket.write_buffered_len.lock() = 0; + return Err(SslError::Ssl("bad write retry".to_string())); + } + // else: already_buffered == data.len(), this is a valid retry + + // Loop to send TLS records, handling WANT_READ/WANT_WRITE + // Matches CPython's do-while loop on SSL_ERROR_WANT_READ/WANT_WRITE + loop { + // Check deadline + if let Some(dl) = deadline + && std::time::Instant::now() >= dl + { + return Err(SslError::Timeout( + "The write operation timed out".to_string(), + )); + } + + // Check if rustls has TLS data to send + if !conn.wants_write() { + // All TLS data sent successfully + break; + } + + // Get TLS records from rustls + let tls_data = ssl_write_tls_records(conn)?; + if tls_data.is_empty() { + break; + } + + // Send TLS data to socket + match send_all_bytes(socket, tls_data, vm, deadline) { + Ok(()) => { + // Successfully sent, continue loop to check for more data + } + Err(SslError::WantWrite) => { + // Non-blocking socket would block - return WANT_WRITE + // Keep write_buffered_len set so we don't re-buffer on retry + return Err(SslError::WantWrite); + } + Err(SslError::WantRead) => { + // Need to read before write can complete (e.g., renegotiation) + // This matches CPython's handling of SSL_ERROR_WANT_READ in write + if is_bio { + // Keep write_buffered_len set so we don't re-buffer on retry + return Err(SslError::WantRead); + } + // For socket mode, try to read TLS data + let recv_result = socket.sock_recv(4096, vm).map_err(SslError::Py)?; + ssl_read_tls_records(conn, recv_result, false, vm)?; + conn.process_new_packets().map_err(SslError::from_rustls)?; + // Continue loop + } + Err(e @ SslError::Timeout(_)) => { + // Preserve buffered state so retry doesn't duplicate data + // (send_all_bytes saved unsent TLS bytes to pending_tls_output) + return Err(e); + } + Err(e) => { + // Clear buffer state on error + *socket.write_buffered_len.lock() = 0; + return Err(e); + } + } + } + + // Final flush to ensure all data is sent + if !is_bio { + socket + .flush_pending_tls_output(vm, deadline) + .map_err(SslError::Py)?; + } + + // Write completed successfully - clear buffer state + *socket.write_buffered_len.lock() = 0; + + Ok(data.len()) +} + // Helper functions (private-ish, used by public SSL functions) /// Write TLS records from rustls to socket @@ -1684,26 +1868,24 @@ fn ssl_read_tls_records( // 1. Clean shutdown: received TLS close_notify → return ZeroReturn (0 bytes) // 2. 
Unexpected EOF: no close_notify → return Eof (SSLEOFError) // - // SSL_ERROR_ZERO_RETURN vs SSL_ERROR_SYSCALL(errno=0) logic + // SSL_ERROR_ZERO_RETURN vs SSL_ERROR_EOF logic // CPython checks SSL_get_shutdown() & SSL_RECEIVED_SHUTDOWN // // Process any buffered TLS records (may contain close_notify) - let _ = conn.process_new_packets(); - - // IMPORTANT: CPython's default behavior (suppress_ragged_eofs=True) - // treats empty recv() as clean shutdown, returning 0 bytes instead of raising SSLEOFError. - // - // This is necessary for HTTP/1.0 servers that: - // 1. Send response without Content-Length header - // 2. Signal end-of-response by closing connection (TCP FIN) - // 3. Don't send TLS close_notify before TCP close - // - // While this could theoretically allow truncation attacks, - // it's the standard behavior for compatibility with real-world servers. - // Python only raises SSLEOFError when suppress_ragged_eofs=False is explicitly set. - // - // TODO: Implement suppress_ragged_eofs parameter if needed for strict security mode. - return Err(SslError::ZeroReturn); + match conn.process_new_packets() { + Ok(io_state) => { + if io_state.peer_has_closed() { + // Received close_notify - normal SSL closure (SSL_ERROR_ZERO_RETURN) + return Err(SslError::ZeroReturn); + } else { + // No close_notify - ragged EOF (SSL_ERROR_EOF → SSLEOFError) + // CPython raises SSLEOFError here, which SSLSocket.read() handles + // based on suppress_ragged_eofs setting + return Err(SslError::Eof); + } + } + Err(e) => return Err(SslError::from_rustls(e)), + } } } @@ -1816,6 +1998,9 @@ fn ssl_ensure_data_available( let data = match socket.sock_recv(2048, vm) { Ok(data) => data, Err(e) => { + if is_blocking_io_error(&e, vm) { + return Err(SslError::WantRead); + } // Before returning socket error, check if rustls already has a queued TLS alert // This mirrors CPython/OpenSSL behavior: SSL errors take precedence over socket errors // On Windows, TCP RST may arrive before we read the alert, but rustls may have diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index 6219eff41b5..cbc59e0e8f6 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -132,4 +132,15 @@ pub(crate) mod ssl_error { "TLS/SSL connection has been closed (EOF)", ) } + + pub fn create_ssl_syscall_error( + vm: &VirtualMachine, + msg: impl Into, + ) -> PyRef { + vm.new_os_subtype_error( + PySSLSyscallError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_SYSCALL), + msg.into(), + ) + } } diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index b74aba41145..da01eff65b9 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -77,6 +77,7 @@ memchr = { workspace = true } caseless = "0.2.2" flamer = { version = "0.5", optional = true } half = "2" +psm = "0.1" optional = { workspace = true } result-like = "0.5.0" timsort = "0.1.2" diff --git a/crates/vm/src/builtins/asyncgenerator.rs b/crates/vm/src/builtins/asyncgenerator.rs index 8f7518f0eae..8b7c107d4b8 100644 --- a/crates/vm/src/builtins/asyncgenerator.rs +++ b/crates/vm/src/builtins/asyncgenerator.rs @@ -87,9 +87,14 @@ impl PyAsyncGen { if let Some(finalizer) = finalizer && !zelf.inner.closed.load() { - // Ignore any errors (PyErr_WriteUnraisable) + // Create a strong reference for the finalizer call. + // This keeps the object alive during the finalizer execution. let obj: PyObjectRef = zelf.to_owned().into(); - let _ = finalizer.call((obj,), vm); + + // Call the finalizer. Any exceptions are handled as unraisable. 
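// Sketch of the "unraisable" pattern this asyncgenerator hunk adopts: the
// finalizer runs during destruction, so its errors must be reported and
// swallowed rather than propagated. `report_unraisable` is a hypothetical
// stand-in for vm.run_unraisable(); only the control flow is the point here.
fn report_unraisable(err: &str, context: &str) {
    eprintln!("Exception ignored in {context}: {err}");
}

fn call_finalizer(finalizer: impl Fn() -> Result<(), String>) {
    if let Err(e) = finalizer() {
        // Report, then continue tearing the object down.
        report_unraisable(&e, "async generator finalizer");
    }
}

fn main() {
    call_finalizer(|| Err("hook raised".to_owned())); // logged, not propagated
    call_finalizer(|| Ok(()));
}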
+ if let Err(e) = finalizer.call((obj,), vm) { + vm.run_unraisable(e, Some("async generator finalizer".to_owned()), finalizer); + } } } @@ -496,11 +501,13 @@ impl PyAsyncGenAThrow { } fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { self.ag.running_async.store(false); + self.ag.inner.closed.store(true); self.state.store(AwaitableState::Closed); vm.new_runtime_error("async generator ignored GeneratorExit") } fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef { self.ag.running_async.store(false); + self.ag.inner.closed.store(true); self.state.store(AwaitableState::Closed); if self.aclose && (exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) @@ -687,12 +694,14 @@ impl IterNext for PyAnextAwaitable { /// _PyGen_Finalize for async generators impl Destructor for PyAsyncGen { fn del(zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { - // Generator isn't paused, so no need to close + // Generator is already closed, nothing to do if zelf.inner.closed.load() { return Ok(()); } + // Call the async generator finalizer hook if set. Self::call_finalizer(zelf, vm); + Ok(()) } } diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index d1adb8a066d..fcb51c2ca0e 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -2,6 +2,7 @@ use super::{ IterStatus, PositionIterInternal, PyBaseExceptionRef, PyGenericAlias, PyMappingProxy, PySet, PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef, set::PySetInner, }; +use crate::object::{Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, TryFromObject, atomic_func, @@ -29,13 +30,28 @@ use std::sync::LazyLock; pub type DictContentType = dict_inner::Dict; -#[pyclass(module = false, name = "dict", unhashable = true, traverse)] +#[pyclass(module = false, name = "dict", unhashable = true, traverse = "manual")] #[derive(Default)] pub struct PyDict { entries: DictContentType, } pub type PyDictRef = PyRef; +// SAFETY: Traverse properly visits all owned PyObjectRefs +unsafe impl Traverse for PyDict { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.entries.traverse(traverse_fn); + } + + fn clear(&mut self, out: &mut Vec) { + // Pop all entries and collect both keys and values + for (key, value) in self.entries.drain_entries() { + out.push(key); + out.push(value); + } + } +} + impl fmt::Debug for PyDict { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // TODO: implement more detailed, non-recursive Debug formatter diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 9297cf07201..632fd867d2e 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -51,6 +51,70 @@ unsafe impl Traverse for PyFunction { closure.as_untyped().traverse(tracer_fn); } self.defaults_and_kwdefaults.traverse(tracer_fn); + // Traverse additional fields that may contain references + self.type_params.lock().traverse(tracer_fn); + self.annotations.lock().traverse(tracer_fn); + self.module.lock().traverse(tracer_fn); + self.doc.lock().traverse(tracer_fn); + } + + fn clear(&mut self, out: &mut Vec) { + // Pop closure if present (equivalent to Py_CLEAR(func_closure)) + if let Some(closure) = self.closure.take() { + out.push(closure.into()); + } + + // Pop defaults and kwdefaults + if let Some(mut guard) = self.defaults_and_kwdefaults.try_lock() { + if let Some(defaults) = guard.0.take() { + out.push(defaults.into()); + } + if let 
Some(kwdefaults) = guard.1.take() { + out.push(kwdefaults.into()); + } + } + + // Clear annotations and annotate (Py_CLEAR) + if let Some(mut guard) = self.annotations.try_lock() + && let Some(annotations) = guard.take() + { + out.push(annotations.into()); + } + if let Some(mut guard) = self.annotate.try_lock() + && let Some(annotate) = guard.take() + { + out.push(annotate); + } + + // Clear module, doc, and type_params (Py_CLEAR) + if let Some(mut guard) = self.module.try_lock() { + let old_module = + std::mem::replace(&mut *guard, Context::genesis().none.to_owned().into()); + out.push(old_module); + } + if let Some(mut guard) = self.doc.try_lock() { + let old_doc = std::mem::replace(&mut *guard, Context::genesis().none.to_owned().into()); + out.push(old_doc); + } + if let Some(mut guard) = self.type_params.try_lock() { + let old_type_params = + std::mem::replace(&mut *guard, Context::genesis().empty_tuple.to_owned()); + out.push(old_type_params.into()); + } + + // Replace name and qualname with empty string to break potential str subclass cycles + // name and qualname could be str subclasses, so they could have reference cycles + if let Some(mut guard) = self.name.try_lock() { + let old_name = std::mem::replace(&mut *guard, Context::genesis().empty_str.to_owned()); + out.push(old_name.into()); + } + if let Some(mut guard) = self.qualname.try_lock() { + let old_qualname = + std::mem::replace(&mut *guard, Context::genesis().empty_str.to_owned()); + out.push(old_qualname.into()); + } + + // Note: globals, builtins, code are NOT cleared (required to be non-NULL) } } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 02475ee12b6..84825de7d3d 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -3,6 +3,7 @@ use crate::atomic_func; use crate::common::lock::{ PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; +use crate::object::{Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, class::PyClassImpl, @@ -23,7 +24,7 @@ use crate::{ use alloc::fmt; use core::ops::DerefMut; -#[pyclass(module = false, name = "list", unhashable = true, traverse)] +#[pyclass(module = false, name = "list", unhashable = true, traverse = "manual")] #[derive(Default)] pub struct PyList { elements: PyRwLock>, @@ -50,6 +51,22 @@ impl FromIterator for PyList { } } +// SAFETY: Traverse properly visits all owned PyObjectRefs +unsafe impl Traverse for PyList { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.elements.traverse(traverse_fn); + } + + fn clear(&mut self, out: &mut Vec) { + // During GC, we use interior mutability to access elements. + // This is safe because during GC collection, the object is unreachable + // and no other code should be accessing it. 
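// Self-contained sketch of the tp_clear contract these clear() impls follow:
// the container moves its strong child references into `out` instead of
// dropping them in place, so the collector controls drop order and reference
// cycles unwind without recursive destructors. Plain Rc/RefCell stand in for
// the real object model.
use std::cell::RefCell;
use std::rc::Rc;

type Obj = Rc<RefCell<Node>>;

#[derive(Default)]
struct Node {
    children: Vec<Obj>,
}

impl Node {
    fn clear(&mut self, out: &mut Vec<Obj>) {
        // drain(..) empties the container; ownership moves to `out`.
        out.extend(self.children.drain(..));
    }
}

fn main() {
    let a: Obj = Rc::default();
    let b: Obj = Rc::default();
    a.borrow_mut().children.push(b.clone());
    b.borrow_mut().children.push(a.clone()); // a <-> b cycle

    let mut out = Vec::new();
    a.borrow_mut().clear(&mut out);
    b.borrow_mut().clear(&mut out);
    drop(out); // extracted refs dropped here; the cycle is broken
    assert_eq!(Rc::strong_count(&a), 1);
    assert_eq!(Rc::strong_count(&b), 1);
}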
+ if let Some(mut guard) = self.elements.try_write() { + out.extend(guard.drain(..)); + } + } +} + impl PyPayload for PyList { #[inline] fn class(ctx: &Context) -> &'static Py { diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 640778c8cb9..d765847c1ab 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -1924,9 +1924,16 @@ impl fmt::Display for PyUtf8Str { } impl MaybeTraverse for PyUtf8Str { + const HAS_TRAVERSE: bool = true; + const HAS_CLEAR: bool = false; + fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) { self.0.try_traverse(traverse_fn); } + + fn try_clear(&mut self, _out: &mut Vec) { + // No clear needed for PyUtf8Str + } } impl PyPayload for PyUtf8Str { diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index f6eff5b91e5..ba296686c73 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -3,6 +3,7 @@ use crate::common::{ hash::{PyHash, PyUHash}, lock::PyMutex, }; +use crate::object::{Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, atomic_func, @@ -24,7 +25,7 @@ use crate::{ use alloc::fmt; use std::sync::LazyLock; -#[pyclass(module = false, name = "tuple", traverse)] +#[pyclass(module = false, name = "tuple", traverse = "manual")] pub struct PyTuple { elements: Box<[R]>, } @@ -36,6 +37,19 @@ impl fmt::Debug for PyTuple { } } +// SAFETY: Traverse properly visits all owned PyObjectRefs +// Note: Only impl for PyTuple (the default) +unsafe impl Traverse for PyTuple { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.elements.traverse(traverse_fn); + } + + fn clear(&mut self, out: &mut Vec) { + let elements = std::mem::take(&mut self.elements); + out.extend(elements.into_vec()); + } +} + impl PyPayload for PyTuple { #[inline] fn class(ctx: &Context) -> &'static Py { diff --git a/crates/vm/src/dict_inner.rs b/crates/vm/src/dict_inner.rs index 1d9fe8403ab..f2a379d99a5 100644 --- a/crates/vm/src/dict_inner.rs +++ b/crates/vm/src/dict_inner.rs @@ -724,6 +724,17 @@ impl Dict { + inner.indices.len() * size_of::() + inner.entries.len() * size_of::>() } + + /// Pop all entries from the dict, returning (key, value) pairs. + /// This is used for circular reference resolution in GC. + /// Requires &mut self to avoid lock contention. 
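// Sketch of why taking `&mut self` lets drain_entries() skip locking: once the
// borrow checker proves exclusive access, get_mut() reaches the data behind
// the lock without acquiring it. Illustrated with std's Mutex; the types here
// are hypothetical, not dict_inner's.
use std::sync::Mutex;

struct Table {
    inner: Mutex<Vec<(String, i32)>>,
}

impl Table {
    fn drain_entries(&mut self) -> impl Iterator<Item = (String, i32)> + '_ {
        // No lock is taken: &mut self guarantees no guard can exist elsewhere,
        // so poisoning is the only possible failure.
        self.inner.get_mut().expect("lock poisoned").drain(..)
    }
}

fn main() {
    let mut t = Table {
        inner: Mutex::new(vec![("a".into(), 1), ("b".into(), 2)]),
    };
    assert_eq!(t.drain_entries().count(), 2);
    assert_eq!(t.inner.get_mut().unwrap().len(), 0);
}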
+ pub fn drain_entries(&mut self) -> impl Iterator + '_ { + let inner = self.inner.get_mut(); + inner.used = 0; + inner.filled = 0; + inner.indices.iter_mut().for_each(|i| *i = IndexEntry::FREE); + inner.entries.drain(..).flatten().map(|e| (e.key, e.value)) + } } type LookupResult = (IndexEntry, IndexIndex); diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index d83be841275..55788ee3e4a 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -13,6 +13,7 @@ use crate::{ coroutine::Coro, exceptions::ExceptionCtor, function::{ArgMapping, Either, FuncArgs}, + object::{Traverse, TraverseFn}, protocol::{PyIter, PyIterReturn}, scope::Scope, stdlib::{builtins, typing}, @@ -66,7 +67,7 @@ type Lasti = atomic::AtomicU32; #[cfg(not(feature = "threading"))] type Lasti = core::cell::Cell; -#[pyclass(module = false, name = "frame")] +#[pyclass(module = false, name = "frame", traverse = "manual")] pub struct Frame { pub code: PyRef, pub func_obj: Option, @@ -97,6 +98,27 @@ impl PyPayload for Frame { } } +unsafe impl Traverse for FrameState { + fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { + self.stack.traverse(tracer_fn); + } +} + +unsafe impl Traverse for Frame { + fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { + self.code.traverse(tracer_fn); + self.func_obj.traverse(tracer_fn); + self.fastlocals.traverse(tracer_fn); + self.cells_frees.traverse(tracer_fn); + self.locals.traverse(tracer_fn); + self.globals.traverse(tracer_fn); + self.builtins.traverse(tracer_fn); + self.trace.traverse(tracer_fn); + self.state.traverse(tracer_fn); + self.temporary_refs.traverse(tracer_fn); + } +} + // Running a frame can result in one of the below: pub enum ExecutionResult { Return(PyObjectRef), @@ -763,9 +785,9 @@ impl ExecutingFrame<'_> { let args = self.collect_keyword_args(nargs.get(arg)); self.execute_call(args, vm) } - Instruction::CallFunctionEx { has_kwargs } => { - // Stack: [callable, self_or_null, args_tuple, (kwargs_dict)?] 
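// Sketch of the calling-convention change seen in the instruction.rs and
// frame.rs hunks of this diff: CALL_FUNCTION_EX now always pops four values
// (callable, self_or_null, args_tuple, kwargs_or_null) and pushes one result,
// so its stack effect is the constant -3 rather than a function of a
// has_kwargs flag. The enum is illustrative, not the real instruction set.
enum Instr {
    Call { nargs: u32 },
    CallFunctionEx,
}

fn stack_effect(i: &Instr) -> i32 {
    match i {
        // pops nargs arguments + self_or_null + callable, pushes result
        Instr::Call { nargs } => -(*nargs as i32) - 2 + 1,
        // pops kwargs_or_null + args_tuple + self_or_null + callable, pushes result
        Instr::CallFunctionEx => -4 + 1,
    }
}

fn main() {
    assert_eq!(stack_effect(&Instr::Call { nargs: 2 }), -3);
    assert_eq!(stack_effect(&Instr::CallFunctionEx), -3);
}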
- let args = self.collect_ex_args(vm, has_kwargs.get(arg))?; + Instruction::CallFunctionEx => { + // Stack: [callable, self_or_null, args_tuple, kwargs_or_null] + let args = self.collect_ex_args(vm)?; self.execute_call(args, vm) } Instruction::CallIntrinsic1 { func } => { @@ -2128,9 +2150,9 @@ impl ExecutingFrame<'_> { FuncArgs::with_kwargs_names(args, kwarg_names) } - fn collect_ex_args(&mut self, vm: &VirtualMachine, has_kwargs: bool) -> PyResult { - let kwargs = if has_kwargs { - let kw_obj = self.pop_value(); + fn collect_ex_args(&mut self, vm: &VirtualMachine) -> PyResult { + let kwargs_or_null = self.pop_value_opt(); + let kwargs = if let Some(kw_obj) = kwargs_or_null { let mut kwargs = IndexMap::new(); // Use keys() method for all mapping objects to preserve order diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index 4e51e296462..c949cae9053 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -93,7 +93,7 @@ pub(super) unsafe fn debug_obj( } /// Call `try_trace` on payload -pub(super) unsafe fn try_trace_obj(x: &PyObject, tracer_fn: &mut TraverseFn<'_>) { +pub(super) unsafe fn try_traverse_obj(x: &PyObject, tracer_fn: &mut TraverseFn<'_>) { let x = unsafe { &*(x as *const PyObject as *const PyInner) }; let payload = &x.payload; payload.try_traverse(tracer_fn) diff --git a/crates/vm/src/object/traverse.rs b/crates/vm/src/object/traverse.rs index 2ce0db41a5e..367076b78e3 100644 --- a/crates/vm/src/object/traverse.rs +++ b/crates/vm/src/object/traverse.rs @@ -12,9 +12,13 @@ pub type TraverseFn<'a> = dyn FnMut(&PyObject) + 'a; /// Every PyObjectPayload impl `MaybeTrace`, which may or may not be traceable pub trait MaybeTraverse { /// if is traceable, will be used by vtable to determine - const IS_TRACE: bool = false; + const HAS_TRAVERSE: bool = false; + /// if has clear implementation for circular reference resolution (tp_clear) + const HAS_CLEAR: bool = false; // if this type is traceable, then call with tracer_fn, default to do nothing fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>); + // if this type has clear, extract child refs for circular reference resolution (tp_clear) + fn try_clear(&mut self, _out: &mut Vec) {} } /// Type that need traverse it's children should impl [`Traverse`] (not [`MaybeTraverse`]) @@ -28,6 +32,11 @@ pub unsafe trait Traverse { /// /// - _**DO NOT**_ clone a [`PyObjectRef`] or [`PyRef`] in [`Traverse::traverse()`] fn traverse(&self, traverse_fn: &mut TraverseFn<'_>); + + /// Extract all owned child PyObjectRefs for circular reference resolution (tp_clear). + /// Called just before object deallocation to break circular references. + /// Default implementation does nothing. 
+ fn clear(&mut self, _out: &mut Vec) {} } unsafe impl Traverse for PyObjectRef { diff --git a/crates/vm/src/object/traverse_object.rs b/crates/vm/src/object/traverse_object.rs index 7a66f0b35f0..840bbd42b39 100644 --- a/crates/vm/src/object/traverse_object.rs +++ b/crates/vm/src/object/traverse_object.rs @@ -5,7 +5,7 @@ use crate::{ PyObject, object::{ Erased, InstanceDict, MaybeTraverse, PyInner, PyObjectPayload, debug_obj, drop_dealloc_obj, - try_trace_obj, + try_traverse_obj, }, }; @@ -25,8 +25,8 @@ impl PyObjVTable { drop_dealloc: drop_dealloc_obj::, debug: debug_obj::, trace: const { - if T::IS_TRACE { - Some(try_trace_obj::) + if T::HAS_TRAVERSE { + Some(try_traverse_obj::) } else { None } diff --git a/crates/vm/src/stdlib/ast.rs b/crates/vm/src/stdlib/ast.rs index 31aad306f96..116d239c033 100644 --- a/crates/vm/src/stdlib/ast.rs +++ b/crates/vm/src/stdlib/ast.rs @@ -19,7 +19,7 @@ use crate::{ convert::ToPyObject, }; use node::Node; -use ruff_python_ast as ruff; +use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange, TextSize}; use rustpython_compiler_core::{ LineIndex, OneIndexed, PositionEncoding, SourceFile, SourceFileBuilder, SourceLocation, @@ -283,8 +283,8 @@ pub(crate) fn parse( })? .into_syntax(); let top = match top { - ruff::Mod::Module(m) => Mod::Module(m), - ruff::Mod::Expression(e) => Mod::Expression(e), + ast::Mod::Module(m) => Mod::Module(m), + ast::Mod::Expression(e) => Mod::Expression(e), }; Ok(top.ast_to_object(vm, &source_file)) } @@ -305,13 +305,13 @@ pub(crate) fn compile( let source_file = SourceFileBuilder::new(filename.to_owned(), "".to_owned()).finish(); let ast: Mod = Node::ast_from_object(vm, &source_file, object)?; let ast = match ast { - Mod::Module(m) => ruff::Mod::Module(m), - Mod::Interactive(ModInteractive { range, body }) => ruff::Mod::Module(ruff::ModModule { + Mod::Module(m) => ast::Mod::Module(m), + Mod::Interactive(ModInteractive { range, body }) => ast::Mod::Module(ast::ModModule { node_index: Default::default(), range, body, }), - Mod::Expression(e) => ruff::Mod::Expression(e), + Mod::Expression(e) => ast::Mod::Expression(e), Mod::FunctionType(_) => todo!(), }; // TODO: create a textual representation of the ast diff --git a/crates/vm/src/stdlib/ast/argument.rs b/crates/vm/src/stdlib/ast/argument.rs index a13200e6502..626024f5bd6 100644 --- a/crates/vm/src/stdlib/ast/argument.rs +++ b/crates/vm/src/stdlib/ast/argument.rs @@ -3,7 +3,7 @@ use rustpython_compiler_core::SourceFile; pub(super) struct PositionalArguments { pub range: TextRange, - pub args: Box<[ruff::Expr]>, + pub args: Box<[ast::Expr]>, } impl Node for PositionalArguments { @@ -27,7 +27,7 @@ impl Node for PositionalArguments { pub(super) struct KeywordArguments { pub range: TextRange, - pub keywords: Box<[ruff::Keyword]>, + pub keywords: Box<[ast::Keyword]>, } impl Node for KeywordArguments { @@ -53,10 +53,10 @@ impl Node for KeywordArguments { pub(super) fn merge_function_call_arguments( pos_args: PositionalArguments, key_args: KeywordArguments, -) -> ruff::Arguments { +) -> ast::Arguments { let range = pos_args.range.cover(key_args.range); - ruff::Arguments { + ast::Arguments { node_index: Default::default(), range, args: pos_args.args, @@ -65,9 +65,9 @@ pub(super) fn merge_function_call_arguments( } pub(super) fn split_function_call_arguments( - args: ruff::Arguments, + args: ast::Arguments, ) -> (PositionalArguments, KeywordArguments) { - let ruff::Arguments { + let ast::Arguments { node_index: _, range: _, args, @@ -100,13 +100,13 @@ pub(super) fn 
split_function_call_arguments( } pub(super) fn split_class_def_args( - args: Option>, + args: Option>, ) -> (Option, Option) { let args = match args { None => return (None, None), Some(args) => *args, }; - let ruff::Arguments { + let ast::Arguments { node_index: _, range: _, args, @@ -141,7 +141,7 @@ pub(super) fn split_class_def_args( pub(super) fn merge_class_def_args( positional_arguments: Option, keyword_arguments: Option, -) -> Option> { +) -> Option> { if positional_arguments.is_none() && keyword_arguments.is_none() { return None; } @@ -157,7 +157,7 @@ pub(super) fn merge_class_def_args( vec![].into_boxed_slice() }; - Some(Box::new(ruff::Arguments { + Some(Box::new(ast::Arguments { node_index: Default::default(), range: Default::default(), // TODO args, diff --git a/crates/vm/src/stdlib/ast/basic.rs b/crates/vm/src/stdlib/ast/basic.rs index d8565029d6c..612b6144eea 100644 --- a/crates/vm/src/stdlib/ast/basic.rs +++ b/crates/vm/src/stdlib/ast/basic.rs @@ -2,7 +2,7 @@ use super::*; use rustpython_codegen::compile::ruff_int_to_bigint; use rustpython_compiler_core::SourceFile; -impl Node for ruff::Identifier { +impl Node for ast::Identifier { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { let id = self.as_str(); vm.ctx.new_str(id).into() @@ -18,7 +18,7 @@ impl Node for ruff::Identifier { } } -impl Node for ruff::Int { +impl Node for ast::Int { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { vm.ctx.new_int(ruff_int_to_bigint(&self).unwrap()).into() } diff --git a/crates/vm/src/stdlib/ast/constant.rs b/crates/vm/src/stdlib/ast/constant.rs index 83b2a7f7015..a6aac224585 100644 --- a/crates/vm/src/stdlib/ast/constant.rs +++ b/crates/vm/src/stdlib/ast/constant.rs @@ -1,6 +1,6 @@ use super::*; use crate::builtins::{PyComplex, PyFrozenSet, PyTuple}; -use ruff::str_prefix::StringLiteralPrefix; +use ast::str_prefix::StringLiteralPrefix; use rustpython_compiler_core::SourceFile; #[derive(Debug)] @@ -22,7 +22,7 @@ impl Constant { } } - pub(super) const fn new_int(value: ruff::Int, range: TextRange) -> Self { + pub(super) const fn new_int(value: ast::Int, range: TextRange) -> Self { Self { range, value: ConstantLiteral::Int(value), @@ -71,7 +71,7 @@ impl Constant { } } - pub(crate) fn into_expr(self) -> ruff::Expr { + pub(crate) fn into_expr(self) -> ast::Expr { constant_to_ruff_expr(self) } } @@ -85,7 +85,7 @@ pub(crate) enum ConstantLiteral { prefix: StringLiteralPrefix, }, Bytes(Box<[u8]>), - Int(ruff::Int), + Int(ast::Int), Tuple(Vec), FrozenSet(Vec), Float(f64), @@ -244,48 +244,48 @@ impl Node for ConstantLiteral { } } -fn constant_to_ruff_expr(value: Constant) -> ruff::Expr { +fn constant_to_ruff_expr(value: Constant) -> ast::Expr { let Constant { value, range } = value; match value { - ConstantLiteral::None => ruff::Expr::NoneLiteral(ruff::ExprNoneLiteral { + ConstantLiteral::None => ast::Expr::NoneLiteral(ast::ExprNoneLiteral { node_index: Default::default(), range, }), - ConstantLiteral::Bool(value) => ruff::Expr::BooleanLiteral(ruff::ExprBooleanLiteral { + ConstantLiteral::Bool(value) => ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { node_index: Default::default(), range, value, }), ConstantLiteral::Str { value, prefix } => { - ruff::Expr::StringLiteral(ruff::ExprStringLiteral { + ast::Expr::StringLiteral(ast::ExprStringLiteral { node_index: Default::default(), range, - value: ruff::StringLiteralValue::single(ruff::StringLiteral { + value: ast::StringLiteralValue::single(ast::StringLiteral { node_index: 
Default::default(), range, value, - flags: ruff::StringLiteralFlags::empty().with_prefix(prefix), + flags: ast::StringLiteralFlags::empty().with_prefix(prefix), }), }) } ConstantLiteral::Bytes(value) => { - ruff::Expr::BytesLiteral(ruff::ExprBytesLiteral { + ast::Expr::BytesLiteral(ast::ExprBytesLiteral { node_index: Default::default(), range, - value: ruff::BytesLiteralValue::single(ruff::BytesLiteral { + value: ast::BytesLiteralValue::single(ast::BytesLiteral { node_index: Default::default(), range, value, - flags: ruff::BytesLiteralFlags::empty(), // TODO + flags: ast::BytesLiteralFlags::empty(), // TODO }), }) } - ConstantLiteral::Int(value) => ruff::Expr::NumberLiteral(ruff::ExprNumberLiteral { + ConstantLiteral::Int(value) => ast::Expr::NumberLiteral(ast::ExprNumberLiteral { node_index: Default::default(), range, - value: ruff::Number::Int(value), + value: ast::Number::Int(value), }), - ConstantLiteral::Tuple(value) => ruff::Expr::Tuple(ruff::ExprTuple { + ConstantLiteral::Tuple(value) => ast::Expr::Tuple(ast::ExprTuple { node_index: Default::default(), range, elts: value @@ -297,21 +297,21 @@ fn constant_to_ruff_expr(value: Constant) -> ruff::Expr { }) }) .collect(), - ctx: ruff::ExprContext::Load, + ctx: ast::ExprContext::Load, // TODO: Does this matter? parenthesized: true, }), - ConstantLiteral::FrozenSet(value) => ruff::Expr::Call(ruff::ExprCall { + ConstantLiteral::FrozenSet(value) => ast::Expr::Call(ast::ExprCall { node_index: Default::default(), range, // idk lol - func: Box::new(ruff::Expr::Name(ruff::ExprName { + func: Box::new(ast::Expr::Name(ast::ExprName { node_index: Default::default(), range: TextRange::default(), - id: ruff::name::Name::new_static("frozenset"), - ctx: ruff::ExprContext::Load, + id: ast::name::Name::new_static("frozenset"), + ctx: ast::ExprContext::Load, })), - arguments: ruff::Arguments { + arguments: ast::Arguments { node_index: Default::default(), range, args: value @@ -326,19 +326,19 @@ fn constant_to_ruff_expr(value: Constant) -> ruff::Expr { keywords: Box::default(), }, }), - ConstantLiteral::Float(value) => ruff::Expr::NumberLiteral(ruff::ExprNumberLiteral { + ConstantLiteral::Float(value) => ast::Expr::NumberLiteral(ast::ExprNumberLiteral { node_index: Default::default(), range, - value: ruff::Number::Float(value), + value: ast::Number::Float(value), }), ConstantLiteral::Complex { real, imag } => { - ruff::Expr::NumberLiteral(ruff::ExprNumberLiteral { + ast::Expr::NumberLiteral(ast::ExprNumberLiteral { node_index: Default::default(), range, - value: ruff::Number::Complex { real, imag }, + value: ast::Number::Complex { real, imag }, }) } - ConstantLiteral::Ellipsis => ruff::Expr::EllipsisLiteral(ruff::ExprEllipsisLiteral { + ConstantLiteral::Ellipsis => ast::Expr::EllipsisLiteral(ast::ExprEllipsisLiteral { node_index: Default::default(), range, }), @@ -348,17 +348,17 @@ fn constant_to_ruff_expr(value: Constant) -> ruff::Expr { pub(super) fn number_literal_to_object( vm: &VirtualMachine, source_file: &SourceFile, - constant: ruff::ExprNumberLiteral, + constant: ast::ExprNumberLiteral, ) -> PyObjectRef { - let ruff::ExprNumberLiteral { + let ast::ExprNumberLiteral { node_index: _, range, value, } = constant; let c = match value { - ruff::Number::Int(n) => Constant::new_int(n, range), - ruff::Number::Float(n) => Constant::new_float(n, range), - ruff::Number::Complex { real, imag } => Constant::new_complex(real, imag, range), + ast::Number::Int(n) => Constant::new_int(n, range), + ast::Number::Float(n) => Constant::new_float(n, range), + 
ast::Number::Complex { real, imag } => Constant::new_complex(real, imag, range), }; c.ast_to_object(vm, source_file) } @@ -366,9 +366,9 @@ pub(super) fn number_literal_to_object( pub(super) fn string_literal_to_object( vm: &VirtualMachine, source_file: &SourceFile, - constant: ruff::ExprStringLiteral, + constant: ast::ExprStringLiteral, ) -> PyObjectRef { - let ruff::ExprStringLiteral { + let ast::ExprStringLiteral { node_index: _, range, value, @@ -384,9 +384,9 @@ pub(super) fn string_literal_to_object( pub(super) fn bytes_literal_to_object( vm: &VirtualMachine, source_file: &SourceFile, - constant: ruff::ExprBytesLiteral, + constant: ast::ExprBytesLiteral, ) -> PyObjectRef { - let ruff::ExprBytesLiteral { + let ast::ExprBytesLiteral { node_index: _, range, value, @@ -399,9 +399,9 @@ pub(super) fn bytes_literal_to_object( pub(super) fn boolean_literal_to_object( vm: &VirtualMachine, source_file: &SourceFile, - constant: ruff::ExprBooleanLiteral, + constant: ast::ExprBooleanLiteral, ) -> PyObjectRef { - let ruff::ExprBooleanLiteral { + let ast::ExprBooleanLiteral { node_index: _, range, value, @@ -413,9 +413,9 @@ pub(super) fn boolean_literal_to_object( pub(super) fn none_literal_to_object( vm: &VirtualMachine, source_file: &SourceFile, - constant: ruff::ExprNoneLiteral, + constant: ast::ExprNoneLiteral, ) -> PyObjectRef { - let ruff::ExprNoneLiteral { + let ast::ExprNoneLiteral { node_index: _, range, } = constant; @@ -426,9 +426,9 @@ pub(super) fn none_literal_to_object( pub(super) fn ellipsis_literal_to_object( vm: &VirtualMachine, source_file: &SourceFile, - constant: ruff::ExprEllipsisLiteral, + constant: ast::ExprEllipsisLiteral, ) -> PyObjectRef { - let ruff::ExprEllipsisLiteral { + let ast::ExprEllipsisLiteral { node_index: _, range, } = constant; diff --git a/crates/vm/src/stdlib/ast/elif_else_clause.rs b/crates/vm/src/stdlib/ast/elif_else_clause.rs index e2a8789dd08..b27e956077e 100644 --- a/crates/vm/src/stdlib/ast/elif_else_clause.rs +++ b/crates/vm/src/stdlib/ast/elif_else_clause.rs @@ -2,12 +2,12 @@ use super::*; use rustpython_compiler_core::SourceFile; pub(super) fn ast_to_object( - clause: ruff::ElifElseClause, - mut rest: alloc::vec::IntoIter, + clause: ast::ElifElseClause, + mut rest: alloc::vec::IntoIter, vm: &VirtualMachine, source_file: &SourceFile, ) -> PyObjectRef { - let ruff::ElifElseClause { + let ast::ElifElseClause { node_index: _, range, test, @@ -48,18 +48,18 @@ pub(super) fn ast_from_object( vm: &VirtualMachine, source_file: &SourceFile, object: PyObjectRef, -) -> PyResult { +) -> PyResult { let test = Node::ast_from_object(vm, source_file, get_node_field(vm, &object, "test", "If")?)?; let body = Node::ast_from_object(vm, source_file, get_node_field(vm, &object, "body", "If")?)?; - let orelse: Vec = Node::ast_from_object( + let orelse: Vec = Node::ast_from_object( vm, source_file, get_node_field(vm, &object, "orelse", "If")?, )?; let range = range_from_object(vm, source_file, object, "If")?; - let elif_else_clauses = if let [ruff::Stmt::If(_)] = &*orelse { - let Some(ruff::Stmt::If(ruff::StmtIf { + let elif_else_clauses = if let [ast::Stmt::If(_)] = &*orelse { + let Some(ast::Stmt::If(ast::StmtIf { node_index: _, range, test, @@ -71,7 +71,7 @@ pub(super) fn ast_from_object( }; elif_else_clauses.insert( 0, - ruff::ElifElseClause { + ast::ElifElseClause { node_index: Default::default(), range, test: Some(*test), @@ -80,7 +80,7 @@ pub(super) fn ast_from_object( ); elif_else_clauses } else { - vec![ruff::ElifElseClause { + vec![ast::ElifElseClause { node_index: 
Default::default(), range, test: None, @@ -88,7 +88,7 @@ pub(super) fn ast_from_object( }] }; - Ok(ruff::StmtIf { + Ok(ast::StmtIf { node_index: Default::default(), test, body, diff --git a/crates/vm/src/stdlib/ast/exception.rs b/crates/vm/src/stdlib/ast/exception.rs index b5b3ca2709a..bdb8b7ad9ac 100644 --- a/crates/vm/src/stdlib/ast/exception.rs +++ b/crates/vm/src/stdlib/ast/exception.rs @@ -2,7 +2,7 @@ use super::*; use rustpython_compiler_core::SourceFile; // sum -impl Node for ruff::ExceptHandler { +impl Node for ast::ExceptHandler { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { match self { Self::ExceptHandler(cons) => cons.ast_to_object(vm, source_file), @@ -16,7 +16,7 @@ impl Node for ruff::ExceptHandler { let _cls = _object.class(); Ok( if _cls.is(pyast::NodeExceptHandlerExceptHandler::static_type()) { - Self::ExceptHandler(ruff::ExceptHandlerExceptHandler::ast_from_object( + Self::ExceptHandler(ast::ExceptHandlerExceptHandler::ast_from_object( _vm, source_file, _object, @@ -32,7 +32,7 @@ impl Node for ruff::ExceptHandler { } // constructor -impl Node for ruff::ExceptHandlerExceptHandler { +impl Node for ast::ExceptHandlerExceptHandler { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, diff --git a/crates/vm/src/stdlib/ast/expression.rs b/crates/vm/src/stdlib/ast/expression.rs index c63d0e0df68..3bf1470795d 100644 --- a/crates/vm/src/stdlib/ast/expression.rs +++ b/crates/vm/src/stdlib/ast/expression.rs @@ -7,7 +7,7 @@ use crate::stdlib::ast::{ use rustpython_compiler_core::SourceFile; // sum -impl Node for ruff::Expr { +impl Node for ast::Expr { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { match self { Self::BoolOp(cons) => cons.ast_to_object(vm, source_file), @@ -59,77 +59,69 @@ impl Node for ruff::Expr { ) -> PyResult { let cls = object.class(); Ok(if cls.is(pyast::NodeExprBoolOp::static_type()) { - Self::BoolOp(ruff::ExprBoolOp::ast_from_object(vm, source_file, object)?) + Self::BoolOp(ast::ExprBoolOp::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprNamedExpr::static_type()) { - Self::Named(ruff::ExprNamed::ast_from_object(vm, source_file, object)?) + Self::Named(ast::ExprNamed::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprBinOp::static_type()) { - Self::BinOp(ruff::ExprBinOp::ast_from_object(vm, source_file, object)?) + Self::BinOp(ast::ExprBinOp::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprUnaryOp::static_type()) { - Self::UnaryOp(ruff::ExprUnaryOp::ast_from_object(vm, source_file, object)?) + Self::UnaryOp(ast::ExprUnaryOp::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprLambda::static_type()) { - Self::Lambda(ruff::ExprLambda::ast_from_object(vm, source_file, object)?) + Self::Lambda(ast::ExprLambda::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprIfExp::static_type()) { - Self::If(ruff::ExprIf::ast_from_object(vm, source_file, object)?) + Self::If(ast::ExprIf::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprDict::static_type()) { - Self::Dict(ruff::ExprDict::ast_from_object(vm, source_file, object)?) + Self::Dict(ast::ExprDict::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprSet::static_type()) { - Self::Set(ruff::ExprSet::ast_from_object(vm, source_file, object)?) + Self::Set(ast::ExprSet::ast_from_object(vm, source_file, object)?) 
} else if cls.is(pyast::NodeExprListComp::static_type()) { - Self::ListComp(ruff::ExprListComp::ast_from_object( - vm, - source_file, - object, - )?) + Self::ListComp(ast::ExprListComp::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprSetComp::static_type()) { - Self::SetComp(ruff::ExprSetComp::ast_from_object(vm, source_file, object)?) + Self::SetComp(ast::ExprSetComp::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprDictComp::static_type()) { - Self::DictComp(ruff::ExprDictComp::ast_from_object( - vm, - source_file, - object, - )?) + Self::DictComp(ast::ExprDictComp::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprGeneratorExp::static_type()) { - Self::Generator(ruff::ExprGenerator::ast_from_object( + Self::Generator(ast::ExprGenerator::ast_from_object( vm, source_file, object, )?) } else if cls.is(pyast::NodeExprAwait::static_type()) { - Self::Await(ruff::ExprAwait::ast_from_object(vm, source_file, object)?) + Self::Await(ast::ExprAwait::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprYield::static_type()) { - Self::Yield(ruff::ExprYield::ast_from_object(vm, source_file, object)?) + Self::Yield(ast::ExprYield::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprYieldFrom::static_type()) { - Self::YieldFrom(ruff::ExprYieldFrom::ast_from_object( + Self::YieldFrom(ast::ExprYieldFrom::ast_from_object( vm, source_file, object, )?) } else if cls.is(pyast::NodeExprCompare::static_type()) { - Self::Compare(ruff::ExprCompare::ast_from_object(vm, source_file, object)?) + Self::Compare(ast::ExprCompare::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprCall::static_type()) { - Self::Call(ruff::ExprCall::ast_from_object(vm, source_file, object)?) + Self::Call(ast::ExprCall::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprAttribute::static_type()) { - Self::Attribute(ruff::ExprAttribute::ast_from_object( + Self::Attribute(ast::ExprAttribute::ast_from_object( vm, source_file, object, )?) } else if cls.is(pyast::NodeExprSubscript::static_type()) { - Self::Subscript(ruff::ExprSubscript::ast_from_object( + Self::Subscript(ast::ExprSubscript::ast_from_object( vm, source_file, object, )?) } else if cls.is(pyast::NodeExprStarred::static_type()) { - Self::Starred(ruff::ExprStarred::ast_from_object(vm, source_file, object)?) + Self::Starred(ast::ExprStarred::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprName::static_type()) { - Self::Name(ruff::ExprName::ast_from_object(vm, source_file, object)?) + Self::Name(ast::ExprName::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprList::static_type()) { - Self::List(ruff::ExprList::ast_from_object(vm, source_file, object)?) + Self::List(ast::ExprList::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprTuple::static_type()) { - Self::Tuple(ruff::ExprTuple::ast_from_object(vm, source_file, object)?) + Self::Tuple(ast::ExprTuple::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeExprSlice::static_type()) { - Self::Slice(ruff::ExprSlice::ast_from_object(vm, source_file, object)?) + Self::Slice(ast::ExprSlice::ast_from_object(vm, source_file, object)?) 
} else if cls.is(pyast::NodeExprConstant::static_type()) { Constant::ast_from_object(vm, source_file, object)?.into_expr() } else if cls.is(pyast::NodeExprJoinedStr::static_type()) { @@ -144,7 +136,7 @@ impl Node for ruff::Expr { } // constructor -impl Node for ruff::ExprBoolOp { +impl Node for ast::ExprBoolOp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -187,7 +179,7 @@ impl Node for ruff::ExprBoolOp { } // constructor -impl Node for ruff::ExprNamed { +impl Node for ast::ExprNamed { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -230,7 +222,7 @@ impl Node for ruff::ExprNamed { } // constructor -impl Node for ruff::ExprBinOp { +impl Node for ast::ExprBinOp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -281,7 +273,7 @@ impl Node for ruff::ExprBinOp { } // constructor -impl Node for ruff::ExprUnaryOp { +impl Node for ast::ExprUnaryOp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -323,7 +315,7 @@ impl Node for ruff::ExprUnaryOp { } // constructor -impl Node for ruff::ExprLambda { +impl Node for ast::ExprLambda { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -366,7 +358,7 @@ impl Node for ruff::ExprLambda { } // constructor -impl Node for ruff::ExprIf { +impl Node for ast::ExprIf { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -417,7 +409,7 @@ impl Node for ruff::ExprIf { } // constructor -impl Node for ruff::ExprDict { +impl Node for ast::ExprDict { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -449,7 +441,7 @@ impl Node for ruff::ExprDict { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let keys: Vec<Option<ruff::Expr>> = Node::ast_from_object( + let keys: Vec<Option<ast::Expr>> = Node::ast_from_object( vm, source_file, get_node_field(vm, &object, "keys", "Dict")?, @@ -462,7 +454,7 @@ impl Node for ruff::ExprDict { let items = keys .into_iter() .zip(values) - .map(|(key, value)| ruff::DictItem { key, value }) + .map(|(key, value)| ast::DictItem { key, value }) .collect(); Ok(Self { node_index: Default::default(), @@ -473,7 +465,7 @@ impl Node for ruff::ExprDict { } // constructor -impl Node for ruff::ExprSet { +impl Node for ast::ExprSet { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -507,7 +499,7 @@ impl Node for ruff::ExprSet { } // constructor -impl Node for ruff::ExprListComp { +impl Node for ast::ExprListComp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -550,7 +542,7 @@ impl Node for ruff::ExprListComp { } // constructor -impl Node for ruff::ExprSetComp { +impl Node for ast::ExprSetComp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -593,7 +585,7 @@ impl Node for ruff::ExprSetComp { } // constructor -impl Node for ruff::ExprDictComp { +impl Node for ast::ExprDictComp { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -644,7 +636,7 @@ impl Node for ruff::ExprDictComp { } // constructor -impl Node for ruff::ExprGenerator { +impl Node for
ast::ExprGenerator { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -690,7 +682,7 @@ impl Node for ruff::ExprGenerator { } // constructor -impl Node for ruff::ExprAwait { +impl Node for ast::ExprAwait { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -724,7 +716,7 @@ impl Node for ruff::ExprAwait { } // constructor -impl Node for ruff::ExprYield { +impl Node for ast::ExprYield { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -757,7 +749,7 @@ impl Node for ruff::ExprYield { } // constructor -impl Node for ruff::ExprYieldFrom { +impl Node for ast::ExprYieldFrom { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -792,7 +784,7 @@ impl Node for ruff::ExprYieldFrom { } // constructor -impl Node for ruff::ExprCompare { +impl Node for ast::ExprCompare { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -853,7 +845,7 @@ impl Node for ruff::ExprCompare { } // constructor -impl Node for ruff::ExprCall { +impl Node for ast::ExprCall { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -914,7 +906,7 @@ impl Node for ruff::ExprCall { } // constructor -impl Node for ruff::ExprAttribute { +impl Node for ast::ExprAttribute { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -965,7 +957,7 @@ impl Node for ruff::ExprAttribute { } // constructor -impl Node for ruff::ExprSubscript { +impl Node for ast::ExprSubscript { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1015,7 +1007,7 @@ impl Node for ruff::ExprSubscript { } // constructor -impl Node for ruff::ExprStarred { +impl Node for ast::ExprStarred { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1057,7 +1049,7 @@ impl Node for ruff::ExprStarred { } // constructor -impl Node for ruff::ExprName { +impl Node for ast::ExprName { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1095,7 +1087,7 @@ impl Node for ruff::ExprName { } // constructor -impl Node for ruff::ExprList { +impl Node for ast::ExprList { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1138,7 +1130,7 @@ impl Node for ruff::ExprList { } // constructor -impl Node for ruff::ExprTuple { +impl Node for ast::ExprTuple { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1183,7 +1175,7 @@ impl Node for ruff::ExprTuple { } // constructor -impl Node for ruff::ExprSlice { +impl Node for ast::ExprSlice { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1228,7 +1220,7 @@ impl Node for ruff::ExprSlice { } // sum -impl Node for ruff::ExprContext { +impl Node for ast::ExprContext { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { let node_type = match self { Self::Load => pyast::NodeExprContextLoad::static_type(), @@ -1266,7 +1258,7 @@ impl Node for ruff::ExprContext { } // product -impl Node for 
ruff::Comprehension { +impl Node for ast::Comprehension { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, diff --git a/crates/vm/src/stdlib/ast/module.rs b/crates/vm/src/stdlib/ast/module.rs index 6fae8f10a33..78f897b8930 100644 --- a/crates/vm/src/stdlib/ast/module.rs +++ b/crates/vm/src/stdlib/ast/module.rs @@ -18,9 +18,9 @@ use rustpython_compiler_core::SourceFile; /// - `FunctionType`: A function signature with argument and return type /// annotations, representing the type hints of a function (e.g., `def add(x: int, y: int) -> int`). pub(super) enum Mod { - Module(ruff::ModModule), + Module(ast::ModModule), Interactive(ModInteractive), - Expression(ruff::ModExpression), + Expression(ast::ModExpression), FunctionType(ModFunctionType), } @@ -42,11 +42,11 @@ impl Node for Mod { ) -> PyResult { let cls = object.class(); Ok(if cls.is(pyast::NodeModModule::static_type()) { - Self::Module(ruff::ModModule::ast_from_object(vm, source_file, object)?) + Self::Module(ast::ModModule::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeModInteractive::static_type()) { Self::Interactive(ModInteractive::ast_from_object(vm, source_file, object)?) } else if cls.is(pyast::NodeModExpression::static_type()) { - Self::Expression(ruff::ModExpression::ast_from_object( + Self::Expression(ast::ModExpression::ast_from_object( vm, source_file, object, @@ -63,7 +63,7 @@ impl Node for Mod { } // constructor -impl Node for ruff::ModModule { +impl Node for ast::ModModule { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -113,7 +113,7 @@ impl Node for ruff::ModModule { pub(super) struct ModInteractive { pub(crate) range: TextRange, - pub(crate) body: Vec<ruff::Stmt>, + pub(crate) body: Vec<ast::Stmt>, } // constructor @@ -147,7 +147,7 @@ impl Node for ModInteractive { } // constructor -impl Node for ruff::ModExpression { +impl Node for ast::ModExpression { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -182,8 +182,8 @@ impl Node for ruff::ModExpression { } pub(super) struct ModFunctionType { - pub(crate) argtypes: Box<[ruff::Expr]>, - pub(crate) returns: ruff::Expr, + pub(crate) argtypes: Box<[ast::Expr]>, + pub(crate) returns: ast::Expr, pub(crate) range: TextRange, } diff --git a/crates/vm/src/stdlib/ast/operator.rs b/crates/vm/src/stdlib/ast/operator.rs index c394152da2c..23aa63c7031 100644 --- a/crates/vm/src/stdlib/ast/operator.rs +++ b/crates/vm/src/stdlib/ast/operator.rs @@ -2,7 +2,7 @@ use super::*; use rustpython_compiler_core::SourceFile; // sum -impl Node for ruff::BoolOp { +impl Node for ast::BoolOp { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { let node_type = match self { Self::And => pyast::NodeBoolOpAnd::static_type(), @@ -34,7 +34,7 @@ impl Node for ruff::BoolOp { } // sum -impl Node for ruff::Operator { +impl Node for ast::Operator { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { let node_type = match self { Self::Add => pyast::NodeOperatorAdd::static_type(), @@ -99,7 +99,7 @@ impl Node for ruff::Operator { } // sum -impl Node for ruff::UnaryOp { +impl Node for ast::UnaryOp { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { let node_type = match self { Self::Invert => pyast::NodeUnaryOpInvert::static_type(), @@ -137,7 +137,7 @@ impl Node for ruff::UnaryOp { } // sum -impl Node for
ruff::CmpOp { +impl Node for ast::CmpOp { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { let node_type = match self { Self::Eq => pyast::NodeCmpOpEq::static_type(), diff --git a/crates/vm/src/stdlib/ast/other.rs b/crates/vm/src/stdlib/ast/other.rs index ce7d5fe4807..8a89a740682 100644 --- a/crates/vm/src/stdlib/ast/other.rs +++ b/crates/vm/src/stdlib/ast/other.rs @@ -1,7 +1,7 @@ use super::*; use rustpython_compiler_core::SourceFile; -impl Node for ruff::ConversionFlag { +impl Node for ast::ConversionFlag { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { vm.ctx.new_int(self as u8).into() } @@ -24,7 +24,7 @@ impl Node for ruff::ConversionFlag { } // /// This is just a string, not strictly an AST node. But it makes AST conversions easier. -impl Node for ruff::name::Name { +impl Node for ast::name::Name { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { vm.ctx.new_str(self.as_str()).to_pyobject(vm) } @@ -41,9 +41,9 @@ impl Node for ruff::name::Name { } } -impl Node for ruff::Decorator { +impl Node for ast::Decorator { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { - ruff::Expr::ast_to_object(self.expression, vm, source_file) + ast::Expr::ast_to_object(self.expression, vm, source_file) } fn ast_from_object( @@ -51,7 +51,7 @@ impl Node for ruff::Decorator { source_file: &SourceFile, object: PyObjectRef, ) -> PyResult { - let expression = ruff::Expr::ast_from_object(vm, source_file, object)?; + let expression = ast::Expr::ast_from_object(vm, source_file, object)?; let range = expression.range(); Ok(Self { node_index: Default::default(), @@ -62,7 +62,7 @@ impl Node for ruff::Decorator { } // product -impl Node for ruff::Alias { +impl Node for ast::Alias { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -103,7 +103,7 @@ impl Node for ruff::Alias { } // product -impl Node for ruff::WithItem { +impl Node for ast::WithItem { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, diff --git a/crates/vm/src/stdlib/ast/parameter.rs b/crates/vm/src/stdlib/ast/parameter.rs index 44fcbb2b464..1e411d41ab6 100644 --- a/crates/vm/src/stdlib/ast/parameter.rs +++ b/crates/vm/src/stdlib/ast/parameter.rs @@ -2,7 +2,7 @@ use super::*; use rustpython_compiler_core::SourceFile; // product -impl Node for ruff::Parameters { +impl Node for ast::Parameters { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -101,7 +101,7 @@ impl Node for ruff::Parameters { } // product -impl Node for ruff::Parameter { +impl Node for ast::Parameter { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -156,7 +156,7 @@ impl Node for ruff::Parameter { } // product -impl Node for ruff::Keyword { +impl Node for ast::Keyword { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -197,7 +197,7 @@ impl Node for ruff::Keyword { struct PositionalParameters { pub _range: TextRange, // TODO: Use this - pub args: Box<[ruff::Parameter]>, + pub args: Box<[ast::Parameter]>, } impl Node for PositionalParameters { @@ -220,7 +220,7 @@ impl Node for PositionalParameters { struct KeywordParameters { pub _range: TextRange, // TODO: Use this - pub keywords: Box<[ruff::Parameter]>, + pub 
keywords: Box<[ast::Parameter]>, } impl Node for KeywordParameters { @@ -243,7 +243,7 @@ struct ParameterDefaults { pub _range: TextRange, // TODO: Use this - defaults: Box<[Option<Box<ruff::Expr>>]>, + defaults: Box<[Option<Box<ast::Expr>>]>, } impl Node for ParameterDefaults { @@ -265,8 +265,8 @@ } fn extract_positional_parameter_defaults( - pos_only_args: Vec<ruff::ParameterWithDefault>, - args: Vec<ruff::ParameterWithDefault>, + pos_only_args: Vec<ast::ParameterWithDefault>, + args: Vec<ast::ParameterWithDefault>, ) -> ( PositionalParameters, PositionalParameters, @@ -325,15 +325,15 @@ fn merge_positional_parameter_defaults( args: PositionalParameters, defaults: ParameterDefaults, ) -> ( - Vec<ruff::ParameterWithDefault>, - Vec<ruff::ParameterWithDefault>, + Vec<ast::ParameterWithDefault>, + Vec<ast::ParameterWithDefault>, ) { let posonlyargs = posonlyargs.args; let args = args.args; let defaults = defaults.defaults; let mut posonlyargs: Vec<_> = <Box<[_]> as IntoIterator>::into_iter(posonlyargs) - .map(|parameter| ruff::ParameterWithDefault { + .map(|parameter| ast::ParameterWithDefault { node_index: Default::default(), range: Default::default(), parameter, @@ -341,7 +341,7 @@ fn merge_positional_parameter_defaults( }) .collect(); let mut args: Vec<_> = <Box<[_]> as IntoIterator>::into_iter(args) - .map(|parameter| ruff::ParameterWithDefault { + .map(|parameter| ast::ParameterWithDefault { node_index: Default::default(), range: Default::default(), parameter, @@ -366,7 +366,7 @@ } fn extract_keyword_parameter_defaults( - kw_only_args: Vec<ruff::ParameterWithDefault>, + kw_only_args: Vec<ast::ParameterWithDefault>, ) -> (KeywordParameters, ParameterDefaults) { let mut defaults = vec![]; defaults.extend(kw_only_args.iter().map(|item| item.default.clone())); @@ -402,9 +402,9 @@ fn merge_keyword_parameter_defaults( kw_only_args: KeywordParameters, defaults: ParameterDefaults, -) -> Vec<ruff::ParameterWithDefault> { +) -> Vec<ast::ParameterWithDefault> { core::iter::zip(kw_only_args.keywords, defaults.defaults) - .map(|(parameter, default)| ruff::ParameterWithDefault { + .map(|(parameter, default)| ast::ParameterWithDefault { node_index: Default::default(), parameter, default, diff --git a/crates/vm/src/stdlib/ast/pattern.rs b/crates/vm/src/stdlib/ast/pattern.rs index d8128cb0622..4531a989cb3 100644 --- a/crates/vm/src/stdlib/ast/pattern.rs +++ b/crates/vm/src/stdlib/ast/pattern.rs @@ -2,7 +2,7 @@ use super::*; use rustpython_compiler_core::SourceFile; // product -impl Node for ruff::MatchCase { +impl Node for ast::MatchCase { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -50,7 +50,7 @@ impl Node for ruff::MatchCase { } // sum -impl Node for ruff::Pattern { +impl Node for ast::Pattern { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { match self { Self::MatchValue(cons) => cons.ast_to_object(vm, source_file), @@ -70,49 +70,49 @@ ) -> PyResult { let _cls = _object.class(); Ok(if _cls.is(pyast::NodePatternMatchValue::static_type()) { - Self::MatchValue(ruff::PatternMatchValue::ast_from_object( + Self::MatchValue(ast::PatternMatchValue::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodePatternMatchSingleton::static_type()) { - Self::MatchSingleton(ruff::PatternMatchSingleton::ast_from_object( + Self::MatchSingleton(ast::PatternMatchSingleton::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodePatternMatchSequence::static_type()) { - Self::MatchSequence(ruff::PatternMatchSequence::ast_from_object( + Self::MatchSequence(ast::PatternMatchSequence::ast_from_object( _vm, source_file, _object, )?)
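// Dispatch in these ast_from_object impls is by exact class identity
// (`_cls.is(...)`), not isinstance, so only the concrete _ast node classes are
// accepted; a Python-level sketch of the consequence (hypothetical session --
// a subclass presumably falls through to the trailing error arm):
//
//     >>> import ast
//     >>> class MyMatchValue(ast.MatchValue): ...
//     >>> # converting a MyMatchValue instance would not take the MatchValue arm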
} else if _cls.is(pyast::NodePatternMatchMapping::static_type()) { - Self::MatchMapping(ruff::PatternMatchMapping::ast_from_object( + Self::MatchMapping(ast::PatternMatchMapping::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodePatternMatchClass::static_type()) { - Self::MatchClass(ruff::PatternMatchClass::ast_from_object( + Self::MatchClass(ast::PatternMatchClass::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodePatternMatchStar::static_type()) { - Self::MatchStar(ruff::PatternMatchStar::ast_from_object( + Self::MatchStar(ast::PatternMatchStar::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodePatternMatchAs::static_type()) { - Self::MatchAs(ruff::PatternMatchAs::ast_from_object( + Self::MatchAs(ast::PatternMatchAs::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodePatternMatchOr::static_type()) { - Self::MatchOr(ruff::PatternMatchOr::ast_from_object( + Self::MatchOr(ast::PatternMatchOr::ast_from_object( _vm, source_file, _object, @@ -126,7 +126,7 @@ impl Node for ruff::Pattern { } } // constructor -impl Node for ruff::PatternMatchValue { +impl Node for ast::PatternMatchValue { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -161,7 +161,7 @@ impl Node for ruff::PatternMatchValue { } // constructor -impl Node for ruff::PatternMatchSingleton { +impl Node for ast::PatternMatchSingleton { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -198,12 +198,12 @@ impl Node for ruff::PatternMatchSingleton { } } -impl Node for ruff::Singleton { +impl Node for ast::Singleton { fn ast_to_object(self, vm: &VirtualMachine, _source_file: &SourceFile) -> PyObjectRef { match self { - ruff::Singleton::None => vm.ctx.none(), - ruff::Singleton::True => vm.ctx.new_bool(true).into(), - ruff::Singleton::False => vm.ctx.new_bool(false).into(), + ast::Singleton::None => vm.ctx.none(), + ast::Singleton::True => vm.ctx.new_bool(true).into(), + ast::Singleton::False => vm.ctx.new_bool(false).into(), } } @@ -213,11 +213,11 @@ impl Node for ruff::Singleton { object: PyObjectRef, ) -> PyResult { if vm.is_none(&object) { - Ok(ruff::Singleton::None) + Ok(ast::Singleton::None) } else if object.is(&vm.ctx.true_value) { - Ok(ruff::Singleton::True) + Ok(ast::Singleton::True) } else if object.is(&vm.ctx.false_value) { - Ok(ruff::Singleton::False) + Ok(ast::Singleton::False) } else { Err(vm.new_value_error(format!( "Expected None, True, or False, got {:?}", @@ -228,7 +228,7 @@ impl Node for ruff::Singleton { } // constructor -impl Node for ruff::PatternMatchSequence { +impl Node for ast::PatternMatchSequence { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -266,7 +266,7 @@ impl Node for ruff::PatternMatchSequence { } // constructor -impl Node for ruff::PatternMatchMapping { +impl Node for ast::PatternMatchMapping { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -318,7 +318,7 @@ impl Node for ruff::PatternMatchMapping { } // constructor -impl Node for ruff::PatternMatchClass { +impl Node for ast::PatternMatchClass { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -377,7 +377,7 @@ impl Node for ruff::PatternMatchClass { get_node_field(vm, &object, "cls", "MatchClass")?, )?, range: 
range_from_object(vm, source_file, object, "MatchClass")?, - arguments: ruff::PatternArguments { + arguments: ast::PatternArguments { node_index: Default::default(), range: Default::default(), patterns, @@ -387,7 +387,7 @@ impl Node for ruff::PatternMatchClass { } } -struct PatternMatchClassPatterns(Vec<ruff::Pattern>); +struct PatternMatchClassPatterns(Vec<ast::Pattern>); impl Node for PatternMatchClassPatterns { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -403,7 +403,7 @@ impl Node for PatternMatchClassPatterns { } } -struct PatternMatchClassKeywordAttributes(Vec<ruff::Identifier>); +struct PatternMatchClassKeywordAttributes(Vec<ast::Identifier>); impl Node for PatternMatchClassKeywordAttributes { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -419,7 +419,7 @@ impl Node for PatternMatchClassKeywordAttributes { } } -struct PatternMatchClassKeywordPatterns(Vec<ruff::Pattern>); +struct PatternMatchClassKeywordPatterns(Vec<ast::Pattern>); impl Node for PatternMatchClassKeywordPatterns { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { @@ -435,7 +435,7 @@ impl Node for PatternMatchClassKeywordPatterns { } } // constructor -impl Node for ruff::PatternMatchStar { +impl Node for ast::PatternMatchStar { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -468,7 +468,7 @@ impl Node for ruff::PatternMatchStar { } // constructor -impl Node for ruff::PatternMatchAs { +impl Node for ast::PatternMatchAs { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -507,7 +507,7 @@ impl Node for ruff::PatternMatchAs { } // constructor -impl Node for ruff::PatternMatchOr { +impl Node for ast::PatternMatchOr { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -541,7 +541,7 @@ impl Node for ruff::PatternMatchOr { } fn split_pattern_match_class( - arguments: ruff::PatternArguments, + arguments: ast::PatternArguments, ) -> ( PatternMatchClassPatterns, PatternMatchClassKeywordAttributes, @@ -562,12 +562,12 @@ fn merge_pattern_match_class( patterns: PatternMatchClassPatterns, kwd_attrs: PatternMatchClassKeywordAttributes, kwd_patterns: PatternMatchClassKeywordPatterns, -) -> (Vec<ruff::Pattern>, Vec<ruff::PatternKeyword>) { +) -> (Vec<ast::Pattern>, Vec<ast::PatternKeyword>) { let keywords = kwd_attrs .0 .into_iter() .zip(kwd_patterns.0) - .map(|(attr, pattern)| ruff::PatternKeyword { + .map(|(attr, pattern)| ast::PatternKeyword { range: Default::default(), node_index: Default::default(), attr, diff --git a/crates/vm/src/stdlib/ast/statement.rs b/crates/vm/src/stdlib/ast/statement.rs index f1d36c52e2e..7716181fc5b 100644 --- a/crates/vm/src/stdlib/ast/statement.rs +++ b/crates/vm/src/stdlib/ast/statement.rs @@ -3,7 +3,7 @@ use crate::stdlib::ast::argument::{merge_class_def_args, split_class_def_args}; use rustpython_compiler_core::SourceFile; // sum -impl Node for ruff::Stmt { +impl Node for ast::Stmt { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { match self { Self::FunctionDef(cons) => cons.ast_to_object(vm, source_file), @@ -44,117 +44,93 @@ ) -> PyResult { let _cls = _object.class(); Ok(if _cls.is(pyast::NodeStmtFunctionDef::static_type()) { - Self::FunctionDef(ruff::StmtFunctionDef::ast_from_object( + Self::FunctionDef(ast::StmtFunctionDef::ast_from_object( _vm, source_file, _object, )?)
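// The arms below intentionally collapse CPython's separate node classes:
// AsyncFunctionDef, AsyncFor, and AsyncWith map onto the same StmtFunctionDef,
// StmtFor, and StmtWith, and TryStar onto StmtTry, since ruff_python_ast encodes
// async and except* through fields (is_async / is_star) rather than distinct
// node types -- e.g. (hypothetical) ast::StmtFor { is_async: true, .. } plays
// the role of CPython's AsyncFor.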
} else if _cls.is(pyast::NodeStmtAsyncFunctionDef::static_type()) { - Self::FunctionDef(ruff::StmtFunctionDef::ast_from_object( + Self::FunctionDef(ast::StmtFunctionDef::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtClassDef::static_type()) { - Self::ClassDef(ruff::StmtClassDef::ast_from_object( + Self::ClassDef(ast::StmtClassDef::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtReturn::static_type()) { - Self::Return(ruff::StmtReturn::ast_from_object( - _vm, - source_file, - _object, - )?) + Self::Return(ast::StmtReturn::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtDelete::static_type()) { - Self::Delete(ruff::StmtDelete::ast_from_object( - _vm, - source_file, - _object, - )?) + Self::Delete(ast::StmtDelete::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtAssign::static_type()) { - Self::Assign(ruff::StmtAssign::ast_from_object( - _vm, - source_file, - _object, - )?) + Self::Assign(ast::StmtAssign::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtTypeAlias::static_type()) { - Self::TypeAlias(ruff::StmtTypeAlias::ast_from_object( + Self::TypeAlias(ast::StmtTypeAlias::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtAugAssign::static_type()) { - Self::AugAssign(ruff::StmtAugAssign::ast_from_object( + Self::AugAssign(ast::StmtAugAssign::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtAnnAssign::static_type()) { - Self::AnnAssign(ruff::StmtAnnAssign::ast_from_object( + Self::AnnAssign(ast::StmtAnnAssign::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtFor::static_type()) { - Self::For(ruff::StmtFor::ast_from_object(_vm, source_file, _object)?) + Self::For(ast::StmtFor::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtAsyncFor::static_type()) { - Self::For(ruff::StmtFor::ast_from_object(_vm, source_file, _object)?) + Self::For(ast::StmtFor::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtWhile::static_type()) { - Self::While(ruff::StmtWhile::ast_from_object(_vm, source_file, _object)?) + Self::While(ast::StmtWhile::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtIf::static_type()) { - Self::If(ruff::StmtIf::ast_from_object(_vm, source_file, _object)?) + Self::If(ast::StmtIf::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtWith::static_type()) { - Self::With(ruff::StmtWith::ast_from_object(_vm, source_file, _object)?) + Self::With(ast::StmtWith::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtAsyncWith::static_type()) { - Self::With(ruff::StmtWith::ast_from_object(_vm, source_file, _object)?) + Self::With(ast::StmtWith::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtMatch::static_type()) { - Self::Match(ruff::StmtMatch::ast_from_object(_vm, source_file, _object)?) + Self::Match(ast::StmtMatch::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtRaise::static_type()) { - Self::Raise(ruff::StmtRaise::ast_from_object(_vm, source_file, _object)?) + Self::Raise(ast::StmtRaise::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtTry::static_type()) { - Self::Try(ruff::StmtTry::ast_from_object(_vm, source_file, _object)?) + Self::Try(ast::StmtTry::ast_from_object(_vm, source_file, _object)?) 
} else if _cls.is(pyast::NodeStmtTryStar::static_type()) { - Self::Try(ruff::StmtTry::ast_from_object(_vm, source_file, _object)?) + Self::Try(ast::StmtTry::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtAssert::static_type()) { - Self::Assert(ruff::StmtAssert::ast_from_object( - _vm, - source_file, - _object, - )?) + Self::Assert(ast::StmtAssert::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtImport::static_type()) { - Self::Import(ruff::StmtImport::ast_from_object( - _vm, - source_file, - _object, - )?) + Self::Import(ast::StmtImport::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtImportFrom::static_type()) { - Self::ImportFrom(ruff::StmtImportFrom::ast_from_object( + Self::ImportFrom(ast::StmtImportFrom::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtGlobal::static_type()) { - Self::Global(ruff::StmtGlobal::ast_from_object( - _vm, - source_file, - _object, - )?) + Self::Global(ast::StmtGlobal::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtNonlocal::static_type()) { - Self::Nonlocal(ruff::StmtNonlocal::ast_from_object( + Self::Nonlocal(ast::StmtNonlocal::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeStmtExpr::static_type()) { - Self::Expr(ruff::StmtExpr::ast_from_object(_vm, source_file, _object)?) + Self::Expr(ast::StmtExpr::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtPass::static_type()) { - Self::Pass(ruff::StmtPass::ast_from_object(_vm, source_file, _object)?) + Self::Pass(ast::StmtPass::ast_from_object(_vm, source_file, _object)?) } else if _cls.is(pyast::NodeStmtBreak::static_type()) { - Self::Break(ruff::StmtBreak::ast_from_object(_vm, source_file, _object)?) + Self::Break(ast::StmtBreak::ast_from_object(_vm, source_file, _object)?) 
} else if _cls.is(pyast::NodeStmtContinue::static_type()) { - Self::Continue(ruff::StmtContinue::ast_from_object( + Self::Continue(ast::StmtContinue::ast_from_object( _vm, source_file, _object, @@ -169,7 +145,7 @@ impl Node for ruff::Stmt { } // constructor -impl Node for ruff::StmtFunctionDef { +impl Node for ast::StmtFunctionDef { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -266,7 +242,7 @@ impl Node for ruff::StmtFunctionDef { } // constructor -impl Node for ruff::StmtClassDef { +impl Node for ast::StmtClassDef { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -348,7 +324,7 @@ impl Node for ruff::StmtClassDef { } } // constructor -impl Node for ruff::StmtReturn { +impl Node for ast::StmtReturn { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -379,7 +355,7 @@ impl Node for ruff::StmtReturn { } } // constructor -impl Node for ruff::StmtDelete { +impl Node for ast::StmtDelete { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -413,7 +389,7 @@ impl Node for ruff::StmtDelete { } // constructor -impl Node for ruff::StmtAssign { +impl Node for ast::StmtAssign { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -461,7 +437,7 @@ impl Node for ruff::StmtAssign { } // constructor -impl Node for ruff::StmtTypeAlias { +impl Node for ast::StmtTypeAlias { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -516,7 +492,7 @@ impl Node for ruff::StmtTypeAlias { } // constructor -impl Node for ruff::StmtAugAssign { +impl Node for ast::StmtAugAssign { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -566,7 +542,7 @@ impl Node for ruff::StmtAugAssign { } // constructor -impl Node for ruff::StmtAnnAssign { +impl Node for ast::StmtAnnAssign { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -626,7 +602,7 @@ impl Node for ruff::StmtAnnAssign { } // constructor -impl Node for ruff::StmtFor { +impl Node for ast::StmtFor { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -704,7 +680,7 @@ impl Node for ruff::StmtFor { } // constructor -impl Node for ruff::StmtWhile { +impl Node for ast::StmtWhile { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -754,7 +730,7 @@ impl Node for ruff::StmtWhile { } } // constructor -impl Node for ruff::StmtIf { +impl Node for ast::StmtIf { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -764,7 +740,7 @@ impl Node for ruff::StmtIf { elif_else_clauses, } = self; elif_else_clause::ast_to_object( - ruff::ElifElseClause { + ast::ElifElseClause { node_index: Default::default(), range, test: Some(*test), @@ -784,7 +760,7 @@ impl Node for ruff::StmtIf { } } // constructor -impl Node for ruff::StmtWith { +impl Node for ast::StmtWith { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -844,7 +820,7 @@ impl Node for ruff::StmtWith { } } // constructor -impl Node for ruff::StmtMatch { +impl 
Node for ast::StmtMatch { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -885,7 +861,7 @@ impl Node for ruff::StmtMatch { } } // constructor -impl Node for ruff::StmtRaise { +impl Node for ast::StmtRaise { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -922,7 +898,7 @@ impl Node for ruff::StmtRaise { } } // constructor -impl Node for ruff::StmtTry { +impl Node for ast::StmtTry { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -996,7 +972,7 @@ impl Node for ruff::StmtTry { } } // constructor -impl Node for ruff::StmtAssert { +impl Node for ast::StmtAssert { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1035,7 +1011,7 @@ impl Node for ruff::StmtAssert { } } // constructor -impl Node for ruff::StmtImport { +impl Node for ast::StmtImport { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1068,7 +1044,7 @@ impl Node for ruff::StmtImport { } } // constructor -impl Node for ruff::StmtImportFrom { +impl Node for ast::StmtImportFrom { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1117,7 +1093,7 @@ impl Node for ruff::StmtImportFrom { } } // constructor -impl Node for ruff::StmtGlobal { +impl Node for ast::StmtGlobal { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1150,7 +1126,7 @@ impl Node for ruff::StmtGlobal { } } // constructor -impl Node for ruff::StmtNonlocal { +impl Node for ast::StmtNonlocal { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1183,7 +1159,7 @@ impl Node for ruff::StmtNonlocal { } } // constructor -impl Node for ruff::StmtExpr { +impl Node for ast::StmtExpr { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1216,7 +1192,7 @@ impl Node for ruff::StmtExpr { } } // constructor -impl Node for ruff::StmtPass { +impl Node for ast::StmtPass { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1241,7 +1217,7 @@ impl Node for ruff::StmtPass { } } // constructor -impl Node for ruff::StmtBreak { +impl Node for ast::StmtBreak { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -1268,7 +1244,7 @@ impl Node for ruff::StmtBreak { } } // constructor -impl Node for ruff::StmtContinue { +impl Node for ast::StmtContinue { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, diff --git a/crates/vm/src/stdlib/ast/string.rs b/crates/vm/src/stdlib/ast/string.rs index c68e378bc74..2533fb8c6b9 100644 --- a/crates/vm/src/stdlib/ast/string.rs +++ b/crates/vm/src/stdlib/ast/string.rs @@ -2,14 +2,13 @@ use super::constant::{Constant, ConstantLiteral}; use super::*; fn ruff_fstring_element_into_iter( - mut fstring_element: ruff::InterpolatedStringElements, -) -> impl Iterator<Item = ruff::InterpolatedStringElement> { - let default = - ruff::InterpolatedStringElement::Literal(ruff::InterpolatedStringLiteralElement { - node_index: Default::default(), - range: Default::default(), - value: Default::default(), - }); + mut fstring_element:
ast::InterpolatedStringElements, +) -> impl Iterator<Item = ast::InterpolatedStringElement> { + let default = ast::InterpolatedStringElement::Literal(ast::InterpolatedStringLiteralElement { + node_index: Default::default(), + range: Default::default(), + value: Default::default(), + }); fstring_element .iter_mut() .map(move |elem| core::mem::replace(elem, default.clone())) @@ -18,19 +17,19 @@ } fn ruff_fstring_element_to_joined_str_part( - element: ruff::InterpolatedStringElement, + element: ast::InterpolatedStringElement, ) -> JoinedStrPart { match element { - ruff::InterpolatedStringElement::Literal(ruff::InterpolatedStringLiteralElement { + ast::InterpolatedStringElement::Literal(ast::InterpolatedStringLiteralElement { range, value, node_index: _, }) => JoinedStrPart::Constant(Constant::new_str( value, - ruff::str_prefix::StringLiteralPrefix::Empty, + ast::str_prefix::StringLiteralPrefix::Empty, range, )), - ruff::InterpolatedStringElement::Interpolation(ruff::InterpolatedElement { + ast::InterpolatedStringElement::Interpolation(ast::InterpolatedElement { range, expression, debug_text: _, // TODO: What is this? @@ -47,12 +46,12 @@ } fn ruff_format_spec_to_joined_str( - format_spec: Option<Box<ruff::InterpolatedStringFormatSpec>>, + format_spec: Option<Box<ast::InterpolatedStringFormatSpec>>, ) -> Option<Box<JoinedStr>> { match format_spec { None => None, Some(format_spec) => { - let ruff::InterpolatedStringFormatSpec { + let ast::InterpolatedStringFormatSpec { range, elements, node_index: _, @@ -67,36 +66,36 @@ } fn ruff_fstring_element_to_ruff_fstring_part( - element: ruff::InterpolatedStringElement, -) -> ruff::FStringPart { + element: ast::InterpolatedStringElement, +) -> ast::FStringPart { match element { - ruff::InterpolatedStringElement::Literal(value) => { - let ruff::InterpolatedStringLiteralElement { + ast::InterpolatedStringElement::Literal(value) => { + let ast::InterpolatedStringLiteralElement { node_index, range, value, } = value; - ruff::FStringPart::Literal(ruff::StringLiteral { + ast::FStringPart::Literal(ast::StringLiteral { node_index, range, value, - flags: ruff::StringLiteralFlags::empty(), + flags: ast::StringLiteralFlags::empty(), }) } - ruff::InterpolatedStringElement::Interpolation(ruff::InterpolatedElement { + ast::InterpolatedStringElement::Interpolation(ast::InterpolatedElement { range, ..
- }) => ruff::FStringPart::FString(ruff::FString { + }) => ast::FStringPart::FString(ast::FString { node_index: Default::default(), range, elements: vec![element].into(), - flags: ruff::FStringFlags::empty(), + flags: ast::FStringFlags::empty(), }), } } fn joined_str_to_ruff_format_spec( joined_str: Option<Box<JoinedStr>>, -) -> Option<Box<ruff::InterpolatedStringFormatSpec>> { +) -> Option<Box<ast::InterpolatedStringFormatSpec>> { match joined_str { None => None, Some(joined_str) => { @@ -104,7 +103,7 @@ let elements: Vec<_> = Box::into_iter(values) .map(joined_str_part_to_ruff_fstring_element) .collect(); - let format_spec = ruff::InterpolatedStringFormatSpec { + let format_spec = ast::InterpolatedStringFormatSpec { node_index: Default::default(), range, elements: elements.into(), @@ -121,32 +120,32 @@ pub(super) struct JoinedStr { } impl JoinedStr { - pub(super) fn into_expr(self) -> ruff::Expr { + pub(super) fn into_expr(self) -> ast::Expr { let Self { range, values } = self; - ruff::Expr::FString(ruff::ExprFString { + ast::Expr::FString(ast::ExprFString { node_index: Default::default(), range: Default::default(), value: match values.len() { // ruff represents an empty fstring like this: - 0 => ruff::FStringValue::single(ruff::FString { + 0 => ast::FStringValue::single(ast::FString { node_index: Default::default(), range, elements: vec![].into(), - flags: ruff::FStringFlags::empty(), + flags: ast::FStringFlags::empty(), }), - 1 => ruff::FStringValue::single( + 1 => ast::FStringValue::single( Box::<[_]>::into_iter(values) .map(joined_str_part_to_ruff_fstring_element) - .map(|element| ruff::FString { + .map(|element| ast::FString { node_index: Default::default(), range, elements: vec![element].into(), - flags: ruff::FStringFlags::empty(), + flags: ast::FStringFlags::empty(), }) .next() .expect("FString has exactly one part"), ), - _ => ruff::FStringValue::concatenated( + _ => ast::FStringValue::concatenated( Box::<[_]>::into_iter(values) .map(joined_str_part_to_ruff_fstring_element) .map(ruff_fstring_element_to_ruff_fstring_part) @@ -157,10 +156,10 @@ } } -fn joined_str_part_to_ruff_fstring_element(part: JoinedStrPart) -> ruff::InterpolatedStringElement { +fn joined_str_part_to_ruff_fstring_element(part: JoinedStrPart) -> ast::InterpolatedStringElement { match part { JoinedStrPart::FormattedValue(value) => { - ruff::InterpolatedStringElement::Interpolation(ruff::InterpolatedElement { + ast::InterpolatedStringElement::Interpolation(ast::InterpolatedElement { node_index: Default::default(), range: value.range, expression: value.value.clone(), @@ -170,7 +169,7 @@ } JoinedStrPart::Constant(value) => { - ruff::InterpolatedStringElement::Literal(ruff::InterpolatedStringLiteralElement { + ast::InterpolatedStringElement::Literal(ast::InterpolatedStringLiteralElement { node_index: Default::default(), range: value.range, value: match value.value { @@ -254,8 +253,8 @@ impl Node for JoinedStrPart { #[derive(Debug)] pub(super) struct FormattedValue { - value: Box<ruff::Expr>, - conversion: ruff::ConversionFlag, + value: Box<ast::Expr>, + conversion: ast::ConversionFlag, format_spec: Option<Box<JoinedStr>>, range: TextRange, } @@ -313,24 +312,24 @@ impl Node for FormattedValue { pub(super) fn fstring_to_object( vm: &VirtualMachine, source_file: &SourceFile, - expression: ruff::ExprFString, + expression: ast::ExprFString, ) -> PyObjectRef { - let ruff::ExprFString { + let ast::ExprFString { range, mut value, node_index: _, } = expression; - let default_part =
ruff::FStringPart::FString(ruff::FString { + let default_part = ast::FStringPart::FString(ast::FString { node_index: Default::default(), range: Default::default(), elements: Default::default(), - flags: ruff::FStringFlags::empty(), + flags: ast::FStringFlags::empty(), }); let mut values = Vec::new(); for i in 0..value.as_slice().len() { let part = core::mem::replace(value.iter_mut().nth(i).unwrap(), default_part.clone()); match part { - ruff::FStringPart::Literal(ruff::StringLiteral { + ast::FStringPart::Literal(ast::StringLiteral { range, value, flags, @@ -342,7 +341,7 @@ range, ))); } - ruff::FStringPart::FString(ruff::FString { + ast::FStringPart::FString(ast::FString { range: _, elements, flags: _, @@ -364,20 +363,20 @@ // ===== TString (Template String) Support ===== fn ruff_tstring_element_to_template_str_part( - element: ruff::InterpolatedStringElement, + element: ast::InterpolatedStringElement, source_file: &SourceFile, ) -> TemplateStrPart { match element { - ruff::InterpolatedStringElement::Literal(ruff::InterpolatedStringLiteralElement { + ast::InterpolatedStringElement::Literal(ast::InterpolatedStringLiteralElement { range, value, node_index: _, }) => TemplateStrPart::Constant(Constant::new_str( value, - ruff::str_prefix::StringLiteralPrefix::Empty, + ast::str_prefix::StringLiteralPrefix::Empty, range, )), - ruff::InterpolatedStringElement::Interpolation(ruff::InterpolatedElement { + ast::InterpolatedStringElement::Interpolation(ast::InterpolatedElement { range, expression, debug_text, @@ -401,13 +400,13 @@ } fn ruff_format_spec_to_template_str( - format_spec: Option<Box<ruff::InterpolatedStringFormatSpec>>, + format_spec: Option<Box<ast::InterpolatedStringFormatSpec>>, source_file: &SourceFile, ) -> Option<Box<TemplateStr>> { match format_spec { None => None, Some(format_spec) => { - let ruff::InterpolatedStringFormatSpec { + let ast::InterpolatedStringFormatSpec { range, elements, node_index: _, @@ -499,9 +498,9 @@ impl Node for TemplateStrPart { #[derive(Debug)] pub(super) struct TStringInterpolation { - value: Box<ruff::Expr>, + value: Box<ast::Expr>, str: String, - conversion: ruff::ConversionFlag, + conversion: ast::ConversionFlag, format_spec: Option<Box<TemplateStr>>, range: TextRange, } @@ -565,18 +564,18 @@ impl Node for TStringInterpolation { pub(super) fn tstring_to_object( vm: &VirtualMachine, source_file: &SourceFile, - expression: ruff::ExprTString, + expression: ast::ExprTString, ) -> PyObjectRef { - let ruff::ExprTString { + let ast::ExprTString { range, mut value, node_index: _, } = expression; - let default_tstring = ruff::TString { + let default_tstring = ast::TString { node_index: Default::default(), range: Default::default(), elements: Default::default(), - flags: ruff::TStringFlags::empty(), + flags: ast::TStringFlags::empty(), }; let mut values = Vec::new(); for i in 0..value.as_slice().len() { diff --git a/crates/vm/src/stdlib/ast/type_parameters.rs b/crates/vm/src/stdlib/ast/type_parameters.rs index 017470f7e64..4801a9a4b28 100644 --- a/crates/vm/src/stdlib/ast/type_parameters.rs +++ b/crates/vm/src/stdlib/ast/type_parameters.rs @@ -1,7 +1,7 @@ use super::*; use rustpython_compiler_core::SourceFile; -impl Node for ruff::TypeParams { +impl Node for ast::TypeParams { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { self.type_params.ast_to_object(vm, source_file) } @@ -11,7 +11,7 @@ impl Node for ruff::TypeParams { _source_file: &SourceFile, _object: PyObjectRef, ) -> PyResult { - let type_params: Vec<ruff::TypeParam> = Node::ast_from_object(_vm,
_source_file, _object)?; + let type_params: Vec<ast::TypeParam> = Node::ast_from_object(_vm, _source_file, _object)?; let range = Option::zip(type_params.first(), type_params.last()) .map(|(first, last)| first.range().cover(last.range())) .unwrap_or_default(); @@ -28,7 +28,7 @@ } // sum -impl Node for ruff::TypeParam { +impl Node for ast::TypeParam { fn ast_to_object(self, vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { match self { Self::TypeVar(cons) => cons.ast_to_object(vm, source_file), @@ -44,19 +44,19 @@ ) -> PyResult { let _cls = _object.class(); Ok(if _cls.is(pyast::NodeTypeParamTypeVar::static_type()) { - Self::TypeVar(ruff::TypeParamTypeVar::ast_from_object( + Self::TypeVar(ast::TypeParamTypeVar::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeTypeParamParamSpec::static_type()) { - Self::ParamSpec(ruff::TypeParamParamSpec::ast_from_object( + Self::ParamSpec(ast::TypeParamParamSpec::ast_from_object( _vm, source_file, _object, )?) } else if _cls.is(pyast::NodeTypeParamTypeVarTuple::static_type()) { - Self::TypeVarTuple(ruff::TypeParamTypeVarTuple::ast_from_object( + Self::TypeVarTuple(ast::TypeParamTypeVarTuple::ast_from_object( _vm, source_file, _object, @@ -71,7 +71,7 @@ impl Node for ruff::TypeParam { } // constructor -impl Node for ruff::TypeParamTypeVar { +impl Node for ast::TypeParamTypeVar { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -118,7 +118,7 @@ impl Node for ruff::TypeParamTypeVar { } // constructor -impl Node for ruff::TypeParamParamSpec { +impl Node for ast::TypeParamParamSpec { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, @@ -165,7 +165,7 @@ impl Node for ruff::TypeParamParamSpec { } // constructor -impl Node for ruff::TypeParamTypeVarTuple { +impl Node for ast::TypeParamTypeVarTuple { fn ast_to_object(self, _vm: &VirtualMachine, source_file: &SourceFile) -> PyObjectRef { let Self { node_index: _, diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 66a43427f83..305e1cfa61d 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -4,6 +4,30 @@ pub(crate) use sys::{ __module_def, DOC, MAXSIZE, RUST_MULTIARCH, UnraisableHookArgsData, multiarch, }; +#[pymodule(name = "_jit")] +mod sys_jit { + /// Return True if the current Python executable supports JIT compilation, + /// and False otherwise. + #[pyfunction] + const fn is_available() -> bool { + false // RustPython has no JIT + } + + /// Return True if JIT compilation is enabled for the current Python process, + /// and False otherwise. + #[pyfunction] + const fn is_enabled() -> bool { + false // RustPython has no JIT + } + + /// Return True if the topmost Python frame is currently executing JIT code, + /// and False otherwise. + #[pyfunction] + const fn is_active() -> bool { + false // RustPython has no JIT + } +} + #[pymodule] mod sys { use crate::{ @@ -52,8 +76,11 @@ mod sys { #[pyattr(name = "_rustpython_debugbuild")] const RUSTPYTHON_DEBUGBUILD: bool = cfg!(debug_assertions); + #[cfg(not(windows))] #[pyattr(name = "abiflags")] - pub(crate) const ABIFLAGS: &str = "t"; // 't' for free-threaded (no GIL) + const ABIFLAGS_ATTR: &str = "t"; // 't' for free-threaded (no GIL) + // Internal constant used for sysconfigdata_name + pub(crate) const ABIFLAGS: &str = "t"; #[pyattr(name = "api_version")] const API_VERSION: u32 = 0x0; // what C api?
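// Python-level view of the sys._jit stubs defined above (hypothetical session;
// the submodule is attached to sys in init_module at the bottom of this file),
// apparently mirroring CPython 3.14's sys._jit introspection surface:
//
//     >>> import sys
//     >>> sys._jit.is_available(), sys._jit.is_enabled(), sys._jit.is_active()
//     (False, False, False)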
#[pyattr(name = "copyright")] @@ -503,6 +530,7 @@ mod sys { "_multiarch" => ctx.new_str(multiarch()), "version" => version_info(vm), "hexversion" => ctx.new_int(version::VERSION_HEX), + "supports_isolated_interpreters" => ctx.new_bool(false), }) } @@ -604,6 +632,12 @@ mod sys { false // RustPython has no GIL (like free-threaded Python) } + /// Return True if remote debugging is enabled, False otherwise. + #[pyfunction] + const fn is_remote_debug_enabled() -> bool { + false // RustPython does not support remote debugging + } + #[pyfunction] fn exit(code: OptionalArg, vm: &VirtualMachine) -> PyResult { let code = code.unwrap_or_none(vm); @@ -1303,6 +1337,8 @@ mod sys { warn_default_encoding: u8, /// -X thread_inherit_context, whether new threads inherit context from parent thread_inherit_context: bool, + /// -X context_aware_warnings, whether warnings are context aware + context_aware_warnings: bool, } impl FlagsData { @@ -1327,6 +1363,7 @@ mod sys { safe_path: settings.safe_path, warn_default_encoding: settings.warn_default_encoding as u8, thread_inherit_context: settings.thread_inherit_context, + context_aware_warnings: settings.context_aware_warnings, } } } @@ -1551,9 +1588,14 @@ pub(crate) fn init_module(vm: &VirtualMachine, module: &Py, builtins: modules .set_item("builtins", builtins.to_owned().into(), vm) .unwrap(); + + // Create sys._jit submodule + let jit_module = sys_jit::make_module(vm); + extend_module!(vm, module, { "__doc__" => sys::DOC.to_owned().to_pyobject(vm), "modules" => modules, + "_jit" => jit_module, }); } diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/thread.rs index 9f0c0535d71..f7e47b15deb 100644 --- a/crates/vm/src/stdlib/thread.rs +++ b/crates/vm/src/stdlib/thread.rs @@ -421,14 +421,14 @@ pub(crate) mod _thread { vm.new_thread() .make_spawn_func(move |vm| run_thread(func, args, vm)), ) - .map(|handle| { - vm.state.thread_count.fetch_add(1); - thread_to_id(&handle) - }) + .map(|handle| thread_to_id(&handle)) .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}"))) } fn run_thread(func: ArgCallable, args: FuncArgs, vm: &VirtualMachine) { + // Increment thread count when thread actually starts executing + vm.state.thread_count.fetch_add(1); + match func.invoke(args, vm) { Ok(_obj) => {} Err(e) if e.fast_isinstance(vm.ctx.exceptions.system_exit) => {} @@ -1168,13 +1168,6 @@ pub(crate) mod _thread { // Mark as done inner_for_cleanup.lock().state = ThreadHandleState::Done; - // Signal waiting threads that this thread is done - { - let (lock, cvar) = &*done_event_for_cleanup; - *lock.lock() = true; - cvar.notify_all(); - } - // Handle sentinels for lock in SENTINELS.take() { if lock.mu.is_locked() { @@ -1189,8 +1182,19 @@ pub(crate) mod _thread { crate::vm::thread::cleanup_current_thread_frames(vm); vm_state.thread_count.fetch_sub(1); + + // Signal waiting threads that this thread is done + // This must be LAST to ensure all cleanup is complete before join() returns + { + let (lock, cvar) = &*done_event_for_cleanup; + *lock.lock() = true; + cvar.notify_all(); + } } + // Increment thread count when thread actually starts executing + vm_state.thread_count.fetch_add(1); + // Run the function match func.invoke((), vm) { Ok(_) => {} @@ -1206,8 +1210,6 @@ pub(crate) mod _thread { })) .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}")))?; - vm.state.thread_count.fetch_add(1); - // Store the join handle handle.inner.lock().join_handle = Some(join_handle); diff --git a/crates/vm/src/vm/mod.rs 
b/crates/vm/src/vm/mod.rs index 0c29d34ef88..5bd2b9e8297 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -86,6 +86,8 @@ pub struct VirtualMachine { pub state: PyRc<PyGlobalState>, pub initialized: bool, recursion_depth: Cell<usize>, + /// C stack soft limit for detecting stack overflow (like c_stack_soft_limit) + c_stack_soft_limit: Cell<usize>, /// Async generator firstiter hook (per-thread, set via sys.set_asyncgen_hooks) pub async_gen_firstiter: RefCell<Option<PyObjectRef>>, /// Async generator finalizer hook (per-thread, set via sys.set_asyncgen_hooks) @@ -228,6 +230,7 @@ impl VirtualMachine { }), initialized: false, recursion_depth: Cell::new(0), + c_stack_soft_limit: Cell::new(Self::calculate_c_stack_soft_limit()), async_gen_firstiter: RefCell::new(None), async_gen_finalizer: RefCell::new(None), }; @@ -689,11 +692,127 @@ impl VirtualMachine { self.recursion_depth.get() } + /// Stack margin bytes (like _PyOS_STACK_MARGIN_BYTES). + /// 2048 * sizeof(void*) = 16KB for 64-bit. + const STACK_MARGIN_BYTES: usize = 2048 * std::mem::size_of::<*const ()>(); + + /// Get the stack boundaries using platform-specific APIs. + /// Returns (base, top) where base is the lowest address and top is the highest. + #[cfg(all(not(miri), windows))] + fn get_stack_bounds() -> (usize, usize) { + use windows_sys::Win32::System::Threading::{ + GetCurrentThreadStackLimits, SetThreadStackGuarantee, + }; + let mut low: usize = 0; + let mut high: usize = 0; + unsafe { + GetCurrentThreadStackLimits(&mut low as *mut usize, &mut high as *mut usize); + // Add the guaranteed stack space (reserved for exception handling) + let mut guarantee: u32 = 0; + SetThreadStackGuarantee(&mut guarantee); + low += guarantee as usize; + } + (low, high) + } + + /// Get stack boundaries on non-Windows platforms. + /// Falls back to estimating based on current stack pointer. + #[cfg(all(not(miri), not(windows)))] + fn get_stack_bounds() -> (usize, usize) { + // Use pthread_attr_getstack on platforms that support it + #[cfg(any(target_os = "linux", target_os = "android"))] + { + use libc::{ + pthread_attr_destroy, pthread_attr_getstack, pthread_attr_t, pthread_getattr_np, + pthread_self, + }; + let mut attr: pthread_attr_t = unsafe { std::mem::zeroed() }; + unsafe { + if pthread_getattr_np(pthread_self(), &mut attr) == 0 { + let mut stack_addr: *mut libc::c_void = std::ptr::null_mut(); + let mut stack_size: libc::size_t = 0; + if pthread_attr_getstack(&attr, &mut stack_addr, &mut stack_size) == 0 { + pthread_attr_destroy(&mut attr); + let base = stack_addr as usize; + let top = base + stack_size; + return (base, top); + } + pthread_attr_destroy(&mut attr); + } + } + } + + #[cfg(target_os = "macos")] + { + use libc::{pthread_get_stackaddr_np, pthread_get_stacksize_np, pthread_self}; + unsafe { + let thread = pthread_self(); + let stack_top = pthread_get_stackaddr_np(thread) as usize; + let stack_size = pthread_get_stacksize_np(thread); + let stack_base = stack_top - stack_size; + return (stack_base, stack_top); + } + } + + // Fallback: estimate based on current SP and a default stack size + #[allow(unreachable_code)] + { + let current_sp = psm::stack_pointer() as usize; + // Assume 8MB stack, estimate base + let estimated_size = 8 * 1024 * 1024; + let base = current_sp.saturating_sub(estimated_size); + let top = current_sp + 1024 * 1024; // Assume we're not at the very top + (base, top) + } + } +
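// Worked example of the margin arithmetic below (hypothetical 64-bit thread with
// an 8 MiB stack): STACK_MARGIN_BYTES = 2048 * 8 = 16 KiB, so the soft limit sits
// 32 KiB above the stack base, and check_c_stack_overflow() returns true exactly
// while base <= sp < base + 32 KiB -- the last two margins before the hard floor.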
+    /// Calculate the C stack soft limit based on actual stack boundaries.
+    /// soft_limit = base + 2 * margin (for downward-growing stacks)
+    #[cfg(not(miri))]
+    fn calculate_c_stack_soft_limit() -> usize {
+        let (base, _top) = Self::get_stack_bounds();
+        // Soft limit is 2 margins above the base
+        base + Self::STACK_MARGIN_BYTES * 2
+    }
+
+    /// Miri doesn't support inline assembly, so disable C stack checking.
+    #[cfg(miri)]
+    fn calculate_c_stack_soft_limit() -> usize {
+        0
+    }
+
+    /// Check if we're near the C stack limit (like _Py_MakeRecCheck).
+    /// Returns true only when stack pointer is in the "danger zone" between
+    /// soft_limit and hard_limit (soft_limit - 2*margin).
+    #[cfg(not(miri))]
+    #[inline(always)]
+    fn check_c_stack_overflow(&self) -> bool {
+        let current_sp = psm::stack_pointer() as usize;
+        let soft_limit = self.c_stack_soft_limit.get();
+        // Stack grows downward: check if we're below soft limit but above hard limit
+        // This matches CPython's _Py_MakeRecCheck behavior
+        current_sp < soft_limit
+            && current_sp >= soft_limit.saturating_sub(Self::STACK_MARGIN_BYTES * 2)
+    }
+
+    /// Miri doesn't support inline assembly, so always return false.
+    #[cfg(miri)]
+    #[inline(always)]
+    fn check_c_stack_overflow(&self) -> bool {
+        false
+    }
+
     /// Used to run the body of a (possibly) recursive function. It will raise a
     /// RecursionError if recursive functions are nested far too many times,
     /// preventing a stack overflow.
     pub fn with_recursion<R, F: FnOnce() -> PyResult<R>>(&self, _where: &str, f: F) -> PyResult<R> {
         self.check_recursive_call(_where)?;
+
+        // Native stack guard: check C stack like _Py_MakeRecCheck
+        if self.check_c_stack_overflow() {
+            return Err(self.new_recursion_error(_where.to_string()));
+        }
+
         self.recursion_depth.set(self.recursion_depth.get() + 1);
         let result = f();
         self.recursion_depth.set(self.recursion_depth.get() - 1);
diff --git a/crates/vm/src/vm/setting.rs b/crates/vm/src/vm/setting.rs
index a7779156f8c..06cc35e933f 100644
--- a/crates/vm/src/vm/setting.rs
+++ b/crates/vm/src/vm/setting.rs
@@ -91,6 +91,9 @@ pub struct Settings {
     /// -X thread_inherit_context, whether new threads inherit context from parent
     pub thread_inherit_context: bool,

+    /// -X context_aware_warnings, whether warnings are context aware
+    pub context_aware_warnings: bool,
+
     /// -i
     pub inspect: bool,

@@ -194,6 +197,7 @@ impl Default for Settings {
             dev_mode: false,
             warn_default_encoding: false,
             thread_inherit_context: false,
+            context_aware_warnings: false,
             warnoptions: vec![],
             path_list: vec![],
             argv: vec![],
diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs
index 7188aa6d270..bfbfc4f6a04 100644
--- a/crates/vm/src/vm/thread.rs
+++ b/crates/vm/src/vm/thread.rs
@@ -234,6 +234,7 @@ impl VirtualMachine {
             state: self.state.clone(),
             initialized: self.initialized,
             recursion_depth: Cell::new(0),
+            c_stack_soft_limit: Cell::new(VirtualMachine::calculate_c_stack_soft_limit()),
             async_gen_firstiter: RefCell::new(None),
             async_gen_finalizer: RefCell::new(None),
         };
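The payoff of wiring the guard into with_recursion(), and of seeding c_stack_soft_limit per spawned VirtualMachine in vm/thread.rs above, is that runaway recursion should surface as a catchable RecursionError rather than a native stack overflow, even when the counter-based limit is set absurdly high. A sketch of the intended observable behavior (not a test from this patch):

    import sys

    sys.setrecursionlimit(1_000_000)  # defeat the recursion-depth counter

    def runaway(n: int) -> int:
        return runaway(n + 1)

    try:
        runaway(0)
    except RecursionError:
        print("caught by the native stack guard instead of segfaulting")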
sequentially", line): in_test_results = True elif line.startswith("-----------"): in_test_results = False @@ -161,6 +161,66 @@ def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> boo return True +def build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]: + """ + Build inheritance information from AST. + + Returns: + class_bases: dict[str, list[str]] - parent classes for each class (only those defined in the file) + class_methods: dict[str, set[str]] - methods directly defined in each class + """ + all_classes = { + node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef) + } + class_bases = {} + class_methods = {} + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + # Collect only parent classes defined in this file + bases = [ + base.id + for base in node.bases + if isinstance(base, ast.Name) and base.id in all_classes + ] + class_bases[node.name] = bases + + # Collect directly defined methods + methods = { + item.name + for item in node.body + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + class_methods[node.name] = methods + + return class_bases, class_methods + + +def find_method_definition( + class_name: str, method_name: str, class_bases: dict, class_methods: dict +) -> str | None: + """Find the class where a method is actually defined. Traverses inheritance chain (BFS).""" + # Check current class first + if method_name in class_methods.get(class_name, set()): + return class_name + + # Search parent classes + visited = set() + queue = list(class_bases.get(class_name, [])) + + while queue: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + + if method_name in class_methods.get(current, set()): + return current + queue.extend(class_bases.get(current, [])) + + return None + + def remove_expected_failures( contents: str, tests_to_remove: set[tuple[str, str]] ) -> str: @@ -172,6 +232,18 @@ def remove_expected_failures( lines = contents.splitlines() lines_to_remove = set() + # Build inheritance information + class_bases, class_methods = build_inheritance_info(tree) + + # Resolve to actual defining classes + resolved_tests = set() + for class_name, method_name in tests_to_remove: + defining_class = find_method_definition( + class_name, method_name, class_bases, class_methods + ) + if defining_class: + resolved_tests.add((defining_class, method_name)) + for node in ast.walk(tree): if not isinstance(node, ast.ClassDef): continue @@ -180,7 +252,7 @@ def remove_expected_failures( if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): continue method_name = item.name - if (class_name, method_name) not in tests_to_remove: + if (class_name, method_name) not in resolved_tests: continue # Check if we should remove the entire method (super() call only) diff --git a/scripts/lib_updater.py b/scripts/lib_updater.py index 6176bd54431..712ea7bc838 100755 --- a/scripts/lib_updater.py +++ b/scripts/lib_updater.py @@ -236,11 +236,14 @@ def build_patch_dict(it: "Iterator[PatchEntry]") -> Patches: def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int, str]]": - cache = {} # Used in phase 2. Stores the end line location of a class name. 
diff --git a/scripts/lib_updater.py b/scripts/lib_updater.py
index 6176bd54431..712ea7bc838 100755
--- a/scripts/lib_updater.py
+++ b/scripts/lib_updater.py
@@ -236,11 +236,14 @@ def build_patch_dict(it: "Iterator[PatchEntry]") -> Patches:


 def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int, str]]":
-    cache = {}  # Used in phase 2. Stores the end line location of a class name.
+    # Build cache of all classes (for Phase 2 to find classes without methods)
+    cache = {}
+    for node in tree.body:
+        if isinstance(node, ast.ClassDef):
+            cache[node.name] = node.end_lineno

     # Phase 1: Iterate and mark existing tests
     for cls_node, fn_node in iter_tests(tree):
-        cache[cls_node.name] = cls_node.end_lineno
         specs = patches.get(cls_node.name, {}).pop(fn_node.name, None)
         if not specs:
             continue
diff --git a/src/lib.rs b/src/lib.rs
index 11f6dc01c35..cf5227a70dd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -208,6 +208,12 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> {
     let scope = vm.new_scope_with_main()?;

+    // Initialize warnings module to process sys.warnoptions
+    // _PyWarnings_Init()
+    if vm.import("warnings", 0).is_err() {
+        warn!("Failed to import warnings module");
+    }
+
     // Import site first, before setting sys.path[0]
     // This matches CPython's behavior where site.removeduppaths() runs
     // before sys.path[0] is set, preventing '' from being converted to cwd
@@ -219,12 +225,6 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode) -> PyResult<()> {
         );
     }

-    // Initialize warnings module to process sys.warnoptions
-    // _PyWarnings_Init()
-    if vm.import("warnings", 0).is_err() {
-        warn!("Failed to import warnings module");
-    }
-
     // _PyPathConfig_ComputeSysPath0 - set sys.path[0] after site import
     if !vm.state.config.settings.safe_path {
         let path0: Option = match &run_mode {
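One observable consequence of moving the warnings import ahead of site: filters supplied via -W (or PYTHONWARNINGS) are installed before site.py and any sitecustomize/.pth hooks run, so warnings raised during site initialization already honor the user's options. A quick way to eyeball this (a sketch; the exact filter tuple repr may differ):

    # Run as: rustpython -W error::DeprecationWarning -c '<this snippet>'
    import sys
    import warnings

    print(sys.warnoptions)      # ['error::DeprecationWarning'], consumed at startup
    print(warnings.filters[0])  # the -W filter already sits at the head of the list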