mirror of
https://github.com/thatmattlove/hyperglass.git
synced 2026-01-17 08:48:05 +00:00
Implement HyperglassMultiModel to manage multiple objects
This commit is contained in:
parent
6bc6cf0e1c
commit
99565da0f9
7 changed files with 206 additions and 26 deletions
|
|
@ -1,7 +1,7 @@
|
|||
"""hyperglass Configuration."""
|
||||
|
||||
# Local
|
||||
from .main import params, devices, directives, ui_params
|
||||
from .main import params, devices, ui_params, directives
|
||||
|
||||
__all__ = (
|
||||
"params",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Import configuration files and returns default values if undefined."""
|
||||
|
||||
# Standard Library
|
||||
from typing import Dict, List, Generator
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
# Third Party
|
||||
|
|
@ -15,7 +15,7 @@ from hyperglass.settings import Settings
|
|||
from hyperglass.constants import PARSED_RESPONSE_FIELDS, __version__
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.util.files import check_path
|
||||
from hyperglass.models.directive import Directive
|
||||
from hyperglass.models.directive import Directive, Directives
|
||||
from hyperglass.exceptions.private import ConfigError, ConfigMissing
|
||||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
|
@ -64,7 +64,7 @@ def _check_config_files(directory: Path):
|
|||
CONFIG_MAIN, CONFIG_DEVICES, CONFIG_DIRECTIVES = _check_config_files(CONFIG_PATH)
|
||||
|
||||
|
||||
def _config_required(config_path: Path) -> Dict:
|
||||
def _config_required(config_path: Path) -> t.Dict[str, t.Any]:
|
||||
try:
|
||||
with config_path.open("r") as cf:
|
||||
config = yaml.safe_load(cf)
|
||||
|
|
@ -78,7 +78,7 @@ def _config_required(config_path: Path) -> Dict:
|
|||
return config
|
||||
|
||||
|
||||
def _config_optional(config_path: Path) -> Dict:
|
||||
def _config_optional(config_path: Path) -> t.Dict[str, t.Any]:
|
||||
|
||||
config = {}
|
||||
|
||||
|
|
@ -96,30 +96,21 @@ def _config_optional(config_path: Path) -> Dict:
|
|||
return config
|
||||
|
||||
|
||||
def _get_directives(data: Dict) -> List[Directive]:
|
||||
directives = []
|
||||
def _get_directives(data: t.Dict[str, t.Any]) -> "Directives":
|
||||
directives = ()
|
||||
for name, directive in data.items():
|
||||
try:
|
||||
directives.append(Directive(id=name, **directive))
|
||||
directives += (Directive(id=name, **directive),)
|
||||
except ValidationError as err:
|
||||
raise ConfigError(
|
||||
message="Validation error in directive '{d}': '{e}'", d=name, e=err
|
||||
) from err
|
||||
return directives
|
||||
return Directives(*directives)
|
||||
|
||||
|
||||
def _device_directives(
|
||||
device: Dict, directives: List[Directive]
|
||||
) -> Generator[Directive, None, None]:
|
||||
for directive in directives:
|
||||
if directive.id in device.get("directives", []):
|
||||
yield directive
|
||||
|
||||
|
||||
def _get_devices(data: List[Dict], directives: List[Directive]) -> Devices:
|
||||
def _get_devices(data: t.List[t.Dict[str, t.Any]], directives: "Directives") -> Devices:
|
||||
for device in data:
|
||||
directives = list(_device_directives(device, directives))
|
||||
device["directives"] = directives
|
||||
device["directives"] = directives.filter_by_ids(*device.get("directives", ()))
|
||||
return Devices(data)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from .proxy import Proxy
|
|||
from .params import Params
|
||||
from ..fields import SupportedDriver
|
||||
from .network import Network
|
||||
from ..directive import Directive
|
||||
from ..directive import Directives
|
||||
from .credential import Credential
|
||||
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ class Device(HyperglassModelWithId, extra="allow"):
|
|||
port: StrictInt = 22
|
||||
ssl: Optional[Ssl]
|
||||
platform: StrictStr
|
||||
directives: List[Directive]
|
||||
directives: Directives
|
||||
structured_output: Optional[StrictBool]
|
||||
driver: Optional[SupportedDriver]
|
||||
attrs: Dict[str, str] = {}
|
||||
|
|
|
|||
|
|
@ -18,11 +18,12 @@ from pydantic import (
|
|||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.types import Series
|
||||
from hyperglass.settings import Settings
|
||||
from hyperglass.exceptions.private import InputValidationError
|
||||
|
||||
# Local
|
||||
from .main import HyperglassModel, HyperglassModelWithId
|
||||
from .main import HyperglassModel, HyperglassMultiModel, HyperglassModelWithId
|
||||
from .fields import Action
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
|
@ -229,6 +230,8 @@ RuleType = t.Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValid
|
|||
class Directive(HyperglassModelWithId):
|
||||
"""A directive contains commands that can be run on a device, as long as defined rules are met."""
|
||||
|
||||
__hyperglass_builtin__: t.ClassVar[bool] = False
|
||||
|
||||
id: StrictStr
|
||||
name: StrictStr
|
||||
rules: t.List[RuleType]
|
||||
|
|
@ -236,6 +239,7 @@ class Directive(HyperglassModelWithId):
|
|||
info: t.Optional[FilePath]
|
||||
plugins: t.List[StrictStr] = []
|
||||
disable_builtins: StrictBool = False
|
||||
table_output: StrictBool = False
|
||||
groups: t.List[
|
||||
StrictStr
|
||||
] = [] # TODO: Flesh this out. Replace VRFs, but use same logic in React to filter available commands for multi-device queries.
|
||||
|
|
@ -299,3 +303,29 @@ class Directive(HyperglassModelWithId):
|
|||
value["options"] = [o.export_dict() for o in self.field.options if o is not None]
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class NativeDirective(Directive):
|
||||
"""Natively-supported directive."""
|
||||
|
||||
__hyperglass_builtin__: t.ClassVar[bool] = True
|
||||
platforms: Series[str] = []
|
||||
|
||||
|
||||
DirectiveT = t.Union[NativeDirective, Directive]
|
||||
|
||||
|
||||
class Directives(HyperglassMultiModel[DirectiveT]):
|
||||
"""Collection of directives."""
|
||||
|
||||
def __init__(self, *items: t.Dict[str, t.Any]) -> None:
|
||||
"""Initialize base class and validate objects."""
|
||||
super().__init__(*items, model=Directive, accessor="id")
|
||||
|
||||
def ids(self) -> t.Tuple[str]:
|
||||
"""Get all directive IDs."""
|
||||
return tuple(directive.id for directive in self)
|
||||
|
||||
def filter_by_ids(self, *ids) -> "Directives":
|
||||
"""Filter directives by directive IDs."""
|
||||
return Directives(*[directive for directive in self if directive.id in ids])
|
||||
|
|
|
|||
|
|
@ -2,16 +2,20 @@
|
|||
|
||||
# Standard Library
|
||||
import re
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
# Third Party
|
||||
from pydantic import HttpUrl, BaseModel, BaseConfig
|
||||
from pydantic import HttpUrl, BaseModel, BaseConfig, PrivateAttr
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.util import snake_to_camel, repr_from_attrs
|
||||
from hyperglass.types import Series
|
||||
|
||||
MultiModelT = t.TypeVar("MultiModelT", bound=BaseModel)
|
||||
|
||||
|
||||
class HyperglassModel(BaseModel):
|
||||
"""Base model for all hyperglass configuration models."""
|
||||
|
|
@ -107,3 +111,110 @@ class HyperglassModelWithId(HyperglassModel):
|
|||
def __hash__(self: "HyperglassModelWithId") -> int:
|
||||
"""Create a hashed representation of this model's name."""
|
||||
return hash(self.id)
|
||||
|
||||
|
||||
class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
|
||||
"""Extension of HyperglassModel for managing multiple models as a list."""
|
||||
|
||||
__root__: t.List[MultiModelT] = []
|
||||
_accessor: str = PrivateAttr()
|
||||
_model: MultiModelT = PrivateAttr()
|
||||
_count: int = PrivateAttr()
|
||||
|
||||
class Config(BaseConfig):
|
||||
"""Pydantic model configuration."""
|
||||
|
||||
validate_all = True
|
||||
extra = "forbid"
|
||||
validate_assignment = True
|
||||
|
||||
def __init__(self, *items: t.Dict[str, t.Any], model: MultiModelT, accessor: str) -> None:
|
||||
"""Validate items."""
|
||||
items = self._valid_items(*items, model=model, accessor=accessor)
|
||||
super().__init__(__root__=items)
|
||||
self._count = len(items)
|
||||
self._accessor = accessor
|
||||
self._model = model
|
||||
|
||||
def __iter__(self) -> t.Iterator[MultiModelT]:
|
||||
"""Iterate items."""
|
||||
return iter(self.__root__)
|
||||
|
||||
def __getitem__(self, value: t.Union[int, str]) -> MultiModelT:
|
||||
"""Get an item by accessor value."""
|
||||
if not isinstance(value, (str, int)):
|
||||
raise TypeError(
|
||||
"Value of {}.{!s} should be a string or integer. Got {!r} ({!s})".format(
|
||||
self.__class__.__name__, self.accessor, value, type(value)
|
||||
)
|
||||
)
|
||||
if isinstance(value, int):
|
||||
return self.__root__[value]
|
||||
|
||||
for item in self:
|
||||
if hasattr(item, self.accessor) and getattr(item, self.accessor) == value:
|
||||
return item
|
||||
raise IndexError(
|
||||
"No match found for {!s}.{!s}={!r}".format(
|
||||
self.model.__class__.__name__, self.accessor, value
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent model."""
|
||||
return repr_from_attrs(self, ["_count", "_accessor"], strip="_")
|
||||
|
||||
@property
|
||||
def accessor(self) -> str:
|
||||
"""Access item accessor."""
|
||||
return self._accessor
|
||||
|
||||
@property
|
||||
def model(self) -> MultiModelT:
|
||||
"""Access item model class."""
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""Access item count."""
|
||||
return self._count
|
||||
|
||||
@staticmethod
|
||||
def _valid_items(
|
||||
*to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]],
|
||||
model: MultiModelT,
|
||||
accessor: str
|
||||
) -> t.List[MultiModelT]:
|
||||
items = [
|
||||
item
|
||||
for item in to_validate
|
||||
if any(
|
||||
(
|
||||
(isinstance(item, dict) and accessor in item),
|
||||
(isinstance(item, model) and hasattr(item, accessor)),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
for index, item in enumerate(items):
|
||||
if isinstance(item, dict):
|
||||
items[index] = model(**item)
|
||||
return items
|
||||
|
||||
def add(self, *items, unique_by: t.Optional[str] = None) -> None:
|
||||
"""Add an item to the model."""
|
||||
to_add = self._valid_items(*items, model=self.model, accessor=self.accessor)
|
||||
if unique_by is not None:
|
||||
unique_by_values = {
|
||||
getattr(obj, unique_by) for obj in (*self, *to_add) if hasattr(obj, unique_by)
|
||||
}
|
||||
unique_by_objects = {
|
||||
v: o
|
||||
for v in unique_by_values
|
||||
for o in (*self, *to_add)
|
||||
if getattr(o, unique_by) == v
|
||||
}
|
||||
self.__root__ = list(unique_by_objects.values())
|
||||
else:
|
||||
self.__root__ = [*self.__root__, *to_add]
|
||||
self._count = len(self.__root__)
|
||||
|
|
|
|||
44
hyperglass/models/tests/test_multi_model.py
Normal file
44
hyperglass/models/tests/test_multi_model.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Test HyperglassMultiModel."""
|
||||
|
||||
# Third Party
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Local
|
||||
from ..main import HyperglassMultiModel
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
"""Test item."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
ITEMS_1 = [
|
||||
{"id": "item1", "name": "Item One"},
|
||||
Item(id="item2", name="Item Two"),
|
||||
{"id": "item3", "name": "Item Three"},
|
||||
]
|
||||
|
||||
ITEMS_2 = [
|
||||
Item(id="item4", name="Item Four"),
|
||||
{"id": "item5", "name": "Item Five"},
|
||||
]
|
||||
|
||||
ITEMS_3 = [
|
||||
{"id": "item1", "name": "Item New One"},
|
||||
{"id": "item6", "name": "Item Six"},
|
||||
]
|
||||
|
||||
|
||||
def test_multi_model():
|
||||
model = HyperglassMultiModel(*ITEMS_1, model=Item, accessor="id")
|
||||
assert model.count == 3
|
||||
assert len([o for o in model]) == model.count # noqa: C416 (Iteration testing)
|
||||
assert model["item1"].name == "Item One"
|
||||
model.add(*ITEMS_2)
|
||||
assert model.count == 5
|
||||
assert model[3].name == "Item Four"
|
||||
model.add(*ITEMS_3, unique_by="id")
|
||||
assert model.count == 6
|
||||
assert model["item1"].name == "Item New One"
|
||||
|
|
@ -206,7 +206,7 @@ def make_repr(_class):
|
|||
return f'{_class.__name__}({", ".join(_process_attrs(dir(_class)))})'
|
||||
|
||||
|
||||
def repr_from_attrs(obj: object, attrs: Series[str]) -> str:
|
||||
def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str:
|
||||
"""Generate a `__repr__()` value from a specific set of attribute names.
|
||||
|
||||
Useful for complex models/objects where `__repr__()` should only display specific fields.
|
||||
|
|
@ -215,7 +215,11 @@ def repr_from_attrs(obj: object, attrs: Series[str]) -> str:
|
|||
attr_names = {a for a in attrs if hasattr(obj, a)}
|
||||
# Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a
|
||||
# `__repr__` method.
|
||||
attr_values = {f: v for f in attr_names if hasattr((v := getattr(obj, f)), "__repr__")}
|
||||
attr_values = {
|
||||
f if strip is None else f.strip(strip): v # noqa: IF100
|
||||
for f in attr_names
|
||||
if hasattr((v := getattr(obj, f)), "__repr__")
|
||||
}
|
||||
pairs = (f"{k}={v!r}" for k, v in attr_values.items())
|
||||
return f"{obj.__class__.__name__}({','.join(pairs)})"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue