Coverage for arguably/_util.py: 88%
285 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 01:01 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 01:01 +0000
1from __future__ import annotations
3import ast
4import asyncio
5import enum
6import functools
7import importlib.util
8import inspect
9import logging
10import math
11import multiprocessing
12import re
13import sys
14import time
15import warnings
16from dataclasses import dataclass
17from io import StringIO
18from pathlib import Path
19from typing import Callable, cast, Any, Union, Optional, List, Dict, Type, Tuple, TextIO
21from docstring_parser import parse as docparse
23from typing import Annotated, get_type_hints, get_args, get_origin # noqa
# Annotations written as `A | B` evaluate to `types.UnionType`, a type that only exists on Python 3.10 and up
if sys.version_info >= (3, 10):
    from types import UnionType  # noqa
else:  # pragma: no cover

    class UnionType:
        """Placeholder for Python < 3.10 - it is only ever the target of issubclass() checks, which then never match."""


# Logger shared under the "arguably" name
logger = logging.getLogger("arguably")
def is_async_callable(obj: Any) -> bool:
    """
    Report whether calling `obj` produces a coroutine.

    Any (possibly nested) `functools.partial` wrappers are peeled off first. An object counts as async either when it is
    a coroutine function itself, or when it is callable and its `__call__` is a coroutine function.
    Based on https://stackoverflow.com/a/72682939
    """
    target = obj
    while isinstance(target, functools.partial):
        target = target.func

    if asyncio.iscoroutinefunction(target):
        return True
    return callable(target) and asyncio.iscoroutinefunction(target.__call__)
def camel_case_to_kebab_case(name: str) -> str:
    """Convert `CamelCase` to `kebab-case`: insert `-` at every lowercase-to-uppercase boundary, then lowercase."""
    hyphenated = re.sub(r"([a-z])([A-Z])", r"\1-\2", name)
    return hyphenated.lower()
def split_unquoted(unsplit: str, delimeter: str, limit: Union[int, float] = math.inf) -> List[str]:
    """
    Split `unsplit` at each `delimeter` that is not inside a quoted section.

    A quoted section runs from a `'` or `"` to the next occurrence of that same character; the other quote character
    has no special meaning inside it. Quote characters are kept in the output. At most `limit` splits are performed;
    everything after the last split stays in the final element.
    """
    assert len(delimeter) == 1
    assert limit > 0

    pieces: List[str] = []
    current: List[str] = []
    open_quote: Optional[str] = None
    splits_left = limit

    for ch in unsplit:
        if ch == delimeter and open_quote is None and splits_left > 0:
            pieces.append("".join(current))
            current = []
            splits_left -= 1
            continue
        if ch in ("'", '"'):
            # Open a quoted section, or close it if this is the character that opened it
            if open_quote is None:
                open_quote = ch
            elif open_quote == ch:
                open_quote = None
        current.append(ch)

    pieces.append("".join(current))
    return pieces
class NoDefault:
    """
    Sentinel marking a parameter that has no default value.

    A dedicated class is needed because `None` is itself a legitimate default value.
    """
def unwrap_quotes(qs: str) -> str:
    """
    Remove one pair of quotes wrapping a string.

    The first and last characters must be the same quote character (`'` or `"`); otherwise `qs` is returned unchanged.
    A string consisting of a single quote character is not considered wrapped, so it is also returned unchanged.
    """
    # The length check stops a lone `"` or `'` (which both starts and ends with a quote) from collapsing to ""
    if len(qs) >= 2 and qs[0] == qs[-1] and qs[0] in ("'", '"'):
        return qs[1:-1]
    return qs
def get_ancestors(command_name: str) -> List[str]:
    """
    List every ancestor of a command, starting from the root. For example, `foo bar bat` yields:
    * `__root__`
    * `foo`
    * `foo bar`

    `__root__` is always an implicit ancestor, and has no ancestors of its own.
    """
    if command_name == "__root__":
        return []
    words = command_name.split(" ")
    # Every proper prefix of the command's words, shortest first
    prefixes = [" ".join(words[:count]) for count in range(1, len(words))]
    return ["__root__", *prefixes]
