Source code for shmpipeline.cli

"""Command-line entry points for shmpipeline."""

from __future__ import annotations

import argparse
import json
import logging
import time
from typing import Sequence

from shmpipeline.config import PipelineConfig
from shmpipeline.graph import PipelineGraph, validate_pipeline_config
from shmpipeline.logging_utils import configure_colored_logging
from shmpipeline.manager import PipelineManager


def _positive_float(text: str) -> float:
    """Parse a strictly positive, finite float for an argparse option.

    Used for polling intervals: ``0`` would busy-spin and negative values make
    ``time.sleep`` raise only after the pipeline has already been started.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a number, or is not in
            the open interval (0, inf). ``nan`` is rejected as well.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a number, got {text!r}"
        ) from None
    # Chained comparison is False for nan, so nan is rejected too.
    if not (0.0 < value < float("inf")):
        raise argparse.ArgumentTypeError(
            f"must be a positive finite number, got {text!r}"
        )
    return value


def _non_negative_float(text: str) -> float:
    """Parse a float that is ``>= 0`` for an argparse option.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a number, is negative,
            or is ``nan``.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a number, got {text!r}"
        ) from None
    # ``not value >= 0`` also catches nan, for which every comparison is False.
    if not value >= 0.0:
        raise argparse.ArgumentTypeError(
            f"must be a non-negative number, got {text!r}"
        )
    return value


def build_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser.

    Defines a global ``--log-level`` option and four required subcommands:
    ``validate``, ``describe``, ``run`` and ``serve``. The selected subcommand
    name is stored on the parsed namespace as ``command``. Numeric timing
    options are validated at parse time so invalid values fail fast with
    argparse's usual exit code 2.
    """
    parser = argparse.ArgumentParser(
        prog="shmpipeline",
        description="Shared-memory pipeline tools built on top of pyshmem.",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=("DEBUG", "INFO", "WARNING", "ERROR"),
        help="Set the CLI logging verbosity.",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    validate_parser = subparsers.add_parser(
        "validate",
        help="Validate a YAML pipeline configuration.",
    )
    validate_parser.add_argument(
        "config", help="Path to the YAML pipeline file."
    )

    describe_parser = subparsers.add_parser(
        "describe",
        help="Describe the pipeline graph without starting it.",
    )
    describe_parser.add_argument(
        "config", help="Path to the YAML pipeline file."
    )
    describe_parser.add_argument(
        "--json",
        action="store_true",
        help="Emit machine-readable JSON instead of human text.",
    )

    run_parser = subparsers.add_parser(
        "run",
        help=(
            "Build and start a pipeline until interrupted or until duration "
            "elapses."
        ),
    )
    run_parser.add_argument("config", help="Path to the YAML pipeline file.")
    run_parser.add_argument(
        "--duration",
        type=_non_negative_float,
        default=None,
        help=(
            "Optional run duration in seconds. Omit to run until interrupted."
        ),
    )
    run_parser.add_argument(
        "--poll-interval",
        type=_positive_float,
        default=0.25,
        help="Runtime polling interval in seconds.",
    )
    run_parser.add_argument(
        "--json-status",
        action="store_true",
        help="Print the final runtime snapshot as JSON before exiting.",
    )

    serve_parser = subparsers.add_parser(
        "serve",
        help=(
            "Expose one pipeline manager over HTTP and Server-Sent Events "
            "for local or remote control."
        ),
    )
    serve_parser.add_argument(
        "config", help="Path to the YAML pipeline file to control."
    )
    serve_parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host interface to bind. Defaults to loopback only.",
    )
    serve_parser.add_argument(
        "--port",
        type=int,
        default=8765,
        help="Control-plane TCP port.",
    )
    serve_parser.add_argument(
        "--token",
        default=None,
        help=(
            "Optional bearer token for API access. Required when binding to "
            "a non-local interface."
        ),
    )
    serve_parser.add_argument(
        "--poll-interval",
        type=_positive_float,
        default=0.1,
        help="Manager event polling interval in seconds.",
    )
    return parser
def main(argv: Sequence[str] | None = None) -> int:
    """Parse ``argv`` and dispatch to the selected subcommand.

    Args:
        argv: Argument list to parse; when ``None`` argparse falls back to
            ``sys.argv[1:]``.

    Returns:
        The exit code produced by the chosen subcommand handler.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    configure_colored_logging(level=getattr(logging, args.log_level))

    # Handlers are wrapped in lambdas so only the selected command's
    # arguments are read off the namespace.
    dispatch = {
        "validate": lambda: _run_validate(args.config),
        "describe": lambda: _run_describe(args.config, as_json=args.json),
        "run": lambda: _run_pipeline(
            args.config,
            duration=args.duration,
            poll_interval=args.poll_interval,
            emit_json_status=args.json_status,
        ),
        "serve": lambda: _run_serve(
            args.config,
            host=args.host,
            port=args.port,
            token=args.token,
            poll_interval=args.poll_interval,
            log_level=args.log_level,
        ),
    }

    handler = dispatch.get(args.command)
    if handler is not None:
        return handler()

    # Defensive fallback: parser.error exits the process, so the return
    # below only matters if that behaviour is ever overridden.
    parser.error(f"unsupported command: {args.command}")
    return 2
def _load_validated_config(config_path: str) -> PipelineConfig | None:
    """Load the YAML pipeline config at ``config_path`` and validate it.

    This is the shared front half of ``validate``, ``describe`` and ``run``.
    Problems are reported on stdout in the CLI's usual format rather than as a
    traceback.

    Returns:
        The loaded config, or ``None`` if it could not be loaded or failed
        ``validate_pipeline_config``.
    """
    try:
        config = PipelineConfig.from_yaml(config_path)
    except Exception as exc:  # CLI boundary: report instead of a traceback.
        print(f"Failed to load pipeline config {config_path}: {exc}")
        return None
    errors = validate_pipeline_config(config)
    if errors:
        print("Validation failed:")
        for error in errors:
            print(f"- {error}")
        return None
    return config


def _run_validate(config_path: str) -> int:
    """Run the ``validate`` subcommand.

    Returns:
        0 when the config loads and validates, otherwise 1.
    """
    if _load_validated_config(config_path) is None:
        return 1
    print(f"Validation passed: {config_path}")
    return 0


def _run_describe(config_path: str, *, as_json: bool) -> int:
    """Run the ``describe`` subcommand without starting the pipeline.

    Args:
        config_path: Path to the YAML pipeline file.
        as_json: Print ``graph.to_dict()`` as sorted, indented JSON instead of
            ``graph.describe()``.

    Returns:
        0 on success, 1 if the config cannot be loaded or is invalid.
    """
    config = _load_validated_config(config_path)
    if config is None:
        return 1
    graph = PipelineGraph.from_config(config)
    if as_json:
        print(json.dumps(graph.to_dict(), indent=2, sort_keys=True))
    else:
        print(graph.describe())
    return 0


def _run_pipeline(
    config_path: str,
    *,
    duration: float | None,
    poll_interval: float,
    emit_json_status: bool,
) -> int:
    """Run the ``run`` subcommand: build, start, supervise, then shut down.

    Args:
        config_path: Path to the YAML pipeline file.
        duration: Seconds to run, measured from just before ``build()``;
            ``None`` runs until interrupted.
        poll_interval: Seconds between runtime snapshots.
        emit_json_status: Print a final ``runtime_snapshot()`` as JSON after
            shutdown.

    Returns:
        0 on a clean run, 130 after Ctrl-C, or at least 1 when loading,
        running, shutdown or the final status snapshot fails.
    """
    config = _load_validated_config(config_path)
    if config is None:
        return 1

    manager = PipelineManager(config)
    started_at = time.monotonic()
    exit_code = 0
    try:
        manager.build()
        manager.start()
        while True:
            snapshot = manager.runtime_snapshot()
            if snapshot["state"] == "failed":
                # Expected to raise the recorded worker failure, which the
                # generic handler below reports.
                manager.raise_if_failed()
            if duration is None:
                time.sleep(poll_interval)
                continue
            remaining = duration - (time.monotonic() - started_at)
            if remaining <= 0:
                break
            # Sleep no longer than what is left so --duration is not
            # overshot by up to a full poll interval.
            time.sleep(min(poll_interval, remaining))
    except KeyboardInterrupt:
        exit_code = 130
    except Exception as exc:
        print(f"Pipeline run failed: {exc}")
        exit_code = 1
    finally:
        # Always tear the pipeline down, whatever ended the loop.
        try:
            manager.shutdown(force=True)
        except Exception as exc:
            print(f"Pipeline shutdown failed: {exc}")
            # max() keeps 130 after an interrupt.
            exit_code = max(exit_code, 1)

    if emit_json_status:
        # The manager may be in a broken state after a failed shutdown; do
        # not let the status dump escape as a traceback.
        try:
            status = manager.runtime_snapshot()
        except Exception as exc:
            print(f"Pipeline status unavailable: {exc}")
            exit_code = max(exit_code, 1)
        else:
            print(json.dumps(status, indent=2, sort_keys=True))
    return exit_code


def _run_serve(
    config_path: str,
    *,
    host: str,
    port: int,
    token: str | None,
    poll_interval: float,
    log_level: str,
) -> int:
    """Run the ``serve`` subcommand (HTTP/SSE control server).

    The server implementation is imported lazily so the optional ``control``
    extra is only needed for this subcommand.

    Returns:
        0 when the server returns normally, 130 after Ctrl-C, and 1 if the
        extra is missing or the server raises.
    """
    try:
        from shmpipeline.control.api import run_control_server
    except ImportError:
        print(
            "Control server support requires the control extra: "
            'pip install "shmpipeline[control]"'
        )
        return 1
    try:
        run_control_server(
            config_path,
            host=host,
            port=port,
            token=token,
            poll_interval=poll_interval,
            log_level=log_level.lower(),
        )
    except KeyboardInterrupt:
        return 130
    except Exception as exc:
        print(f"Control server failed: {exc}")
        return 1
    return 0


if __name__ == "__main__":  # pragma: no cover - module entry point
    raise SystemExit(main())