Coverage for arguably/_argparse_extensions.py: 89%
188 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 01:01 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 01:01 +0000
1from __future__ import annotations
3import argparse
4import enum
5import sys
6from gettext import gettext
7from typing import (
8 Callable,
9 cast,
10 Any,
11 TextIO,
12 IO,
13 Sequence,
14 Union,
15 Optional,
16 List,
17 Dict,
18 Tuple,
19)
21import arguably._context as ctx
22import arguably._modifiers as mods
23import arguably._util as util
def normalize_action_input(values: Union[str, Sequence[Any], None]) -> List[str]:
    """
    Coerce whatever argparse handed to an action into a list of strings.

    - ``None`` becomes an empty list.
    - A single string becomes a one-element list, except the sentinel ``"-"``, which means "no values".
    - Any other sequence is copied into a new list.
    """
    if isinstance(values, str):
        # A lone "-" on the command line is the explicit "empty" marker
        return [] if values == "-" else [values]
    if values is None:
        return []
    return list(values)
class HelpFormatter(argparse.HelpFormatter):
    """HelpFormatter modified for arguably"""

    def add_argument(self, action: argparse.Action) -> None:
        """
        Corrects _max_action_length for the indenting of subparsers, see https://stackoverflow.com/questions/32888815/

        Each subaction invocation is padded by the extra indent it will receive when rendered, so the maximum
        action length used for column alignment accounts for nested subparser indentation.
        """
        if action.help is not argparse.SUPPRESS:
            # find all invocations
            get_invocation = self._format_action_invocation
            invocations = [get_invocation(action)]
            current_indent = self._current_indent
            for subaction in self._iter_indented_subactions(action):
                # compensate for the indent that will be added; only the padding length matters, not its chars
                indent_chg = self._current_indent - current_indent
                added_indent = "x" * indent_chg
                invocations.append(added_indent + get_invocation(subaction))

            # update the maximum item length
            invocation_length = max(len(s) for s in invocations)
            action_length = invocation_length + self._current_indent
            self._action_max_length = max(self._action_max_length, action_length)

            # add the item to the list
            self._add_item(self._format_action, [action])

    def _format_action_invocation(self, action: argparse.Action) -> str:
        """
        Changes metavar printing for parameters, only displays it once.

        Positionals and options taking no value (``nargs == 0``) use the stock rendering. Other options render as
        ``-f, --foo METAVAR`` instead of repeating the metavar after every option string.
        """
        if not action.option_strings or action.nargs == 0:
            # noinspection PyProtectedMember
            return super()._format_action_invocation(action)
        default_metavar = self._get_default_metavar_for_optional(action)
        args_string = self._format_args(action, default_metavar)
        return ", ".join(action.option_strings) + " " + args_string

    def _metavar_formatter(
        self,
        action: argparse.Action,
        default_metavar: str,
    ) -> Callable[[int], Tuple[str, ...]]:
        """
        Mostly copied from the original _metavar_formatter, but special-cases enum member printing.

        Enum members in ``choices`` are shown by their ``.name`` rather than ``str()``. Each choice is checked
        individually, so an empty ``choices`` renders as ``{}`` (as in stock argparse) instead of raising
        StopIteration, and mixed enum/non-enum choices no longer fail on ``.name``.

        Returns:
            A function mapping a tuple size to a tuple of metavar strings of that size (or the explicit metavar
            tuple unchanged, if one was configured).
        """
        result: Union[str, Tuple[str, ...]]
        if action.metavar is not None:
            result = action.metavar
        elif action.choices is not None:
            choice_strs = [
                choice.name if isinstance(choice, enum.Enum) else str(choice) for choice in action.choices
            ]
            result = "{%s}" % ",".join(choice_strs)
        else:
            result = default_metavar

        def _format(tuple_size: int) -> Tuple[str, ...]:
            if isinstance(result, tuple):
                return result
            else:
                return (result,) * tuple_size

        return _format

    def _split_lines(self, text: str, width: int) -> List[str]:
        """
        Copied from the original _split_lines, but we don't replace multiple spaces with only one.

        Stock argparse collapses whitespace runs before wrapping. Here the text is passed straight to
        ``textwrap.wrap``, which still turns each individual whitespace character into a space but keeps runs.
        """
        # The textwrap module is used only for formatting help.
        # Delay its import for speeding up the common usage of argparse.
        import textwrap

        return textwrap.wrap(text, width)

    def _format_args(self, action: argparse.Action, default_metavar: str) -> str:
        """Same as stock, but backport ZERO_OR_MORE behavior for 3.8"""
        get_metavar = self._metavar_formatter(action, default_metavar)
        if action.nargs is None:
            result = "%s" % get_metavar(1)
        elif action.nargs == argparse.OPTIONAL:
            result = "[%s]" % get_metavar(1)
        elif action.nargs == argparse.ZERO_OR_MORE:
            metavar = get_metavar(1)
            # a two-element metavar tuple renders as "[A [B ...]]", matching newer argparse versions
            if len(metavar) == 2:
                result = "[%s [%s ...]]" % metavar
            else:
                result = "[%s ...]" % metavar
        elif action.nargs == argparse.ONE_OR_MORE:
            result = "%s [%s ...]" % get_metavar(2)
        elif action.nargs == argparse.REMAINDER:
            result = "..."
        elif action.nargs == argparse.PARSER:
            result = "%s ..." % get_metavar(1)
        elif action.nargs == argparse.SUPPRESS:
            result = ""
        else:
            # integer nargs: one metavar slot per expected value
            try:
                formats = ["%s" for _ in range(action.nargs)]  # type: ignore[arg-type]
            except TypeError:
                raise ValueError("invalid nargs value") from None
            result = " ".join(formats) % get_metavar(action.nargs)  # type: ignore[arg-type]
        return result
class ArgumentParser(argparse.ArgumentParser):
    """
    argparse.ArgumentParser subclass used by arguably.

    Adds an optional output stream that receives every printed message. Conversion and choice errors also list
    the valid choices, with enum members shown by their normalized names.
    """

    def __init__(self, *args: Any, output: Optional[TextIO] = None, **kwargs: Any):
        """Accepts everything argparse.ArgumentParser does, plus `output`: a stream to redirect all prints to."""
        super().__init__(*args, **kwargs)
        self._output = output

    def _print_message(self, message: str, file: Optional[IO[str]] = None) -> None:
        """Write a non-empty `message` to the redirect stream, else to `file`, else to stderr."""
        if not message:
            return
        # stock argparse falls back to sys.stderr here, even though most help/usage output targets stdout
        target = self._output or file or sys.stderr
        target.write(message)

    def _get_value(self, action: argparse.Action, arg_string: str) -> Any:
        """
        Convert `arg_string` using the action's type (mostly the stock logic).

        Enum types are resolved through the context's enum name mapping. On a conversion failure, the error
        message lists the available choices whenever the action defines any.

        Raises:
            argparse.ArgumentError: if the type is not callable or the conversion fails.
        """
        converter = self._registry_get("type", action.type, action.type)
        if not callable(converter):
            template = gettext("%r is not callable")
            raise argparse.ArgumentError(action, template % converter)

        try:
            if isinstance(converter, type) and issubclass(converter, enum.Enum):
                lookup = ctx.context.get_enum_mapping(converter)
                if arg_string not in lookup:
                    raise ValueError(arg_string)
                converted = lookup[arg_string]
            else:
                converted = cast(Callable, converter)(arg_string)
        except argparse.ArgumentTypeError as type_err:
            raise argparse.ArgumentError(action, str(type_err))
        except (TypeError, ValueError):
            type_name = getattr(action.type, "__name__", repr(action.type))
            if action.choices is None:
                details = {"type": type_name, "value": arg_string}
                template = gettext("invalid %(type)s value: %(value)r")
            else:
                # arguably addition: show the choices; enum members by their normalized name
                labels = [
                    util.normalize_name(option.name, spaces=False) if isinstance(option, enum.Enum) else str(option)
                    for option in action.choices
                ]
                details = {"type": type_name, "value": arg_string, "choices": ", ".join(map(repr, labels))}
                template = gettext("invalid choice: %(value)r (choose from %(choices)s)")
            raise argparse.ArgumentError(action, template % details)

        return converted

    def _check_value(self, action: argparse.Action, value: Any) -> None:
        """
        Validate `value` against `action.choices`, skipping the check entirely for enum types.

        Raises:
            argparse.ArgumentError: if choices are set and `value` is not among them.
        """
        converter = self._registry_get("type", action.type, action.type)
        if isinstance(converter, type) and issubclass(converter, enum.Enum):
            return

        if action.choices is None or value in action.choices:
            return

        rendered_choices = ", ".join(
            util.normalize_name(option.name, spaces=False) if isinstance(option, enum.Enum) else repr(option)
            for option in action.choices
        )
        details = {"value": value, "choices": rendered_choices}
        template = gettext("invalid choice: %(value)r (choose from %(choices)s)")
        raise argparse.ArgumentError(action, template % details)
class ListTupleBuilderAction(argparse.Action):
    """
    Special action for arguably - handles lists, tuples, and builders. Designed to handle:
    * lists - List[int], List[str]
    * tuples - Tuple[int, int, int], Tuple[str, float, int]
    * builders - Annotated[FooClass, arguably.arg.builder()]
    * list of tuples - List[Tuple[int, int]]
    * list of builders - Annotated[List[FooClass], arguably.arg.builder()]

    argparse is told to accept plain strings (``type=str``). This action splits each received string on unquoted
    commas and converts the pieces itself, using the real type(s) captured at construction time.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # `command_arg` is arguably-specific; remove it before argparse.Action sees the kwargs
        self._command_arg = kwargs.pop("command_arg")

        super().__init__(*args, **kwargs)

        # Check if we're handling a list
        self._is_list = any(isinstance(m, mods.ListModifier) for m in self._command_arg.modifiers)

        # Check if we're handling a tuple (or a list of tuples)
        self._is_tuple = any(isinstance(m, mods.TupleModifier) for m in self._command_arg.modifiers)

        # Check if we're handling a builder (or a list of builders)
        self._is_builder = any(isinstance(m, mods.BuilderModifier) for m in self._command_arg.modifiers)

        if self._is_tuple and self._is_builder:
            raise util.ArguablyException(f"{'/'.join(self.option_strings)} cannot use both tuple and builder")

        # Validate that type is callable (a list of types means one type per tuple element)
        check_type_list = self.type if isinstance(self.type, list) else [self.type]
        for type_ in check_type_list:
            if not callable(type_):
                type_name = f"{self.type}" if not isinstance(self.type, list) else f"{type_} in {self.type}"
                raise util.ArguablyException(f"{'/'.join(self.option_strings)} type {type_name} is not callable")

        # Keep track of the real type, and lie to argparse so it hands us a single (comma-separated) string
        assert isinstance(self.type, type) or isinstance(self.type, list)
        self._real_type: Union[type, List[type]] = self.type
        self.type = str

        # Make metavar comma-separated as well
        if isinstance(self.metavar, tuple):
            self.metavar = ",".join(self.metavar)

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        value_strs: Union[str, Sequence[Any], None],
        option_string: Optional[str] = None,
    ) -> None:
        """
        Split each received string on unquoted commas, convert the pieces, and store the result on `namespace`.

        Lists accumulate across repeated occurrences. Non-list tuples and builders store exactly one value.

        Raises:
            argparse.ArgumentError: on a wrong element count for a tuple, on a failed element conversion, or when
                a non-list argument does not end up with exactly one value.
        """
        value_strs = normalize_action_input(value_strs)

        # Split values and convert to self._real_type
        values: List[Any] = list()
        for value_str in value_strs:
            split_value_str = [util.unwrap_quotes(v) for v in util.split_unquoted(value_str, delimeter=",")]
            if self._is_tuple:
                assert isinstance(self._real_type, list)
                if len(split_value_str) != len(self._real_type):
                    raise argparse.ArgumentError(self, f"expected {len(self._real_type)} values")
                value = tuple(
                    self._convert(type_, str_) for str_, type_ in zip(split_value_str, self._real_type)
                )
                values.append(value)
            elif self._is_builder:
                values.append(self._build_from_str_values(parser, option_string, split_value_str))
            else:
                assert self._is_list
                assert isinstance(self._real_type, type)
                values.extend(self._convert(self._real_type, str_) for str_ in split_value_str)

        # Set namespace variable
        if self._is_list:
            # copy before extending so the existing (possibly default) list object is not mutated in place
            items = getattr(namespace, self.dest, list())
            items = argparse._copy_items(items)  # type: ignore[attr-defined]
            items.extend(values)
            setattr(namespace, self.dest, items)
        else:
            if len(values) != 1:
                # e.g. a lone "-" (normalized to no values) passed to a single tuple/builder argument
                raise argparse.ArgumentError(self, "expected exactly one value")
            setattr(namespace, self.dest, values[0])

    def _convert(self, type_: Callable[[str], Any], value_str: str) -> Any:
        """
        Convert a single token with `type_`, reporting failures the way stock argparse does in `_get_value`.

        Raises:
            argparse.ArgumentError: if `type_` raises ArgumentTypeError, TypeError or ValueError.
        """
        try:
            return type_(value_str)
        except argparse.ArgumentTypeError as err:
            raise argparse.ArgumentError(self, str(err)) from err
        except (TypeError, ValueError) as err:
            name = getattr(type_, "__name__", repr(type_))
            msg = gettext("invalid %(type)s value: %(value)r")
            raise argparse.ArgumentError(self, msg % {"type": name, "value": value_str}) from err

    def _build_from_str_values(
        self,
        parser: argparse.ArgumentParser,
        option_string: Optional[str],
        split_value_str: List[str],
    ) -> Any:
        """
        Builds a class from the passed-in strings. Example:
            split_value_str=['foo', 'bar=123', 'bat=asdf'] -> FooClass(bar=123, bat='asdf')

        A leading token without "=" is taken as the subtype name. A key given more than once collects its
        values into a list. Resolution is delegated to ``ctx.context.resolve_subtype``.

        Raises:
            argparse.ArgumentError: if a keyword token is not of the form ``key=value``.
        """

        # Separate out subtype and kwargs
        kwargs: Dict[str, Any] = dict()
        subtype_ = None
        if len(split_value_str) > 0 and "=" not in split_value_str[0]:
            subtype_ = split_value_str[0]
            kwarg_strs = split_value_str[1:]
        else:
            kwarg_strs = split_value_str

        # Build kwargs dict
        for kwarg_str in kwarg_strs:
            key, eq, value = kwarg_str.partition("=")
            if not eq:
                # without "=", `value` is empty, so report the whole offending token instead
                raise argparse.ArgumentError(
                    self, f"type arguments should be of form key=value, {kwarg_str} does not match"
                )
            if key not in kwargs:
                kwargs[key] = value
            elif isinstance(kwargs[key], list):
                kwargs[key].append(value)
            else:
                kwargs[key] = [kwargs[key], value]

        # Resolve via the context while `parser` is the current parser
        option_name = "" if option_string is None else option_string.lstrip("-")
        with ctx.context.current_parser(parser):
            assert isinstance(self._real_type, type)
            return ctx.context.resolve_subtype(option_name, self._real_type, subtype_, kwargs)
class FlagAction(argparse.Action):
    """
    Special action for arguably - handles `enum.Flag`. Clears default value and ORs together flag values.

    `self.const` carries the flag info: the member value to apply and the namespace attribute it targets.
    Whether the value already on the namespace is ORed in is decided by
    `ctx.context.check_and_set_enum_flag_default_status`.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Union[str, Sequence[Any], None],
        option_string: Optional[str] = None,
    ) -> None:
        """Store this flag's value on `namespace`, combined with the existing value when the context allows it."""
        info = cast(util.EnumFlagInfo, self.const)
        combined = info.value
        attr_name = info.cli_arg_name

        keep_existing = ctx.context.check_and_set_enum_flag_default_status(parser, attr_name)
        if keep_existing:
            combined |= getattr(namespace, attr_name)
        setattr(namespace, attr_name, combined)