Coverage for arguably/_argparse_extensions.py: 89%

188 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-10 01:01 +0000

1from __future__ import annotations 

2 

3import argparse 

4import enum 

5import sys 

6from gettext import gettext 

7from typing import ( 

8 Callable, 

9 cast, 

10 Any, 

11 TextIO, 

12 IO, 

13 Sequence, 

14 Union, 

15 Optional, 

16 List, 

17 Dict, 

18 Tuple, 

19) 

20 

21import arguably._context as ctx 

22import arguably._modifiers as mods 

23import arguably._util as util 

24 

25 

def normalize_action_input(values: Union[str, Sequence[Any], None]) -> List[str]:
    """
    Coerce the raw `values` that argparse hands to an action into a list.

    A lone string becomes a one-item list, except "-" which stands for "no values". `None` also becomes an empty
    list. Any other sequence is copied into a new list.
    """
    if isinstance(values, str):
        # "-" means empty
        return [] if values == "-" else [values]
    return [] if values is None else list(values)

35 

36 

class HelpFormatter(argparse.HelpFormatter):
    """HelpFormatter modified for arguably"""

    def add_argument(self, action: argparse.Action) -> None:
        """
        Registers `action` for the help listing, like the stock method, but corrects `_action_max_length` for the
        indenting of subparsers, see https://stackoverflow.com/questions/32888815/

        Each indented subaction's invocation is padded by the extra indent it is formatted under, so the widest
        nested entry is measured at its real width when the help column width is computed.
        """
        if action.help is not argparse.SUPPRESS:
            # find all invocations
            get_invocation = self._format_action_invocation
            invocations = [get_invocation(action)]
            current_indent = self._current_indent
            for subaction in self._iter_indented_subactions(action):
                # compensate for the indent that will be added; only the length matters, "x" is a placeholder
                indent_chg = self._current_indent - current_indent
                added_indent = "x" * indent_chg
                invocations.append(added_indent + get_invocation(subaction))

            # update the maximum item length
            invocation_length = max(len(s) for s in invocations)
            action_length = invocation_length + self._current_indent
            self._action_max_length = max(self._action_max_length, action_length)

            # add the item to the list
            self._add_item(self._format_action, [action])

    def _format_action_invocation(self, action: argparse.Action) -> str:
        """
        Changes metavar printing for parameters, only displays it once.

        Positionals and options that take no value use the stock formatting. Any other option is rendered as
        "-a, --all METAVAR" rather than repeating the metavar after every option string.
        """
        if not action.option_strings or action.nargs == 0:
            # noinspection PyProtectedMember
            return super()._format_action_invocation(action)
        default_metavar = self._get_default_metavar_for_optional(action)
        args_string = self._format_args(action, default_metavar)
        return ", ".join(action.option_strings) + " " + args_string

    def _metavar_formatter(
        self,
        action: argparse.Action,
        default_metavar: str,
    ) -> Callable[[int], Tuple[str, ...]]:
        """
        Mostly copied from the original _metavar_formatter, but enum members in `choices` are shown by `.name`.

        Returns a function that takes the number of metavars needed. It returns a tuple of that many copies of the
        metavar, or the action's own metavar unchanged when that is already a tuple.
        """
        if action.metavar is not None:
            result = action.metavar
        elif action.choices is not None:
            # Decide per choice (instead of peeking at the first one) so that an empty `choices` renders as "{}"
            # like stock argparse, rather than raising StopIteration
            choice_strs = [choice.name if isinstance(choice, enum.Enum) else str(choice) for choice in action.choices]
            result = "{%s}" % ",".join(choice_strs)
        else:
            result = default_metavar

        def _format(tuple_size: int) -> Tuple[str, ...]:
            if isinstance(result, tuple):
                return result
            else:
                return (result,) * tuple_size

        return _format

    def _split_lines(self, text: str, width: int) -> List[str]:
        """
        Copied from the original _split_lines, but we don't replace multiple spaces with only one.

        Stock argparse first collapses every whitespace run in `text` to a single space; that step is skipped here
        so intentional spacing in help text survives wrapping.
        """
        # The textwrap module is used only for formatting help.
        # Delay its import for speeding up the common usage of argparse.
        import textwrap

        return textwrap.wrap(text, width)

    def _format_args(self, action: argparse.Action, default_metavar: str) -> str:
        """
        Same as stock, but backport ZERO_OR_MORE behavior for 3.8.

        Renders the argument portion of an invocation ("X", "[X]", "[X ...]", "X [X ...]", ...) according to
        `action.nargs`.

        Raises:
            ValueError: if `action.nargs` is not one of argparse's markers and not an integer.
        """
        get_metavar = self._metavar_formatter(action, default_metavar)
        if action.nargs is None:
            result = "%s" % get_metavar(1)
        elif action.nargs == argparse.OPTIONAL:
            result = "[%s]" % get_metavar(1)
        elif action.nargs == argparse.ZERO_OR_MORE:
            metavar = get_metavar(1)
            if len(metavar) == 2:
                result = "[%s [%s ...]]" % metavar
            else:
                result = "[%s ...]" % metavar
        elif action.nargs == argparse.ONE_OR_MORE:
            result = "%s [%s ...]" % get_metavar(2)
        elif action.nargs == argparse.REMAINDER:
            result = "..."
        elif action.nargs == argparse.PARSER:
            result = "%s ..." % get_metavar(1)
        elif action.nargs == argparse.SUPPRESS:
            result = ""
        else:
            try:
                formats = ["%s" for _ in range(action.nargs)]  # type: ignore[arg-type]
            except TypeError:
                raise ValueError("invalid nargs value") from None
            result = " ".join(formats) % get_metavar(action.nargs)  # type: ignore[arg-type]
        return result

134 

135 

class ArgumentParser(argparse.ArgumentParser):
    """
    ArgumentParser modified for arguably.

    Adds an optional output redirect for everything the parser prints. Enum-typed arguments are converted through
    the arguably context's enum mapping, and conversion errors list the valid choices when there are any.
    """

    def __init__(self, *args: Any, output: Optional[TextIO] = None, **kwargs: Any):
        """
        Accepts every stock `argparse.ArgumentParser` argument, plus:

        Args:
            output: when given, every message this parser prints is written here instead of the stream argparse
                picked.
        """
        super().__init__(*args, **kwargs)
        self._output = output

    def _print_message(self, message: str, file: Optional[IO[str]] = None) -> None:
        """Writes `message` to the redirect target, else `file`, else stderr. Empty messages are dropped."""
        if not message:
            return
        # argparse.ArgumentParser defaults to sys.stderr in this function, though most seems to go to stdout
        target = self._output or file or sys.stderr
        target.write(message)

    def _get_value(self, action: argparse.Action, arg_string: str) -> Any:
        """
        Converts `arg_string` with the action's type. Mostly the stock `_get_value`, except:
          * enum types are looked up in the mapping returned by `ctx.context.get_enum_mapping`
          * when conversion fails and the action has choices, the error lists those choices

        Raises:
            argparse.ArgumentError: if the type is not callable, or conversion fails.
        """
        converter = self._registry_get("type", action.type, action.type)
        if not callable(converter):
            not_callable = gettext("%r is not callable")
            raise argparse.ArgumentError(action, not_callable % converter)

        try:
            if isinstance(converter, type) and issubclass(converter, enum.Enum):
                lookup = ctx.context.get_enum_mapping(converter)
                if arg_string not in lookup:
                    raise ValueError(arg_string)
                return lookup[arg_string]
            return cast(Callable, converter)(arg_string)
        except argparse.ArgumentTypeError as err:
            raise argparse.ArgumentError(action, str(err))
        except (TypeError, ValueError):
            type_name = getattr(action.type, "__name__", repr(action.type))
            if action.choices is None:
                fmt_args = {"type": type_name, "value": arg_string}
                template = gettext("invalid %(type)s value: %(value)r")
            else:
                # Added relative to stock argparse: list the valid choices, showing enum members by normalized name
                shown = [
                    util.normalize_name(choice.name, spaces=False) if isinstance(choice, enum.Enum) else str(choice)
                    for choice in action.choices
                ]
                fmt_args = {"type": type_name, "value": arg_string, "choices": ", ".join(map(repr, shown))}
                template = gettext("invalid choice: %(value)r (choose from %(choices)s)")
            raise argparse.ArgumentError(action, template % fmt_args)

    def _check_value(self, action: argparse.Action, value: Any) -> None:
        """
        'Just trust me' for enums, otherwise default behavior.

        Enum-typed values are not checked here; `_get_value` already resolved them through the enum mapping.
        Otherwise the converted value must be one of `action.choices`, when choices are given.

        Raises:
            argparse.ArgumentError: if `value` is not among the choices.
        """
        converter = self._registry_get("type", action.type, action.type)
        if isinstance(converter, type) and issubclass(converter, enum.Enum):
            return

        choices = action.choices
        if choices is None or value in choices:
            return

        rendered = [
            util.normalize_name(choice.name, spaces=False) if isinstance(choice, enum.Enum) else repr(choice)
            for choice in choices
        ]
        fmt_args = {"value": value, "choices": ", ".join(rendered)}
        template = gettext("invalid choice: %(value)r (choose from %(choices)s)")
        raise argparse.ArgumentError(action, template % fmt_args)

205 

206 

class ListTupleBuilderAction(argparse.Action):
    """
    Special action for arguably - handles lists, tuples, and builders. Designed to handle:
    * lists - List[int], List[str]
    * tuples - Tuple[int, int, int], Tuple[str, float, int]
    * builders - Annotated[FooClass, arguably.arg.builder()]
    * list of tuples - List[Tuple[int, int]]
    * list of builders - Annotated[List[FooClass], arguably.arg.builder()]

    argparse is told this action's type is `str`. Each received string is split on unquoted commas and converted
    here with the real type(s).
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Accepts the stock `argparse.Action` arguments plus a required `command_arg` keyword. The types found in
        `command_arg.modifiers` decide whether this action builds a list, tuples, and/or builder objects.

        Raises:
            util.ArguablyException: if both the tuple and builder modifiers are present, or a type is not callable.
        """
        # Strip our own keyword before argparse.Action sees it
        self._command_arg = kwargs.pop("command_arg")

        super().__init__(*args, **kwargs)

        # Check if we're handling a list
        self._is_list = any(isinstance(m, mods.ListModifier) for m in self._command_arg.modifiers)

        # Check if we're handling a tuple (or a list of tuples)
        self._is_tuple = any(isinstance(m, mods.TupleModifier) for m in self._command_arg.modifiers)

        # Check if we're handling a builder (or a list of builders)
        self._is_builder = any(isinstance(m, mods.BuilderModifier) for m in self._command_arg.modifiers)

        if self._is_tuple and self._is_builder:
            raise util.ArguablyException(f"{'/'.join(self.option_strings)} cannot use both tuple and builder")

        # Validate that type is callable (for tuples, `type` is a list holding one type per element)
        check_type_list = self.type if isinstance(self.type, list) else [self.type]
        for type_ in check_type_list:
            if not callable(type_):
                type_name = f"{self.type}" if not isinstance(self.type, list) else f"{type_} in {self.type}"
                raise util.ArguablyException(f"{'/'.join(self.option_strings)} type {type_name} is not callable")

        # Keep track of the real type and real nargs, lie to argparse to take in a single (comma-separated) string
        assert isinstance(self.type, type) or isinstance(self.type, list)
        self._real_type: Union[type, List[type]] = self.type
        self.type = str

        # Make metavar comma-separated as well
        if isinstance(self.metavar, tuple):
            self.metavar = ",".join(self.metavar)

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        value_strs: Union[str, Sequence[Any], None],
        option_string: Optional[str] = None,
    ) -> None:
        """
        Splits each received string on unquoted commas (unwrapping quotes), converts the pieces, and stores the
        result on `namespace.<dest>`. For list arguments the result is appended to a copy of the existing list;
        otherwise it is stored as the single value. `None` and "-" count as no values.

        Raises:
            argparse.ArgumentError: on a wrong tuple arity, a malformed builder argument, or a failed type conversion.
        """
        value_strs = normalize_action_input(value_strs)

        # Split values and convert to self._real_type
        values: List[Any] = []
        for value_str in value_strs:
            split_value_str = [util.unwrap_quotes(v) for v in util.split_unquoted(value_str, delimeter=",")]
            if self._is_tuple:
                assert isinstance(self._real_type, list)
                if len(split_value_str) != len(self._real_type):
                    raise argparse.ArgumentError(self, f"expected {len(self._real_type)} values")
                value = tuple(self._convert(type_, str_) for str_, type_ in zip(split_value_str, self._real_type))
                values.append(value)
            elif self._is_builder:
                values.append(self._build_from_str_values(parser, option_string, split_value_str))
            else:
                assert self._is_list
                assert isinstance(self._real_type, type)
                values.extend(self._convert(self._real_type, str_) for str_ in split_value_str)

        # Set namespace variable
        if self._is_list:
            items = getattr(namespace, self.dest, list())
            # copy so the default (or a previous value) is never mutated in place
            items = argparse._copy_items(items)  # type: ignore[attr-defined]
            items.extend(values)
            setattr(namespace, self.dest, items)
        else:
            assert len(values) == 1
            setattr(namespace, self.dest, values[0])

    def _convert(self, type_: Callable[[str], Any], value_str: str) -> Any:
        """
        Converts one split-out string with `type_`. Failures are reported as `argparse.ArgumentError`, using the
        same message format as `ArgumentParser._get_value`, instead of letting a raw exception escape.

        Raises:
            argparse.ArgumentError: if `type_` raises ArgumentTypeError, TypeError, or ValueError.
        """
        try:
            return type_(value_str)
        except argparse.ArgumentTypeError as err:
            raise argparse.ArgumentError(self, str(err)) from err
        except (TypeError, ValueError) as err:
            name = getattr(type_, "__name__", repr(type_))
            args = {"type": name, "value": value_str}
            raise argparse.ArgumentError(self, gettext("invalid %(type)s value: %(value)r") % args) from err

    def _build_from_str_values(
        self,
        parser: argparse.ArgumentParser,
        option_string: Optional[str],
        split_value_str: List[str],
    ) -> Any:
        """
        Builds a class from the passed-in strings. Example:
            split_value_str=['foo', 'bar=123', 'bat=asdf'] -> FooClass(bar=123, bat='asdf')

        A leading item without "=" names the subtype. Every other item must be key=value, and repeated keys are
        collected into a list. The returned object comes from `ctx.context.resolve_subtype`.

        Raises:
            argparse.ArgumentError: if an item after the subtype is not of the form key=value.
        """

        # Separate out subtype and kwargs
        kwargs: Dict[str, Any] = dict()
        subtype_ = None
        if len(split_value_str) > 0 and "=" not in split_value_str[0]:
            subtype_ = split_value_str[0]
            kwarg_strs = split_value_str[1:]
        else:
            kwarg_strs = split_value_str

        # Build kwargs dict
        for kwarg_str in kwarg_strs:
            key, eq, value = kwarg_str.partition("=")
            if len(eq) == 0:
                # report the whole offending item; `value` is always "" when no "=" was found
                raise argparse.ArgumentError(
                    self, f"type arguments should be of form key=value, {kwarg_str} does not match"
                )
            if key in kwargs:
                if not isinstance(kwargs[key], list):
                    kwargs[key] = [kwargs[key], value]
                else:
                    kwargs[key].append(value)
            else:
                kwargs[key] = value

        # Set the value in the namespace to be a `_BuildTypeSpec`, which will be consumed later to build the class
        option_name = "" if option_string is None else option_string.lstrip("-")
        with ctx.context.current_parser(parser):
            assert isinstance(self._real_type, type)
            return ctx.context.resolve_subtype(option_name, self._real_type, subtype_, kwargs)

327 

328 

class FlagAction(argparse.Action):
    """
    Special action for arguably - handles `enum.Flag`. Clears default value and ORs together flag values.

    `self.const` is treated as a `util.EnumFlagInfo`. Its `value` is the flag this option contributes, and its
    `cli_arg_name` is the namespace attribute that the flags are combined into.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Union[str, Sequence[Any], None],
        option_string: Optional[str] = None,
    ) -> None:
        info = cast(util.EnumFlagInfo, self.const)
        combined = info.value
        attr_name = info.cli_arg_name

        # The namespace's current value is only ORed in when the context reports so. Presumably this means the
        # default has already been cleared by an earlier flag; TODO confirm against _context.
        if ctx.context.check_and_set_enum_flag_default_status(parser, attr_name):
            combined |= getattr(namespace, attr_name)
        setattr(namespace, attr_name, combined)