1
0
Fork 1
mirror of https://github.com/thatmattlove/hyperglass.git synced 2026-01-17 08:48:05 +00:00

Implement HyperglassMultiModel to manage multiple objects

This commit is contained in:
thatmattlove 2021-09-17 01:12:33 -07:00
parent 6bc6cf0e1c
commit 99565da0f9
7 changed files with 206 additions and 26 deletions

View file

@ -1,7 +1,7 @@
"""hyperglass Configuration."""
# Local
from .main import params, devices, directives, ui_params
from .main import params, devices, ui_params, directives
__all__ = (
"params",

View file

@ -1,7 +1,7 @@
"""Import configuration files and returns default values if undefined."""
# Standard Library
from typing import Dict, List, Generator
import typing as t
from pathlib import Path
# Third Party
@ -15,7 +15,7 @@ from hyperglass.settings import Settings
from hyperglass.constants import PARSED_RESPONSE_FIELDS, __version__
from hyperglass.models.ui import UIParameters
from hyperglass.util.files import check_path
from hyperglass.models.directive import Directive
from hyperglass.models.directive import Directive, Directives
from hyperglass.exceptions.private import ConfigError, ConfigMissing
from hyperglass.models.config.params import Params
from hyperglass.models.config.devices import Devices
@ -64,7 +64,7 @@ def _check_config_files(directory: Path):
CONFIG_MAIN, CONFIG_DEVICES, CONFIG_DIRECTIVES = _check_config_files(CONFIG_PATH)
def _config_required(config_path: Path) -> Dict:
def _config_required(config_path: Path) -> t.Dict[str, t.Any]:
try:
with config_path.open("r") as cf:
config = yaml.safe_load(cf)
@ -78,7 +78,7 @@ def _config_required(config_path: Path) -> Dict:
return config
def _config_optional(config_path: Path) -> Dict:
def _config_optional(config_path: Path) -> t.Dict[str, t.Any]:
config = {}
@ -96,30 +96,21 @@ def _config_optional(config_path: Path) -> Dict:
return config
def _get_directives(data: Dict) -> List[Directive]:
directives = []
def _get_directives(data: t.Dict[str, t.Any]) -> "Directives":
directives = ()
for name, directive in data.items():
try:
directives.append(Directive(id=name, **directive))
directives += (Directive(id=name, **directive),)
except ValidationError as err:
raise ConfigError(
message="Validation error in directive '{d}': '{e}'", d=name, e=err
) from err
return directives
return Directives(*directives)
def _device_directives(
device: Dict, directives: List[Directive]
) -> Generator[Directive, None, None]:
for directive in directives:
if directive.id in device.get("directives", []):
yield directive
def _get_devices(data: List[Dict], directives: List[Directive]) -> Devices:
def _get_devices(data: t.List[t.Dict[str, t.Any]], directives: "Directives") -> Devices:
for device in data:
directives = list(_device_directives(device, directives))
device["directives"] = directives
device["directives"] = directives.filter_by_ids(*device.get("directives", ()))
return Devices(data)

View file

@ -29,7 +29,7 @@ from .proxy import Proxy
from .params import Params
from ..fields import SupportedDriver
from .network import Network
from ..directive import Directive
from ..directive import Directives
from .credential import Credential
@ -46,7 +46,7 @@ class Device(HyperglassModelWithId, extra="allow"):
port: StrictInt = 22
ssl: Optional[Ssl]
platform: StrictStr
directives: List[Directive]
directives: Directives
structured_output: Optional[StrictBool]
driver: Optional[SupportedDriver]
attrs: Dict[str, str] = {}

View file

@ -18,11 +18,12 @@ from pydantic import (
# Project
from hyperglass.log import log
from hyperglass.types import Series
from hyperglass.settings import Settings
from hyperglass.exceptions.private import InputValidationError
# Local
from .main import HyperglassModel, HyperglassModelWithId
from .main import HyperglassModel, HyperglassMultiModel, HyperglassModelWithId
from .fields import Action
if t.TYPE_CHECKING:
@ -229,6 +230,8 @@ RuleType = t.Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValid
class Directive(HyperglassModelWithId):
"""A directive contains commands that can be run on a device, as long as defined rules are met."""
__hyperglass_builtin__: t.ClassVar[bool] = False
id: StrictStr
name: StrictStr
rules: t.List[RuleType]
@ -236,6 +239,7 @@ class Directive(HyperglassModelWithId):
info: t.Optional[FilePath]
plugins: t.List[StrictStr] = []
disable_builtins: StrictBool = False
table_output: StrictBool = False
groups: t.List[
StrictStr
] = [] # TODO: Flesh this out. Replace VRFs, but use same logic in React to filter available commands for multi-device queries.
@ -299,3 +303,29 @@ class Directive(HyperglassModelWithId):
value["options"] = [o.export_dict() for o in self.field.options if o is not None]
return value
class NativeDirective(Directive):
"""Natively-supported directive."""
__hyperglass_builtin__: t.ClassVar[bool] = True
platforms: Series[str] = []
DirectiveT = t.Union[NativeDirective, Directive]
class Directives(HyperglassMultiModel[DirectiveT]):
"""Collection of directives."""
def __init__(self, *items: t.Dict[str, t.Any]) -> None:
"""Initialize base class and validate objects."""
super().__init__(*items, model=Directive, accessor="id")
def ids(self) -> t.Tuple[str]:
"""Get all directive IDs."""
return tuple(directive.id for directive in self)
def filter_by_ids(self, *ids) -> "Directives":
"""Filter directives by directive IDs."""
return Directives(*[directive for directive in self if directive.id in ids])

View file

@ -2,16 +2,20 @@
# Standard Library
import re
import typing as t
from pathlib import Path
# Third Party
from pydantic import HttpUrl, BaseModel, BaseConfig
from pydantic import HttpUrl, BaseModel, BaseConfig, PrivateAttr
from pydantic.generics import GenericModel
# Project
from hyperglass.log import log
from hyperglass.util import snake_to_camel, repr_from_attrs
from hyperglass.types import Series
MultiModelT = t.TypeVar("MultiModelT", bound=BaseModel)
class HyperglassModel(BaseModel):
"""Base model for all hyperglass configuration models."""
@ -107,3 +111,110 @@ class HyperglassModelWithId(HyperglassModel):
def __hash__(self: "HyperglassModelWithId") -> int:
"""Create a hashed representation of this model's name."""
return hash(self.id)
class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
"""Extension of HyperglassModel for managing multiple models as a list."""
__root__: t.List[MultiModelT] = []
_accessor: str = PrivateAttr()
_model: MultiModelT = PrivateAttr()
_count: int = PrivateAttr()
class Config(BaseConfig):
"""Pydantic model configuration."""
validate_all = True
extra = "forbid"
validate_assignment = True
def __init__(self, *items: t.Dict[str, t.Any], model: MultiModelT, accessor: str) -> None:
"""Validate items."""
items = self._valid_items(*items, model=model, accessor=accessor)
super().__init__(__root__=items)
self._count = len(items)
self._accessor = accessor
self._model = model
def __iter__(self) -> t.Iterator[MultiModelT]:
"""Iterate items."""
return iter(self.__root__)
def __getitem__(self, value: t.Union[int, str]) -> MultiModelT:
"""Get an item by accessor value."""
if not isinstance(value, (str, int)):
raise TypeError(
"Value of {}.{!s} should be a string or integer. Got {!r} ({!s})".format(
self.__class__.__name__, self.accessor, value, type(value)
)
)
if isinstance(value, int):
return self.__root__[value]
for item in self:
if hasattr(item, self.accessor) and getattr(item, self.accessor) == value:
return item
raise IndexError(
"No match found for {!s}.{!s}={!r}".format(
self.model.__class__.__name__, self.accessor, value
),
)
def __repr__(self) -> str:
"""Represent model."""
return repr_from_attrs(self, ["_count", "_accessor"], strip="_")
@property
def accessor(self) -> str:
"""Access item accessor."""
return self._accessor
@property
def model(self) -> MultiModelT:
"""Access item model class."""
return self._model
@property
def count(self) -> int:
"""Access item count."""
return self._count
@staticmethod
def _valid_items(
*to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]],
model: MultiModelT,
accessor: str
) -> t.List[MultiModelT]:
items = [
item
for item in to_validate
if any(
(
(isinstance(item, dict) and accessor in item),
(isinstance(item, model) and hasattr(item, accessor)),
),
)
]
for index, item in enumerate(items):
if isinstance(item, dict):
items[index] = model(**item)
return items
def add(self, *items, unique_by: t.Optional[str] = None) -> None:
"""Add an item to the model."""
to_add = self._valid_items(*items, model=self.model, accessor=self.accessor)
if unique_by is not None:
unique_by_values = {
getattr(obj, unique_by) for obj in (*self, *to_add) if hasattr(obj, unique_by)
}
unique_by_objects = {
v: o
for v in unique_by_values
for o in (*self, *to_add)
if getattr(o, unique_by) == v
}
self.__root__ = list(unique_by_objects.values())
else:
self.__root__ = [*self.__root__, *to_add]
self._count = len(self.__root__)

View file

@ -0,0 +1,44 @@
"""Test HyperglassMultiModel."""
# Third Party
from pydantic import BaseModel
# Local
from ..main import HyperglassMultiModel
class Item(BaseModel):
"""Test item."""
id: str
name: str
ITEMS_1 = [
{"id": "item1", "name": "Item One"},
Item(id="item2", name="Item Two"),
{"id": "item3", "name": "Item Three"},
]
ITEMS_2 = [
Item(id="item4", name="Item Four"),
{"id": "item5", "name": "Item Five"},
]
ITEMS_3 = [
{"id": "item1", "name": "Item New One"},
{"id": "item6", "name": "Item Six"},
]
def test_multi_model():
model = HyperglassMultiModel(*ITEMS_1, model=Item, accessor="id")
assert model.count == 3
assert len([o for o in model]) == model.count # noqa: C416 (Iteration testing)
assert model["item1"].name == "Item One"
model.add(*ITEMS_2)
assert model.count == 5
assert model[3].name == "Item Four"
model.add(*ITEMS_3, unique_by="id")
assert model.count == 6
assert model["item1"].name == "Item New One"

View file

@ -206,7 +206,7 @@ def make_repr(_class):
return f'{_class.__name__}({", ".join(_process_attrs(dir(_class)))})'
def repr_from_attrs(obj: object, attrs: Series[str]) -> str:
def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str:
"""Generate a `__repr__()` value from a specific set of attribute names.
Useful for complex models/objects where `__repr__()` should only display specific fields.
@ -215,7 +215,11 @@ def repr_from_attrs(obj: object, attrs: Series[str]) -> str:
attr_names = {a for a in attrs if hasattr(obj, a)}
# Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a
# `__repr__` method.
attr_values = {f: v for f in attr_names if hasattr((v := getattr(obj, f)), "__repr__")}
attr_values = {
f if strip is None else f.strip(strip): v # noqa: IF100
for f in attr_names
if hasattr((v := getattr(obj, f)), "__repr__")
}
pairs = (f"{k}={v!r}" for k, v in attr_values.items())
return f"{obj.__class__.__name__}({','.join(pairs)})"