mirror of
https://github.com/thatmattlove/hyperglass.git
synced 2026-01-17 00:38:06 +00:00
Continue output plugin implementation
This commit is contained in:
parent
560663601d
commit
74fcb5dba4
16 changed files with 124 additions and 129 deletions
|
|
@ -58,7 +58,6 @@ async def send_webhook(query_data: Query, request: Request, timestamp: datetime)
|
|||
log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err))
|
||||
|
||||
|
||||
@log.catch
|
||||
async def query(query_data: Query, request: Request, background_tasks: BackgroundTasks):
|
||||
"""Ingest request data pass it to the backend application to perform the query."""
|
||||
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ class PublicHyperglassError(HyperglassError):
|
|||
"""Format error message with keyword arguments."""
|
||||
if "error" in kwargs:
|
||||
error = kwargs.pop("error")
|
||||
error = self._safe_format(error, **kwargs)
|
||||
error = self._safe_format(str(error), **kwargs)
|
||||
kwargs["error"] = error
|
||||
self._message = self._safe_format(self._message_template, **kwargs)
|
||||
self._keywords = list(kwargs.values())
|
||||
|
|
@ -150,7 +150,7 @@ class PrivateHyperglassError(HyperglassError):
|
|||
"""Format error message with keyword arguments."""
|
||||
if "error" in kwargs:
|
||||
error = kwargs.pop("error")
|
||||
error = self._safe_format(error, **kwargs)
|
||||
error = self._safe_format(str(error), **kwargs)
|
||||
kwargs["error"] = error
|
||||
self._message = self._safe_format(message, **kwargs)
|
||||
self._keywords = list(kwargs.values())
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""User-facing/Public exceptions."""
|
||||
|
||||
# Standard Library
|
||||
from typing import Any, Dict, Optional, ForwardRef
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
# Project
|
||||
from hyperglass.configuration import params
|
||||
|
|
@ -9,8 +9,10 @@ from hyperglass.configuration import params
|
|||
# Local
|
||||
from ._common import PublicHyperglassError
|
||||
|
||||
Query = ForwardRef("Query")
|
||||
Device = ForwardRef("Device")
|
||||
if TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.api.query import Query
|
||||
from hyperglass.models.config.devices import Device
|
||||
|
||||
|
||||
class ScrapeError(
|
||||
|
|
@ -18,7 +20,7 @@ class ScrapeError(
|
|||
):
|
||||
"""Raised when an SSH driver error occurs."""
|
||||
|
||||
def __init__(self, error: BaseException, *, device: Device):
|
||||
def __init__(self, *, error: BaseException, device: "Device"):
|
||||
"""Initialize parent error."""
|
||||
super().__init__(error=str(error), device=device.name, proxy=device.proxy)
|
||||
|
||||
|
|
@ -28,7 +30,7 @@ class AuthError(
|
|||
):
|
||||
"""Raised when authentication to a device fails."""
|
||||
|
||||
def __init__(self, error: BaseException, *, device: Device):
|
||||
def __init__(self, *, error: BaseException, device: "Device"):
|
||||
"""Initialize parent error."""
|
||||
super().__init__(error=str(error), device=device.name, proxy=device.proxy)
|
||||
|
||||
|
|
@ -36,7 +38,7 @@ class AuthError(
|
|||
class RestError(PublicHyperglassError, template=params.messages.connection_error, level="danger"):
|
||||
"""Raised upon a rest API client error."""
|
||||
|
||||
def __init__(self, error: BaseException, *, device: Device):
|
||||
def __init__(self, *, error: BaseException, device: "Device"):
|
||||
"""Initialize parent error."""
|
||||
super().__init__(error=str(error), device=device.name)
|
||||
|
||||
|
|
@ -46,7 +48,7 @@ class DeviceTimeout(
|
|||
):
|
||||
"""Raised when the connection to a device times out."""
|
||||
|
||||
def __init__(self, error: BaseException, *, device: Device):
|
||||
def __init__(self, *, error: BaseException, device: "Device"):
|
||||
"""Initialize parent error."""
|
||||
super().__init__(error=str(error), device=device.name, proxy=device.proxy)
|
||||
|
||||
|
|
@ -55,7 +57,7 @@ class InvalidQuery(PublicHyperglassError, template=params.messages.invalid_query
|
|||
"""Raised when input validation fails."""
|
||||
|
||||
def __init__(
|
||||
self, error: Optional[str] = None, *, query: "Query", **kwargs: Dict[str, Any]
|
||||
self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize parent error."""
|
||||
|
||||
|
|
@ -107,7 +109,7 @@ class InputInvalid(PublicHyperglassError, template=params.messages.invalid_input
|
|||
"""Raised when input validation fails."""
|
||||
|
||||
def __init__(
|
||||
self, error: Optional[Any] = None, *, target: str, **kwargs: Dict[str, Any]
|
||||
self, *, error: Optional[Any] = None, target: str, **kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize parent error."""
|
||||
|
||||
|
|
@ -123,7 +125,7 @@ class InputNotAllowed(PublicHyperglassError, template=params.messages.acl_not_al
|
|||
"""Raised when input validation fails due to a configured check."""
|
||||
|
||||
def __init__(
|
||||
self, error: Optional[str] = None, *, query: Query, **kwargs: Dict[str, Any]
|
||||
self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize parent error."""
|
||||
|
||||
|
|
@ -143,7 +145,7 @@ class ResponseEmpty(PublicHyperglassError, template=params.messages.no_output):
|
|||
"""Raised when hyperglass can connect to the device but the response is empty."""
|
||||
|
||||
def __init__(
|
||||
self, error: Optional[str] = None, *, query: Query, **kwargs: Dict[str, Any]
|
||||
self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize parent error."""
|
||||
|
||||
|
|
|
|||
|
|
@ -7,23 +7,21 @@ from typing import TYPE_CHECKING, Dict, Union, Sequence
|
|||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.plugins import OutputPluginManager
|
||||
from hyperglass.models.api import Query
|
||||
from hyperglass.parsing.nos import scrape_parsers, structured_parsers
|
||||
from hyperglass.parsing.common import parsers
|
||||
from hyperglass.models.config.devices import Device
|
||||
|
||||
# Local
|
||||
from ._construct import Construct
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.api import Query
|
||||
from hyperglass.compat._sshtunnel import SSHTunnelForwarder
|
||||
from hyperglass.models.config.devices import Device
|
||||
|
||||
|
||||
class Connection(ABC):
|
||||
"""Base transport driver class."""
|
||||
|
||||
def __init__(self, device: Device, query_data: Query) -> None:
|
||||
def __init__(self, device: "Device", query_data: "Query") -> None:
|
||||
"""Initialize connection to device."""
|
||||
self.device = device
|
||||
self.query_data = query_data
|
||||
|
|
@ -38,53 +36,14 @@ class Connection(ABC):
|
|||
"""Return a preconfigured sshtunnel.SSHTunnelForwarder instance."""
|
||||
pass
|
||||
|
||||
async def parsed_response( # noqa: C901 ("too complex")
|
||||
self, output: Sequence[str]
|
||||
) -> Union[str, Sequence[Dict]]:
|
||||
async def parsed_response(self, output: Sequence[str]) -> Union[str, Sequence[Dict]]:
|
||||
"""Send output through common parsers."""
|
||||
|
||||
log.debug("Pre-parsed responses:\n{}", output)
|
||||
parsed = ()
|
||||
response = None
|
||||
|
||||
structured_nos = structured_parsers.keys()
|
||||
structured_query_types = structured_parsers.get(self.device.nos, {}).keys()
|
||||
|
||||
scrape_nos = scrape_parsers.keys()
|
||||
scrape_query_types = scrape_parsers.get(self.device.nos, {}).keys()
|
||||
|
||||
if not self.device.structured_output:
|
||||
_parsed = ()
|
||||
for func in parsers:
|
||||
for response in output:
|
||||
_output = func(commands=self.query, output=response)
|
||||
_parsed += (_output,)
|
||||
if self.device.nos in scrape_nos and self.query_type in scrape_query_types:
|
||||
func = scrape_parsers[self.device.nos][self.query_type]
|
||||
for response in _parsed:
|
||||
_output = func(response)
|
||||
parsed += (_output,)
|
||||
else:
|
||||
parsed += _parsed
|
||||
|
||||
response = "\n\n".join(parsed)
|
||||
elif (
|
||||
self.device.structured_output
|
||||
and self.device.nos in structured_nos
|
||||
and self.query_type not in structured_query_types
|
||||
):
|
||||
for func in parsers:
|
||||
for response in output:
|
||||
_output = func(commands=self.query, output=response)
|
||||
parsed += (_output,)
|
||||
response = "\n\n".join(parsed)
|
||||
elif (
|
||||
self.device.structured_output
|
||||
and self.device.nos in structured_nos
|
||||
and self.query_type in structured_query_types
|
||||
):
|
||||
func = structured_parsers[self.device.nos][self.query_type]
|
||||
response = func(output)
|
||||
response = self.plugin_manager.execute(
|
||||
directive=self.query_data.directive, output=output, device=self.device
|
||||
)
|
||||
|
||||
if response is None:
|
||||
response = "\n\n".join(output)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,10 @@ async def execute(query: "Query") -> Union[str, Sequence[Dict]]:
|
|||
mapped_driver = map_driver(query.device.driver)
|
||||
driver: "Connection" = mapped_driver(query.device, query)
|
||||
|
||||
signal.signal(signal.SIGALRM, handle_timeout(error=TimeoutError(), device=query.device))
|
||||
signal.signal(
|
||||
signal.SIGALRM,
|
||||
handle_timeout(error=TimeoutError("Connection timed out"), device=query.device),
|
||||
)
|
||||
signal.alarm(params.request_timeout - 1)
|
||||
|
||||
if query.device.proxy:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import shutil
|
|||
import logging
|
||||
import platform
|
||||
from typing import TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
|
||||
# Third Party
|
||||
from gunicorn.app.base import BaseApplication # type: ignore
|
||||
|
|
@ -124,10 +123,8 @@ def cache_config() -> bool:
|
|||
def register_all_plugins(devices: "Devices") -> None:
|
||||
"""Validate and register configured plugins."""
|
||||
|
||||
for plugin_file in {
|
||||
Path(p) for p in (p for d in devices.objects for c in d.commands for p in c.plugins)
|
||||
}:
|
||||
failures = register_plugin(plugin_file)
|
||||
for plugin_file, directives in devices.directive_plugins().items():
|
||||
failures = register_plugin(plugin_file, directives=directives)
|
||||
for failure in failures:
|
||||
log.warning(
|
||||
"Plugin '{}' is not a valid hyperglass plugin, and was not registered", failure,
|
||||
|
|
@ -203,11 +200,9 @@ class HyperglassWSGI(BaseApplication):
|
|||
|
||||
def start(**kwargs):
|
||||
"""Start hyperglass via gunicorn."""
|
||||
# Project
|
||||
from hyperglass.api import app
|
||||
|
||||
HyperglassWSGI(
|
||||
app=app,
|
||||
app="hyperglass.api:app",
|
||||
options={
|
||||
"worker_class": "uvicorn.workers.UvicornWorker",
|
||||
"preload": True,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
"""All Data Models used by hyperglass."""
|
||||
|
||||
# Local
|
||||
from .main import HyperglassModel
|
||||
from .main import HyperglassModel, HyperglassModelWithId
|
||||
|
||||
__all__ = ("HyperglassModel",)
|
||||
__all__ = (
|
||||
"HyperglassModel",
|
||||
"HyperglassModelWithId",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class Query(BaseModel):
|
|||
def validate_query_location(cls, value):
|
||||
"""Ensure query_location is defined."""
|
||||
|
||||
valid_id = value in devices._ids
|
||||
valid_id = value in devices.ids
|
||||
valid_hostname = value in devices.hostnames
|
||||
|
||||
if not any((valid_id, valid_hostname)):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from hyperglass.log import log
|
|||
from hyperglass.exceptions.private import InputValidationError
|
||||
|
||||
# Local
|
||||
from ..main import HyperglassModel
|
||||
from ..main import HyperglassModel, HyperglassModelWithId
|
||||
from ..fields import Action
|
||||
from ..config.params import Params
|
||||
|
||||
|
|
@ -224,7 +224,7 @@ class RuleWithoutValidation(Rule):
|
|||
Rules = Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValidation]
|
||||
|
||||
|
||||
class Directive(HyperglassModel):
|
||||
class Directive(HyperglassModelWithId):
|
||||
"""A directive contains commands that can be run on a device, as long as defined rules are met."""
|
||||
|
||||
id: StrictStr
|
||||
|
|
|
|||
|
|
@ -3,19 +3,12 @@
|
|||
# Standard Library
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
from typing import Any, Set, Dict, List, Tuple, Union, Optional
|
||||
from pathlib import Path
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
|
||||
# Third Party
|
||||
from pydantic import (
|
||||
StrictInt,
|
||||
StrictStr,
|
||||
StrictBool,
|
||||
PrivateAttr,
|
||||
validator,
|
||||
root_validator,
|
||||
)
|
||||
from pydantic import StrictInt, StrictStr, StrictBool, validator, root_validator
|
||||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
|
|
@ -26,7 +19,7 @@ from hyperglass.models.commands.generic import Directive
|
|||
|
||||
# Local
|
||||
from .ssl import Ssl
|
||||
from ..main import HyperglassModel
|
||||
from ..main import HyperglassModel, HyperglassModelWithId
|
||||
from .proxy import Proxy
|
||||
from .params import Params
|
||||
from ..fields import SupportedDriver
|
||||
|
|
@ -34,10 +27,10 @@ from .network import Network
|
|||
from .credential import Credential
|
||||
|
||||
|
||||
class Device(HyperglassModel, extra="allow"):
|
||||
class Device(HyperglassModelWithId, extra="allow"):
|
||||
"""Validation model for per-router config in devices.yaml."""
|
||||
|
||||
_id: StrictStr = PrivateAttr()
|
||||
id: StrictStr
|
||||
name: StrictStr
|
||||
address: Union[IPv4Address, IPv6Address, StrictStr]
|
||||
network: Network
|
||||
|
|
@ -55,23 +48,9 @@ class Device(HyperglassModel, extra="allow"):
|
|||
def __init__(self, **kwargs) -> None:
|
||||
"""Set the device ID."""
|
||||
_id, values = self._generate_id(kwargs)
|
||||
super().__init__(**values)
|
||||
self._id = _id
|
||||
super().__init__(id=_id, **values)
|
||||
self._validate_directive_attrs()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Make device object hashable so the object can be deduplicated with set()."""
|
||||
return hash((self.name,))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Make device object comparable so the object can be deduplicated with set()."""
|
||||
result = False
|
||||
|
||||
if isinstance(other, HyperglassModel):
|
||||
result = self.name == other.name
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def _target(self):
|
||||
return str(self.address)
|
||||
|
|
@ -104,7 +83,7 @@ class Device(HyperglassModel, extra="allow"):
|
|||
def export_api(self) -> Dict[str, Any]:
|
||||
"""Export API-facing device fields."""
|
||||
return {
|
||||
"id": self._id,
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"network": self.network.display_name,
|
||||
}
|
||||
|
|
@ -233,7 +212,7 @@ class Device(HyperglassModel, extra="allow"):
|
|||
class Devices(HyperglassModel, extra="allow"):
|
||||
"""Validation model for device configurations."""
|
||||
|
||||
_ids: List[StrictStr] = []
|
||||
ids: List[StrictStr] = []
|
||||
hostnames: List[StrictStr] = []
|
||||
objects: List[Device] = []
|
||||
all_nos: List[StrictStr] = []
|
||||
|
|
@ -248,7 +227,7 @@ class Devices(HyperglassModel, extra="allow"):
|
|||
all_nos = set()
|
||||
objects = set()
|
||||
hostnames = set()
|
||||
_ids = set()
|
||||
ids = set()
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
|
|
@ -261,13 +240,13 @@ class Devices(HyperglassModel, extra="allow"):
|
|||
# list with `devices.hostnames`, same for all router
|
||||
# classes, for when iteration over all routers is required.
|
||||
hostnames.add(device.name)
|
||||
_ids.add(device._id)
|
||||
ids.add(device.id)
|
||||
objects.add(device)
|
||||
all_nos.add(device.nos)
|
||||
|
||||
# Convert the de-duplicated sets to a standard list, add lists
|
||||
# as class attributes. Sort router list by router name attribute
|
||||
init_kwargs["_ids"] = list(_ids)
|
||||
init_kwargs["ids"] = list(ids)
|
||||
init_kwargs["hostnames"] = list(hostnames)
|
||||
init_kwargs["all_nos"] = list(all_nos)
|
||||
init_kwargs["objects"] = sorted(objects, key=lambda x: x.name)
|
||||
|
|
@ -277,7 +256,7 @@ class Devices(HyperglassModel, extra="allow"):
|
|||
def __getitem__(self, accessor: str) -> Device:
|
||||
"""Get a device by its name."""
|
||||
for device in self.objects:
|
||||
if device._id == accessor:
|
||||
if device.id == accessor:
|
||||
return device
|
||||
elif device.name == accessor:
|
||||
return device
|
||||
|
|
@ -296,7 +275,7 @@ class Devices(HyperglassModel, extra="allow"):
|
|||
"display_name": name,
|
||||
"locations": [
|
||||
{
|
||||
"id": device._id,
|
||||
"id": device.id,
|
||||
"name": device.name,
|
||||
"network": device.network.display_name,
|
||||
"directives": [c.frontend(params) for c in device.commands],
|
||||
|
|
@ -307,3 +286,20 @@ class Devices(HyperglassModel, extra="allow"):
|
|||
}
|
||||
for name in names
|
||||
]
|
||||
|
||||
def directive_plugins(self) -> Dict[Path, Tuple[StrictStr]]:
|
||||
"""Get a mapping of plugin paths to associated directive IDs."""
|
||||
result: Dict[Path, Set[StrictStr]] = {}
|
||||
# Unique set of all directives.
|
||||
directives = {directive for device in self.objects for directive in device.commands}
|
||||
# Unique set of all plugin file names.
|
||||
plugin_names = {plugin for directive in directives for plugin in directive.plugins}
|
||||
|
||||
for directive in directives:
|
||||
# Convert each plugin file name to a `Path` object.
|
||||
for plugin in (Path(p) for p in directive.plugins if p in plugin_names):
|
||||
if plugin not in result:
|
||||
result[plugin] = set()
|
||||
result[plugin].add(directive.id)
|
||||
# Convert the directive set to a tuple.
|
||||
return {k: tuple(v) for k, v in result.items()}
|
||||
|
|
|
|||
|
|
@ -80,3 +80,25 @@ class HyperglassModel(BaseModel):
|
|||
}
|
||||
|
||||
return yaml.safe_dump(json.loads(self.export_json(**export_kwargs)), *args, **kwargs)
|
||||
|
||||
|
||||
class HyperglassModelWithId(HyperglassModel):
|
||||
"""hyperglass model that is unique by its `id` field."""
|
||||
|
||||
id: str
|
||||
|
||||
def __eq__(self: "HyperglassModelWithId", other: "HyperglassModelWithId") -> bool:
|
||||
"""Other model is equal to this model."""
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
if hasattr(other, "id"):
|
||||
return other and self.id == other.id
|
||||
return False
|
||||
|
||||
def __ne__(self: "HyperglassModelWithId", other: "HyperglassModelWithId") -> bool:
|
||||
"""Other model is not equal to this model."""
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self: "HyperglassModelWithId") -> int:
|
||||
"""Create a hashed representation of this model's name."""
|
||||
return hash(self.id)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
# Standard Library
|
||||
from abc import ABC
|
||||
from typing import Any, Union, Literal, TypeVar
|
||||
from typing import Any, Union, Literal, TypeVar, Sequence
|
||||
from inspect import Signature
|
||||
|
||||
# Third Party
|
||||
|
|
@ -52,3 +52,9 @@ class HyperglassPlugin(BaseModel, ABC):
|
|||
"""Initialize plugin instance."""
|
||||
name = kwargs.pop("name", None) or self.__class__.__name__
|
||||
super().__init__(name=name, **kwargs)
|
||||
|
||||
|
||||
class DirectivePlugin(HyperglassPlugin):
|
||||
"""Plugin associated with directives."""
|
||||
|
||||
directives: Sequence[str] = ()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
# Local
|
||||
from ._base import HyperglassPlugin
|
||||
from ._base import DirectivePlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Project
|
||||
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
InputPluginReturn = Union[None, bool]
|
||||
|
||||
|
||||
class InputPlugin(HyperglassPlugin):
|
||||
class InputPlugin(DirectivePlugin):
|
||||
"""Plugin to validate user input prior to running commands."""
|
||||
|
||||
def validate(self, query: "Query") -> InputPluginReturn:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import json
|
||||
import codecs
|
||||
import pickle
|
||||
from typing import TYPE_CHECKING, List, Generic, TypeVar, Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any, List, Generic, TypeVar, Callable, Generator
|
||||
from inspect import isclass
|
||||
|
||||
# Project
|
||||
|
|
@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
|||
# Project
|
||||
from hyperglass.models.api.query import Query
|
||||
from hyperglass.models.config.devices import Device
|
||||
from hyperglass.models.commands.generic import Directive
|
||||
|
||||
PluginT = TypeVar("PluginT")
|
||||
|
||||
|
|
@ -73,7 +74,7 @@ class PluginManager(Generic[PluginT]):
|
|||
def plugins(self: "PluginManager") -> List[PluginT]:
|
||||
"""Get all plugins, with built-in plugins last."""
|
||||
return sorted(
|
||||
self.plugins,
|
||||
self._get_plugins(),
|
||||
key=lambda p: -1 if p.__hyperglass_builtin__ else 1, # flake8: noqa IF100
|
||||
reverse=True,
|
||||
)
|
||||
|
|
@ -117,12 +118,12 @@ class PluginManager(Generic[PluginT]):
|
|||
return
|
||||
raise PluginError("Plugin '{}' is not a valid hyperglass plugin", repr(plugin))
|
||||
|
||||
def register(self: "PluginManager", plugin: PluginT) -> None:
|
||||
def register(self: "PluginManager", plugin: PluginT, *args: Any, **kwargs: Any) -> None:
|
||||
"""Add a plugin to currently active plugins."""
|
||||
# Create a set of plugins so duplicate plugins are not mistakenly added.
|
||||
try:
|
||||
if issubclass(plugin, HyperglassPlugin):
|
||||
instance = plugin()
|
||||
instance = plugin(*args, **kwargs)
|
||||
plugins = {
|
||||
# Create a base64 representation of a picked plugin.
|
||||
codecs.encode(pickle.dumps(p), "base64").decode()
|
||||
|
|
@ -131,7 +132,10 @@ class PluginManager(Generic[PluginT]):
|
|||
}
|
||||
# Add plugins from cache.
|
||||
self._cache.set(f"hyperglass.plugins.{self._type}", json.dumps(list(plugins)))
|
||||
log.success("Registered plugin '{}'", instance.name)
|
||||
if instance.__hyperglass_builtin__ is True:
|
||||
log.debug("Registered built-in plugin '{}'", instance.name)
|
||||
else:
|
||||
log.success("Registered plugin '{}'", instance.name)
|
||||
return
|
||||
except TypeError:
|
||||
raise PluginError(
|
||||
|
|
@ -145,13 +149,15 @@ class PluginManager(Generic[PluginT]):
|
|||
class InputPluginManager(PluginManager[InputPlugin], type="input"):
|
||||
"""Manage Input Validation Plugins."""
|
||||
|
||||
def execute(self: "InputPluginManager", query: "Query") -> InputPluginReturn:
|
||||
def execute(
|
||||
self: "InputPluginManager", *, directive: "Directive", query: "Query"
|
||||
) -> InputPluginReturn:
|
||||
"""Execute all input validation plugins.
|
||||
|
||||
If any plugin returns `False`, execution is halted.
|
||||
"""
|
||||
result = None
|
||||
for plugin in self.plugins:
|
||||
for plugin in (plugin for plugin in self.plugins if directive.id in plugin.directives):
|
||||
if result is False:
|
||||
return result
|
||||
result = plugin.validate(query)
|
||||
|
|
@ -161,13 +167,15 @@ class InputPluginManager(PluginManager[InputPlugin], type="input"):
|
|||
class OutputPluginManager(PluginManager[OutputPlugin], type="output"):
|
||||
"""Manage Output Processing Plugins."""
|
||||
|
||||
def execute(self: "OutputPluginManager", output: str, device: "Device") -> OutputPluginReturn:
|
||||
def execute(
|
||||
self: "OutputPluginManager", *, directive: "Directive", output: str, device: "Device"
|
||||
) -> OutputPluginReturn:
|
||||
"""Execute all output parsing plugins.
|
||||
|
||||
The result of each plugin is passed to the next plugin.
|
||||
"""
|
||||
result = output
|
||||
for plugin in self.plugins:
|
||||
for plugin in (plugin for plugin in self.plugins if directive.id in plugin.directives):
|
||||
if result is False:
|
||||
return result
|
||||
# Pass the result of each plugin to the next plugin.
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
"""Device output plugins."""
|
||||
|
||||
# Standard Library
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Union, Sequence
|
||||
|
||||
# Local
|
||||
from ._base import HyperglassPlugin
|
||||
from ._base import DirectivePlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Project
|
||||
|
|
@ -14,9 +14,11 @@ if TYPE_CHECKING:
|
|||
OutputPluginReturn = Union[None, "ParsedRoutes", str]
|
||||
|
||||
|
||||
class OutputPlugin(HyperglassPlugin):
|
||||
class OutputPlugin(DirectivePlugin):
|
||||
"""Plugin to interact with device command output."""
|
||||
|
||||
directive_ids: Sequence[str] = ()
|
||||
|
||||
def process(self, output: Union["ParsedRoutes", str], device: "Device") -> OutputPluginReturn:
|
||||
"""Process or manipulate output from a device."""
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def _is_class(module: Any, obj: object) -> bool:
|
|||
return isclass(obj) and obj.__module__ == module.__name__
|
||||
|
||||
|
||||
def _register_from_module(module: Any) -> Tuple[str, ...]:
|
||||
def _register_from_module(module: Any, **kwargs: Any) -> Tuple[str, ...]:
|
||||
"""Register defined classes from the module."""
|
||||
failures = ()
|
||||
defs = getmembers(module, lambda o: _is_class(module, o))
|
||||
|
|
@ -35,7 +35,7 @@ def _register_from_module(module: Any) -> Tuple[str, ...]:
|
|||
else:
|
||||
failures += (name,)
|
||||
continue
|
||||
manager.register(plugin)
|
||||
manager.register(plugin, **kwargs)
|
||||
return failures
|
||||
return failures
|
||||
|
||||
|
|
@ -57,10 +57,10 @@ def init_plugins() -> None:
|
|||
_register_from_module(_builtin)
|
||||
|
||||
|
||||
def register_plugin(plugin_file: Path) -> Tuple[str, ...]:
|
||||
def register_plugin(plugin_file: Path, **kwargs) -> Tuple[str, ...]:
|
||||
"""Register an external plugin by file path."""
|
||||
if plugin_file.exists():
|
||||
module = _module_from_file(plugin_file)
|
||||
results = _register_from_module(module)
|
||||
results = _register_from_module(module, **kwargs)
|
||||
return results
|
||||
raise FileNotFoundError(str(plugin_file))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue