diff --git a/hyperglass/configuration/main.py b/hyperglass/configuration/main.py index 5b4bb80..804c8db 100644 --- a/hyperglass/configuration/main.py +++ b/hyperglass/configuration/main.py @@ -167,9 +167,22 @@ def init_directives() -> "Directives": def init_devices() -> "Devices": """Validate & initialize devices.""" - _user_devices = _config_required(CONFIG_DEVICES) - log.debug("Unvalidated devices from {}: {}", CONFIG_DEVICES, _user_devices) - return Devices(_user_devices.get("devices", _user_devices.get("routers", []))) + devices_config = _config_required(CONFIG_DEVICES) + log.debug("Unvalidated devices from {!s}: {!r}", CONFIG_DEVICES, devices_config) + items = [] + + for key in ("main", "devices", "routers"): + if key in devices_config: + items = devices_config[key] + break + + if len(items) < 1: + raise ConfigError("No devices are defined in devices.yaml") + + devices = Devices(*items) + log.info("Initialized devices {!r}", devices) + + return devices def init_ui_params(*, params: "Params", devices: "Devices") -> "UIParameters": diff --git a/hyperglass/models/config/devices.py b/hyperglass/models/config/devices.py index 0dce116..ab9cf43 100644 --- a/hyperglass/models/config/devices.py +++ b/hyperglass/models/config/devices.py @@ -24,7 +24,7 @@ from hyperglass.exceptions.private import ConfigError, UnsupportedDevice # Local from .ssl import Ssl -from ..main import HyperglassModel, HyperglassModelWithId +from ..main import MultiModel, HyperglassModel, HyperglassModelWithId from ..util import check_legacy_fields from .proxy import Proxy from .params import Params @@ -58,11 +58,12 @@ class Device(HyperglassModelWithId, extra="allow"): driver: Optional[SupportedDriver] attrs: Dict[str, str] = {} - def __init__(self, **kwargs) -> None: - """Set the device ID.""" - kwargs = check_legacy_fields("Device", **kwargs) - _id, values = self._generate_id(kwargs) - super().__init__(id=_id, **values) + def __init__(self, **kw) -> None: + """Check legacy fields and ensure an `id` is set.""" + kw = check_legacy_fields("Device", **kw) + if "id" not in kw: + kw = self._with_id(kw) + super().__init__(**kw) self._validate_directive_attrs() @property @@ -70,7 +71,7 @@ class Device(HyperglassModelWithId, extra="allow"): return str(self.address) @staticmethod - def _generate_id(values: Dict) -> Tuple[str, Dict]: + def _with_id(values: Dict) -> str: """Generate device id & handle legacy display_name field.""" def generate_id(name: str) -> str: @@ -92,7 +93,7 @@ class Device(HyperglassModelWithId, extra="allow"): device_id = generate_id(name) display_name = name - return device_id, {"name": display_name, "display_name": None, **values} + return {"id": device_id, "name": display_name, "display_name": None, **values} def export_api(self) -> Dict[str, Any]: """Export API-facing device fields.""" @@ -256,85 +257,23 @@ class Device(HyperglassModelWithId, extra="allow"): return get_driver(values["platform"], value) -class Devices(HyperglassModel, extra="allow"): - """Validation model for device configurations.""" +class Devices(MultiModel, model=Device, unique_by="id"): + """Container for all devices.""" - ids: List[StrictStr] = [] - hostnames: List[StrictStr] = [] - objects: List[Device] = [] - - def __init__(self, input_params: List[Dict]) -> None: - """Import loaded YAML, initialize per-network definitions. - - Remove unsupported characters from device names, dynamically - set attributes for the devices class. Builds lists of common - attributes for easy access in other modules. - """ - objects = set() - hostnames = set() - ids = set() - - init_kwargs = {} - - for definition in input_params: - # Validate each router config against Router() model/schema - device = Device(**definition) - - # Add router-level attributes (assumed to be unique) to - # class lists, e.g. so all hostnames can be accessed as a - # list with `devices.hostnames`, same for all router - # classes, for when iteration over all routers is required. - hostnames.add(device.name) - ids.add(device.id) - objects.add(device) - - # Convert the de-duplicated sets to a standard list, add lists - # as class attributes. Sort router list by router name attribute - init_kwargs["ids"] = list(ids) - init_kwargs["hostnames"] = list(hostnames) - init_kwargs["objects"] = sorted(objects, key=lambda x: x.name) - - super().__init__(**init_kwargs) - - def __getitem__(self, accessor: str) -> Device: - """Get a device by its name.""" - for device in self.objects: - if device.id == accessor: - return device - elif device.name == accessor: - return device - - raise AttributeError(f"No device named '{accessor}'") + def __init__(self, *items: Dict[str, Any]) -> None: + """Generate IDs prior to validation.""" + with_id = (Device._with_id(item) for item in items) + super().__init__(*with_id) def export_api(self) -> List[Dict[str, Any]]: """Export API-facing device fields.""" - return [d.export_api() for d in self.objects] - - def networks(self, params: Params) -> List[Dict[str, Any]]: - """Group devices by network.""" - names = {device.network.display_name for device in self.objects} - return [ - { - "display_name": name, - "locations": [ - { - "id": device.id, - "name": device.name, - "network": device.network.display_name, - "directives": [d.frontend(params) for d in device.directives], - } - for device in self.objects - if device.network.display_name == name - ], - } - for name in names - ] + return [d.export_api() for d in self] def directive_plugins(self) -> Dict[Path, Tuple[StrictStr]]: """Get a mapping of plugin paths to associated directive IDs.""" result: Dict[Path, Set[StrictStr]] = {} # Unique set of all directives. - directives = {directive for device in self.objects for directive in device.directives} + directives = {directive for device in self for directive in device.directives} # Unique set of all plugin file names. plugin_names = {plugin for directive in directives for plugin in directive.plugins} @@ -346,3 +285,23 @@ class Devices(HyperglassModel, extra="allow"): result[plugin].add(directive.id) # Convert the directive set to a tuple. return {k: tuple(v) for k, v in result.items()} + + def networks(self, params: Params) -> List[Dict[str, Any]]: + """Group devices by network.""" + names = {device.network.display_name for device in self} + return [ + { + "display_name": name, + "locations": [ + { + "id": device.id, + "name": device.name, + "network": device.network.display_name, + "directives": [d.frontend(params) for d in device.directives], + } + for device in self + if device.network.display_name == name + ], + } + for name in names + ]