def normalize_name(name: str, spaces: bool = True) -> str:
    """
    Convert a Python identifier into a CLI-friendly name.

    The name is lowercased and leading/trailing `_` are removed. When `spaces` is true, each `__` then becomes a single
    space ` `. Finally, every remaining `_` becomes `-`.

    Raises:
        ArguablyException: if nothing but dashes and spaces would remain.
    """
    trimmed = name.lower().strip("_")
    spaced = trimmed.replace("__", " ") if spaces else trimmed
    normalized = spaced.replace("_", "-")
    if not normalized.strip("- "):
        raise ArguablyException(f"Cannot normalize name `{name}` - cannot just be underscores and dashes.")
    return normalized
@dataclass
class EnumFlagInfo:
    """One member of an `enum.Flag`, described as a command-line switch - the flag counterpart of `_CommandArg`."""

    # CLI switches for this member, e.g. ("--read",) or ("-r", "--read")
    option: Union[Tuple[str], Tuple[str, str]]
    # Name of the CLI argument this flag member belongs to
    cli_arg_name: str
    # The enum member itself
    value: Any
    # Help text for the switch
    description: str
def get_enum_member_docs(enum_class: Type[enum.Enum]) -> Dict[str, str]:
    """
    Extracts docstrings for enum members similar to PEP-224, which has become a pseudo-standard supported by a lot of
    tooling: a string literal placed as the statement directly after a member's single-name assignment documents that
    member.

    Returns:
        A mapping of member name to its docstring. If the class source can't be retrieved or parsed (for example for
        enums created through the functional API), an empty dict is returned, since there are no docs to extract.
    """
    import textwrap

    try:
        source = inspect.getsource(enum_class)
    except (OSError, TypeError):
        # Source isn't available - the enum simply has no extractable member docs
        return {}

    try:
        # Classes nested in a function or another class come back indented, which `ast.parse` rejects unless dedented
        parsed = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        return {}

    if len(parsed.body) != 1 or not isinstance(parsed.body[0], ast.ClassDef):
        return {}
    classdef = parsed.body[0]

    # Search for a string expression directly following an assignment
    result: Dict[str, str] = {}
    prev: Optional[ast.stmt] = None
    for item in classdef.body:
        is_docstring = (
            isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant) and isinstance(item.value.value, str)
        )
        if (
            is_docstring
            and isinstance(prev, ast.Assign)
            and len(prev.targets) == 1
            and isinstance(prev.targets[0], ast.Name)
        ):
            result[cast(ast.Name, prev.targets[0]).id] = cast(ast.Constant, cast(ast.Expr, item).value).value
        prev = item

    return result
def info_for_flags(cli_arg_name: str, flag_class: Type[enum.Flag]) -> List[EnumFlagInfo]:
    """
    Generates a list of `EnumFlagInfo` corresponding to all flags in an `enum.Flag`.

    Each member's description is taken from the first of:
    1. A `var`/`cvar`/`attribute`/`attr` entry naming the member in the class docstring
    2. A PEP-224-style string literal directly after the member's assignment
    A leading alias spec in that description (e.g. `[-r/--read]`) sets the member's short/long switch names.
    """
    result: List[EnumFlagInfo] = []
    docs = docparse(flag_class.__doc__ or "")
    enum_member_docs = get_enum_member_docs(flag_class)

    for item in flag_class:
        arg_description = ""

        # `docstring_parser` does not specially parse out attributes declared in the docstring - we have to do that
        # ourselves here.
        for doc_item in docs.meta:
            # Single-argument entries (e.g. `Returns:` or `Examples:` sections) can't name a member
            if len(doc_item.args) < 2:
                continue
            doc_item_type, doc_item_name = doc_item.args[0], doc_item.args[-1]
            if doc_item_name == item.name and doc_item_type in ("var", "cvar", "attribute", "attr"):
                arg_description = doc_item.description or ""
                break
        else:
            # Not documented in the class docstring - fall back to a docstring placed after the member itself
            # noinspection PyTypeChecker
            arg_description = enum_member_docs.get(cast(str, item.name), "")

        # Extract the alias from the docstring for the flag item
        long_name: Optional[str] = normalize_name(cast(str, item.name))
        arg_description, short_name, long_name = parse_short_and_long_name(long_name, arg_description, flag_class)

        options: List[str] = []
        if short_name is not None:
            options.append(f"-{short_name}")
        if long_name is not None:
            options.append(f"--{long_name}")

        result.append(
            EnumFlagInfo(cast(Union[Tuple[str], Tuple[str, str]], tuple(options)), cli_arg_name, item, arg_description)
        )
    return result
def parse_short_and_long_name(
    long_name: Optional[str], arg_description: str, func_or_class: Callable
) -> Tuple[str, Optional[str], Optional[str]]:
    """
    Split a leading alias spec off an option's description.

    Recognized prefixes (whatever follows, minus leading spaces, is the remaining description):
    * `[-s]` or `[-s/--long]` - a short name, optionally overriding the long name
    * `[--long]` or `[--long/-s]` - a long name override, optionally with a short name
    * `[-s/]` - a short name only; the long name is dropped
    A description that starts with `[` but matches none of these triggers a warning and is returned untouched.

    Returns:
        `(description, short_name, long_name)`
    """
    if alias := re.match(r"^\[-([a-zA-Z0-9])(/--([a-zA-Z0-9][a-zA-Z0-9\-]*))?]", arg_description):
        override = alias.group(3)
        remainder = arg_description[alias.end() :].lstrip(" ")
        return remainder, alias.group(1), long_name if override is None else override

    if alias := re.match(r"^\[--([a-zA-Z0-9][a-zA-Z0-9\-]*)(/-([a-zA-Z0-9]))?]", arg_description):
        return arg_description[alias.end() :].lstrip(" "), alias.group(3), alias.group(1)

    if alias := re.match(r"^\[-([a-zA-Z0-9])/]", arg_description):
        return arg_description[alias.end() :].lstrip(" "), alias.group(1), None

    if arg_description.startswith("["):
        # TODO: Should this be an exception?
        warn(f"Description for {long_name} starts with `[`, but doesn't match any known option formats.", func_or_class)
    return arg_description, None, long_name
########################################################################################################################
# For __main__
class RedirectedIO(StringIO):
    """A text stream that forwards every write to a pipe's `send()` instead of buffering it."""

    def __init__(self, pipe: Any) -> None:
        super().__init__()
        # Sending end of a pipe - anything with a `send(str)` method
        self.pipe = pipe

    def write(self, s: str) -> int:
        """Send `s` through the pipe and report it as fully written."""
        self.pipe.send(s)
        return len(s)


def capture_stdout_stderr(stdout_writer: Any, stderr_writer: Any, target: Callable, args: Tuple[Any, ...]) -> None:
    """
    Run `target(*args)` with `sys.stdout`/`sys.stderr` redirected into the given pipe writers.

    The streams that were active before the call are restored afterwards, even if `target` raises.
    """
    previous_stdout, previous_stderr = sys.stdout, sys.stderr
    sys.stdout = RedirectedIO(stdout_writer)
    sys.stderr = RedirectedIO(stderr_writer)
    try:
        target(*args)
    finally:
        sys.stdout = previous_stdout
        sys.stderr = previous_stderr
def io_redirector(proc: multiprocessing.Process, pipe: Any, file: TextIO) -> None:
    """
    Relay text received on `pipe` to `file` until the pipe reports EOF.

    Each received chunk has surrounding whitespace stripped; empty chunks are dropped and the rest are printed one per
    line. An `OSError` while `proc` is still alive is retried after a short pause; once `proc` has exited, it ends the
    relay.
    """
    while True:
        try:
            text = pipe.recv().strip()
            if text:
                print(text, file=file)
        except EOFError:
            return
        except OSError:
            if not proc.is_alive():
                return
            time.sleep(0.05)
def run_redirected_io(mp_ctx: multiprocessing.context.SpawnContext, target: Callable, args: Tuple[Any, ...]) -> None:
    """
    Run `target(*args)` in a subprocess, redirecting the subprocess's stdout/stderr back to THIS process's
    stdout/stderr.

    Returns once the subprocess has exited and everything it wrote has been relayed.
    """
    from threading import Thread

    # Set up multiprocessing so we can launch a new process to load the file.
    # We redirect stdout and stderr back to us and print in threads.
    stdout_reader, stdout_writer = mp_ctx.Pipe()
    stderr_reader, stderr_writer = mp_ctx.Pipe()

    proc = mp_ctx.Process(target=capture_stdout_stderr, args=(stdout_writer, stderr_writer, target, args))

    # Run the external process
    proc.start()

    # The child now holds its own copies of the writing ends. Closing ours lets the readers see EOF once the child
    # exits, rather than blocking forever on a pipe this process still holds open.
    stdout_writer.close()
    stderr_writer.close()

    relays = [
        Thread(target=io_redirector, args=(proc, stdout_reader, sys.stdout)),
        Thread(target=io_redirector, args=(proc, stderr_reader, sys.stderr)),
    ]
    for relay in relays:
        relay.start()

    proc.join()
    # Make sure all of the child's output has been printed before returning
    for relay in relays:
        relay.join()

    stdout_reader.close()
    stderr_reader.close()
@dataclass
class LoadAndRunResult:
    """Outcome reported by `load_and_run`; both fields stay None when the run succeeded."""

    # Error description, if one was reported
    error: Optional[str] = None
    # Exception raised while loading or running the file, if any
    exception: Optional[BaseException] = None
@dataclass
class ArgSpec:
    """Positional and keyword arguments captured for a call (returned by `log_args`)."""

    # Positional arguments, in order
    args: Tuple[Any, ...]
    # Keyword arguments, by name
    kwargs: Dict[str, Any]
def get_parser_name(prog_name: str) -> str:
    """Drop the first word of a program string (e.g. `script foo bar` -> `foo bar`); an empty remainder means `__root__`."""
    _, _, remainder = prog_name.partition(" ")
    return remainder or "__root__"
def log_args(logger_fn: Callable, msg: str, fn_name: str, *args: Any, **kwargs: Any) -> ArgSpec:
    """
    Pass `msg` followed by a rendered call `fn_name(arg, ..., key=value, ...)` to `logger_fn`.

    Returns:
        The captured positional and keyword arguments as an `ArgSpec`.
    """
    positional = ", ".join(repr(arg) for arg in args)
    keyword = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
    # Only put a separator between the two groups when both are non-empty
    rendered = ", ".join(part for part in (positional, keyword) if part)
    logger_fn(f"{msg}{fn_name}({rendered})")
    return ArgSpec(args, kwargs)
def func_or_class_info(func_or_class: Callable) -> Optional[Tuple[str, int]]:
    """
    Locate the source file and the line of the `def`/`class` statement for a function or class.

    Returns:
        `(source_file, line_number)`, or None when the source can't be located (for example for builtins, or objects
        whose source isn't available), so that callers such as `warn` can fall back instead of crashing.
    """
    try:
        source_file = inspect.getsourcefile(func_or_class)
        if source_file is None:
            return None
        source_lines, line_number = inspect.getsourcelines(func_or_class)
    except (OSError, TypeError):
        return None

    # Skip lines before the `def` (e.g. decorators). Should be cleaned up in the future.
    for line in source_lines:
        if "def " in line or "class " in line:
            break
        line_number += 1

    return source_file, line_number
def warn(message: str, func_or_class: Optional[Callable] = None) -> None:
    """
    Issue an `ArguablyWarning` through the `warnings` module (a library shouldn't configure logging).

    When `func_or_class` is given and its source can be located, the warning is attributed to the line where it is
    defined; otherwise it is issued normally.
    """
    location = None if func_or_class is None else func_or_class_info(func_or_class)

    if location is None:
        warnings.warn(message, ArguablyWarning)
        return

    source_file, source_file_line = location
    warnings.warn_explicit(message, ArguablyWarning, source_file, source_file_line)
def get_callable_methods(cls: type) -> List[Callable]:
    """
    Gets the staticmethods and classmethods defined directly on a class, in definition order.

    Dunder names (`__x__`) and abstract methods are skipped. Static methods are returned as plain functions and
    classmethods as methods bound to `cls`.
    """
    callable_methods: List[Callable] = []

    for name, attr in vars(cls).items():
        if name.startswith("__") and name.endswith("__"):
            continue
        if not isinstance(attr, (staticmethod, classmethod)):
            continue
        # `inspect.isabstract` only recognizes abstract *classes*; abstract methods (including ones wrapped in
        # staticmethod/classmethod) are flagged through `__isabstractmethod__` instead
        if getattr(attr, "__isabstractmethod__", False):
            continue
        callable_methods.append(getattr(cls, name))

    return callable_methods
def _set_callable_name(fn: Any, name: str) -> None:
    """Sets `__name__` on a function, or on the underlying function of a bound method (e.g. a bound classmethod)."""
    if inspect.ismethod(fn):
        # We have to set through .__func__ for the bound @classmethod
        fn.__func__.__name__ = name
    else:
        fn.__name__ = name


def _try_register_command(function: Any) -> None:
    """Registers `function` with arguably if it looks like a function or class; registration failures become warnings."""
    import arguably

    try:
        # Heuristic for determining what is close enough to a class or function. `inspect.signature` raises ValueError
        # (not only TypeError) when no signature can be provided, e.g. for some builtin classes.
        inspect.signature(function)
        get_type_hints(function, include_extras=True)
    except (TypeError, ValueError):
        return

    try:
        arguably.command(function)
    except Exception as e:
        warn(f"Unable to add function {function.__name__}: {str(e)}", function)


def load_and_run_inner(file: Path, *args: str, debug: bool, no_warn: bool) -> LoadAndRunResult:
    """
    Import `file`, register its top-level functions, classes, classmethods and staticmethods as commands, then run.

    Args:
        file: Python file to load. It is imported as the module `_arguably_imported`.
        *args: Extra command-line arguments, appended to `sys.argv` before running.
        debug: Configure DEBUG-level logging.
        no_warn: Ignore `ArguablyWarning`s.

    Returns:
        An empty `LoadAndRunResult` once `arguably.run` returns.
    """
    import arguably

    if debug:
        logging.basicConfig(level=logging.DEBUG, format="%(pathname)s:%(lineno)d: %(message)s")

    if no_warn:
        warnings.filterwarnings(action="ignore", category=arguably.ArguablyWarning)

    # Load the specified file
    spec = importlib.util.spec_from_file_location("_arguably_imported", str(file))
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules["_arguably_imported"] = module
    spec.loader.exec_module(module)

    # Collect all callables (classes and functions), keeping definition order
    callables = [item for item in vars(module).values() if callable(item)]

    # Each entry pairs a callable with the real name to restore once it has been processed (None if it wasn't renamed).
    # Plain functions come first, then each class followed by its classmethods and staticmethods.
    entries: List[Tuple[Any, Optional[str]]] = [(item, None) for item in callables if not isinstance(item, type)]

    for cls in (item for item in callables if isinstance(item, type)):
        # Add the initializer for the class itself, if it's not abstract
        if not inspect.isabstract(cls):
            entries.append((cls, None))

        # classmethods and staticmethods are registered as `<class-name>.<method-name>`
        cls_name = camel_case_to_kebab_case(cls.__name__)
        for callable_method in get_callable_methods(cls):
            real_name = callable_method.__name__
            _set_callable_name(callable_method, f"{cls_name}.{real_name}")
            entries.append((callable_method, real_name))

    arguably._context.context.reset()

    # Add all functions to arguably
    for function, real_name in entries:
        try:
            _try_register_command(function)
        finally:
            # Revert the name of a classmethod or staticmethod, whether or not registration succeeded
            if real_name is not None:
                _set_callable_name(function, real_name)

    sys.argv.extend(args)

    import __main__

    __main__.__doc__ = module.__doc__ or ""
    version = False
    if hasattr(module, "__version__"):
        __main__.__version__ = module.__version__
        version = True

    # Run and return success
    arguably.run(name=file.stem, always_subcommand=True, strict=False, version_flag=version)
    return LoadAndRunResult()
def load_and_run(results: multiprocessing.Queue, file: Path, argv: List[str], debug: bool, no_warn: bool) -> None:
    """
    Entry point for the subprocess: load `file`, register its top-level functions, classmethods and staticmethods,
    then run.

    Exactly one `LoadAndRunResult` is put on `results`. Any exception - including `BaseException`s such as `SystemExit`
    - is reported through its `exception` field instead of propagating.
    """
    try:
        outcome = load_and_run_inner(file, *argv, debug=debug, no_warn=no_warn)
        results.put(outcome)
    except BaseException as exc:
        results.put(LoadAndRunResult(exception=exc))
########################################################################################################################
# Exposed for API
class ArguablyException(Exception):
    """
    Raised when a decorated function is set up incorrectly in some way. It is *not* raised when a user provides
    incorrect input to the CLI.

    Examples:
        ```python
        #!/usr/bin/env python3
        import arguably

        @arguably.command
        def example(collision_, _collision):
            print("You should never see this")

        if __name__ == "__main__":
            arguably.run()
        ```

        ```console
        user@machine:~$ python3 arguably-exception.py
        Traceback (most recent call last):
          File ".../arguably/etc/scripts/api-examples/arguably-exception.py", line 9, in <module>
            arguably.run()
          File ".../arguably/arguably/_context.py", line 706, in run
            cmd = self._process_decorator_info(command_decorator_info)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File ".../arguably/arguably/_context.py", line 284, in _process_decorator_info
            return Command(
                   ^^^^^^^^
          File "<string>", line 9, in __init__
          File ".../arguably/arguably/_commands.py", line 214, in __post_init__
            raise util.ArguablyException(
        arguably._util.ArguablyException: Function argument `_collision` in `example` conflicts with `collision_`, both
        names simplify to `collision`
        ```
    """
class ArguablyWarning(UserWarning):
    """
    Emitted when strict checks are disabled through `arguably.run(strict=False)` and a decorated function is set up
    incorrectly in some way, but `arguably` can still continue. It is *not* emitted when a user provides incorrect input
    to the CLI.

    When `arguably` is invoked directly through `python3 -m arguably ...`, `strict=False` is always set.

    Note that this is a warning - it is issued through the `warnings` module rather than raised.

    Examples:
        ```python
        def example_failed(collision_, _collision):
            print("You should never see this")

        def example_ok():
            print("All good")
        ```

        ```console
        user@machine:~$ python3 -m arguably arguably-warn.py -h
        .../arguably/etc/scripts/api-examples/arguably-warn.py:1: ArguablyWarning: Unable to add function
        example_failed: Function argument `_collision` in `example-failed` conflicts with `collision_`, both names
        simplify to `collision`
          def example_failed(collision_, _collision):
        usage: arguably-warn [-h] command ...

        positional arguments:
          command
            example-ok

        options:
          -h, --help  show this help message and exit
        ```
    """