"""Sphinx extension for automatic Cyclopts CLI documentation."""

from typing import TYPE_CHECKING, Any

import attrs

from cyclopts import __version__
from cyclopts.utils import import_app

if TYPE_CHECKING:
    from sphinx.application import Sphinx

from docutils import nodes
from sphinx.application import Sphinx
from sphinx.util import logging
from sphinx.util.docutils import SphinxDirective

logger = logging.getLogger(__name__)

# Directive options whose value is a comma-separated list of command names.
_COMMA_SEPARATED_FIELDS = frozenset({"commands", "exclude_commands"})


@attrs.define(kw_only=True)
class DirectiveOptions:
    """Configuration for the Cyclopts directive.

    Each field maps to a directive option whose name is the field name with
    underscores replaced by dashes (e.g. ``heading_level`` -> ``:heading-level:``).
    """

    heading_level: int = 2
    max_heading_level: int = 6
    commands: list[str] | None = None
    exclude_commands: list[str] | None = None

    # All booleans must have ``False`` default.
    # (They are exposed as presence-only flags, so there is no way to pass ``False``.)
    no_recursive: bool = False
    include_hidden: bool = False
    flatten_commands: bool = False
    code_block_title: bool = False
    skip_preamble: bool = False

    @classmethod
    def from_dict(cls, options: dict) -> "DirectiveOptions":
        """Create options from a directive options dictionary.

        Parameters
        ----------
        options : dict
            The directive's ``self.options`` mapping, keyed by dashed option names.

        Returns
        -------
        DirectiveOptions
            Options with every unspecified field left at its default.
        """
        kwargs = {}
        for field in attrs.fields(cls):
            # Convert underscore to dash for looking up in options
            option_name = field.name.replace("_", "-")
            if field.type is bool:
                # ``directives.flag`` stores None when the flag is present, and the
                # key is absent otherwise; presence therefore means True.
                if option_name in options:
                    kwargs[field.name] = True
            elif option_name in options:
                value = options[option_name]
                if field.name in _COMMA_SEPARATED_FIELDS:
                    # Parse comma-separated list, stripping whitespace and dropping
                    # empty entries; an empty option value yields an empty list.
                    if value:
                        kwargs[field.name] = [cmd.strip() for cmd in value.split(",") if cmd.strip()]
                    else:
                        kwargs[field.name] = []
                else:
                    kwargs[field.name] = value
            # If not specified, the attrs default is used.
        return cls(**kwargs)

    @staticmethod
    def spec() -> dict[str, Any]:
        """Generate the Sphinx ``option_spec`` from the ``DirectiveOptions`` fields.

        Returns
        -------
        dict[str, Any]
            Mapping of dashed option name to a docutils option converter.
        """
        from docutils.parsers.rst import directives

        type_mapping = {
            bool: directives.flag,
            int: directives.nonnegative_int,
            str: directives.unchanged,
        }
        option_spec = {}
        for field in attrs.fields(DirectiveOptions):
            option_name = field.name.replace("_", "-")
            if field.name in _COMMA_SEPARATED_FIELDS:
                # Kept as a raw string here; split into a list in ``from_dict``.
                validator = directives.unchanged
            else:
                validator = type_mapping.get(field.type, directives.unchanged)
            option_spec[option_name] = validator
        return option_spec


def _should_include_command(
    command_name: str,
    command_path: list[str],
    commands_filter: list[str] | None,
    exclude_commands: list[str] | None,
) -> bool:
    """Check if a command should be included in documentation.

    Entries in either list may be a bare command name or a dotted path
    (``"parent.child"``). Exclusion is checked first and wins over inclusion.
    Listing a parent path includes or excludes all of its subcommands.

    Parameters
    ----------
    command_name : str
        The name of the command.
    command_path : list[str]
        The path of parent commands leading to this command (not including it).
    commands_filter : list[str] | None
        If specified, only include commands in this list.
    exclude_commands : list[str] | None
        If specified, exclude commands in this list.

    Returns
    -------
    bool
        True if the command should be included.
    """
    full_path = ".".join(command_path + [command_name])
    # Dotted prefixes of the parent chain, e.g. ["a", "a.b"] for command_path ["a", "b"].
    parent_paths = [".".join(command_path[: i + 1]) for i in range(len(command_path))]

    if exclude_commands:
        if command_name in exclude_commands or full_path in exclude_commands:
            return False
        if any(parent in exclude_commands for parent in parent_paths):
            return False

    if commands_filter is not None:
        # With a filter, a command is kept only if it (or an ancestor) is listed.
        if command_name in commands_filter or full_path in commands_filter:
            return True
        return any(parent in commands_filter for parent in parent_paths)

    # No filter specified, include by default.
    return True


def _filter_commands(
    commands: dict,
    commands_filter: list[str] | None,
    exclude_commands: list[str] | None,
    parent_path: list[str] | None = None,
) -> dict:
    """Filter commands based on inclusion/exclusion lists.

    Parameters
    ----------
    commands : dict
        Dictionary mapping command names to App instances.
    commands_filter : list[str] | None
        If specified, only include commands in this list.
    exclude_commands : list[str] | None
        If specified, exclude commands in this list.
    parent_path : list[str] | None
        Path to the parent command for nested commands; defaults to the root.

    Returns
    -------
    dict
        A new dictionary with only the included commands, in original order.
    """
    if parent_path is None:
        parent_path = []
    return {
        name: app
        for name, app in commands.items()
        if _should_include_command(name, parent_path, commands_filter, exclude_commands)
    }


def _process_rst_content(content: str, skip_title: bool = False) -> list[str]:
    """Split RST content into lines, removing elements that do not fit the directive.

    Parameters
    ----------
    content : str
        Raw RST text.
    skip_title : bool
        If True and the first line is followed by an underline made only of
        ``- = ^ ~ "`` characters, drop that title and its underline.

    Returns
    -------
    list[str]
        The remaining lines, with any ``.. contents::`` directive (plus its
        indented options and one trailing blank line) removed.
    """
    lines = content.splitlines()
    processed = []
    i = 0
    while i < len(lines):
        line = lines[i]

        # Skip title and underline if requested
        if skip_title and i == 0 and line.strip() and i + 1 < len(lines):
            next_line = lines[i + 1].strip()
            if next_line and set(next_line) <= {"-", "=", "^", "~", '"'}:
                i += 2
                continue

        # Skip .. contents:: directive together with its indented option lines.
        if line.strip().startswith(".. contents::"):
            i += 1
            while i < len(lines) and lines[i].strip() and lines[i][0] in " \t":
                i += 1
            if i < len(lines) and not lines[i].strip():
                i += 1
            continue

        processed.append(line)
        i += 1

    return processed


def _is_underline_at(lines: list[str], index: int) -> bool:
    """Return True if ``lines[index]`` exists and is a non-empty run of ``-`` characters."""
    if index >= len(lines):
        return False
    stripped = lines[index].strip()
    return bool(stripped) and all(c == "-" for c in stripped)


def _collect_literal_block(lines: list[str], start: int) -> tuple[list[str], int]:
    """Collect the 4-space-indented body of a literal block beginning at ``start``.

    Blank lines are kept when more indented content follows them, so a literal
    block that contains empty lines is not cut short.

    Returns
    -------
    tuple[list[str], int]
        The de-indented body lines and the index of the first line after the block.
    """
    content: list[str] = []
    i = start
    while i < len(lines):
        line = lines[i]
        if line.startswith("    "):
            content.append(line[4:])
            i += 1
            continue
        if not line.strip():
            # A run of blank lines belongs to the block only if indented text resumes after it.
            j = i
            while j < len(lines) and not lines[j].strip():
                j += 1
            if j < len(lines) and lines[j].startswith("    "):
                content.extend([""] * (j - i))
                i = j
                continue
        break
    return content, i


def _create_section_nodes(lines: list[str], state: Any) -> list["nodes.Node"]:
    """Convert processed RST lines into docutils nodes.

    A line followed by a ``-----`` underline starts a ``section`` node. Its body
    runs up to the next such title and is parsed with ``state.nested_parse``.
    A standalone ``::`` line starts a literal block. Any other run of text is
    parsed, and its resulting child nodes are returned directly.

    Parameters
    ----------
    lines : list[str]
        RST source lines, as produced by ``_process_rst_content``.
    state : Any
        The directive's parser state; only ``nested_parse`` is used.

    Returns
    -------
    list[nodes.Node]
        Top-level nodes in document order.
    """
    from docutils.statemachine import StringList

    result = []
    i = 0
    while i < len(lines):
        line = lines[i]

        # Section header: title line followed by a "-" underline.
        if _is_underline_at(lines, i + 1):
            section = nodes.section()
            title_text = line.strip()
            section["ids"] = [title_text.lower().replace(" ", "-").replace("cyclopts-", "cli-cyclopts-")]
            section += nodes.title(text=title_text)

            content_lines = []
            i += 2  # Skip title and underline
            # Everything up to the next section title belongs to this section.
            while i < len(lines) and not _is_underline_at(lines, i + 1):
                content_lines.append(lines[i])
                i += 1

            if content_lines:
                state.nested_parse(StringList(content_lines), 0, section)
            result.append(section)
            continue

        # Standalone literal block marker (::)
        if line.strip() == "::":
            i += 1
            # Skip the blank line(s) separating "::" from its indented body.
            while i < len(lines) and not lines[i].strip():
                i += 1

            literal_content, i = _collect_literal_block(lines, i)
            if literal_content:
                text = "\n".join(literal_content)
                # Equivalent to setting ``rawsource`` and appending a ``Text`` child.
                result.append(nodes.literal_block(text, text))

            # Skip any trailing blank line
            if i < len(lines) and not lines[i].strip():
                i += 1
            continue

        # Regular content - accumulate lines until a section title, a "::" marker,
        # or two consecutive blank lines.
        if line.strip():
            content_lines = [line]
            i += 1
            while i < len(lines):
                if _is_underline_at(lines, i + 1):
                    break
                if lines[i].strip() == "::":
                    break
                if not lines[i].strip():
                    # Keep a single blank line (paragraph separator) and look ahead.
                    content_lines.append(lines[i])
                    i += 1
                    if i >= len(lines) or not lines[i].strip():
                        break
                else:
                    content_lines.append(lines[i])
                    i += 1

            # Parse into a temporary container and hoist its children.
            para = nodes.paragraph()
            state.nested_parse(StringList(content_lines), 0, para)
            if para.children:
                result.extend(para.children)
        else:
            i += 1

    return result


class CycloptsDirective(SphinxDirective):  # type: ignore[misc,valid-type]
    """Sphinx directive for documenting Cyclopts CLI applications.

    Takes one argument, the path passed to ``import_app`` to locate the app.
    """

    has_content = False
    required_arguments = 1
    optional_arguments = 0
    final_argument_whitespace = False
    option_spec = DirectiveOptions.spec()

    def run(self) -> list["nodes.Node"]:
        """Generate documentation nodes for the Cyclopts app.

        Any exception during generation is reported as an error node rather
        than aborting the Sphinx build.
        """
        module_path = self.arguments[0]
        opts = DirectiveOptions.from_dict(self.options)

        try:
            rst_content = self._generate_documentation(module_path, opts)
            return self._create_nodes(rst_content, opts)
        except Exception as e:
            return self._error_node(f"Error generating Cyclopts documentation: {e}")

    def _generate_documentation(self, module_path: str, opts: DirectiveOptions) -> str:
        """Import the app via ``import_app`` and render its RST documentation."""
        from cyclopts.docs.rst import generate_rst_docs

        app = import_app(module_path)

        # Call generate_rst_docs directly to access internal no_root_title parameter
        return generate_rst_docs(
            app,
            recursive=not opts.no_recursive,
            include_hidden=opts.include_hidden,
            heading_level=opts.heading_level,
            max_heading_level=opts.max_heading_level,
            flatten_commands=opts.flatten_commands,
            commands_filter=opts.commands,
            exclude_commands=opts.exclude_commands,
            no_root_title=True,  # Always skip root title in Sphinx context
            code_block_title=opts.code_block_title,
            skip_preamble=opts.skip_preamble,
        )

    def _create_nodes(self, rst_content: str, opts: DirectiveOptions) -> list["nodes.Node"]:
        """Create docutils nodes from RST content."""
        lines = _process_rst_content(rst_content, skip_title=False)  # Title already skipped in generate_docs

        # Always use section nodes for better Sphinx integration
        return _create_section_nodes(lines, self.state)

    def _error_node(self, message: str) -> list["nodes.Node"]:
        """Log ``message`` as an error and return it wrapped in an error node."""
        logger.error(message)
        return [nodes.error("", nodes.paragraph(text=message))]


def setup(app: "Sphinx") -> dict[str, Any]:
    """Register the ``cyclopts`` directive with Sphinx."""
    app.add_directive("cyclopts", CycloptsDirective)

    return {
        "version": __version__,
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }