Coverage for arguably/_util.py: 88%

285 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-10 01:01 +0000

1from __future__ import annotations 

2 

3import ast 

4import asyncio 

5import enum 

6import functools 

7import importlib.util 

8import inspect 

9import logging 

10import math 

11import multiprocessing 

12import re 

13import sys 

14import time 

15import warnings 

16from dataclasses import dataclass 

17from io import StringIO 

18from pathlib import Path 

19from typing import Callable, cast, Any, Union, Optional, List, Dict, Type, Tuple, TextIO 

20 

21from docstring_parser import parse as docparse 

22 

23from typing import Annotated, get_type_hints, get_args, get_origin # noqa 

24 

# `UnionType` is the runtime type behind the `A | B` annotation syntax, which only exists on Python 3.10 and up
if sys.version_info < (3, 10):  # pragma: no cover

    class UnionType:
        """Placeholder for older interpreters - we only ever use it for issubclass() checks"""

else:
    from types import UnionType  # noqa


# Logger shared by the package, registered under the name "arguably"
logger = logging.getLogger("arguably")

35 

36 

def is_async_callable(obj: Any) -> bool:
    """
    Reports whether `obj` is an async callable - approach from https://stackoverflow.com/a/72682939

    Any layers of `functools.partial` are peeled off first. What remains counts as async when it is a coroutine
    function itself, or when it is callable and its `__call__` is a coroutine function.
    """
    target = obj
    while isinstance(target, functools.partial):
        target = target.func

    if asyncio.iscoroutinefunction(target):
        return True
    return callable(target) and asyncio.iscoroutinefunction(target.__call__)

42 

43 

def camel_case_to_kebab_case(name: str) -> str:
    """Lowercases `name`, inserting a `-` wherever a lowercase letter is directly followed by an uppercase one."""
    hyphenated = re.sub(r"([a-z])([A-Z])", r"\1-\2", name)
    return hyphenated.lower()

46 

47 

def split_unquoted(unsplit: str, delimeter: str, limit: Union[int, float] = math.inf) -> List[str]:
    """
    Splits `unsplit` on a single-character delimiter, ignoring delimiters that sit inside single or double quotes.

    At most `limit` splits are performed; once they are used up the rest of the text lands in the final piece. Quote
    characters are kept in the returned pieces. A quote of the other kind inside an open quote is plain text.
    """
    assert len(delimeter) == 1
    assert limit > 0

    pieces: List[str] = []
    current: List[str] = []
    open_quote: Optional[str] = None
    splits_left = limit

    for ch in unsplit:
        if ch == delimeter and open_quote is None and splits_left > 0:
            # Unquoted delimiter with splits remaining: finish the current piece (the delimiter is dropped)
            pieces.append("".join(current))
            current = []
            splits_left -= 1
            continue

        if ch in ("'", '"'):
            if open_quote is None:
                open_quote = ch
            elif open_quote == ch:
                open_quote = None

        current.append(ch)

    # Whatever is left over (possibly nothing) is always the last piece
    pieces.append("".join(current))
    return pieces

75 

76 

class NoDefault:
    """Marker meaning "this parameter has no default value" - needed because `None` is itself a valid default."""

79 

80 

def unwrap_quotes(qs: str) -> str:
    """
    Removes one pair of quotes wrapping a string.

    The quotes must match (both `'` or both `"`) and must be the first and last characters. A string consisting of
    a single quote character is not a wrapped string and is returned unchanged, as is any other input without a
    matching pair.
    """
    # The length check stops a lone quote character from counting as both the opening and the closing quote
    if len(qs) >= 2 and qs[0] == qs[-1] and qs[0] in ("'", '"'):
        return qs[1:-1]
    return qs

86 

87 

def get_ancestors(command_name: str) -> List[str]:
    """
    Lists every ancestor of a command, outermost first. For example, `foo bar bat` yields a list with:
      * `__root__`
      * `foo`
      * `foo bar`

    `__root__` is always an implicit ancestor, and has no ancestors itself.
    """
    if command_name == "__root__":
        return []

    tokens = command_name.split(" ")
    ancestors = ["__root__"]
    # Each proper prefix of the token list names one ancestor; the full list is the command itself
    for depth in range(1, len(tokens)):
        ancestors.append(" ".join(tokens[:depth]))
    return ancestors

101 

102 

def normalize_name(name: str, spaces: bool = True) -> str:
    """
    Normalizes a name: lowercases it, strips leading and trailing `_`, and turns each remaining `_` into `-`.
    When `spaces` is true, every `__` is turned into a single space ` ` before that.

    Raises `ArguablyException` if nothing but dashes and spaces would be left.
    """
    normalized = name.lower().strip("_")
    if spaces:
        normalized = normalized.replace("__", " ")
    normalized = normalized.replace("_", "-")

    if not normalized.strip("- "):
        raise ArguablyException(f"Cannot normalize name `{name}` - cannot just be underscores and dashes.")
    return normalized

115 

116 

@dataclass
class EnumFlagInfo:
    """Plays a role similar to _CommandArg, but describes a single entry of an `enum.Flag`."""

    # Option strings for the entry: one or two of them (e.g. a short and/or a long form)
    option: Union[Tuple[str], Tuple[str, str]]
    # Name of the CLI argument this flag entry belongs to
    cli_arg_name: str
    # The flag member itself
    value: Any
    # Help text for the entry
    description: str

125 

126 

def get_enum_member_docs(enum_class: Type[enum.Enum]) -> Dict[str, str]:
    """
    Extracts docstrings for enum members similar to PEP-224, which has become a pseudo-standard supported by a lot of
    tooling.

    A member docstring is a bare string expression placed directly after a simple `NAME = value` assignment in the
    class body.

    Args:
        enum_class: The enum class whose source should be inspected.

    Returns:
        A mapping from member name to its docstring. Members without a docstring are absent. An empty mapping is
        returned if the source of the class cannot be retrieved.
    """
    import textwrap  # function-scope import so the block is self-contained

    try:
        source = inspect.getsource(enum_class)
    except (OSError, TypeError):
        # `inspect.getsource` could not provide source text - without it there are no member docstrings to report
        return {}

    # Enums declared inside a function or another class come back indented, which `ast.parse` would reject
    parsed = ast.parse(textwrap.dedent(source))
    assert len(parsed.body) == 1

    classdef = parsed.body[0]
    assert isinstance(classdef, ast.ClassDef)

    # Search for a string expression following an assignment
    prev = None
    result: Dict[str, str] = dict()
    for item in classdef.body:
        if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant) and isinstance(item.value.value, str):
            if isinstance(prev, ast.Assign) and len(prev.targets) == 1 and isinstance(prev.targets[0], ast.Name):
                result[cast(ast.Name, prev.targets[0]).id] = item.value.value
        prev = item

    return result

146 

147 

def info_for_flags(cli_arg_name: str, flag_class: Type[enum.Flag]) -> List[EnumFlagInfo]:
    """
    Generates a list of `EnumFlagInfo` corresponding to all flags in an `enum.Flag`.

    The description of each flag comes from the class docstring (a `var`/`cvar`/`attribute`/`attr` entry whose name
    matches the member) or, failing that, from a docstring written directly after the member. A leading option
    spec in that description is handed to `parse_short_and_long_name` to pick the short and long option names.

    Args:
        cli_arg_name: Name of the CLI argument the flags belong to; stored on every returned entry.
        flag_class: The `enum.Flag` subclass to describe.

    Returns:
        One `EnumFlagInfo` per member yielded by iterating `flag_class`, in iteration order.
    """
    result: List[EnumFlagInfo] = list()
    docs = docparse(flag_class.__doc__ or "")
    enum_member_docs = get_enum_member_docs(flag_class)
    for item in flag_class:
        arg_description = ""

        # `docstring_parser` does not specially parse out attributes declared in the docstring - we have to do that
        # ourselves here.
        for doc_item in docs.meta:
            # Entries with fewer than two args carry no (kind, name) pair, so they cannot describe a member. Skip
            # them instead of asserting, so unrelated docstring sections don't abort processing.
            if len(doc_item.args) < 2:
                continue
            doc_item_type, doc_item_name = doc_item.args[0], doc_item.args[-1]
            if item.name != doc_item_name:
                continue
            if doc_item_type not in ["var", "cvar", "attribute", "attr"]:
                continue
            arg_description = doc_item.description or ""
            break
        else:
            # Nothing in the class docstring matched - fall back to a docstring placed after the member
            if item.name in enum_member_docs:
                # noinspection PyTypeChecker
                arg_description = enum_member_docs[item.name]

        # Extract the alias from the docstring for the flag item
        options = list()
        long_name: Optional[str] = normalize_name(cast(str, item.name))
        arg_description, short_name, long_name = parse_short_and_long_name(long_name, arg_description, flag_class)
        if short_name is not None:
            options.append(f"-{short_name}")
        if long_name is not None:
            options.append(f"--{long_name}")

        result.append(
            EnumFlagInfo(cast(Union[Tuple[str], Tuple[str, str]], tuple(options)), cli_arg_name, item, arg_description)
        )
    return result

188 

189 

def parse_short_and_long_name(
    long_name: Optional[str], arg_description: str, func_or_class: Callable
) -> Tuple[str, Optional[str], Optional[str]]:
    """
    Extracts the short and long name for an option from the start of its description.

    Recognized prefixes are `[-s]`, `[-s/--long]`, `[--long]`, `[--long/-s]` and `[-s/]` (short name only, the long
    name is dropped). A recognized prefix and any spaces after it are removed from the description. Returns the
    tuple `(description, short_name, long_name)`; with no prefix, the inputs come back with no short name.
    """
    # `[-s]` or `[-s/--long]`: the incoming long name is kept unless the prefix supplies one
    short_first = re.match(r"^\[-([a-zA-Z0-9])(/--([a-zA-Z0-9][a-zA-Z0-9\-]*))?]", arg_description)
    if short_first is not None:
        remainder = arg_description[short_first.end() :].lstrip(" ")
        explicit_long = short_first.group(3)
        return remainder, short_first.group(1), long_name if explicit_long is None else explicit_long

    # `[--long]` or `[--long/-s]`
    long_first = re.match(r"^\[--([a-zA-Z0-9][a-zA-Z0-9\-]*)(/-([a-zA-Z0-9]))?]", arg_description)
    if long_first is not None:
        remainder = arg_description[long_first.end() :].lstrip(" ")
        return remainder, long_first.group(3), long_first.group(1)

    # `[-s/]`: short name only
    short_only = re.match(r"^\[-([a-zA-Z0-9])/]", arg_description)
    if short_only is not None:
        remainder = arg_description[short_only.end() :].lstrip(" ")
        return remainder, short_only.group(1), None

    if arg_description.startswith("["):
        # TODO: Should this be an exception?
        warn(f"Description for {long_name} starts with `[`, but doesn't match any known option formats.", func_or_class)
    return arg_description, None, long_name

211 

212 

213######################################################################################################################## 

214# For __main__ 

215 

216 

class RedirectedIO(StringIO):
    """A `StringIO` whose writes are not buffered but forwarded through `pipe.send()`."""

    def __init__(self, pipe: Any) -> None:
        """Remembers `pipe`, the object every written string is sent through."""
        super().__init__()
        self.pipe = pipe

    def write(self, s: str) -> int:
        """Sends `s` through the pipe and reports its length as the number of characters written."""
        self.pipe.send(s)
        return len(s)

225 

226 

def capture_stdout_stderr(stdout_writer: Any, stderr_writer: Any, target: Callable, args: Tuple[Any, ...]) -> None:
    """
    Runs `target(*args)` with `sys.stdout` and `sys.stderr` replaced by `RedirectedIO` wrappers around the writers.

    The streams are reset to `sys.__stdout__` / `sys.__stderr__` afterwards, even when `target` raises; the
    exception still propagates to the caller.
    """
    sys.stdout = RedirectedIO(stdout_writer)
    sys.stderr = RedirectedIO(stderr_writer)

    try:
        target(*args)
    finally:
        # Restore the interpreter's original streams no matter how `target` exits
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

235 

236 

def io_redirector(proc: multiprocessing.Process, pipe: Any, file: TextIO) -> None:
    """
    Keeps receiving strings from `pipe` and printing them to `file` until the pipe is finished.

    Each received string is stripped of surrounding whitespace, and empty results are skipped. The loop ends on
    `EOFError`, or on `OSError` once `proc` is no longer alive. While `proc` is still alive, an `OSError` just
    causes a short sleep before retrying.
    """
    while True:
        try:
            message = pipe.recv().strip()
            if message:
                print(message, file=file)
        except EOFError:
            return
        except OSError:
            if proc.is_alive():
                time.sleep(0.05)
                continue
            return

249 

250 

def run_redirected_io(mp_ctx: multiprocessing.context.SpawnContext, target: Callable, args: Tuple[Any, ...]) -> None:
    """
    Redirects the subprocess's stdout/stderr back to THIS process's stdout/stderr.

    `target(*args)` runs in a new process created from `mp_ctx`. Its output travels over two pipes and is printed
    here by one forwarding thread per stream. Once the process has finished, both ends of both pipes are closed.
    """
    from threading import Thread

    # Set up multiprocessing so we can launch a new process to load the file
    # We redirect stdout and stderr back to us and print in threads
    stdout_reader, stdout_writer = mp_ctx.Pipe()
    stderr_reader, stderr_writer = mp_ctx.Pipe()

    proc = mp_ctx.Process(target=capture_stdout_stderr, args=(stdout_writer, stderr_writer, target, args))

    # Run the external process
    proc.start()
    Thread(target=io_redirector, args=(proc, stdout_reader, sys.stdout)).start()
    Thread(target=io_redirector, args=(proc, stderr_reader, sys.stderr)).start()
    try:
        proc.join()
    finally:
        # Close every pipe end we hold, for stdout as well as stderr. This releases the handles and is what
        # lets the forwarding threads leave their receive loops.
        stdout_reader.close()
        stdout_writer.close()
        stderr_reader.close()
        stderr_writer.close()

270 

271 

@dataclass
class LoadAndRunResult:
    """Outcome reported by load_and_run; both fields stay `None` unless they are explicitly set."""

    # Error message, if any
    error: Optional[str] = None
    # Exception caught while loading or running, if any
    exception: Optional[BaseException] = None

278 

279 

@dataclass
class ArgSpec:
    """A captured set of call arguments: the positional ones and the keyword ones."""

    # Positional arguments, in call order
    args: Tuple[Any, ...]
    # Keyword arguments by name
    kwargs: Dict[str, Any]

284 

285 

def get_parser_name(prog_name: str) -> str:
    """Returns everything after the first space of `prog_name`, or `__root__` when nothing follows it."""
    _, _, nice_name = prog_name.partition(" ")
    return nice_name or "__root__"

291 

292 

def log_args(logger_fn: Callable, msg: str, fn_name: str, *args: Any, **kwargs: Any) -> ArgSpec:
    """
    Logs a call-like rendering `{msg}{fn_name}(...)` of the given arguments through `logger_fn`.

    Positional arguments are shown via `repr`, keyword arguments as `name=repr(value)`. The same arguments are
    returned bundled in an `ArgSpec`.
    """
    positional = ", ".join(map(repr, args))
    keyword = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items())
    # A separator is only needed when both halves have content
    full_arg_string = ", ".join(part for part in (positional, keyword) if part)
    logger_fn(f"{msg}{fn_name}({full_arg_string})")
    return ArgSpec(args, kwargs)

302 

303 

def func_or_class_info(func_or_class: Callable) -> Optional[Tuple[str, int]]:
    """
    Locates a function or class in its source: returns `(source file, line number)`, or `None` when
    `inspect.getsourcefile` finds no file.

    The line number is that of the first source line containing `def ` or `class `. If no line matches, it is the
    line just past the end of the retrieved source.
    """
    source_file = inspect.getsourcefile(func_or_class)
    if source_file is None:
        return None

    # Skip lines before the `def`. Should be cleaned up in the future.
    source_lines, first_line = inspect.getsourcelines(func_or_class)
    skipped = next(
        (index for index, text in enumerate(source_lines) if "def " in text or "class " in text),
        len(source_lines),
    )
    return source_file, first_line + skipped

317 

318 

def warn(message: str, func_or_class: Optional[Callable] = None) -> None:
    """
    Issues an `ArguablyWarning`. As a library we go through `warnings` and not through logging.

    When `func_or_class` is given and its source location can be found, the warning is attributed to that file and
    line. Otherwise a regular warning is emitted.
    """
    info = func_or_class_info(func_or_class) if func_or_class is not None else None
    if info is None:
        warnings.warn(message, ArguablyWarning)
        return

    source_file, source_file_line = info
    warnings.warn_explicit(message, ArguablyWarning, source_file, source_file_line)

336 

337 

def get_callable_methods(cls: type) -> List[Callable]:
    """
    Gets the classmethods and staticmethods defined directly on a class. Skips abstract methods.

    Only entries of `cls`'s own namespace are considered (nothing inherited). Dunder names are ignored. Each result
    is fetched via `getattr(cls, name)`, so classmethods come back bound to `cls` and staticmethods as plain
    functions.
    """
    callable_methods = []

    for name, method in vars(cls).items():
        if name.startswith("__") and name.endswith("__"):
            continue
        if not isinstance(method, (staticmethod, classmethod)):
            continue
        # `inspect.isabstract` only recognizes abstract *classes*. A method decorated with `abstractmethod` is
        # flagged through `__isabstractmethod__`, which the staticmethod/classmethod wrappers expose as well.
        if getattr(method, "__isabstractmethod__", False):
            continue
        callable_methods.append(getattr(cls, name))

    return callable_methods

351 

352 

def load_and_run_inner(file: Path, *args: str, debug: bool, no_warn: bool) -> LoadAndRunResult:
    """
    Imports `file`, registers its callables as `arguably` commands, and runs the resulting CLI.

    Module-level functions, non-abstract classes, and the classmethods/staticmethods of every class are offered to
    `arguably.command`. Anything that fails to register produces a warning and is skipped. `args` are appended to
    `sys.argv` before `arguably.run` is invoked.

    Args:
        file: Path of the Python file to load.
        *args: Extra command-line arguments for the loaded CLI.
        debug: Configure logging at DEBUG level when true.
        no_warn: Suppress `ArguablyWarning` when true.
    """
    import arguably

    if debug:
        logging.basicConfig(level=logging.DEBUG, format="%(pathname)s:%(lineno)d: %(message)s")

    if no_warn:
        warnings.filterwarnings(action="ignore", category=arguably.ArguablyWarning)

    # Load the specified file under a fixed module name
    module_name = "_arguably_imported"
    spec = importlib.util.spec_from_file_location(module_name, str(file))
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    # Collect all callables and partition them into classes and everything else
    module_callables = [member for member in vars(module).values() if callable(member)]
    functions = [member for member in module_callables if not isinstance(member, type)]
    classes = [member for member in module_callables if isinstance(member, type)]

    # For classmethods and staticmethods, we prepend the class name when we call `arguably.command`. Keep track of the
    # real name here so we can revert it after
    real_names: Dict[Any, str] = dict()

    for cls in classes:
        # The class itself (its initializer) is a command, unless it's abstract
        if not inspect.isabstract(cls):
            functions.append(cls)

        # Its classmethods and staticmethods are commands too, temporarily renamed to `<class-name>.<method>`
        for callable_method in get_callable_methods(cls):
            cls_name = camel_case_to_kebab_case(cls.__name__)
            original_name = callable_method.__name__
            real_names[callable_method] = original_name
            prefixed_name = f"{cls_name}.{original_name}"
            if inspect.ismethod(callable_method):
                # We have to set through .__func__ for the bound @classmethod
                bound_method = cast(classmethod, callable_method)
                bound_method.__func__.__name__ = prefixed_name
            else:
                callable_method.__name__ = prefixed_name
            functions.append(callable_method)

    arguably._context.context.reset()

    # Add all functions to arguably
    for function in functions:
        try:
            # Heuristic for determining what is close enough to a class or function
            inspect.signature(function)
            get_type_hints(function, include_extras=True)
        except TypeError:
            continue

        try:
            arguably.command(function)
        except Exception as e:
            warn(f"Unable to add function {function.__name__}: {str(e)}", function)
            continue

        # If it's a classmethod or staticmethod, revert the name
        if function in real_names:
            name_holder = function.__func__ if inspect.ismethod(function) else function
            name_holder.__name__ = real_names[function]

    sys.argv.extend(args)

    import __main__

    # Carry the loaded module's docstring and version (if it has one) over to `__main__`
    __main__.__doc__ = module.__doc__ or ""
    version = hasattr(module, "__version__")
    if version:
        __main__.__version__ = module.__version__

    # Run and return success
    arguably.run(name=file.stem, always_subcommand=True, strict=False, version_flag=version)
    return LoadAndRunResult()

441 

442 

def load_and_run(results: multiprocessing.Queue, file: Path, argv: List[str], debug: bool, no_warn: bool) -> None:
    """
    Load the specified file, register all callable top-level functions, classmethods, and staticmethods, then run.

    The `LoadAndRunResult` is put on the `results` queue. Anything raised along the way (any `BaseException`) is
    caught and reported through the queue as the result's `exception`.
    """
    try:
        results.put(load_and_run_inner(file, *argv, debug=debug, no_warn=no_warn))
    except BaseException as exc:
        failure = LoadAndRunResult(exception=exc)
        results.put(failure)

449 

450 

451######################################################################################################################## 

452# Exposed for API 

453 

454 

class ArguablyException(Exception):
    """
    Raised when a decorated function is incorrectly set up in some way. Will *not* be raised when a user provides
    incorrect input to the CLI.

    Examples:
        ```python
        #!/usr/bin/env python3
        import arguably

        @arguably.command
        def example(collision_, _collision):
            print("You should never see this")

        if __name__ == "__main__":
            arguably.run()
        ```

        ```console
        user@machine:~$ python3 arguably-exception.py
        Traceback (most recent call last):
          File ".../arguably/etc/scripts/api-examples/arguably-exception.py", line 9, in <module>
            arguably.run()
          File ".../arguably/arguably/_context.py", line 706, in run
            cmd = self._process_decorator_info(command_decorator_info)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File ".../arguably/arguably/_context.py", line 284, in _process_decorator_info
            return Command(
                   ^^^^^^^^
          File "<string>", line 9, in __init__
          File ".../arguably/arguably/_commands.py", line 214, in __post_init__
            raise util.ArguablyException(
        arguably._util.ArguablyException: Function argument `_collision` in `example` conflicts with `collision_`, both
        names simplify to `collision`
        ```
    """

491 

492 

class ArguablyWarning(UserWarning):
    """
    If strict checks are disabled through `arguably.run(strict=False)` this is emitted when a decorated function is
    incorrectly set up in some way, but arguably can continue. Will *not* be raised when a user provides incorrect input
    to the CLI.

    When `arguably` is directly invoked through `python3 -m arguably ...`, `strict=False` is always set.

    Note that this is a warning - it is used with `warnings.warn`.

    Examples:
        ```python
        def example_failed(collision_, _collision):
            print("You should never see this")

        def example_ok():
            print("All good")
        ```

        ```console
        user@machine:~$ python3 -m arguably arguably-warn.py -h
        .../arguably/etc/scripts/api-examples/arguably-warn.py:1: ArguablyWarning: Unable to add function
        example_failed: Function argument `_collision` in `example-failed` conflicts with `collision_`, both names
        simplify to `collision`
          def example_failed(collision_, _collision):
        usage: arguably-warn [-h] command ...

        positional arguments:
          command
            example-ok

        options:
          -h, --help  show this help message and exit
        ```
    """