class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
_deprecated: set[Action] = set()
_json_tip: str = (
"When passing JSON CLI arguments, the following sets of arguments "
"are equivalent:\n"
' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
" --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
"Additionally, list elements can be passed individually using +:\n"
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
)
_search_keyword: str | None = None
def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = SortedHelpFormatter
# Pop kwarg "add_json_tip" to control whether to add the JSON tip
self.add_json_tip = kwargs.pop("add_json_tip", True)
super().__init__(*args, **kwargs)
if sys.version_info < (3, 13):
# Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None):
if args is not None and "--disable-log-requests" in args:
# Special case warning because the warning below won't trigger
# if –-disable-log-requests because its value is default.
logger.warning_once(
"argument '--disable-log-requests' is deprecated and "
"replaced with '--enable-log-requests'. This will be "
"removed in v0.12.0."
)
namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated:
if (
hasattr(namespace, dest := action.dest)
and getattr(namespace, dest) != action.default
):
logger.warning_once("argument '%s' is deprecated", dest)
return namespace, args
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
class _FlexibleArgumentGroup(_ArgumentGroup):
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
def add_argument_group(self, *args, **kwargs):
group = self._FlexibleArgumentGroup(self, *args, **kwargs)
self._action_groups.append(group)
return group
def format_help(self):
# Only use custom help formatting for bottom level parsers
if self._subparsers is not None:
return super().format_help()
formatter = self._get_formatter()
# Handle keyword search of the args
if (search_keyword := self._search_keyword) is not None:
# Normalise the search keyword
search_keyword = search_keyword.lower().replace("_", "-")
# Return full help if searching for 'all'
if search_keyword == "all":
self.epilog = self._json_tip
return super().format_help()
# Return group help if searching for a group title
for group in self._action_groups:
if group.title and group.title.lower() == search_keyword:
formatter.start_section(group.title)
formatter.add_text(group.description)
formatter.add_arguments(group._group_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# Return matched args if searching for an arg name
matched_actions = []
for group in self._action_groups:
for action in group._group_actions:
# search option name
if any(
search_keyword in opt.lower() for opt in action.option_strings
):
matched_actions.append(action)
if matched_actions:
formatter.start_section(f"Arguments matching '{search_keyword}'")
formatter.add_arguments(matched_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# No match found
formatter.add_text(
f"No group or arguments matching '{search_keyword}'.\n"
"Use '--help' to see available groups or "
"'--help=all' to see all available parameters."
)
return formatter.format_help()
# usage
formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
# description
formatter.add_text(self.description)
# positionals, optionals and user-defined groups
formatter.start_section("Config Groups")
config_groups = ""
for group in self._action_groups:
if not group._group_actions:
continue
title = group.title
description = group.description or ""
config_groups += f"{title: <24}{description}\n"
formatter.add_text(config_groups)
formatter.end_section()
# epilog
formatter.add_text(self.epilog)
# determine help from format above
return formatter.format_help()
def parse_args( # type: ignore[override]
self,
args: list[str] | None = None,
namespace: Namespace | None = None,
):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
try:
model_idx = next(
i
for i, arg in enumerate(args)
if arg == "--model" or arg.startswith("--model=")
)
logger.warning(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option. "
"The `--model` option will be removed in v0.13."
)
if args[model_idx] == "--model":
model_tag = args[model_idx + 1]
rest_start_idx = model_idx + 2
else:
model_tag = args[model_idx].removeprefix("--model=")
rest_start_idx = model_idx + 1
# Move <model> to the front, e,g:
# [Before]
# vllm serve -tp 2 --model <model> --enforce-eager --port 8001
# [After]
# vllm serve <model> -tp 2 --enforce-eager --port 8001
args = [
"serve",
model_tag,
*args[1:model_idx],
*args[rest_start_idx:],
]
except StopIteration:
pass
if "--config" in args:
args = self._pull_args_from_config(args)
def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace("_", "-")
# Everything between the first -- and the first .
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names
processed_args = list[str]()
for i, arg in enumerate(args):
if arg.startswith("--help="):
FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
processed_args.append("--help")
elif arg.startswith("--"):
if "=" in arg:
key, value = arg.split("=", 1)
key = pattern.sub(repl, key, count=1)
processed_args.append(f"{key}={value}")
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
else:
processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value.
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
`{"a": {"b": {"c": 1}}}`
"""
nested_dict: Any = value
for key in reversed(keys):
nested_dict = {key: nested_dict}
return nested_dict
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
) -> set[str]:
"""Recursively updates a dictionary with another dictionary.
Returns a set of duplicate keys that were overwritten.
"""
duplicates = set[str]()
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
nested_duplicates = recursive_dict_update(original[k], v)
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
elif isinstance(v, list) and isinstance(original.get(k), list):
original[k] += v
else:
if k in original:
duplicates.add(k)
original[k] = v
return duplicates
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, '.' was only in the value
continue
else:
value_str = processed_args[i + 1]
delete.add(i + 1)
if processed_arg.endswith("+"):
processed_arg = processed_arg[:-1]
value_str = json.dumps(list(value_str.split(",")))
key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
duplicates |= {f"{key}.{d}" for d in arg_duplicates}
delete.add(i)
# Filter out the dict args we set to None
processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
if duplicates:
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
processed_args.append(json.dumps(dict_value))
return super().parse_args(processed_args, namespace)
def check_port(self, value):
try:
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value
def _pull_args_from_config(self, args: list[str]) -> list[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
The arguments in config file will be inserted between
the argument list.
example:
```yaml
port: 12323
tensor-parallel-size: 4
```
```python
$: vllm {serve,chat,complete} "facebook/opt-12B" \
--config config.yaml -tp 2
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
]
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
'-tp', '2'
]
```
Please note how the config args are inserted after the sub command.
this way the order of priorities is maintained when these are args
parsed by super().
"""
assert args.count("--config") <= 1, "More than one config file specified!"
index = args.index("--config")
if index == len(args) - 1:
raise ValueError(
"No config file specified! \
Please check your command-line arguments."
)
file_path = args[index + 1]
config_args = self.load_config_file(file_path)
# 0th index might be the sub command {serve,chat,complete,...}
# optionally followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if args[0].startswith("-"):
# No sub command (e.g., api_server entry point)
args = config_args + args[0:index] + args[index + 2 :]
elif args[0] == "serve":
model_in_cli = len(args) > 1 and not args[1].startswith("-")
model_in_config = any(arg == "--model" for arg in config_args)
if not model_in_cli and not model_in_config:
raise ValueError(
"No model specified! Please specify model either "
"as a positional argument or in a config file."
)
if model_in_cli:
# Model specified as positional arg, keep CLI version
args = (
[args[0]]
+ [args[1]]
+ config_args
+ args[2:index]
+ args[index + 2 :]
)
else:
# No model in CLI, use config if available
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
return args
def load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
port: 12323
tensor-parallel-size: 4
```
returns:
processed_args: list[str] = [
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split(".")[-1]
if extension not in ("yaml", "yml"):
raise ValueError(
f"Config file must be of a yaml/yml type. {extension} supplied"
)
# only expecting a flat dictionary of atomic types
processed_args: list[str] = []
config: dict[str, int | str] = {}
try:
with open(file_path) as config_file:
config = yaml.safe_load(config_file)
except Exception as ex:
logger.error(
"Unable to read the config file at %s. Check path correctness",
file_path,
)
raise ex
store_boolean_arguments = [
action.dest for action in self._actions if isinstance(action, StoreBoolean)
]
for key, value in config.items():
if isinstance(value, bool) and key not in store_boolean_arguments:
if value:
processed_args.append("--" + key)
elif isinstance(value, list):
if value:
processed_args.append("--" + key)
for item in value:
processed_args.append(str(item))
else:
processed_args.append("--" + key)
processed_args.append(str(value))
return processed_args