"""Pipeline manager responsible for building and supervising workers."""
from __future__ import annotations
import multiprocessing as mp
import os
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
import pyshmem
from shmpipeline.config import PipelineConfig
from shmpipeline.errors import (
ConfigValidationError,
StateTransitionError,
WorkerProcessError,
)
from shmpipeline.graph import PipelineGraph
from shmpipeline.logging_utils import get_logger
from shmpipeline.registry import KernelRegistry, get_default_registry
from shmpipeline.runtime import drain_events, run_kernel_process
from shmpipeline.scheduling import (
WorkerPlacementPolicy,
normalize_placement_policy,
)
from shmpipeline.shm_cleanup import close_stream, unlink_stream_name
from shmpipeline.state import PipelineState
if TYPE_CHECKING:
from shmpipeline.synthetic import SyntheticInputConfig
@dataclass
class _WorkerHandle:
    """Runtime handle for one worker process.

    Groups the objects the manager keeps for each spawned kernel worker so
    it can signal, poll, and report on that worker.
    """

    # Kernel name; the manager also uses it as the key in its worker map.
    name: str
    # Process object created from the manager's multiprocessing context.
    process: Any
    # Per-worker event the manager sets to ask the worker to stop.
    stop_event: Any
    # Manager-side read end of the worker's event pipe; None once closed.
    event_reader: Any | None
    # CPU slot returned by the placement policy for this worker, if any.
    cpu_slot: int | None
[docs]
class PipelineManager:
"""Create shared memory, spawn workers, and supervise pipeline state.
The manager owns the pipeline lifecycle from validated config through
runtime monitoring and teardown. It is the primary user-facing runtime API
for Python callers.
"""
def __init__(
    self,
    config: PipelineConfig | str | Path,
    *,
    spawn_method: str = "spawn",
    placement_policy: WorkerPlacementPolicy | None = None,
    registry: KernelRegistry | None = None,
) -> None:
    """Initialize a manager from a config object or YAML path.

    Args:
        config: A ``PipelineConfig`` instance, or a ``str``/``Path`` that
            is loaded through ``PipelineConfig.from_yaml``.
        spawn_method: Start-method name handed to
            ``multiprocessing.get_context``.
        placement_policy: Worker placement policy; it is passed through
            ``normalize_placement_policy`` before being stored.
        registry: Kernel registry used for validation. When ``None`` the
            default registry is used by the manager, and ``None`` is what
            gets forwarded to worker processes.
    """
    if isinstance(config, (str, Path)):
        config = PipelineConfig.from_yaml(config)
    self.config = config
    # Compare against None explicitly. A truthiness test would swap a
    # falsy-but-valid registry (for example an empty one that defines
    # ``__len__``) for the default, while ``_runtime_registry`` below still
    # forwards the caller's object to workers - the two would disagree.
    self.registry = (
        registry if registry is not None else get_default_registry()
    )
    self._runtime_registry = registry
    self.context = mp.get_context(spawn_method)
    self.placement_policy = normalize_placement_policy(placement_policy)
    self._logger = get_logger("manager")
    self.state = PipelineState.INITIALIZED
    self._streams: dict[str, Any] = {}
    self._pause_event: Any | None = None
    self._workers: dict[str, _WorkerHandle] = {}
    self._failures: list[dict[str, Any]] = []
    # Bounded history: only the most recent 256 events are retained.
    self._events: deque[dict[str, Any]] = deque(maxlen=256)
    self._worker_metrics: dict[str, dict[str, Any]] = {}
    self._worker_runtime: dict[str, dict[str, Any]] = {}
    self._synthetic_sources: dict[str, Any] = {}
    self._kernel_configs = {
        kernel.name: kernel for kernel in self.config.kernels
    }
    self._logger.info(
        "manager initialized: spawn_method=%s placement_policy=%s "
        "kernels=%d streams=%d",
        spawn_method,
        self.placement_policy.describe(),
        len(self.config.kernels),
        len(self.config.shared_memory),
    )
@property
def graph(self) -> PipelineGraph:
    """Build a graph view of the configuration held by this manager.

    ``PipelineGraph.from_config`` is called on every access; the manager
    does not keep the result.
    """
    derived = PipelineGraph.from_config(self.config)
    return derived
def _transition_state(
self, new_state: PipelineState, *, reason: str
) -> None:
"""Update manager state and emit a state-transition log."""
old_state = self.state
self.state = new_state
self._logger.info(
"state transition: %s -> %s (%s)",
old_state.value,
new_state.value,
reason,
)
[docs]
def build(self) -> None:
    """Validate the configuration and create the shared-memory streams.

    Worker processes are not started here. Allowed only from the
    INITIALIZED or STOPPED state; on success the manager moves to BUILT.

    Raises:
        StateTransitionError: if called from any other state.
        ConfigValidationError: with the first graph validation error,
            when the graph reports any.
    """
    buildable_states = {
        PipelineState.INITIALIZED,
        PipelineState.STOPPED,
    }
    if self.state not in buildable_states:
        raise StateTransitionError(
            f"cannot build pipeline while state is {self.state.value!r}"
        )

    log = self._logger
    log.info("build started")

    # Drop bookkeeping left over from an earlier build or run.
    self._close_event_readers()
    for bucket in (
        self._events,
        self._failures,
        self._worker_metrics,
        self._worker_runtime,
    ):
        bucket.clear()

    shared_by_name = self.config.shared_memory_by_name
    log.info("validating pipeline config")
    graph_errors = self.graph.validation_errors()
    if graph_errors:
        raise ConfigValidationError(graph_errors[0])

    for kernel in self.config.kernels:
        self.registry.validate(kernel, shared_by_name)
        log.info(
            "validated kernel: name=%s kind=%s input=%s output=%s "
            "auxiliary=%s",
            kernel.name,
            kernel.kind,
            kernel.input,
            kernel.output,
            kernel.auxiliary_by_alias,
        )

    for spec in self.config.shared_memory:
        self._streams[spec.name] = self._build_stream(spec)

    self._transition_state(PipelineState.BUILT, reason="build complete")
def _build_stream(self, spec) -> Any:
    """Return a stream for ``spec``, reusing, replacing, or creating it.

    An existing stream that can be opened is reused when it matches the
    spec and is otherwise closed and unlinked. If creation then reports
    ``FileExistsError``, the attach is retried once; a name that exists but
    cannot be opened is unlinked before the stream is created again.
    """
    candidate = self._open_existing_stream(spec)
    if candidate is not None:
        kept = self._reuse_or_replace_existing_stream(spec, candidate)
        if kept is not None:
            return kept

    creation_options = {"shape": spec.shape, "dtype": spec.dtype}
    if spec.storage == "gpu":
        creation_options.update(
            gpu_device=spec.gpu_device,
            cpu_mirror=spec.cpu_mirror,
        )

    try:
        stream = pyshmem.create(spec.name, **creation_options)
    except FileExistsError:
        self._logger.info(
            "shared memory already exists during create; retrying attach: name=%s",
            spec.name,
        )
        candidate = self._open_existing_stream(spec)
        if candidate is None:
            self._logger.warning(
                "shared memory exists but is not attachable; recreating stale stream: name=%s",
                spec.name,
            )
            unlink_stream_name(spec.name)
        else:
            kept = self._reuse_or_replace_existing_stream(spec, candidate)
            if kept is not None:
                return kept
        # The old name is gone now (replaced or unlinked): create again.
        stream = pyshmem.create(spec.name, **creation_options)

    self._logger.info(
        "created shared memory: name=%s storage=%s shape=%s dtype=%s",
        spec.name,
        spec.storage,
        spec.shape,
        spec.dtype,
    )
    return stream
def _reuse_or_replace_existing_stream(
self, spec, existing_stream
) -> Any | None:
"""Return a reusable existing stream or replace it when incompatible."""
if self._stream_matches_spec(existing_stream, spec):
self._logger.info(
"reusing shared memory: name=%s storage=%s shape=%s dtype=%s",
spec.name,
spec.storage,
spec.shape,
spec.dtype,
)
return existing_stream
self._logger.info(
"replacing incompatible shared memory: name=%s expected_storage=%s "
"expected_shape=%s expected_dtype=%s existing_storage=%s "
"existing_shape=%s existing_dtype=%s",
spec.name,
spec.storage,
spec.shape,
spec.dtype,
"gpu" if existing_stream.gpu_enabled else "cpu",
existing_stream.shape,
existing_stream.dtype,
)
close_stream(existing_stream, unlink=True)
return None
def _open_existing_stream(self, spec) -> Any | None:
    """Attach to the stream named ``spec.name`` when one can be opened.

    Only the name is passed to ``pyshmem.open``. Per the original design
    note, this lets build inspect and replace stale GPU streams without
    reopening their IPC handles.

    Returns:
        The opened stream, or ``None`` when ``pyshmem.open`` raises
        ``FileNotFoundError`` or ``ValueError``.
    """
    try:
        handle = pyshmem.open(spec.name)
    except (FileNotFoundError, ValueError):
        return None
    return handle
def _stream_matches_spec(self, stream: Any, spec) -> bool:
"""Return whether an existing stream matches the requested config."""
if spec.storage == "gpu":
return False
if tuple(stream.shape) != tuple(spec.shape):
return False
if stream.dtype != spec.dtype:
return False
if stream.gpu_enabled != (spec.storage == "gpu"):
return False
return True
[docs]
def start(self) -> None:
    """Start worker processes for every configured kernel.

    The manager must already be in the built state before workers can be
    spawned. On success the manager moves to RUNNING.

    If spawning a worker or waiting for the start events raises, every
    worker spawned by this call is stopped again before the exception is
    re-raised. The manager stays in the built state with no workers, so
    nothing is left running that ``stop`` or ``shutdown`` would skip.

    Raises:
        StateTransitionError: if the manager is not in the built state.
        WorkerProcessError: if a worker fails or does not report a start
            event in time.
    """
    self.poll_events()
    if self.state != PipelineState.BUILT:
        raise StateTransitionError(
            f"cannot start pipeline while state is {self.state.value!r}"
        )
    self._logger.info("start requested")
    self._pause_event = self.context.Event()
    self._failures.clear()
    self._worker_metrics.clear()
    self._worker_runtime.clear()
    self._close_event_readers()
    cpu_count = max(1, os.cpu_count() or 1)
    try:
        for index, kernel_config in enumerate(self.config.kernels):
            stop_event = self.context.Event()
            event_reader, event_writer = self.context.Pipe(duplex=False)
            cpu_slot = self.placement_policy.cpu_slot_for(
                kernel=kernel_config,
                index=index,
                cpu_count=cpu_count,
            )
            process = self.context.Process(
                target=run_kernel_process,
                args=(
                    kernel_config,
                    self.config.shared_memory,
                    self._pause_event,
                    stop_event,
                    event_writer,
                    cpu_slot,
                    self._runtime_registry,
                ),
                name=f"shmpipeline-{kernel_config.name}",
            )
            try:
                process.start()
            except BaseException:
                event_reader.close()
                event_writer.close()
                raise
            # The manager only reads from the pipe; drop its copy of the
            # write end once the child has been started.
            event_writer.close()
            self._logger.info(
                "spawned worker: kernel=%s pid=%s cpu_slot=%s",
                kernel_config.name,
                process.pid,
                cpu_slot,
            )
            self._workers[kernel_config.name] = _WorkerHandle(
                name=kernel_config.name,
                process=process,
                stop_event=stop_event,
                event_reader=event_reader,
                cpu_slot=cpu_slot,
            )
        self._wait_for_workers_started()
    except BaseException:
        # State is still BUILT here, so neither stop() nor shutdown() would
        # touch these processes; clean them up before propagating.
        self._abort_start()
        raise
    self._transition_state(
        PipelineState.RUNNING, reason="all workers started"
    )


def _abort_start(self, *, timeout: float = 1.0) -> None:
    """Stop and forget workers spawned by a start attempt that failed."""
    for worker in self._workers.values():
        worker.stop_event.set()
    for worker in self._workers.values():
        worker.process.join(timeout)
        if worker.process.is_alive():
            worker.process.terminate()
            worker.process.join(timeout)
    self._close_event_readers()
    self._workers.clear()
    self._pause_event = None
def _wait_for_workers_started(self, *, timeout: float = 5.0) -> None:
"""Block until all worker processes report a started event."""
expected = set(self._workers)
end_time = time.monotonic() + timeout
while time.monotonic() < end_time:
self.poll_events()
started = {
event["kernel"]
for event in self._events
if event.get("type") == "worker_started"
}
if started >= expected:
return
for worker in self._workers.values():
if worker.process.exitcode not in (None, 0):
self.raise_if_failed()
time.sleep(0.01)
missing = sorted(
expected
- {
event["kernel"]
for event in self._events
if event.get("type") == "worker_started"
}
)
raise WorkerProcessError(
"timed out waiting for workers to start: " + ", ".join(missing)
)
[docs]
def pause(self) -> None:
    """Ask every worker to pause while keeping the pipeline built.

    Pending worker events are polled first. Pausing is only permitted
    from the RUNNING state and moves the manager to PAUSED.

    Raises:
        StateTransitionError: if the manager is not running.
    """
    self.poll_events()
    if self.state != PipelineState.RUNNING:
        raise StateTransitionError(
            f"cannot pause pipeline while state is {self.state.value!r}"
        )
    pause_event = self._pause_event
    assert pause_event is not None
    pause_event.set()
    self._transition_state(PipelineState.PAUSED, reason="pause requested")
[docs]
def resume(self) -> None:
    """Let paused workers continue.

    Pending worker events are polled first. Resuming is only permitted
    from the PAUSED state and moves the manager back to RUNNING.

    Raises:
        StateTransitionError: if the manager is not paused.
    """
    self.poll_events()
    if self.state != PipelineState.PAUSED:
        raise StateTransitionError(
            f"cannot resume pipeline while state is {self.state.value!r}"
        )
    pause_event = self._pause_event
    assert pause_event is not None
    pause_event.clear()
    self._transition_state(
        PipelineState.RUNNING, reason="resume requested"
    )
[docs]
def stop(self, *, timeout: float = 5.0, force: bool = False) -> None:
    """Stop the worker processes and return to the built state.

    Shared-memory streams are left in place, so callers may inspect them,
    start workers again, or go on to a final shutdown. Calling this while
    already in the built state is a no-op.

    Args:
        timeout: Seconds to wait in each ``join`` on a worker process.
        force: Use ``kill`` instead of ``terminate`` for a worker that is
            still alive after the first join, when the process object
            offers ``kill``.

    Raises:
        StateTransitionError: if called from a state other than RUNNING,
            PAUSED, FAILED, or BUILT.
    """
    self.poll_events()
    stoppable_states = {
        PipelineState.RUNNING,
        PipelineState.PAUSED,
        PipelineState.FAILED,
    }
    if self.state not in stoppable_states:
        if self.state != PipelineState.BUILT:
            raise StateTransitionError(
                f"cannot stop pipeline while state is {self.state.value!r}"
            )
        return

    self._logger.info(
        "stop requested: timeout=%s force=%s", timeout, force
    )

    # Signal every worker first so they can wind down concurrently.
    for handle in self._workers.values():
        handle.stop_event.set()

    for handle in self._workers.values():
        proc = handle.process
        proc.join(timeout)
        if proc.is_alive():
            if force and hasattr(proc, "kill"):
                proc.kill()
            else:
                proc.terminate()
            proc.join(timeout)
        self._logger.info(
            "worker stopped: kernel=%s pid=%s exitcode=%s",
            handle.name,
            proc.pid,
            proc.exitcode,
        )

    # Collect whatever the workers reported before their pipes are closed.
    self.poll_events()
    self._close_event_readers()
    self._workers.clear()
    self._pause_event = None
    self._transition_state(PipelineState.BUILT, reason="workers stopped")
[docs]
def shutdown(self, *, unlink: bool = True, force: bool = False) -> None:
    """Stop workers, close local handles, and optionally unlink streams.

    This is the terminal cleanup step for a manager instance.

    Every stream is attempted even when closing one of them fails. Streams
    that closed are removed from the manager; streams that failed stay
    registered so a later call can retry them. If any close failed, the
    first such exception is re-raised after the loop and the manager does
    not move to STOPPED.

    Args:
        unlink: Passed to ``close_stream`` for every stream.
        force: Passed to ``stop`` when workers are still present.
    """
    self._logger.info(
        "shutdown requested: unlink=%s force=%s", unlink, force
    )
    self.stop_all_synthetic_inputs()
    if self.state in {
        PipelineState.RUNNING,
        PipelineState.PAUSED,
        PipelineState.FAILED,
    }:
        self.stop(force=force)

    first_error: Exception | None = None
    # Iterate over a copy: closed streams are removed while looping.
    for name, stream in list(self._streams.items()):
        try:
            close_stream(stream, unlink=unlink)
        except Exception as exc:
            self._logger.error(
                "failed to release shared memory: name=%s unlink=%s "
                "error=%s",
                name,
                unlink,
                exc,
            )
            if first_error is None:
                first_error = exc
            continue
        del self._streams[name]
        self._logger.info(
            "released shared memory: name=%s unlink=%s",
            name,
            unlink,
        )

    if first_error is not None:
        raise first_error

    self._streams.clear()
    self._transition_state(
        PipelineState.STOPPED, reason="shutdown complete"
    )
[docs]
def poll_events(self) -> list[dict[str, Any]]:
    """Drain worker events and update manager failure state.

    Pending events are read from every open worker event reader. Each
    event is appended to the bounded event history, logged, and used to
    update per-worker metrics and runtime bookkeeping or to record a
    failure. Workers whose process exited with a non-zero code and have
    no recorded failure get a generic failure entry. If any failure is
    recorded while the manager is RUNNING or PAUSED, it moves to FAILED.

    Returns:
        The events drained by this call; the exact empty value is whatever
        ``drain_events`` returns when nothing is pending.
    """
    events = drain_events(
        [
            worker.event_reader
            for worker in self._workers.values()
            if worker.event_reader is not None
        ]
    )
    if events:
        for event in events:
            self._events.append(event)
            self._log_event(event)
            kernel_name = event.get("kernel")
            if event.get("type") == "worker_metrics":
                observed_at = time.time()
                # Keep the latest metrics snapshot reported by this kernel.
                self._worker_metrics[event["kernel"]] = {
                    "pid": event.get("pid"),
                    "cpu_slot": event.get("cpu_slot"),
                    "frames_processed": event.get("frames_processed", 0),
                    "last_exec_ms": event.get("last_exec_ms"),
                    "last_exec_us": event.get("last_exec_us"),
                    "avg_exec_ms": event.get("avg_exec_ms"),
                    "avg_exec_us": event.get("avg_exec_us"),
                    "jitter_us_rms": event.get("jitter_us_rms"),
                    "throughput_hz": event.get("throughput_hz"),
                    "runtime_s": event.get("runtime_s"),
                    "last_output_count": event.get("last_output_count"),
                    "metrics_window": event.get("metrics_window"),
                }
                runtime = self._worker_runtime.setdefault(
                    event["kernel"],
                    {},
                )
                # If no start time is known yet, estimate it from the
                # runtime the worker reports in this event.
                runtime.setdefault(
                    "started_at",
                    observed_at - float(event.get("runtime_s") or 0.0),
                )
                runtime["last_metric_at"] = observed_at
                previous_frames = int(runtime.get("frames_processed") or 0)
                previous_output_count = runtime.get("last_output_count")
                current_frames = int(event.get("frames_processed") or 0)
                current_output_count = event.get("last_output_count")
                # Progress means more frames or a higher output count; a
                # missing output count on either side also counts.
                if (
                    current_frames > previous_frames
                    or previous_output_count is None
                    or current_output_count is None
                    or current_output_count > previous_output_count
                ):
                    runtime["last_progress_at"] = observed_at
                runtime["frames_processed"] = current_frames
                runtime["last_output_count"] = current_output_count
            elif event.get("type") == "worker_started" and kernel_name:
                runtime = self._worker_runtime.setdefault(kernel_name, {})
                runtime["started_at"] = time.time()
            elif event.get("type") == "worker_failed":
                self._record_worker_failure(event)
    # Catch workers that died without a failure being recorded for them.
    for worker in self._workers.values():
        if worker.process.exitcode not in (None, 0):
            already_recorded = any(
                failure.get("kernel") == worker.name
                for failure in self._failures
            )
            if not already_recorded:
                self._record_worker_failure(
                    {
                        "type": "worker_failed",
                        "kernel": worker.name,
                        "pid": worker.process.pid,
                        "error": (
                            "worker exited with code "
                            f"{worker.process.exitcode}"
                        ),
                    }
                )
    # Only RUNNING and PAUSED are moved to FAILED; other states are kept.
    if self._failures and self.state in {
        PipelineState.RUNNING,
        PipelineState.PAUSED,
    }:
        self._transition_state(
            PipelineState.FAILED,
            reason="worker failure recorded",
        )
    return events
def _close_event_readers(self) -> None:
"""Close any open manager-side worker event readers."""
for worker in self._workers.values():
event_reader = worker.event_reader
if event_reader is None:
continue
try:
event_reader.close()
except OSError:
pass
worker.event_reader = None
def _log_event(self, event: dict[str, Any]) -> None:
"""Write manager-readable logs for worker lifecycle events."""
event_type = event.get("type")
if event_type == "worker_started":
self._logger.info(
"worker started: kernel=%s pid=%s cpu_slot=%s",
event.get("kernel"),
event.get("pid"),
event.get("cpu_slot"),
)
return
if event_type == "worker_stopped":
self._logger.info(
"worker stopped event: kernel=%s pid=%s cpu_slot=%s",
event.get("kernel"),
event.get("pid"),
event.get("cpu_slot"),
)
return
if event_type == "worker_failed":
self._logger.error(
"worker failed: kernel=%s pid=%s error=%s",
event.get("kernel"),
event.get("pid"),
event.get("error"),
)
return
if event_type == "worker_metrics":
self._logger.debug(
"worker metrics: kernel=%s frames=%s avg_exec_us=%s "
"jitter_us_rms=%s throughput_hz=%s",
event.get("kernel"),
event.get("frames_processed"),
event.get("avg_exec_us"),
event.get("jitter_us_rms"),
event.get("throughput_hz"),
)
return
self._logger.debug("worker event: %s", event)
def _record_worker_failure(self, failure: dict[str, Any]) -> None:
"""Record one worker failure, preferring detailed failures."""
kernel_name = failure.get("kernel")
if kernel_name is None:
self._failures.append(failure)
return
existing_index: int | None = None
for index, existing_failure in enumerate(self._failures):
if existing_failure.get("kernel") == kernel_name:
existing_index = index
break
if existing_index is None:
self._failures.append(failure)
return
existing_failure = self._failures[existing_index]
if self._failure_is_more_specific(failure, existing_failure):
self._failures[existing_index] = failure
@staticmethod
def _failure_is_more_specific(
candidate: dict[str, Any], existing: dict[str, Any]
) -> bool:
"""Return whether a candidate failure is more useful than existing."""
candidate_has_traceback = bool(candidate.get("traceback"))
existing_has_traceback = bool(existing.get("traceback"))
if candidate_has_traceback != existing_has_traceback:
return candidate_has_traceback
candidate_error = str(candidate.get("error", ""))
existing_error = str(existing.get("error", ""))
candidate_is_generic = candidate_error.startswith(
"worker exited with code"
)
existing_is_generic = existing_error.startswith(
"worker exited with code"
)
if candidate_is_generic != existing_is_generic:
return not candidate_is_generic
return False
[docs]
def status(self, *, poll: bool = True) -> dict[str, Any]:
    """Collect a snapshot of manager state, workers, and failures.

    The result is meant to be JSON-friendly so the CLI, GUI, and other
    tooling can share one structure.

    Args:
        poll: Drain pending worker events before building the snapshot.
    """
    if poll:
        self.poll_events()

    per_worker = {}
    for worker_name, handle in self._workers.items():
        per_worker[worker_name] = self._status_for_worker(
            worker_name, handle
        )

    snapshot = {
        "state": self.state.value,
        "workers": per_worker,
        # Copies, so callers cannot mutate the manager's own records.
        "failures": list(self._failures),
    }
    snapshot["metrics"] = {
        worker_name: dict(reported)
        for worker_name, reported in self._worker_metrics.items()
    }
    snapshot["synthetic_sources"] = self.synthetic_input_status()
    snapshot["placement_policy"] = self.placement_policy.describe()
    snapshot["summary"] = self._status_summary(per_worker)
    return snapshot
def _status_for_worker(
self, name: str, worker: _WorkerHandle
) -> dict[str, Any]:
metrics = dict(self._worker_metrics.get(name, {}))
health, idle_s, last_metric_age_s = self._worker_health(name, worker)
status = {
"pid": worker.process.pid,
"alive": worker.process.is_alive(),
"exitcode": worker.process.exitcode,
"cpu_slot": worker.cpu_slot,
"health": health,
"idle_s": idle_s,
"last_metric_age_s": last_metric_age_s,
**metrics,
}
return status
def _worker_health(
self, name: str, worker: _WorkerHandle
) -> tuple[str, float | None, float | None]:
now = time.time()
runtime = self._worker_runtime.get(name, {})
started_at = runtime.get("started_at")
last_progress_at = runtime.get("last_progress_at")
last_metric_at = runtime.get("last_metric_at")
frames_processed = int(
self._worker_metrics.get(name, {}).get("frames_processed") or 0
)
if any(failure.get("kernel") == name for failure in self._failures):
idle_s = None
if last_progress_at is not None:
idle_s = max(0.0, now - float(last_progress_at))
last_metric_age_s = None
if last_metric_at is not None:
last_metric_age_s = max(0.0, now - float(last_metric_at))
return "failed", idle_s, last_metric_age_s
if not worker.process.is_alive():
return "stopped", None, None
if self.state == PipelineState.PAUSED:
idle_s = None
if last_progress_at is not None:
idle_s = max(0.0, now - float(last_progress_at))
last_metric_age_s = None
if last_metric_at is not None:
last_metric_age_s = max(0.0, now - float(last_metric_at))
return "paused", idle_s, last_metric_age_s
if started_at is None:
return "starting", None, None
reference_at = last_progress_at or started_at
idle_s = max(0.0, now - float(reference_at))
last_metric_age_s = None
if last_metric_at is not None:
last_metric_age_s = max(0.0, now - float(last_metric_at))
if frames_processed <= 0:
return "waiting-input", idle_s, last_metric_age_s
kernel_config = self._kernel_configs.get(name)
idle_threshold_s = 1.0
if kernel_config is not None:
idle_threshold_s = max(1.0, 4.0 * kernel_config.read_timeout)
if self.state == PipelineState.RUNNING and idle_s > idle_threshold_s:
return "idle", idle_s, last_metric_age_s
return "active", idle_s, last_metric_age_s
def _status_summary(
self, workers_status: dict[str, dict[str, Any]]
) -> dict[str, int]:
summary = {
"workers_total": len(workers_status),
"active_workers": 0,
"idle_workers": 0,
"waiting_workers": 0,
"paused_workers": 0,
"failed_workers": 0,
"stopped_workers": 0,
}
for worker in workers_status.values():
health = worker.get("health")
if health == "active":
summary["active_workers"] += 1
elif health == "idle":
summary["idle_workers"] += 1
elif health in {"waiting-input", "starting"}:
summary["waiting_workers"] += 1
elif health == "paused":
summary["paused_workers"] += 1
elif health == "failed":
summary["failed_workers"] += 1
elif health == "stopped":
summary["stopped_workers"] += 1
return summary
[docs]
def runtime_snapshot(self, *, poll: bool = True) -> dict[str, Any]:
    """Build the status snapshot plus a timestamp and the graph.

    The mapping starts with ``"timestamp"``, continues with every key of
    :meth:`status`, and ends with ``"graph"`` holding the derived graph as
    a dictionary.

    Args:
        poll: Forwarded to :meth:`status`.
    """
    status = self.status(poll=poll)
    snapshot = {"timestamp": time.time()}
    snapshot.update(status)
    snapshot["graph"] = self.graph.to_dict()
    return snapshot
[docs]
def raise_if_failed(self) -> None:
    """Raise the first worker failure, if any has been recorded.

    Pending events are polled first. Failure records are read
    defensively: a record without a ``"kernel"`` or ``"error"`` key still
    produces a ``WorkerProcessError`` rather than a ``KeyError``.

    Raises:
        WorkerProcessError: describing the first recorded failure.
    """
    self.poll_events()
    if not self._failures:
        return
    failure = self._failures[0]
    kernel_name = failure.get("kernel")
    error = failure.get("error", "unknown error")
    if kernel_name is None:
        message = f"worker failed: {error}"
    else:
        message = f"worker {kernel_name!r} failed: {error}"
    raise WorkerProcessError(message)
[docs]
def get_stream(self, name: str):
    """Look up the manager-owned shared-memory handle called ``name``.

    Raises:
        KeyError: if the manager holds no stream with that name.
    """
    owned_streams = self._streams
    return owned_streams[name]
@property
def failures(self) -> tuple[dict[str, Any], ...]:
    """Poll for new events, then return the recorded failures as a tuple."""
    self.poll_events()
    recorded = self._failures
    return tuple(recorded)
@property
def events(self) -> tuple[dict[str, Any], ...]:
    """Poll for new events, then return the retained event history.

    The history is whatever the manager's bounded event buffer still
    holds, in arrival order.
    """
    self.poll_events()
    history = self._events
    return tuple(history)
def __enter__(self) -> "PipelineManager":
"""Support context-managed pipeline lifecycles."""
return self
def __exit__(self, *_: object) -> None:
"""Ensure worker processes and streams are torn down on exit."""
self.shutdown()