Coverage for arguably/_commands.py: 89%
218 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 01:01 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 01:01 +0000
1from __future__ import annotations
3import asyncio
4import enum
5import inspect
6import re
7from dataclasses import dataclass, field
8from typing import Callable, Any, Union, Optional, List, Dict, Tuple, cast, Set
10from docstring_parser import parse as docparse
12import arguably._modifiers as mods
13import arguably._util as util
class InputMethod(enum.Enum):
    """How the user supplies a given command-line argument."""

    REQUIRED_POSITIONAL = 0  # usage: foo BAR
    OPTIONAL_POSITIONAL = 1  # usage: foo [BAR]
    OPTION = 2  # Examples: -F, --test_scripts, --filename foo.txt

    @property
    def is_positional(self) -> bool:
        """True for both positional kinds; only OPTION is passed by flag."""
        return self is not InputMethod.OPTION

    @property
    def is_optional(self) -> bool:
        """True for everything except a required positional argument."""
        return self is not InputMethod.REQUIRED_POSITIONAL
@dataclass
class CommandDecoratorInfo:
    """Used for keeping a reference to everything marked with @arguably.command

    On construction, the CLI ``name`` is derived from the decorated function (or class), and the function is
    immediately processed into a ``Command`` stored on ``command``.
    """

    function: Callable
    alias: Optional[str] = None
    help: bool = True
    name: str = field(init=False)
    command: Command = field(init=False)

    # Matches a metavar descriptor such as `{A, B, C}` inside an argument description. The single capture group holds
    # the comma-separated names without the braces. Not annotated, so dataclass does not treat it as a field.
    _METAVAR_PATTERN = re.compile(r"\{((?:[a-zA-Z0-9_-]+(?:, *)*)+)}")

    def __post_init__(self) -> None:
        if self.function.__name__ == "__root__":
            self.name = "__root__"
        elif isinstance(self.function, type):
            # Classes are CamelCase, so convert to kebab-case before normalizing
            self.name = util.normalize_name(util.camel_case_to_kebab_case(self.function.__name__))
        else:
            self.name = util.normalize_name(self.function.__name__)

        self.command = self._process()

    @staticmethod
    def _get_arg_description(docs: Any, param: inspect.Parameter, processed_name: str) -> str:
        """Return the docstring description for ``param``, or "" if the docstring does not describe it.

        Raises:
            util.ArguablyException: if the docstring has more than one entry for ``param``.
        """
        if docs is None or docs.params is None:
            return ""
        # Leading asterisks are ignored so `*args` can be documented as either `args` or `*args`
        ds_matches = [ds_p for ds_p in docs.params if ds_p.arg_name.lstrip("*") == param.name]
        if len(ds_matches) > 1:
            raise util.ArguablyException(
                f"Function argument `{param.name}` in " f"`{processed_name}` has multiple docstring entries."
            )
        if len(ds_matches) == 1 and ds_matches[0].description is not None:
            return ds_matches[0].description
        return ""

    @classmethod
    def _extract_metavars(
        cls,
        arg_description: str,
        param: inspect.Parameter,
        processed_name: str,
        is_variadic: bool,
        expected_metavars: int,
    ) -> Tuple[str, Optional[List[str]]]:
        """Pull a ``{A, B}`` metavar descriptor out of an argument description.

        Returns:
            A tuple of (description with the braces removed but the names kept, upper-cased metavar names). The
            metavar list is None when the description contains no descriptor.

        Raises:
            util.ArguablyException: if there is more than one descriptor, or the number of names does not fit the
                argument (variadic arguments take exactly one; tuples take one per element, or one to be repeated).
        """
        metavar_split = cls._METAVAR_PATTERN.split(arg_description)
        if len(metavar_split) > 3:
            raise util.ArguablyException(
                f"Function argument `{param.name}` in `{processed_name}` has multiple metavar sequences - "
                f"these are denoted like {{A, B, C}}. There should be only one."
            )
        if len(metavar_split) != 3:
            # No descriptor present
            return arg_description, None

        # format would be: ['pre-metavar', 'METAVAR', 'post-metavar']
        match_items = [i.strip() for i in metavar_split[1].split(",")]
        if is_variadic:
            if len(match_items) != 1:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{processed_name}` should only have one item in "
                    f"its metavar descriptor, but found {len(match_items)}: {','.join(match_items)}."
                )
        elif len(match_items) != expected_metavars:
            if len(match_items) == 1:
                # A single name is reused for every element of the tuple
                match_items *= expected_metavars
            else:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{processed_name}` takes {expected_metavars} "
                    f"items, but metavar descriptor has {len(match_items)}: {','.join(match_items)}."
                )

        # Re-joining the split pieces drops only the { and }, keeping the names in the description text
        return "".join(metavar_split), [i.upper() for i in match_items]

    def _process(self) -> Command:
        """Takes the decorator info and returns a processed command.

        Reads the command description and per-argument descriptions from the docstring, resolves type hints, and
        builds one ``CommandArg`` per parameter.

        Raises:
            util.ArguablyException: if the function uses ``**kwargs``, or an argument's docstring entry, type, or
                metavar descriptor is invalid.
        """
        processed_name = self.name
        # For a class, the docstring and type hints come from __init__, while the parameter list below comes from
        # inspect.signature on the class itself.
        func = self.function.__init__ if isinstance(self.function, type) else self.function  # type: ignore[misc]

        # Get the description from the docstring
        if func.__doc__ is None:
            docs = None
            processed_description = ""
        else:
            docs = docparse(func.__doc__)
            processed_description = "" if docs.short_description is None else docs.short_description

        try:
            hints = util.get_type_hints(func, include_extras=True)
        except NameError as e:
            # Not fatal: without hints, argument types are guessed by CommandArg.normalize_type
            hints = {}
            util.warn(f"Unable to resolve type hints for function {processed_name}: {str(e)}", func)

        processed_args: List[CommandArg] = list()

        for func_arg_name, param in inspect.signature(self.function).parameters.items():
            cli_arg_name = util.normalize_name(func_arg_name, spaces=False)
            arg_default = util.NoDefault if param.default is param.empty else param.default

            # Handle variadic arguments
            if param.kind is param.VAR_KEYWORD:
                raise util.ArguablyException(f"`{processed_name}` is using **kwargs, which is not supported")
            is_variadic = param.kind is param.VAR_POSITIONAL

            # Get the type and normalize it; a tuple type needs one metavar per element
            arg_value_type, modifiers = CommandArg.normalize_type(processed_name, param, hints)
            tuple_modifiers = [m for m in modifiers if isinstance(m, mods.TupleModifier)]
            expected_metavars = 1
            if len(tuple_modifiers) > 0:
                assert len(tuple_modifiers) == 1
                expected_metavars = len(tuple_modifiers[0].tuple_arg)

            # What kind of argument is this? Is it required-positional, optional-positional, or an option?
            if param.kind == param.KEYWORD_ONLY:
                input_method = InputMethod.OPTION
            elif arg_default is util.NoDefault:
                input_method = InputMethod.REQUIRED_POSITIONAL
            else:
                input_method = InputMethod.OPTIONAL_POSITIONAL

            arg_description = self._get_arg_description(docs, param, processed_name)

            # Extract the alias. Only --options may declare short/long names in their description.
            arg_alias = None
            has_long_name = True
            if input_method == InputMethod.OPTION:
                arg_description, arg_alias, long_name = util.parse_short_and_long_name(
                    cli_arg_name, arg_description, func
                )
                if long_name is None:
                    has_long_name = False
                else:
                    cli_arg_name = long_name
            elif arg_description.startswith("["):
                util.warn(
                    f"Function argument `{param.name}` in `{processed_name}` is a positional argument, but starts "
                    f"with a `[`, which is used to specify --option names. To make this argument an --option, make "
                    f"it into be a keyword-only argument.",
                    func,
                )

            arg_description, metavars = self._extract_metavars(
                arg_description, param, processed_name, is_variadic, expected_metavars
            )

            # Check modifiers
            for modifier in modifiers:
                modifier.check_valid(arg_value_type, param, processed_name)

            # Finished processing this arg
            processed_args.append(
                CommandArg(
                    func_arg_name,
                    cli_arg_name,
                    input_method,
                    is_variadic,
                    arg_value_type,
                    arg_description,
                    arg_alias,
                    has_long_name,
                    metavars,
                    arg_default,
                    modifiers,
                )
            )

        # Return the processed command
        return Command(
            self.function,
            processed_name,
            processed_args,
            processed_description,
            self.alias,
            self.help,
        )
@dataclass
class SubtypeDecoratorInfo:
    """Record of a type registered through @arguably.subtype, kept until commands are built."""

    # The class that was decorated
    type_: type
    # Optional alternate name for the subtype
    alias: Optional[str] = None
    # Flag stored as given by the decorator
    ignore: bool = False
    # Optional callable stored as given by the decorator
    factory: Optional[Callable] = None
@dataclass
class CommandArg:
    """A single argument to a given command"""

    func_arg_name: str
    cli_arg_name: str

    input_method: InputMethod
    is_variadic: bool
    arg_value_type: type

    description: str
    alias: Optional[str] = None
    has_long_name: bool = True
    metavars: Optional[List[str]] = None

    default: Any = util.NoDefault

    modifiers: List[mods.CommandArgModifier] = field(default_factory=list)

    def __post_init__(self) -> None:
        # The argument must be reachable by at least one of -alias / --long-name
        if not self.has_long_name and self.alias is None:
            raise ValueError("CommandArg has no short or long name")

    def get_options(self) -> Union[Tuple[()], Tuple[str], Tuple[str, str]]:
        """Return the option strings for this argument: ``-alias`` (if any) followed by ``--name`` (if any)."""
        options = list()
        if self.alias is not None:
            options.append(f"-{self.alias}")
        if self.has_long_name:
            options.append(f"--{self.cli_arg_name}")
        return cast(Union[Tuple[()], Tuple[str], Tuple[str, str]], tuple(options))

    @staticmethod
    def _normalize_type_union(
        function_name: str,
        param: inspect.Parameter,
        value_type: type,
    ) -> type:
        """
        Collapse ``X | None`` / ``Optional[X]`` to ``X``; any other type is returned unchanged.

        We break this out because Python 3.10 seems to want to wrap `Annotated[Optional[...` in another `Optional`, so
        we call this twice.

        Raises:
            util.ArguablyException: if the union has anything other than exactly one non-None member.
        """
        if isinstance(value_type, util.UnionType) or util.get_origin(value_type) is Union:
            filtered_types = [x for x in util.get_args(value_type) if x is not type(None)]
            if len(filtered_types) != 1:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is an unsupported type. It must be either "
                    f"a single, non-generic type or a Union with None."
                )
            value_type = filtered_types[0]
        return value_type

    @staticmethod
    def normalize_type(
        function_name: str,
        param: inspect.Parameter,
        hints: Dict[str, Any],
    ) -> Tuple[type, List[mods.CommandArgModifier]]:
        """
        Normalizes the function argument type. Most of the logic here is validation. Explanation of what's returned for
        a given function argument type:
          * SomeType -> value_type=SomeType, modifiers=[]
          * int | None -> value_type=int, modifiers=[]
          * Tuple[float, float] -> value_type=type(None), modifiers=[mods.TupleModifier([float, float])]
          * List[str] -> value_type=str, modifiers=[mods.ListModifier()]
          * list (unparameterized) -> value_type=str, modifiers=[mods.ListModifier()]
          * Annotated[int, <modifier>] -> value_type=int, modifiers=[<modifier>]
          * no annotation -> value_type=type(default) if the default is present and not None, otherwise str

        Things that will cause an exception:
          * A Union with more than one non-None member
          * Annotated[...] metadata that is not a CommandArgModifier
          * List[...] with more than one item between the brackets
          * Tuple without item types, or flexible-length Tuple[SomeType, ...]
          * A tuple or other non-list generic on *args / **kwargs
          * Any other parameterized (generic) type

        Raises:
            util.ArguablyException: for any of the unsupported cases above.
        """

        modifiers: List[mods.CommandArgModifier] = list()

        if param.name in hints:
            value_type = hints[param.name]
        elif param.default is not param.empty and param.default is not None:
            # No type hint, so guess the type from the default value. Identity checks are used because a default
            # with a custom `__eq__` can't be safely compared with `==` (an `in [...]` test does that).
            value_type = type(param.default)
        else:
            # No type hint and no usable default, so default to string
            value_type = str

        # Extra call to normalize a union here, see note in `_normalize_type_union`
        value_type = CommandArg._normalize_type_union(function_name, param, value_type)

        # Handle annotated types: the first argument is the real type, the rest must be modifiers
        if util.get_origin(value_type) == util.Annotated:
            type_args = util.get_args(value_type)
            if len(type_args) == 0:
                raise util.ArguablyException(f"Function argument `{param.name}` is Annotated, but no type is specified")
            value_type = type_args[0]
            for type_arg in type_args[1:]:
                if not isinstance(type_arg, mods.CommandArgModifier):
                    raise util.ArguablyException(
                        f"Function argument `{param.name}` has an invalid annotation value: {type_arg}"
                    )
                modifiers.append(type_arg)

        # Normalize Union with None
        value_type = CommandArg._normalize_type_union(function_name, param, value_type)

        is_var_arg = param.kind in [param.VAR_KEYWORD, param.VAR_POSITIONAL]

        # Validate list/tuple and error on other parameterized types
        origin = util.get_origin(value_type)
        if (isinstance(value_type, type) and issubclass(value_type, list)) or (
            isinstance(origin, type) and issubclass(origin, list)
        ):
            type_args = util.get_args(value_type)
            if len(type_args) == 0:
                value_type = str
            elif len(type_args) > 1:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` has too many items passed to List[...]. "
                    f"There should be exactly one item between the square brackets."
                )
            else:
                value_type = type_args[0]
            modifiers.append(mods.ListModifier())
        elif (isinstance(value_type, type) and issubclass(value_type, tuple)) or (
            isinstance(origin, type) and issubclass(origin, tuple)
        ):
            if is_var_arg:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is an *args or **kwargs, which should "
                    f"be annotated with what only one of its items should be."
                )
            type_args = util.get_args(value_type)
            if len(type_args) == 0:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is a tuple but doesn't specify the "
                    f"type of its items, which arguably requires."
                )
            if type_args[-1] is Ellipsis:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is a variable-length tuple, which is "
                    f"not supported."
                )
            # The element types live on the TupleModifier; there is no single value type
            value_type = type(None)
            modifiers.append(mods.TupleModifier(list(type_args)))
        elif origin is not None:
            if is_var_arg:
                raise util.ArguablyException(
                    f"Function argument `{param.name}` in `{function_name}` is an *args or **kwargs, which should "
                    f"be annotated with what only one of its items should be."
                )
            raise util.ArguablyException(
                f"Function argument `{param.name}` in `{function_name}` is a generic type "
                f"(`{origin}`), which is not supported."
            )

        return value_type, modifiers
@dataclass
class Command:
    """A fully processed command"""

    function: Callable
    name: str
    args: List[CommandArg]
    description: str = ""
    alias: Optional[str] = None
    add_help: bool = True

    # Both lookups are derived from `args` in __post_init__; any value passed to the constructor is replaced
    func_arg_names: Set[str] = field(default_factory=set)
    cli_arg_map: Dict[str, CommandArg] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Build the argument lookups from ``args``.

        Raises:
            util.ArguablyException: if two arguments share the same CLI name.
        """
        # Rebuild both lookups from scratch. Reusing a passed-in set (e.g. `dataclasses.replace`, which forwards the
        # existing one) would trip the duplicate check below and mutate the other instance's set.
        self.func_arg_names = set()
        self.cli_arg_map = dict()
        for arg in self.args:
            assert arg.func_arg_name not in self.func_arg_names
            self.func_arg_names.add(arg.func_arg_name)

            if arg.cli_arg_name in self.cli_arg_map:
                raise util.ArguablyException(
                    f"Function argument `{arg.func_arg_name}` in `{self.name}` conflicts with "
                    f"`{self.cli_arg_map[arg.cli_arg_name].func_arg_name}`, both have the CLI name `{arg.cli_arg_name}`"
                )
            self.cli_arg_map[arg.cli_arg_name] = arg

    def call(self, parsed_args: Dict[str, Any]) -> Any:
        """Filters arguments from argparse to only include the ones used by this command, then calls it.

        Positional arguments are passed positionally (variadic ones are expanded); options are passed as keywords.
        Async functions are run to completion on a fresh event loop. Returns whatever the function returns.
        """
        args = list()
        kwargs = dict()

        filtered_args = {k: v for k, v in parsed_args.items() if k in self.func_arg_names}

        # Add to either args or kwargs
        for arg in self.args:
            if arg.input_method.is_positional and not arg.is_variadic:
                args.append(filtered_args[arg.func_arg_name])
            elif arg.input_method.is_positional and arg.is_variadic:
                args.extend(filtered_args[arg.func_arg_name])
            else:
                kwargs[arg.func_arg_name] = filtered_args[arg.func_arg_name]

        # Call the function
        if util.is_async_callable(self.function):
            util.log_args(
                util.logger.debug, f"Calling {self.name} function async: ", self.function.__name__, *args, **kwargs
            )
            # asyncio.run is the supported entry point for running a coroutine from synchronous code; implicitly
            # creating a loop through asyncio.get_event_loop() is deprecated in newer Python versions.
            return asyncio.run(self.function(*args, **kwargs))
        else:
            util.log_args(util.logger.debug, f"Calling {self.name} function: ", self.function.__name__, *args, **kwargs)
            return self.function(*args, **kwargs)

    def get_subcommand_metavar(self, command_metavar: str) -> str:
        """If this command has a subparser (for subcommands of its own), this can be called to generate a unique name
        for the subparser's command metavar"""
        if self.name == "__root__":
            return command_metavar
        return f"{self.name.replace(' ', '_')}{'_' if len(self.name) > 0 else ''}{command_metavar}"