From 99565da0f9ae0c85b9a09518b77a82f0b11236f1 Mon Sep 17 00:00:00 2001 From: thatmattlove Date: Fri, 17 Sep 2021 01:12:33 -0700 Subject: [PATCH] Implement `HyperglassMultiModel` to manage multiple objects --- hyperglass/configuration/__init__.py | 2 +- hyperglass/configuration/main.py | 29 ++--- hyperglass/models/config/devices.py | 4 +- hyperglass/models/directive.py | 32 +++++- hyperglass/models/main.py | 113 +++++++++++++++++++- hyperglass/models/tests/test_multi_model.py | 44 ++++++++ hyperglass/util/__init__.py | 8 +- 7 files changed, 206 insertions(+), 26 deletions(-) create mode 100644 hyperglass/models/tests/test_multi_model.py diff --git a/hyperglass/configuration/__init__.py b/hyperglass/configuration/__init__.py index b9ce906..fb841e0 100644 --- a/hyperglass/configuration/__init__.py +++ b/hyperglass/configuration/__init__.py @@ -1,7 +1,7 @@ """hyperglass Configuration.""" # Local -from .main import params, devices, directives, ui_params +from .main import params, devices, ui_params, directives __all__ = ( "params", diff --git a/hyperglass/configuration/main.py b/hyperglass/configuration/main.py index 3e46e84..456f1f9 100644 --- a/hyperglass/configuration/main.py +++ b/hyperglass/configuration/main.py @@ -1,7 +1,7 @@ """Import configuration files and returns default values if undefined.""" # Standard Library -from typing import Dict, List, Generator +import typing as t from pathlib import Path # Third Party @@ -15,7 +15,7 @@ from hyperglass.settings import Settings from hyperglass.constants import PARSED_RESPONSE_FIELDS, __version__ from hyperglass.models.ui import UIParameters from hyperglass.util.files import check_path -from hyperglass.models.directive import Directive +from hyperglass.models.directive import Directive, Directives from hyperglass.exceptions.private import ConfigError, ConfigMissing from hyperglass.models.config.params import Params from hyperglass.models.config.devices import Devices @@ -64,7 +64,7 @@ def _check_config_files(directory: Path): CONFIG_MAIN, CONFIG_DEVICES, CONFIG_DIRECTIVES = _check_config_files(CONFIG_PATH) -def _config_required(config_path: Path) -> Dict: +def _config_required(config_path: Path) -> t.Dict[str, t.Any]: try: with config_path.open("r") as cf: config = yaml.safe_load(cf) @@ -78,7 +78,7 @@ def _config_required(config_path: Path) -> Dict: return config -def _config_optional(config_path: Path) -> Dict: +def _config_optional(config_path: Path) -> t.Dict[str, t.Any]: config = {} @@ -96,30 +96,21 @@ def _config_optional(config_path: Path) -> Dict: return config -def _get_directives(data: Dict) -> List[Directive]: - directives = [] +def _get_directives(data: t.Dict[str, t.Any]) -> "Directives": + directives = () for name, directive in data.items(): try: - directives.append(Directive(id=name, **directive)) + directives += (Directive(id=name, **directive),) except ValidationError as err: raise ConfigError( message="Validation error in directive '{d}': '{e}'", d=name, e=err ) from err - return directives + return Directives(*directives) -def _device_directives( - device: Dict, directives: List[Directive] -) -> Generator[Directive, None, None]: - for directive in directives: - if directive.id in device.get("directives", []): - yield directive - - -def _get_devices(data: List[Dict], directives: List[Directive]) -> Devices: +def _get_devices(data: t.List[t.Dict[str, t.Any]], directives: "Directives") -> Devices: for device in data: - directives = list(_device_directives(device, directives)) - device["directives"] = directives + device["directives"] = directives.filter_by_ids(*device.get("directives", ())) return Devices(data) diff --git a/hyperglass/models/config/devices.py b/hyperglass/models/config/devices.py index 0210b75..02ea88f 100644 --- a/hyperglass/models/config/devices.py +++ b/hyperglass/models/config/devices.py @@ -29,7 +29,7 @@ from .proxy import Proxy from .params import Params from ..fields import SupportedDriver from .network import Network -from ..directive import Directive +from ..directive import Directives from .credential import Credential @@ -46,7 +46,7 @@ class Device(HyperglassModelWithId, extra="allow"): port: StrictInt = 22 ssl: Optional[Ssl] platform: StrictStr - directives: List[Directive] + directives: Directives structured_output: Optional[StrictBool] driver: Optional[SupportedDriver] attrs: Dict[str, str] = {} diff --git a/hyperglass/models/directive.py b/hyperglass/models/directive.py index a295348..b5f0049 100644 --- a/hyperglass/models/directive.py +++ b/hyperglass/models/directive.py @@ -18,11 +18,12 @@ from pydantic import ( # Project from hyperglass.log import log +from hyperglass.types import Series from hyperglass.settings import Settings from hyperglass.exceptions.private import InputValidationError # Local -from .main import HyperglassModel, HyperglassModelWithId +from .main import HyperglassModel, HyperglassMultiModel, HyperglassModelWithId from .fields import Action if t.TYPE_CHECKING: @@ -229,6 +230,8 @@ RuleType = t.Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValid class Directive(HyperglassModelWithId): """A directive contains commands that can be run on a device, as long as defined rules are met.""" + __hyperglass_builtin__: t.ClassVar[bool] = False + id: StrictStr name: StrictStr rules: t.List[RuleType] @@ -236,6 +239,7 @@ class Directive(HyperglassModelWithId): info: t.Optional[FilePath] plugins: t.List[StrictStr] = [] disable_builtins: StrictBool = False + table_output: StrictBool = False groups: t.List[ StrictStr ] = [] # TODO: Flesh this out. Replace VRFs, but use same logic in React to filter available commands for multi-device queries. @@ -299,3 +303,29 @@ class Directive(HyperglassModelWithId): value["options"] = [o.export_dict() for o in self.field.options if o is not None] return value + + +class NativeDirective(Directive): + """Natively-supported directive.""" + + __hyperglass_builtin__: t.ClassVar[bool] = True + platforms: Series[str] = [] + + +DirectiveT = t.Union[NativeDirective, Directive] + + +class Directives(HyperglassMultiModel[DirectiveT]): + """Collection of directives.""" + + def __init__(self, *items: t.Dict[str, t.Any]) -> None: + """Initialize base class and validate objects.""" + super().__init__(*items, model=Directive, accessor="id") + + def ids(self) -> t.Tuple[str]: + """Get all directive IDs.""" + return tuple(directive.id for directive in self) + + def filter_by_ids(self, *ids) -> "Directives": + """Filter directives by directive IDs.""" + return Directives(*[directive for directive in self if directive.id in ids]) diff --git a/hyperglass/models/main.py b/hyperglass/models/main.py index 32174e2..a80a585 100644 --- a/hyperglass/models/main.py +++ b/hyperglass/models/main.py @@ -2,16 +2,20 @@ # Standard Library import re +import typing as t from pathlib import Path # Third Party -from pydantic import HttpUrl, BaseModel, BaseConfig +from pydantic import HttpUrl, BaseModel, BaseConfig, PrivateAttr +from pydantic.generics import GenericModel # Project from hyperglass.log import log from hyperglass.util import snake_to_camel, repr_from_attrs from hyperglass.types import Series +MultiModelT = t.TypeVar("MultiModelT", bound=BaseModel) + class HyperglassModel(BaseModel): """Base model for all hyperglass configuration models.""" @@ -107,3 +111,110 @@ class HyperglassModelWithId(HyperglassModel): def __hash__(self: "HyperglassModelWithId") -> int: """Create a hashed representation of this model's name.""" return hash(self.id) + + +class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]): + """Extension of HyperglassModel for managing multiple models as a list.""" + + __root__: t.List[MultiModelT] = [] + _accessor: str = PrivateAttr() + _model: MultiModelT = PrivateAttr() + _count: int = PrivateAttr() + + class Config(BaseConfig): + """Pydantic model configuration.""" + + validate_all = True + extra = "forbid" + validate_assignment = True + + def __init__(self, *items: t.Dict[str, t.Any], model: MultiModelT, accessor: str) -> None: + """Validate items.""" + items = self._valid_items(*items, model=model, accessor=accessor) + super().__init__(__root__=items) + self._count = len(items) + self._accessor = accessor + self._model = model + + def __iter__(self) -> t.Iterator[MultiModelT]: + """Iterate items.""" + return iter(self.__root__) + + def __getitem__(self, value: t.Union[int, str]) -> MultiModelT: + """Get an item by accessor value.""" + if not isinstance(value, (str, int)): + raise TypeError( + "Value of {}.{!s} should be a string or integer. Got {!r} ({!s})".format( + self.__class__.__name__, self.accessor, value, type(value) + ) + ) + if isinstance(value, int): + return self.__root__[value] + + for item in self: + if hasattr(item, self.accessor) and getattr(item, self.accessor) == value: + return item + raise IndexError( + "No match found for {!s}.{!s}={!r}".format( + self.model.__class__.__name__, self.accessor, value + ), + ) + + def __repr__(self) -> str: + """Represent model.""" + return repr_from_attrs(self, ["_count", "_accessor"], strip="_") + + @property + def accessor(self) -> str: + """Access item accessor.""" + return self._accessor + + @property + def model(self) -> MultiModelT: + """Access item model class.""" + return self._model + + @property + def count(self) -> int: + """Access item count.""" + return self._count + + @staticmethod + def _valid_items( + *to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]], + model: MultiModelT, + accessor: str + ) -> t.List[MultiModelT]: + items = [ + item + for item in to_validate + if any( + ( + (isinstance(item, dict) and accessor in item), + (isinstance(item, model) and hasattr(item, accessor)), + ), + ) + ] + + for index, item in enumerate(items): + if isinstance(item, dict): + items[index] = model(**item) + return items + + def add(self, *items, unique_by: t.Optional[str] = None) -> None: + """Add an item to the model.""" + to_add = self._valid_items(*items, model=self.model, accessor=self.accessor) + if unique_by is not None: + unique_by_values = { + getattr(obj, unique_by) for obj in (*self, *to_add) if hasattr(obj, unique_by) + } + unique_by_objects = { + v: o + for v in unique_by_values + for o in (*self, *to_add) + if getattr(o, unique_by) == v + } + self.__root__ = list(unique_by_objects.values()) + else: + self.__root__ = [*self.__root__, *to_add] + self._count = len(self.__root__) diff --git a/hyperglass/models/tests/test_multi_model.py b/hyperglass/models/tests/test_multi_model.py new file mode 100644 index 0000000..17a3941 --- /dev/null +++ b/hyperglass/models/tests/test_multi_model.py @@ -0,0 +1,44 @@ +"""Test HyperglassMultiModel.""" + +# Third Party +from pydantic import BaseModel + +# Local +from ..main import HyperglassMultiModel + + +class Item(BaseModel): + """Test item.""" + + id: str + name: str + + +ITEMS_1 = [ + {"id": "item1", "name": "Item One"}, + Item(id="item2", name="Item Two"), + {"id": "item3", "name": "Item Three"}, +] + +ITEMS_2 = [ + Item(id="item4", name="Item Four"), + {"id": "item5", "name": "Item Five"}, +] + +ITEMS_3 = [ + {"id": "item1", "name": "Item New One"}, + {"id": "item6", "name": "Item Six"}, +] + + +def test_multi_model(): + model = HyperglassMultiModel(*ITEMS_1, model=Item, accessor="id") + assert model.count == 3 + assert len([o for o in model]) == model.count # noqa: C416 (Iteration testing) + assert model["item1"].name == "Item One" + model.add(*ITEMS_2) + assert model.count == 5 + assert model[3].name == "Item Four" + model.add(*ITEMS_3, unique_by="id") + assert model.count == 6 + assert model["item1"].name == "Item New One" diff --git a/hyperglass/util/__init__.py b/hyperglass/util/__init__.py index 5a4aa1a..12ed801 100644 --- a/hyperglass/util/__init__.py +++ b/hyperglass/util/__init__.py @@ -206,7 +206,7 @@ def make_repr(_class): return f'{_class.__name__}({", ".join(_process_attrs(dir(_class)))})' -def repr_from_attrs(obj: object, attrs: Series[str]) -> str: +def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str: """Generate a `__repr__()` value from a specific set of attribute names. Useful for complex models/objects where `__repr__()` should only display specific fields. @@ -215,7 +215,11 @@ def repr_from_attrs(obj: object, attrs: Series[str]) -> str: attr_names = {a for a in attrs if hasattr(obj, a)} # Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a # `__repr__` method. - attr_values = {f: v for f in attr_names if hasattr((v := getattr(obj, f)), "__repr__")} + attr_values = { + f if strip is None else f.strip(strip): v # noqa: IF100 + for f in attr_names + if hasattr((v := getattr(obj, f)), "__repr__") + } pairs = (f"{k}={v!r}" for k, v in attr_values.items()) return f"{obj.__class__.__name__}({','.join(pairs)})"