mirror of
https://github.com/thatmattlove/hyperglass.git
synced 2026-01-17 08:48:05 +00:00
Overhaul configuration initialization process, add missing device to directive association
This commit is contained in:
parent
af0d5345bf
commit
cd87c254e4
12 changed files with 319 additions and 142 deletions
|
|
@ -1,11 +1,32 @@
|
|||
"""hyperglass Configuration."""
|
||||
|
||||
# Local
|
||||
from .main import params, devices, ui_params, directives
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.state import use_state
|
||||
from hyperglass.defaults.directives import init_builtin_directives
|
||||
|
||||
__all__ = (
|
||||
"params",
|
||||
"devices",
|
||||
"directives",
|
||||
"ui_params",
|
||||
)
|
||||
# Local
|
||||
from .main import init_params, init_devices, init_ui_params, init_directives
|
||||
|
||||
__all__ = ("init_user_config",)
|
||||
|
||||
|
||||
def init_user_config() -> None:
|
||||
"""Initialize all user configurations and add them to global state."""
|
||||
state = use_state()
|
||||
|
||||
params = init_params()
|
||||
builtins = init_builtin_directives()
|
||||
custom = init_directives()
|
||||
directives = builtins + custom
|
||||
with state.cache.pipeline() as pipeline:
|
||||
# Write params and directives to the cache first to avoid a race condition where ui_params
|
||||
# or devices try to access params or directives before they're available.
|
||||
pipeline.set("params", params)
|
||||
pipeline.set("directives", directives)
|
||||
|
||||
devices = init_devices()
|
||||
ui_params = init_ui_params(params=params, devices=devices)
|
||||
with state.cache.pipeline() as pipeline:
|
||||
pipeline.set("devices", devices)
|
||||
pipeline.set("ui_params", ui_params)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,7 @@ from pydantic import ValidationError
|
|||
|
||||
# Project
|
||||
from hyperglass.log import log, enable_file_logging, enable_syslog_logging
|
||||
from hyperglass.defaults import CREDIT
|
||||
from hyperglass.settings import Settings
|
||||
from hyperglass.constants import PARSED_RESPONSE_FIELDS, __version__
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.util.files import check_path
|
||||
from hyperglass.models.directive import Directive, Directives
|
||||
|
|
@ -24,11 +22,15 @@ from hyperglass.models.config.devices import Devices
|
|||
from .markdown import get_markdown
|
||||
from .validation import validate_config
|
||||
|
||||
CONFIG_PATH = Settings.app_path
|
||||
log.info("Configuration directory: {d!s}", d=CONFIG_PATH)
|
||||
__all__ = (
|
||||
"init_params",
|
||||
"init_directives",
|
||||
"init_devices",
|
||||
"init_ui_params",
|
||||
)
|
||||
|
||||
# Project Directories
|
||||
WORKING_DIR = Path(__file__).resolve().parent
|
||||
CONFIG_PATH = Settings.app_path
|
||||
CONFIG_FILES = (
|
||||
("hyperglass.yaml", False),
|
||||
("devices.yaml", True),
|
||||
|
|
@ -108,81 +110,88 @@ def _get_directives(data: t.Dict[str, t.Any]) -> "Directives":
|
|||
return Directives(*directives)
|
||||
|
||||
|
||||
def _get_devices(data: t.List[t.Dict[str, t.Any]], directives: "Directives") -> Devices:
|
||||
for device in data:
|
||||
device["directives"] = directives.filter_by_ids(*device.get("directives", ()))
|
||||
return Devices(data)
|
||||
def init_params() -> "Params":
|
||||
"""Validate & initialize configuration parameters."""
|
||||
user_config = _config_optional(CONFIG_MAIN)
|
||||
# Map imported user configuration to expected schema.
|
||||
log.debug("Unvalidated configuration from {}: {}", CONFIG_MAIN, user_config)
|
||||
params = validate_config(config=user_config, importer=Params)
|
||||
|
||||
|
||||
user_config = _config_optional(CONFIG_MAIN)
|
||||
|
||||
# Read raw debug value from config to enable debugging quickly.
|
||||
|
||||
# Map imported user configuration to expected schema.
|
||||
log.debug("Unvalidated configuration from {}: {}", CONFIG_MAIN, user_config)
|
||||
params = validate_config(config=user_config, importer=Params)
|
||||
|
||||
# Map imported user directives to expected schema.
|
||||
_user_directives = _config_optional(CONFIG_DIRECTIVES)
|
||||
log.debug("Unvalidated directives from {!s}: {}", CONFIG_DIRECTIVES, _user_directives)
|
||||
directives = _get_directives(_user_directives)
|
||||
|
||||
# Map imported user devices to expected schema.
|
||||
_user_devices = _config_required(CONFIG_DEVICES)
|
||||
log.debug("Unvalidated devices from {}: {}", CONFIG_DEVICES, _user_devices)
|
||||
devices = _get_devices(_user_devices.get("devices", _user_devices.get("routers", [])), directives)
|
||||
|
||||
|
||||
# Set up file logging once configuration parameters are initialized.
|
||||
enable_file_logging(
|
||||
logger=log,
|
||||
log_directory=params.logging.directory,
|
||||
log_format=params.logging.format,
|
||||
log_max_size=params.logging.max_size,
|
||||
)
|
||||
|
||||
# Set up syslog logging if enabled.
|
||||
if params.logging.syslog is not None and params.logging.syslog.enable:
|
||||
enable_syslog_logging(
|
||||
logger=log, syslog_host=params.logging.syslog.host, syslog_port=params.logging.syslog.port,
|
||||
# Set up file logging once configuration parameters are initialized.
|
||||
enable_file_logging(
|
||||
logger=log,
|
||||
log_directory=params.logging.directory,
|
||||
log_format=params.logging.format,
|
||||
log_max_size=params.logging.max_size,
|
||||
)
|
||||
|
||||
if params.logging.http is not None and params.logging.http.enable:
|
||||
log.debug("HTTP logging is enabled")
|
||||
# Set up syslog logging if enabled.
|
||||
if params.logging.syslog is not None and params.logging.syslog.enable:
|
||||
enable_syslog_logging(
|
||||
logger=log,
|
||||
syslog_host=params.logging.syslog.host,
|
||||
syslog_port=params.logging.syslog.port,
|
||||
)
|
||||
|
||||
# Perform post-config initialization string formatting or other
|
||||
# functions that require access to other config levels. E.g.,
|
||||
# something in 'params.web.text' needs to be formatted with a value
|
||||
# from params.
|
||||
try:
|
||||
params.web.text.subtitle = params.web.text.subtitle.format(
|
||||
**params.dict(exclude={"web", "queries", "messages"})
|
||||
if params.logging.http is not None and params.logging.http.enable:
|
||||
log.debug("HTTP logging is enabled")
|
||||
|
||||
# Perform post-config initialization string formatting or other
|
||||
# functions that require access to other config levels. E.g.,
|
||||
# something in 'params.web.text' needs to be formatted with a value
|
||||
# from params.
|
||||
try:
|
||||
params.web.text.subtitle = params.web.text.subtitle.format(
|
||||
**params.dict(exclude={"web", "queries", "messages"})
|
||||
)
|
||||
|
||||
# If keywords are unmodified (default), add the org name &
|
||||
# site_title.
|
||||
if Params().site_keywords == params.site_keywords:
|
||||
params.site_keywords = sorted(
|
||||
{*params.site_keywords, params.org_name, params.site_title}
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def init_directives() -> "Directives":
|
||||
"""Validate & initialize directives."""
|
||||
# Map imported user directives to expected schema.
|
||||
_user_directives = _config_optional(CONFIG_DIRECTIVES)
|
||||
log.debug("Unvalidated directives from {!s}: {}", CONFIG_DIRECTIVES, _user_directives)
|
||||
return _get_directives(_user_directives)
|
||||
|
||||
|
||||
def init_devices() -> "Devices":
|
||||
"""Validate & initialize devices."""
|
||||
_user_devices = _config_required(CONFIG_DEVICES)
|
||||
log.debug("Unvalidated devices from {}: {}", CONFIG_DEVICES, _user_devices)
|
||||
return Devices(_user_devices.get("devices", _user_devices.get("routers", [])))
|
||||
|
||||
|
||||
def init_ui_params(*, params: "Params", devices: "Devices") -> "UIParameters":
|
||||
"""Validate & initialize UI parameters."""
|
||||
|
||||
# Project
|
||||
from hyperglass.defaults import CREDIT
|
||||
from hyperglass.constants import PARSED_RESPONSE_FIELDS, __version__
|
||||
|
||||
content_greeting = get_markdown(
|
||||
config_path=params.web.greeting, default="", params={"title": params.web.greeting.title},
|
||||
)
|
||||
content_credit = CREDIT.format(version=__version__)
|
||||
|
||||
# If keywords are unmodified (default), add the org name &
|
||||
# site_title.
|
||||
if Params().site_keywords == params.site_keywords:
|
||||
params.site_keywords = sorted({*params.site_keywords, params.org_name, params.site_title})
|
||||
_ui_params = params.frontend()
|
||||
_ui_params["web"]["logo"]["light_format"] = params.web.logo.light.suffix
|
||||
_ui_params["web"]["logo"]["dark_format"] = params.web.logo.dark.suffix
|
||||
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
content_greeting = get_markdown(
|
||||
config_path=params.web.greeting, default="", params={"title": params.web.greeting.title},
|
||||
)
|
||||
|
||||
|
||||
content_credit = CREDIT.format(version=__version__)
|
||||
|
||||
_ui_params = params.frontend()
|
||||
_ui_params["web"]["logo"]["light_format"] = params.web.logo.light.suffix
|
||||
_ui_params["web"]["logo"]["dark_format"] = params.web.logo.dark.suffix
|
||||
|
||||
ui_params = UIParameters(
|
||||
**_ui_params,
|
||||
version=__version__,
|
||||
networks=devices.networks(params),
|
||||
parsed_data_fields=PARSED_RESPONSE_FIELDS,
|
||||
content={"credit": content_credit, "greeting": content_greeting},
|
||||
)
|
||||
return UIParameters(
|
||||
**_ui_params,
|
||||
version=__version__,
|
||||
networks=devices.networks(params),
|
||||
parsed_data_fields=PARSED_RESPONSE_FIELDS,
|
||||
content={"credit": content_credit, "greeting": content_greeting},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from pathlib import Path
|
|||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.state import use_state
|
||||
from hyperglass.models.directive import Directives
|
||||
|
||||
|
||||
def register_builtin_directives() -> None:
|
||||
def init_builtin_directives() -> "Directives":
|
||||
"""Find all directives and register them with global state manager."""
|
||||
directives_dir = Path(__file__).parent
|
||||
state = use_state()
|
||||
directives = ()
|
||||
for _, name, __ in pkgutil.iter_modules([directives_dir]):
|
||||
module = importlib.import_module(f"hyperglass.defaults.directives.{name}")
|
||||
|
||||
|
|
@ -22,4 +22,5 @@ def register_builtin_directives() -> None:
|
|||
log.warning("Module '{!s}' is missing an '__all__' export", module)
|
||||
|
||||
exports = (getattr(module, p) for p in module.__all__ if hasattr(module, p))
|
||||
state.add_directive(*exports)
|
||||
directives += (*exports,)
|
||||
return Directives(*directives)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Default Juniper Directives."""
|
||||
|
||||
# Project
|
||||
from hyperglass.models.directive import Rule, Text, NativeDirective
|
||||
from hyperglass.models.directive import Rule, Text, BuiltinDirective
|
||||
|
||||
__all__ = (
|
||||
"JuniperBGPRoute",
|
||||
|
|
@ -14,7 +14,7 @@ __all__ = (
|
|||
"JuniperBGPCommunityTable",
|
||||
)
|
||||
|
||||
JuniperBGPRoute = NativeDirective(
|
||||
JuniperBGPRoute = BuiltinDirective(
|
||||
id="__hyperglass_juniper_bgp_route__",
|
||||
name="BGP Route",
|
||||
rules=[
|
||||
|
|
@ -33,7 +33,7 @@ JuniperBGPRoute = NativeDirective(
|
|||
platforms=["juniper"],
|
||||
)
|
||||
|
||||
JuniperBGPASPath = NativeDirective(
|
||||
JuniperBGPASPath = BuiltinDirective(
|
||||
id="__hyperglass_juniper_bgp_aspath__",
|
||||
name="BGP AS Path",
|
||||
rules=[
|
||||
|
|
@ -50,7 +50,7 @@ JuniperBGPASPath = NativeDirective(
|
|||
platforms=["juniper"],
|
||||
)
|
||||
|
||||
JuniperBGPCommunity = NativeDirective(
|
||||
JuniperBGPCommunity = BuiltinDirective(
|
||||
id="__hyperglass_juniper_bgp_community__",
|
||||
name="BGP Community",
|
||||
rules=[
|
||||
|
|
@ -68,7 +68,7 @@ JuniperBGPCommunity = NativeDirective(
|
|||
)
|
||||
|
||||
|
||||
JuniperPing = NativeDirective(
|
||||
JuniperPing = BuiltinDirective(
|
||||
id="__hyperglass_juniper_ping__",
|
||||
name="Ping",
|
||||
rules=[
|
||||
|
|
@ -87,7 +87,7 @@ JuniperPing = NativeDirective(
|
|||
platforms=["juniper"],
|
||||
)
|
||||
|
||||
JuniperTraceroute = NativeDirective(
|
||||
JuniperTraceroute = BuiltinDirective(
|
||||
id="__hyperglass_juniper_traceroute__",
|
||||
name="Traceroute",
|
||||
rules=[
|
||||
|
|
@ -108,7 +108,7 @@ JuniperTraceroute = NativeDirective(
|
|||
|
||||
# Table Output Directives
|
||||
|
||||
JuniperBGPRouteTable = NativeDirective(
|
||||
JuniperBGPRouteTable = BuiltinDirective(
|
||||
id="__hyperglass_juniper_bgp_route_table__",
|
||||
name="BGP Route",
|
||||
rules=[
|
||||
|
|
@ -128,7 +128,7 @@ JuniperBGPRouteTable = NativeDirective(
|
|||
platforms=["juniper"],
|
||||
)
|
||||
|
||||
JuniperBGPASPathTable = NativeDirective(
|
||||
JuniperBGPASPathTable = BuiltinDirective(
|
||||
id="__hyperglass_juniper_bgp_aspath_table__",
|
||||
name="BGP AS Path",
|
||||
rules=[
|
||||
|
|
@ -146,7 +146,7 @@ JuniperBGPASPathTable = NativeDirective(
|
|||
platforms=["juniper"],
|
||||
)
|
||||
|
||||
JuniperBGPCommunityTable = NativeDirective(
|
||||
JuniperBGPCommunityTable = BuiltinDirective(
|
||||
id="__hyperglass_juniper_bgp_community_table__",
|
||||
name="BGP Community",
|
||||
rules=[
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from .plugins import (
|
|||
)
|
||||
from .constants import MIN_NODE_VERSION, MIN_PYTHON_VERSION, __version__
|
||||
from .util.frontend import get_node_version
|
||||
from .defaults.directives import register_builtin_directives
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Local
|
||||
|
|
@ -43,6 +42,7 @@ if node_major != MIN_NODE_VERSION:
|
|||
from .util import cpu_count
|
||||
from .state import use_state
|
||||
from .settings import Settings
|
||||
from .configuration import init_user_config
|
||||
from .util.frontend import build_frontend
|
||||
|
||||
|
||||
|
|
@ -89,8 +89,6 @@ def on_starting(server: "Arbiter") -> None:
|
|||
|
||||
state = use_state()
|
||||
|
||||
register_builtin_directives()
|
||||
|
||||
register_all_plugins(state.devices)
|
||||
|
||||
asyncio.run(build_ui())
|
||||
|
|
@ -103,7 +101,7 @@ def on_starting(server: "Arbiter") -> None:
|
|||
)
|
||||
|
||||
|
||||
def on_exit(*_: t.Any) -> None:
|
||||
def on_exit(_: t.Any) -> None:
|
||||
"""Gunicorn shutdown tasks."""
|
||||
|
||||
log.critical("Stopping hyperglass {}", __version__)
|
||||
|
|
@ -165,6 +163,7 @@ def start(*, log_level: str, workers: int, **kwargs) -> None:
|
|||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
init_user_config()
|
||||
set_log_level(log, Settings.debug)
|
||||
|
||||
log.debug("System settings: {!r}", Settings)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from hyperglass.util import (
|
|||
resolve_hostname,
|
||||
validate_platform,
|
||||
)
|
||||
from hyperglass.state import use_state
|
||||
from hyperglass.settings import Settings
|
||||
from hyperglass.constants import SCRAPE_HELPERS, SUPPORTED_STRUCTURED_OUTPUT
|
||||
from hyperglass.exceptions.private import ConfigError, UnsupportedDevice
|
||||
|
|
@ -33,6 +34,12 @@ from ..directive import Directives
|
|||
from .credential import Credential
|
||||
|
||||
|
||||
class DirectiveOptions(HyperglassModel, extra="ignore"):
|
||||
"""Per-device directive options."""
|
||||
|
||||
builtins: Union[StrictBool, List[StrictStr]] = True
|
||||
|
||||
|
||||
class Device(HyperglassModelWithId, extra="allow"):
|
||||
"""Validation model for per-router config in devices.yaml."""
|
||||
|
||||
|
|
@ -95,10 +102,6 @@ class Device(HyperglassModelWithId, extra="allow"):
|
|||
"network": self.network.display_name,
|
||||
}
|
||||
|
||||
@property
|
||||
def directive_builtins(self) -> List[str]:
|
||||
...
|
||||
|
||||
@property
|
||||
def directive_commands(self) -> List[str]:
|
||||
"""Get all commands associated with the device."""
|
||||
|
|
@ -189,7 +192,7 @@ class Device(HyperglassModelWithId, extra="allow"):
|
|||
|
||||
@root_validator(pre=True)
|
||||
def validate_device(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate & rewrite device platform, set default directives."""
|
||||
"""Validate & rewrite device platform, set default `directives`."""
|
||||
|
||||
platform = values.get("platform")
|
||||
if platform is None:
|
||||
|
|
@ -210,21 +213,39 @@ class Device(HyperglassModelWithId, extra="allow"):
|
|||
|
||||
values["platform"] = platform
|
||||
|
||||
directives = values.get("directives")
|
||||
directives = use_state("directives")
|
||||
|
||||
if directives is None:
|
||||
# TODO: This should be different now, and could be removed after there's a way to associate default directives
|
||||
# If no directive are defined, set directive to the NOS.
|
||||
inferred = values["platform"]
|
||||
directive_ids = values.get("directives", [])
|
||||
|
||||
# If the _telnet prefix is added, remove it from the command
|
||||
# profile so the commands are the same regardless of
|
||||
# protocol.
|
||||
if "_telnet" in inferred:
|
||||
inferred = inferred.replace("_telnet", "")
|
||||
# Directive options
|
||||
directive_options = DirectiveOptions(
|
||||
**{
|
||||
k: v
|
||||
for statement in directive_ids
|
||||
if isinstance(statement, Dict)
|
||||
for k, v in statement.items()
|
||||
}
|
||||
)
|
||||
|
||||
values["directives"] = [inferred]
|
||||
# String directive IDs, excluding builtins and options.
|
||||
directive_ids = [
|
||||
statement
|
||||
for statement in directive_ids
|
||||
if isinstance(statement, str) and not statement.startswith("__")
|
||||
]
|
||||
# Directives matching provided IDs.
|
||||
device_directives = directives.filter_by_ids(*directive_ids)
|
||||
# Matching built-in directives for this device's platform.
|
||||
builtins = directives.device_builtins(platform=platform)
|
||||
|
||||
if directive_options.builtins is True:
|
||||
# Add all builtins.
|
||||
device_directives += builtins
|
||||
elif isinstance(directive_options.builtins, List):
|
||||
# If the user provides a list of builtin directives to include, add only those.
|
||||
device_directives += builtins.matching(*directive_options.builtins)
|
||||
|
||||
values["directives"] = device_directives
|
||||
return values
|
||||
|
||||
@validator("driver")
|
||||
|
|
|
|||
|
|
@ -305,14 +305,14 @@ class Directive(HyperglassModelWithId):
|
|||
return value
|
||||
|
||||
|
||||
class NativeDirective(Directive):
|
||||
class BuiltinDirective(Directive):
|
||||
"""Natively-supported directive."""
|
||||
|
||||
__hyperglass_builtin__: t.ClassVar[bool] = True
|
||||
platforms: Series[str] = []
|
||||
|
||||
|
||||
DirectiveT = t.Union[NativeDirective, Directive]
|
||||
DirectiveT = t.Union[BuiltinDirective, Directive]
|
||||
|
||||
|
||||
class Directives(HyperglassMultiModel[Directive]):
|
||||
|
|
@ -322,10 +322,35 @@ class Directives(HyperglassMultiModel[Directive]):
|
|||
"""Initialize base class and validate objects."""
|
||||
super().__init__(*items, model=Directive, accessor="id")
|
||||
|
||||
def __add__(self, other: "Directives") -> "Directives":
|
||||
"""Create a new `Directives` instance by merging this instance with another."""
|
||||
valid = all(
|
||||
(
|
||||
isinstance(other, self.__class__),
|
||||
hasattr(other, "model"),
|
||||
getattr(other, "model", None) == self.model,
|
||||
),
|
||||
)
|
||||
if not valid:
|
||||
raise TypeError(f"Cannot add {other!r} to {self.__class__.__name__}")
|
||||
merged = self._merge_with(*other, unique_by=self.accessor)
|
||||
return Directives(*merged)
|
||||
|
||||
def ids(self) -> t.Tuple[str]:
|
||||
"""Get all directive IDs."""
|
||||
return tuple(directive.id for directive in self)
|
||||
return tuple(sorted(directive.id for directive in self))
|
||||
|
||||
def filter_by_ids(self, *ids) -> "Directives":
|
||||
def filter_by_ids(self, *ids: str) -> "Directives":
|
||||
"""Filter directives by directive IDs."""
|
||||
return Directives(*[directive for directive in self if directive.id in ids])
|
||||
return Directives(*(directive for directive in self if directive.id in ids))
|
||||
|
||||
def device_builtins(self, *, platform: str):
|
||||
"""Get builtin directives for a device."""
|
||||
return Directives(
|
||||
*(
|
||||
directive
|
||||
for directive in self
|
||||
if directive.__hyperglass_builtin__ is True
|
||||
and platform in getattr(directive, "platforms", ())
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Data models used throughout hyperglass."""
|
||||
|
||||
# Standard Library
|
||||
|
||||
# Standard Library
|
||||
import re
|
||||
import typing as t
|
||||
|
|
@ -11,7 +13,7 @@ from pydantic.generics import GenericModel
|
|||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.util import snake_to_camel, repr_from_attrs
|
||||
from hyperglass.util import compare_init, snake_to_camel, repr_from_attrs
|
||||
from hyperglass.types import Series
|
||||
|
||||
MultiModelT = t.TypeVar("MultiModelT", bound=BaseModel)
|
||||
|
|
@ -138,6 +140,10 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
|
|||
super().__init__(__root__=valid)
|
||||
self._count = len(self.__root__)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent model."""
|
||||
return repr_from_attrs(self, ["_count", "_accessor"], strip="_")
|
||||
|
||||
def __iter__(self) -> t.Iterator[MultiModelT]:
|
||||
"""Iterate items."""
|
||||
return iter(self.__root__)
|
||||
|
|
@ -162,9 +168,29 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
|
|||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent model."""
|
||||
return repr_from_attrs(self, ["_count", "_accessor"], strip="_")
|
||||
def __add__(self, other: MultiModelT) -> MultiModelT:
|
||||
"""Merge another MultiModel with this one.
|
||||
|
||||
Note: If you're subclassing `HyperglassMultiModel` and overriding `__init__`, you need to
|
||||
override this too.
|
||||
"""
|
||||
valid = all(
|
||||
(
|
||||
isinstance(other, self.__class__),
|
||||
hasattr(other, "model"),
|
||||
getattr(other, "model", None) == self.model,
|
||||
),
|
||||
)
|
||||
if not valid:
|
||||
raise TypeError(f"Cannot add {other!r} to {self.__class__.__name__}")
|
||||
merged = self._merge_with(*other, unique_by=self.accessor)
|
||||
|
||||
if compare_init(self.__class__, other.__class__):
|
||||
return self.__class__(*merged, model=self.model, accessor=self.accessor)
|
||||
raise TypeError(
|
||||
f"{self.__class__.__name__} and {other.__class__.__name__} have different `__init__` "
|
||||
"signatures. You probably need to override `HyperglassMultiModel.__add__`"
|
||||
)
|
||||
|
||||
@property
|
||||
def accessor(self) -> str:
|
||||
|
|
@ -199,8 +225,24 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
|
|||
items[index] = self.model(**item)
|
||||
return items
|
||||
|
||||
def add(self, *items, unique_by: t.Optional[str] = None) -> None:
|
||||
"""Add an item to the model."""
|
||||
def matching(self, *accessors: str) -> MultiModelT:
|
||||
"""Get a new instance containing partial matches from `accessors`."""
|
||||
|
||||
def matches(*searches: str) -> t.Generator[MultiModelT, None, None]:
|
||||
"""Get any matching items by accessor value.
|
||||
|
||||
For example, if `accessors` is `('one', 'two')`, and `Model.<accessor>` is `'one'`,
|
||||
`Model` is yielded.
|
||||
"""
|
||||
for search in searches:
|
||||
pattern = re.compile(fr".*{search}.*", re.IGNORECASE)
|
||||
for item in self:
|
||||
if pattern.match(getattr(item, self.accessor)):
|
||||
yield item
|
||||
|
||||
return self.__class__(*matches(*accessors))
|
||||
|
||||
def _merge_with(self, *items, unique_by: t.Optional[str] = None) -> Series[MultiModelT]:
|
||||
to_add = self._valid_items(*items)
|
||||
if unique_by is not None:
|
||||
unique_by_values = {
|
||||
|
|
@ -212,10 +254,12 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]):
|
|||
for o in (*self, *to_add)
|
||||
if getattr(o, unique_by) == v
|
||||
}
|
||||
new: t.List[MultiModelT] = list(unique_by_objects.values())
|
||||
return tuple(unique_by_objects.values())
|
||||
return (*self.__root__, *to_add)
|
||||
|
||||
else:
|
||||
new: t.List[MultiModelT] = [*self.__root__, *to_add]
|
||||
def add(self, *items, unique_by: t.Optional[str] = None) -> None:
|
||||
"""Add an item to the model."""
|
||||
new = self._merge_with(*items, unique_by=unique_by)
|
||||
self.__root__ = new
|
||||
self._count = len(self.__root__)
|
||||
for item in new:
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ def _tester(sample: str):
|
|||
credential={"username": "", "password": ""},
|
||||
platform="juniper",
|
||||
structured_output=True,
|
||||
directives=[{"id": "test", "name": "Test", "rules": []}],
|
||||
directives=[],
|
||||
attrs={"source4": "192.0.2.1", "source6": "2001:db8::1"},
|
||||
)
|
||||
|
||||
# Override has_directives method for testing.
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def _use_state(attr: t.Optional[str] = None) -> "HyperglassState":
|
|||
if attr is None:
|
||||
return HyperglassState(settings=Settings)
|
||||
if attr in ("cache", "redis"):
|
||||
return HyperglassState(settings=Settings).redis
|
||||
return HyperglassState(settings=Settings).cache
|
||||
if attr in HyperglassState.properties():
|
||||
return getattr(HyperglassState(settings=Settings), attr)
|
||||
raise StateError("'{attr}' does not exist on HyperglassState", attr=attr)
|
||||
|
|
|
|||
|
|
@ -3,15 +3,18 @@
|
|||
# Standard Library
|
||||
import pickle
|
||||
import typing as t
|
||||
from types import TracebackType
|
||||
from typing import overload
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.exceptions.private import StateError
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Third Party
|
||||
from redis import Redis
|
||||
from redis.client import Pipeline
|
||||
|
||||
|
||||
class RedisManager:
|
||||
|
|
@ -125,3 +128,57 @@ class RedisManager:
|
|||
"""Add a value to a hash map (dict)."""
|
||||
name = self.key(key)
|
||||
self.instance.hset(name, item, pickle.dumps(value))
|
||||
|
||||
def pipeline(self):
|
||||
"""Enter a Redis Pipeline, but expose all the custom interaction methods."""
|
||||
# Copy the base RedisManager and remove the pipeline method (this method).
|
||||
ctx = type(
|
||||
"RedisManagerExcludePipeline",
|
||||
(RedisManager,),
|
||||
{k: v for k, v in self.__dict__.items() if k != "pipeline"},
|
||||
)
|
||||
|
||||
def nested_pipeline(*_, **__) -> None:
|
||||
"""Ensure pipeline is never called from within pipeline."""
|
||||
raise AttributeError("Cannot access pipeline from pipeline")
|
||||
|
||||
class RedisManagerPipeline(ctx):
|
||||
"""Copy of RedisManager, but uses `Redis.pipeline` as the `instance`."""
|
||||
|
||||
parent: "Redis"
|
||||
instance: "Pipeline"
|
||||
pipeline: t.Any = nested_pipeline
|
||||
|
||||
def __init__(
|
||||
pipeline_self, # noqa: N805 Avoid `self` namespace conflict
|
||||
*,
|
||||
parent: "Redis",
|
||||
instance: "Pipeline",
|
||||
namespace: str,
|
||||
) -> None:
|
||||
pipeline_self.parent = parent
|
||||
super().__init__(instance=instance, namespace=namespace)
|
||||
|
||||
def __enter__(
|
||||
pipeline_self: "RedisManagerPipeline", # noqa: N805 Avoid `self` namespace conflict
|
||||
) -> "RedisManagerPipeline":
|
||||
return pipeline_self
|
||||
|
||||
def __exit__(
|
||||
pipeline_self: "RedisManagerPipeline", # noqa: N805 Avoid `self` namespace conflict
|
||||
exc_type: t.Optional[t.Type[BaseException]] = None,
|
||||
exc_value: t.Optional[BaseException] = None,
|
||||
_: t.Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
pipeline_self.instance.execute()
|
||||
if exc_type is not None:
|
||||
log.error(
|
||||
"Error in pipeline {!r} from parent instance {!r}:\n{!s}",
|
||||
pipeline_self,
|
||||
pipeline_self.parent,
|
||||
exc_value,
|
||||
)
|
||||
|
||||
return RedisManagerPipeline(
|
||||
parent=self.instance, instance=self.instance.pipeline(), namespace=self.namespace,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,6 @@ import codecs
|
|||
import pickle
|
||||
import typing as t
|
||||
|
||||
# Project
|
||||
from hyperglass.configuration import params, devices, ui_params, directives
|
||||
|
||||
# Local
|
||||
from .manager import StateManager
|
||||
|
||||
|
|
@ -20,6 +17,9 @@ if t.TYPE_CHECKING:
|
|||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
||||
# Local
|
||||
from .manager import RedisManager
|
||||
|
||||
|
||||
PluginT = t.TypeVar("PluginT", bound="HyperglassPlugin")
|
||||
|
||||
|
|
@ -31,12 +31,6 @@ class HyperglassState(StateManager):
|
|||
"""Initialize state store and reset plugins."""
|
||||
super().__init__(settings=settings)
|
||||
|
||||
# Add configuration objects.
|
||||
self.redis.set("params", params)
|
||||
self.redis.set("devices", devices)
|
||||
self.redis.set("ui_params", ui_params)
|
||||
self.redis.set("directives", directives)
|
||||
|
||||
# Ensure plugins are empty.
|
||||
self.reset_plugins("output")
|
||||
self.reset_plugins("input")
|
||||
|
|
@ -78,6 +72,11 @@ class HyperglassState(StateManager):
|
|||
"""Delete all cache keys."""
|
||||
self.redis.instance.flushdb(asynchronous=True)
|
||||
|
||||
@property
|
||||
def cache(self) -> "RedisManager":
|
||||
"""Get the redis manager instance."""
|
||||
return self.redis
|
||||
|
||||
@property
|
||||
def params(self) -> "Params":
|
||||
"""Get hyperglass configuration parameters (`hyperglass.yaml`)."""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue