Coverage for arguably/_commands.py: 89%

218 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-10 01:01 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import enum 

5import inspect 

6import re 

7from dataclasses import dataclass, field 

8from typing import Callable, Any, Union, Optional, List, Dict, Tuple, cast, Set 

9 

10from docstring_parser import parse as docparse 

11 

12import arguably._modifiers as mods 

13import arguably._util as util 

14 

15 

class InputMethod(enum.Enum):
    """How a single command argument is supplied on the command line."""

    REQUIRED_POSITIONAL = 0  # usage: foo BAR
    OPTIONAL_POSITIONAL = 1  # usage: foo [BAR]
    OPTION = 2  # Examples: -F, --test_scripts, --filename foo.txt

    @property
    def is_positional(self) -> bool:
        """True for both required and optional positional arguments, False for --options."""
        return self is InputMethod.REQUIRED_POSITIONAL or self is InputMethod.OPTIONAL_POSITIONAL

    @property
    def is_optional(self) -> bool:
        """True when the argument may be omitted: optional positionals and --options."""
        # Only three members exist, so "optional" is everything except a required positional
        return self is not InputMethod.REQUIRED_POSITIONAL

30 

31 

@dataclass
class CommandDecoratorInfo:
    """Used for keeping a reference to everything marked with @arguably.command

    Constructing an instance inspects ``function`` (its signature, type hints and docstring) and stores the resulting
    processed ``Command`` on ``self.command``.
    """

    function: Callable  # Decorated function, or a class (whose ``__init__`` then supplies docs and type hints)
    alias: Optional[str] = None
    help: bool = True
    name: str = field(init=False)  # Normalized CLI name, derived from the function/class name
    command: Command = field(init=False)  # Output of ``_process``

    # Matches a metavar descriptor like ``{A, B, C}`` inside an argument description. The pattern has exactly one
    # capture group, so ``split`` yields ['pre-metavar', 'METAVAR', 'post-metavar'] for a single descriptor and a
    # longer list when several descriptors are present.
    _METAVAR_PATTERN = re.compile(r"\{((?:[a-zA-Z0-9_-]+(?:, *)*)+)}")

    def __post_init__(self) -> None:
        if self.function.__name__ == "__root__":
            self.name = "__root__"
        elif isinstance(self.function, type):
            # Class names are CamelCase; convert to kebab-case before normalizing
            self.name = util.normalize_name(util.camel_case_to_kebab_case(self.function.__name__))
        else:
            self.name = util.normalize_name(self.function.__name__)

        self.command = self._process()

    def _process(self) -> Command:
        """Takes the decorator info and return a processed command

        Raises:
            util.ArguablyException: if the function uses ``**kwargs``, an argument has multiple docstring entries, or
                an argument's metavar descriptor is malformed. Type-related errors are raised by
                ``CommandArg.normalize_type`` and by each modifier's ``check_valid``.
        """
        processed_name = self.name
        # For a class, docs and type hints come from ``__init__``, while the signature below is still taken from the
        # class itself (``inspect.signature`` on a class omits ``self``).
        func = self.function.__init__ if isinstance(self.function, type) else self.function  # type: ignore[misc]

        # Get the description from the docstring
        if func.__doc__ is None:
            docs = None
            processed_description = ""
        else:
            docs = docparse(func.__doc__)
            processed_description = "" if docs.short_description is None else docs.short_description

        try:
            hints = util.get_type_hints(func, include_extras=True)
        except NameError as e:
            # Unresolvable annotations: continue without hints (types are then guessed per-argument)
            hints = {}
            util.warn(f"Unable to resolve type hints for function {processed_name}: {str(e)}", func)

        # Will be filled in as we loop over all parameters
        processed_args: List[CommandArg] = list()

        for func_arg_name, param in inspect.signature(self.function).parameters.items():
            cli_arg_name = util.normalize_name(func_arg_name, spaces=False)
            arg_default = util.NoDefault if param.default is param.empty else param.default

            # Handle variadic arguments: *args is supported, **kwargs is not
            if param.kind is param.VAR_KEYWORD:
                raise util.ArguablyException(f"`{processed_name}` is using **kwargs, which is not supported")
            is_variadic = param.kind is param.VAR_POSITIONAL

            # Get the type and normalize it
            arg_value_type, modifiers = CommandArg.normalize_type(processed_name, param, hints)
            tuple_modifiers = [m for m in modifiers if isinstance(m, mods.TupleModifier)]
            expected_metavars = 1
            if len(tuple_modifiers) > 0:
                assert len(tuple_modifiers) == 1
                # A tuple-typed argument expects one metavar per tuple item
                expected_metavars = len(tuple_modifiers[0].tuple_arg)

            # Keyword-only parameters become --options; everything else is positional, and is optional only when it
            # has a default value.
            if param.kind == param.KEYWORD_ONLY:
                input_method = InputMethod.OPTION
            elif arg_default is util.NoDefault:
                input_method = InputMethod.REQUIRED_POSITIONAL
            else:
                input_method = InputMethod.OPTIONAL_POSITIONAL

            # Get the description from the matching docstring param entry
            arg_description = ""
            if docs is not None and docs.params is not None:
                # Leading stars are stripped so that a docstring entry written as ``*args`` matches ``args``
                ds_matches = [ds_p for ds_p in docs.params if ds_p.arg_name.lstrip("*") == param.name]
                if len(ds_matches) > 1:
                    raise util.ArguablyException(
                        f"Function argument `{param.name}` in " f"`{processed_name}` has multiple docstring entries."
                    )
                if len(ds_matches) == 1:
                    ds_info = ds_matches[0]
                    arg_description = "" if ds_info.description is None else ds_info.description

            # Extract the alias (only --options can have short/long names declared in their description)
            arg_alias = None
            has_long_name = True
            if input_method == InputMethod.OPTION:
                arg_description, arg_alias, long_name = util.parse_short_and_long_name(
                    cli_arg_name, arg_description, func
                )
                if long_name is None:
                    has_long_name = False
                else:
                    cli_arg_name = long_name
            elif arg_description.startswith("["):
                util.warn(
                    f"Function argument `{param.name}` in `{processed_name}` is a positional argument, but starts "
                    f"with a `[`, which is used to specify --option names. To make this argument an --option, make "
                    "it into a keyword-only argument.",
                    func,
                )

            # Extract the metavars
            metavars, arg_description = self._parse_metavars(param, arg_description, is_variadic, expected_metavars)

            # Check modifiers
            for modifier in modifiers:
                modifier.check_valid(arg_value_type, param, processed_name)

            # Finished processing this arg
            processed_args.append(
                CommandArg(
                    func_arg_name,
                    cli_arg_name,
                    input_method,
                    is_variadic,
                    arg_value_type,
                    arg_description,
                    arg_alias,
                    has_long_name,
                    metavars,
                    arg_default,
                    modifiers,
                )
            )

        # Return the processed command
        return Command(
            self.function,
            processed_name,
            processed_args,
            processed_description,
            self.alias,
            self.help,
        )

    def _parse_metavars(
        self,
        param: inspect.Parameter,
        arg_description: str,
        is_variadic: bool,
        expected_metavars: int,
    ) -> Tuple[Optional[List[str]], str]:
        """Extracts the ``{A, B}`` metavar descriptor, if any, from an argument description.

        Args:
            param: The function parameter being processed (used for error messages).
            arg_description: The argument's description text.
            is_variadic: True for ``*args``, which must have exactly one metavar name.
            expected_metavars: How many metavars a non-variadic argument takes. A single name is repeated to fill
                this count.

        Returns:
            The upper-cased metavar names (``None`` if there is no descriptor) and the description with the
            descriptor's ``{`` and ``}`` removed.

        Raises:
            util.ArguablyException: if there is more than one descriptor, or the number of names doesn't fit.
        """
        metavar_split = self._METAVAR_PATTERN.split(arg_description)
        if len(metavar_split) > 3:
            raise util.ArguablyException(
                f"Function argument `{param.name}` in `{self.name}` has multiple metavar sequences - "
                f"these are denoted like {{A, B, C}}. There should be only one."
            )
        if len(metavar_split) != 3:
            # No descriptor in the description
            return None, arg_description

        # format is: ['pre-metavar', 'METAVAR', 'post-metavar']
        match_items = [i.strip() for i in metavar_split[1].split(",")]
        if is_variadic:
            if len(match_items) != 1:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{self.name}` should only have one item in "
                    f"its metavar descriptor, but found {len(match_items)}: {','.join(match_items)}."
                )
        elif len(match_items) != expected_metavars:
            if len(match_items) == 1:
                match_items *= expected_metavars
            else:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{self.name}` takes {expected_metavars} "
                    f"items, but metavar descriptor has {len(match_items)}: {','.join(match_items)}."
                )
        # Joining the split pieces drops the { and } around the metavars in the description
        return [i.upper() for i in match_items], "".join(metavar_split)

193 

194 

@dataclass
class SubtypeDecoratorInfo:
    """Record of a type registered through @arguably.subtype, kept so it can be looked up later."""

    type_: type  # The registered type
    alias: Optional[str] = field(default=None)
    ignore: bool = field(default=False)
    factory: Optional[Callable] = field(default=None)

203 

204 

@dataclass
class CommandArg:
    """A single argument to a given command"""

    func_arg_name: str  # Name of the parameter in the Python function signature
    cli_arg_name: str  # Name used on the command line; rendered as ``--<cli_arg_name>`` for options

    input_method: InputMethod
    is_variadic: bool  # True for a ``*args`` parameter
    arg_value_type: type  # Normalized value type, see ``normalize_type``

    description: str
    alias: Optional[str] = None  # Short option name, rendered as ``-<alias>``
    has_long_name: bool = True  # Whether ``--<cli_arg_name>`` is offered
    metavars: Optional[List[str]] = None

    default: Any = util.NoDefault  # ``util.NoDefault`` means the parameter has no default value

    modifiers: List[mods.CommandArgModifier] = field(default_factory=list)

    def __post_init__(self) -> None:
        # The argument must be reachable through at least one of ``-<alias>`` / ``--<cli_arg_name>``
        if not self.has_long_name and self.alias is None:
            raise ValueError("CommandArg has no short or long name")

    def get_options(self) -> Union[Tuple[()], Tuple[str], Tuple[str, str]]:
        """Returns this argument's option strings: ``-<alias>`` first (if set), then ``--<cli_arg_name>`` (if
        ``has_long_name``)."""
        options = list()
        if self.alias is not None:
            options.append(f"-{self.alias}")
        if self.has_long_name:
            options.append(f"--{self.cli_arg_name}")
        return cast(Union[Tuple[()], Tuple[str], Tuple[str, str]], tuple(options))

    @staticmethod
    def _normalize_type_union(
        function_name: str,
        param: inspect.Parameter,
        value_type: type,
    ) -> type:
        """Unwraps ``Optional[X]`` / ``X | None`` to ``X``; any non-union type is returned unchanged.

        We break this out because Python 3.10 seems to want to wrap `Annotated[Optional[...` in another `Optional`, so
        we call this twice.

        Raises:
            util.ArguablyException: if the union does not have exactly one non-None member.
        """
        if isinstance(value_type, util.UnionType) or util.get_origin(value_type) is Union:
            filtered_types = [x for x in util.get_args(value_type) if x is not type(None)]
            if len(filtered_types) != 1:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is an unsupported type. It must be either "
                    f"a single, non-generic type or a Union with None."
                )
            value_type = filtered_types[0]
        return value_type

    @staticmethod
    def normalize_type(
        function_name: str,
        param: inspect.Parameter,
        hints: Dict[str, Any],
    ) -> Tuple[type, List[mods.CommandArgModifier]]:
        """
        Normalizes the function argument type. Most of the logic here is validation. Explanation of what's returned for
        a given function argument type:
        * SomeType -> value_type=SomeType, modifiers=[]
        * int | None -> value_type=int, modifiers=[]
        * Tuple[float, float] -> value_type=type(None), modifiers=[TupleModifier([float, float])]
        * List[str] -> value_type=str, modifiers=[ListModifier()]
        * list (bare) -> value_type=str, modifiers=[ListModifier()]
        * Annotated[int, <modifier>...] -> value_type=int, modifiers=[<modifier>...]
        * no annotation -> value_type=type(default), or str when there is no default or the default is None

        Things that will cause an exception:
        * A Union with more than one non-None member
        * Annotated metadata that is not a CommandArgModifier
        * List[...] with more than one type argument
        * Tuple without item types, or a flexible-length Tuple[SomeType, ...]
        * *args annotated with a tuple or any other parameterized type
        * Any other parameterized (generic) type
        """

        modifiers: List[mods.CommandArgModifier] = list()

        if param.name in hints:
            value_type = hints[param.name]
        elif param.default is not param.empty and param.default is not None:
            # No type hint: guess the type from the default value. Identity checks avoid invoking a default's __eq__.
            value_type = type(param.default)
        else:
            # No type hint and no usable default: default to string
            value_type = str

        # Extra call to normalize a union here, see note in `_normalize_type_union`
        value_type = CommandArg._normalize_type_union(function_name, param, value_type)

        # Handle annotated types: the first argument is the type, the rest must be modifiers
        if util.get_origin(value_type) == util.Annotated:
            type_args = util.get_args(value_type)
            if len(type_args) == 0:
                raise util.ArguablyException(f"Function argument `{param.name}` is Annotated, but no type is specified")
            else:
                value_type = type_args[0]
                for type_arg in type_args[1:]:
                    if not isinstance(type_arg, mods.CommandArgModifier):
                        raise util.ArguablyException(
                            f"Function argument `{param.name}` has an invalid annotation value: {type_arg}"
                        )
                    modifiers.append(type_arg)

        # Normalize Union with None
        value_type = CommandArg._normalize_type_union(function_name, param, value_type)

        # Validate list/tuple and error on other parameterized types
        origin = util.get_origin(value_type)
        if (isinstance(value_type, type) and issubclass(value_type, list)) or (
            isinstance(origin, type) and issubclass(origin, list)
        ):
            type_args = util.get_args(value_type)
            if len(type_args) == 0:
                value_type = str
            elif len(type_args) > 1:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` has too many items passed to List[...]. "
                    f"There should be exactly one item between the square brackets."
                )
            else:
                value_type = type_args[0]
            modifiers.append(mods.ListModifier())
        elif (isinstance(value_type, type) and issubclass(value_type, tuple)) or (
            isinstance(origin, type) and issubclass(origin, tuple)
        ):
            if param.kind in [param.VAR_KEYWORD, param.VAR_POSITIONAL]:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is an *args or **kwargs, which should "
                    f"be annotated with what only one of its items should be."
                )
            type_args = util.get_args(value_type)
            if len(type_args) == 0:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is a tuple but doesn't specify the "
                    f"type of its items, which arguably requires."
                )
            if type_args[-1] is Ellipsis:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is a variable-length tuple, which is "
                    f"not supported."
                )
            # Tuple item types live on the TupleModifier; there is no single value type
            value_type = type(None)
            modifiers.append(mods.TupleModifier(list(type_args)))
        elif origin is not None:
            if param.kind in [param.VAR_KEYWORD, param.VAR_POSITIONAL]:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is an *args or **kwargs, which should "
                    f"be annotated with what only one of its items should be."
                )
            raise util.ArguablyException(
                f"Function argument `{param.name}` in `{function_name}` is a generic type "
                f"(`{origin}`), which is not supported."
            )

        return value_type, modifiers

355 

356 

@dataclass
class Command:
    """A fully processed command"""

    function: Callable
    name: str
    args: List[CommandArg]
    description: str = ""
    alias: Optional[str] = None
    add_help: bool = True

    func_arg_names: Set[str] = field(default_factory=set)  # Populated by __post_init__ from ``args``
    cli_arg_map: Dict[str, CommandArg] = field(default_factory=dict)  # CLI name -> argument, built by __post_init__

    def __post_init__(self) -> None:
        """Indexes ``args`` by their function and CLI names.

        Raises:
            util.ArguablyException: if two arguments share the same CLI name.
        """
        self.cli_arg_map = dict()
        for arg in self.args:
            # Python parameter names are unique, so a duplicate here indicates an internal error
            assert arg.func_arg_name not in self.func_arg_names
            self.func_arg_names.add(arg.func_arg_name)

            if arg.cli_arg_name in self.cli_arg_map:
                raise util.ArguablyException(
                    f"Function argument `{arg.func_arg_name}` in `{self.name}` conflicts with "
                    f"`{self.cli_arg_map[arg.cli_arg_name].func_arg_name}`, both have the CLI name `{arg.cli_arg_name}`"
                )
            self.cli_arg_map[arg.cli_arg_name] = arg

    def call(self, parsed_args: Dict[str, Any]) -> Any:
        """Filters arguments from argparse to only include the ones used by this command, then calls it

        Positional arguments are passed positionally (a variadic ``*args`` value is spread out), and options are
        passed as keyword arguments. Async functions are run to completion with ``asyncio.run``.

        Returns:
            Whatever the command's function returns.
        """
        args: List[Any] = list()
        kwargs: Dict[str, Any] = dict()

        filtered_args = {k: v for k, v in parsed_args.items() if k in self.func_arg_names}

        # Add to either args or kwargs
        for arg in self.args:
            value = filtered_args[arg.func_arg_name]
            if not arg.input_method.is_positional:
                kwargs[arg.func_arg_name] = value
            elif arg.is_variadic:
                args.extend(value)
            else:
                args.append(value)

        # Call the function
        if util.is_async_callable(self.function):
            util.log_args(
                util.logger.debug, f"Calling {self.name} function async: ", self.function.__name__, *args, **kwargs
            )

            async def _await_function() -> Any:
                return await self.function(*args, **kwargs)

            # ``asyncio.get_event_loop()`` is deprecated when there is no current event loop and errors on newer
            # Pythons; ``asyncio.run`` creates, runs and closes a fresh loop instead. Awaiting inside a wrapper
            # coroutine keeps support for callables that return any awaitable, not only coroutines.
            return asyncio.run(_await_function())
        else:
            util.log_args(util.logger.debug, f"Calling {self.name} function: ", self.function.__name__, *args, **kwargs)
            return self.function(*args, **kwargs)

    def get_subcommand_metavar(self, command_metavar: str) -> str:
        """If this command has a subparser (for subcommands of its own), this can be called to generate a unique name
        for the subparser's command metavar

        The root command uses ``command_metavar`` unchanged. Other commands prefix it with their name (spaces replaced
        by underscores) followed by an underscore.
        """
        if self.name == "__root__":
            return command_metavar
        return f"{self.name.replace(' ', '_')}{'_' if len(self.name) > 0 else ''}{command_metavar}"