From 74fcb5dba41ba619a818e502f72769005dc80036 Mon Sep 17 00:00:00 2001 From: thatmattlove Date: Sun, 12 Sep 2021 18:27:33 -0700 Subject: [PATCH] Continue output plugin implementation --- hyperglass/api/routes.py | 1 - hyperglass/exceptions/_common.py | 4 +- hyperglass/exceptions/public.py | 24 +++++----- hyperglass/execution/drivers/_common.py | 55 +++------------------ hyperglass/execution/main.py | 5 +- hyperglass/main.py | 11 ++--- hyperglass/models/__init__.py | 7 ++- hyperglass/models/api/query.py | 2 +- hyperglass/models/commands/generic.py | 4 +- hyperglass/models/config/devices.py | 64 ++++++++++++------------- hyperglass/models/main.py | 22 +++++++++ hyperglass/plugins/_base.py | 8 +++- hyperglass/plugins/_input.py | 4 +- hyperglass/plugins/_manager.py | 26 ++++++---- hyperglass/plugins/_output.py | 8 ++-- hyperglass/plugins/main.py | 8 ++-- 16 files changed, 124 insertions(+), 129 deletions(-) diff --git a/hyperglass/api/routes.py b/hyperglass/api/routes.py index 8367dd0..94c2da4 100644 --- a/hyperglass/api/routes.py +++ b/hyperglass/api/routes.py @@ -58,7 +58,6 @@ async def send_webhook(query_data: Query, request: Request, timestamp: datetime) log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err)) -@log.catch async def query(query_data: Query, request: Request, background_tasks: BackgroundTasks): """Ingest request data pass it to the backend application to perform the query.""" diff --git a/hyperglass/exceptions/_common.py b/hyperglass/exceptions/_common.py index 117db00..3ee210a 100644 --- a/hyperglass/exceptions/_common.py +++ b/hyperglass/exceptions/_common.py @@ -120,7 +120,7 @@ class PublicHyperglassError(HyperglassError): """Format error message with keyword arguments.""" if "error" in kwargs: error = kwargs.pop("error") - error = self._safe_format(error, **kwargs) + error = self._safe_format(str(error), **kwargs) kwargs["error"] = error self._message = self._safe_format(self._message_template, **kwargs) self._keywords = list(kwargs.values()) @@ -150,7 +150,7 @@ class PrivateHyperglassError(HyperglassError): """Format error message with keyword arguments.""" if "error" in kwargs: error = kwargs.pop("error") - error = self._safe_format(error, **kwargs) + error = self._safe_format(str(error), **kwargs) kwargs["error"] = error self._message = self._safe_format(message, **kwargs) self._keywords = list(kwargs.values()) diff --git a/hyperglass/exceptions/public.py b/hyperglass/exceptions/public.py index 1ca8062..3866481 100644 --- a/hyperglass/exceptions/public.py +++ b/hyperglass/exceptions/public.py @@ -1,7 +1,7 @@ """User-facing/Public exceptions.""" # Standard Library -from typing import Any, Dict, Optional, ForwardRef +from typing import TYPE_CHECKING, Any, Dict, Optional # Project from hyperglass.configuration import params @@ -9,8 +9,10 @@ from hyperglass.configuration import params # Local from ._common import PublicHyperglassError -Query = ForwardRef("Query") -Device = ForwardRef("Device") +if TYPE_CHECKING: + # Project + from hyperglass.models.api.query import Query + from hyperglass.models.config.devices import Device class ScrapeError( @@ -18,7 +20,7 @@ class ScrapeError( ): """Raised when an SSH driver error occurs.""" - def __init__(self, error: BaseException, *, device: Device): + def __init__(self, *, error: BaseException, device: "Device"): """Initialize parent error.""" super().__init__(error=str(error), device=device.name, proxy=device.proxy) @@ -28,7 +30,7 @@ class AuthError( ): """Raised when authentication to a device fails.""" - def __init__(self, error: BaseException, *, device: Device): + def __init__(self, *, error: BaseException, device: "Device"): """Initialize parent error.""" super().__init__(error=str(error), device=device.name, proxy=device.proxy) @@ -36,7 +38,7 @@ class AuthError( class RestError(PublicHyperglassError, template=params.messages.connection_error, level="danger"): """Raised upon a rest API client error.""" - def __init__(self, error: BaseException, *, device: Device): + def __init__(self, *, error: BaseException, device: "Device"): """Initialize parent error.""" super().__init__(error=str(error), device=device.name) @@ -46,7 +48,7 @@ class DeviceTimeout( ): """Raised when the connection to a device times out.""" - def __init__(self, error: BaseException, *, device: Device): + def __init__(self, *, error: BaseException, device: "Device"): """Initialize parent error.""" super().__init__(error=str(error), device=device.name, proxy=device.proxy) @@ -55,7 +57,7 @@ class InvalidQuery(PublicHyperglassError, template=params.messages.invalid_query """Raised when input validation fails.""" def __init__( - self, error: Optional[str] = None, *, query: "Query", **kwargs: Dict[str, Any] + self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any] ) -> None: """Initialize parent error.""" @@ -107,7 +109,7 @@ class InputInvalid(PublicHyperglassError, template=params.messages.invalid_input """Raised when input validation fails.""" def __init__( - self, error: Optional[Any] = None, *, target: str, **kwargs: Dict[str, Any] + self, *, error: Optional[Any] = None, target: str, **kwargs: Dict[str, Any] ) -> None: """Initialize parent error.""" @@ -123,7 +125,7 @@ class InputNotAllowed(PublicHyperglassError, template=params.messages.acl_not_al """Raised when input validation fails due to a configured check.""" def __init__( - self, error: Optional[str] = None, *, query: Query, **kwargs: Dict[str, Any] + self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any] ) -> None: """Initialize parent error.""" @@ -143,7 +145,7 @@ class ResponseEmpty(PublicHyperglassError, template=params.messages.no_output): """Raised when hyperglass can connect to the device but the response is empty.""" def __init__( - self, error: Optional[str] = None, *, query: Query, **kwargs: Dict[str, Any] + self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any] ) -> None: """Initialize parent error.""" diff --git a/hyperglass/execution/drivers/_common.py b/hyperglass/execution/drivers/_common.py index 67e6c1c..e37aa1d 100644 --- a/hyperglass/execution/drivers/_common.py +++ b/hyperglass/execution/drivers/_common.py @@ -7,23 +7,21 @@ from typing import TYPE_CHECKING, Dict, Union, Sequence # Project from hyperglass.log import log from hyperglass.plugins import OutputPluginManager -from hyperglass.models.api import Query -from hyperglass.parsing.nos import scrape_parsers, structured_parsers -from hyperglass.parsing.common import parsers -from hyperglass.models.config.devices import Device # Local from ._construct import Construct if TYPE_CHECKING: # Project + from hyperglass.models.api import Query from hyperglass.compat._sshtunnel import SSHTunnelForwarder + from hyperglass.models.config.devices import Device class Connection(ABC): """Base transport driver class.""" - def __init__(self, device: Device, query_data: Query) -> None: + def __init__(self, device: "Device", query_data: "Query") -> None: """Initialize connection to device.""" self.device = device self.query_data = query_data @@ -38,53 +36,14 @@ class Connection(ABC): """Return a preconfigured sshtunnel.SSHTunnelForwarder instance.""" pass - async def parsed_response( # noqa: C901 ("too complex") - self, output: Sequence[str] - ) -> Union[str, Sequence[Dict]]: + async def parsed_response(self, output: Sequence[str]) -> Union[str, Sequence[Dict]]: """Send output through common parsers.""" log.debug("Pre-parsed responses:\n{}", output) - parsed = () - response = None - structured_nos = structured_parsers.keys() - structured_query_types = structured_parsers.get(self.device.nos, {}).keys() - - scrape_nos = scrape_parsers.keys() - scrape_query_types = scrape_parsers.get(self.device.nos, {}).keys() - - if not self.device.structured_output: - _parsed = () - for func in parsers: - for response in output: - _output = func(commands=self.query, output=response) - _parsed += (_output,) - if self.device.nos in scrape_nos and self.query_type in scrape_query_types: - func = scrape_parsers[self.device.nos][self.query_type] - for response in _parsed: - _output = func(response) - parsed += (_output,) - else: - parsed += _parsed - - response = "\n\n".join(parsed) - elif ( - self.device.structured_output - and self.device.nos in structured_nos - and self.query_type not in structured_query_types - ): - for func in parsers: - for response in output: - _output = func(commands=self.query, output=response) - parsed += (_output,) - response = "\n\n".join(parsed) - elif ( - self.device.structured_output - and self.device.nos in structured_nos - and self.query_type in structured_query_types - ): - func = structured_parsers[self.device.nos][self.query_type] - response = func(output) + response = self.plugin_manager.execute( + directive=self.query_data.directive, output=output, device=self.device + ) if response is None: response = "\n\n".join(output) diff --git a/hyperglass/execution/main.py b/hyperglass/execution/main.py index d0970ad..b8aa93d 100644 --- a/hyperglass/execution/main.py +++ b/hyperglass/execution/main.py @@ -55,7 +55,10 @@ async def execute(query: "Query") -> Union[str, Sequence[Dict]]: mapped_driver = map_driver(query.device.driver) driver: "Connection" = mapped_driver(query.device, query) - signal.signal(signal.SIGALRM, handle_timeout(error=TimeoutError(), device=query.device)) + signal.signal( + signal.SIGALRM, + handle_timeout(error=TimeoutError("Connection timed out"), device=query.device), + ) signal.alarm(params.request_timeout - 1) if query.device.proxy: diff --git a/hyperglass/main.py b/hyperglass/main.py index 4997134..8d753ca 100644 --- a/hyperglass/main.py +++ b/hyperglass/main.py @@ -7,7 +7,6 @@ import shutil import logging import platform from typing import TYPE_CHECKING -from pathlib import Path # Third Party from gunicorn.app.base import BaseApplication # type: ignore @@ -124,10 +123,8 @@ def cache_config() -> bool: def register_all_plugins(devices: "Devices") -> None: """Validate and register configured plugins.""" - for plugin_file in { - Path(p) for p in (p for d in devices.objects for c in d.commands for p in c.plugins) - }: - failures = register_plugin(plugin_file) + for plugin_file, directives in devices.directive_plugins().items(): + failures = register_plugin(plugin_file, directives=directives) for failure in failures: log.warning( "Plugin '{}' is not a valid hyperglass plugin, and was not registered", failure, @@ -203,11 +200,9 @@ class HyperglassWSGI(BaseApplication): def start(**kwargs): """Start hyperglass via gunicorn.""" - # Project - from hyperglass.api import app HyperglassWSGI( - app=app, + app="hyperglass.api:app", options={ "worker_class": "uvicorn.workers.UvicornWorker", "preload": True, diff --git a/hyperglass/models/__init__.py b/hyperglass/models/__init__.py index 3a064f0..c8186f9 100644 --- a/hyperglass/models/__init__.py +++ b/hyperglass/models/__init__.py @@ -1,6 +1,9 @@ """All Data Models used by hyperglass.""" # Local -from .main import HyperglassModel +from .main import HyperglassModel, HyperglassModelWithId -__all__ = ("HyperglassModel",) +__all__ = ( + "HyperglassModel", + "HyperglassModelWithId", +) diff --git a/hyperglass/models/api/query.py b/hyperglass/models/api/query.py index 47752ca..b251163 100644 --- a/hyperglass/models/api/query.py +++ b/hyperglass/models/api/query.py @@ -168,7 +168,7 @@ class Query(BaseModel): def validate_query_location(cls, value): """Ensure query_location is defined.""" - valid_id = value in devices._ids + valid_id = value in devices.ids valid_hostname = value in devices.hostnames if not any((valid_id, valid_hostname)): diff --git a/hyperglass/models/commands/generic.py b/hyperglass/models/commands/generic.py index 7f42158..9c4789a 100644 --- a/hyperglass/models/commands/generic.py +++ b/hyperglass/models/commands/generic.py @@ -23,7 +23,7 @@ from hyperglass.log import log from hyperglass.exceptions.private import InputValidationError # Local -from ..main import HyperglassModel +from ..main import HyperglassModel, HyperglassModelWithId from ..fields import Action from ..config.params import Params @@ -224,7 +224,7 @@ class RuleWithoutValidation(Rule): Rules = Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValidation] -class Directive(HyperglassModel): +class Directive(HyperglassModelWithId): """A directive contains commands that can be run on a device, as long as defined rules are met.""" id: StrictStr diff --git a/hyperglass/models/config/devices.py b/hyperglass/models/config/devices.py index a3cf2d1..1731f70 100644 --- a/hyperglass/models/config/devices.py +++ b/hyperglass/models/config/devices.py @@ -3,19 +3,12 @@ # Standard Library import os import re -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Set, Dict, List, Tuple, Union, Optional from pathlib import Path from ipaddress import IPv4Address, IPv6Address # Third Party -from pydantic import ( - StrictInt, - StrictStr, - StrictBool, - PrivateAttr, - validator, - root_validator, -) +from pydantic import StrictInt, StrictStr, StrictBool, validator, root_validator # Project from hyperglass.log import log @@ -26,7 +19,7 @@ from hyperglass.models.commands.generic import Directive # Local from .ssl import Ssl -from ..main import HyperglassModel +from ..main import HyperglassModel, HyperglassModelWithId from .proxy import Proxy from .params import Params from ..fields import SupportedDriver @@ -34,10 +27,10 @@ from .network import Network from .credential import Credential -class Device(HyperglassModel, extra="allow"): +class Device(HyperglassModelWithId, extra="allow"): """Validation model for per-router config in devices.yaml.""" - _id: StrictStr = PrivateAttr() + id: StrictStr name: StrictStr address: Union[IPv4Address, IPv6Address, StrictStr] network: Network @@ -55,23 +48,9 @@ class Device(HyperglassModel, extra="allow"): def __init__(self, **kwargs) -> None: """Set the device ID.""" _id, values = self._generate_id(kwargs) - super().__init__(**values) - self._id = _id + super().__init__(id=_id, **values) self._validate_directive_attrs() - def __hash__(self) -> int: - """Make device object hashable so the object can be deduplicated with set().""" - return hash((self.name,)) - - def __eq__(self, other: Any) -> bool: - """Make device object comparable so the object can be deduplicated with set().""" - result = False - - if isinstance(other, HyperglassModel): - result = self.name == other.name - - return result - @property def _target(self): return str(self.address) @@ -104,7 +83,7 @@ class Device(HyperglassModel, extra="allow"): def export_api(self) -> Dict[str, Any]: """Export API-facing device fields.""" return { - "id": self._id, + "id": self.id, "name": self.name, "network": self.network.display_name, } @@ -233,7 +212,7 @@ class Device(HyperglassModel, extra="allow"): class Devices(HyperglassModel, extra="allow"): """Validation model for device configurations.""" - _ids: List[StrictStr] = [] + ids: List[StrictStr] = [] hostnames: List[StrictStr] = [] objects: List[Device] = [] all_nos: List[StrictStr] = [] @@ -248,7 +227,7 @@ class Devices(HyperglassModel, extra="allow"): all_nos = set() objects = set() hostnames = set() - _ids = set() + ids = set() init_kwargs = {} @@ -261,13 +240,13 @@ class Devices(HyperglassModel, extra="allow"): # list with `devices.hostnames`, same for all router # classes, for when iteration over all routers is required. hostnames.add(device.name) - _ids.add(device._id) + ids.add(device.id) objects.add(device) all_nos.add(device.nos) # Convert the de-duplicated sets to a standard list, add lists # as class attributes. Sort router list by router name attribute - init_kwargs["_ids"] = list(_ids) + init_kwargs["ids"] = list(ids) init_kwargs["hostnames"] = list(hostnames) init_kwargs["all_nos"] = list(all_nos) init_kwargs["objects"] = sorted(objects, key=lambda x: x.name) @@ -277,7 +256,7 @@ class Devices(HyperglassModel, extra="allow"): def __getitem__(self, accessor: str) -> Device: """Get a device by its name.""" for device in self.objects: - if device._id == accessor: + if device.id == accessor: return device elif device.name == accessor: return device @@ -296,7 +275,7 @@ class Devices(HyperglassModel, extra="allow"): "display_name": name, "locations": [ { - "id": device._id, + "id": device.id, "name": device.name, "network": device.network.display_name, "directives": [c.frontend(params) for c in device.commands], @@ -307,3 +286,20 @@ class Devices(HyperglassModel, extra="allow"): } for name in names ] + + def directive_plugins(self) -> Dict[Path, Tuple[StrictStr]]: + """Get a mapping of plugin paths to associated directive IDs.""" + result: Dict[Path, Set[StrictStr]] = {} + # Unique set of all directives. + directives = {directive for device in self.objects for directive in device.commands} + # Unique set of all plugin file names. + plugin_names = {plugin for directive in directives for plugin in directive.plugins} + + for directive in directives: + # Convert each plugin file name to a `Path` object. + for plugin in (Path(p) for p in directive.plugins if p in plugin_names): + if plugin not in result: + result[plugin] = set() + result[plugin].add(directive.id) + # Convert the directive set to a tuple. + return {k: tuple(v) for k, v in result.items()} diff --git a/hyperglass/models/main.py b/hyperglass/models/main.py index 3a4b95e..5cee705 100644 --- a/hyperglass/models/main.py +++ b/hyperglass/models/main.py @@ -80,3 +80,25 @@ class HyperglassModel(BaseModel): } return yaml.safe_dump(json.loads(self.export_json(**export_kwargs)), *args, **kwargs) + + +class HyperglassModelWithId(HyperglassModel): + """hyperglass model that is unique by its `id` field.""" + + id: str + + def __eq__(self: "HyperglassModelWithId", other: "HyperglassModelWithId") -> bool: + """Other model is equal to this model.""" + if not isinstance(other, self.__class__): + return False + if hasattr(other, "id"): + return other and self.id == other.id + return False + + def __ne__(self: "HyperglassModelWithId", other: "HyperglassModelWithId") -> bool: + """Other model is not equal to this model.""" + return not self.__eq__(other) + + def __hash__(self: "HyperglassModelWithId") -> int: + """Create a hashed representation of this model's name.""" + return hash(self.id) diff --git a/hyperglass/plugins/_base.py b/hyperglass/plugins/_base.py index 5b79833..977bf2f 100644 --- a/hyperglass/plugins/_base.py +++ b/hyperglass/plugins/_base.py @@ -2,7 +2,7 @@ # Standard Library from abc import ABC -from typing import Any, Union, Literal, TypeVar +from typing import Any, Union, Literal, TypeVar, Sequence from inspect import Signature # Third Party @@ -52,3 +52,9 @@ class HyperglassPlugin(BaseModel, ABC): """Initialize plugin instance.""" name = kwargs.pop("name", None) or self.__class__.__name__ super().__init__(name=name, **kwargs) + + +class DirectivePlugin(HyperglassPlugin): + """Plugin associated with directives.""" + + directives: Sequence[str] = () diff --git a/hyperglass/plugins/_input.py b/hyperglass/plugins/_input.py index 513cd2a..532aaf1 100644 --- a/hyperglass/plugins/_input.py +++ b/hyperglass/plugins/_input.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Union # Local -from ._base import HyperglassPlugin +from ._base import DirectivePlugin if TYPE_CHECKING: # Project @@ -13,7 +13,7 @@ if TYPE_CHECKING: InputPluginReturn = Union[None, bool] -class InputPlugin(HyperglassPlugin): +class InputPlugin(DirectivePlugin): """Plugin to validate user input prior to running commands.""" def validate(self, query: "Query") -> InputPluginReturn: diff --git a/hyperglass/plugins/_manager.py b/hyperglass/plugins/_manager.py index c6fb6fa..f2a7c5d 100644 --- a/hyperglass/plugins/_manager.py +++ b/hyperglass/plugins/_manager.py @@ -4,7 +4,7 @@ import json import codecs import pickle -from typing import TYPE_CHECKING, List, Generic, TypeVar, Callable, Generator +from typing import TYPE_CHECKING, Any, List, Generic, TypeVar, Callable, Generator from inspect import isclass # Project @@ -22,6 +22,7 @@ if TYPE_CHECKING: # Project from hyperglass.models.api.query import Query from hyperglass.models.config.devices import Device + from hyperglass.models.commands.generic import Directive PluginT = TypeVar("PluginT") @@ -73,7 +74,7 @@ class PluginManager(Generic[PluginT]): def plugins(self: "PluginManager") -> List[PluginT]: """Get all plugins, with built-in plugins last.""" return sorted( - self.plugins, + self._get_plugins(), key=lambda p: -1 if p.__hyperglass_builtin__ else 1, # flake8: noqa IF100 reverse=True, ) @@ -117,12 +118,12 @@ class PluginManager(Generic[PluginT]): return raise PluginError("Plugin '{}' is not a valid hyperglass plugin", repr(plugin)) - def register(self: "PluginManager", plugin: PluginT) -> None: + def register(self: "PluginManager", plugin: PluginT, *args: Any, **kwargs: Any) -> None: """Add a plugin to currently active plugins.""" # Create a set of plugins so duplicate plugins are not mistakenly added. try: if issubclass(plugin, HyperglassPlugin): - instance = plugin() + instance = plugin(*args, **kwargs) plugins = { # Create a base64 representation of a picked plugin. codecs.encode(pickle.dumps(p), "base64").decode() @@ -131,7 +132,10 @@ class PluginManager(Generic[PluginT]): } # Add plugins from cache. self._cache.set(f"hyperglass.plugins.{self._type}", json.dumps(list(plugins))) - log.success("Registered plugin '{}'", instance.name) + if instance.__hyperglass_builtin__ is True: + log.debug("Registered built-in plugin '{}'", instance.name) + else: + log.success("Registered plugin '{}'", instance.name) return except TypeError: raise PluginError( @@ -145,13 +149,15 @@ class PluginManager(Generic[PluginT]): class InputPluginManager(PluginManager[InputPlugin], type="input"): """Manage Input Validation Plugins.""" - def execute(self: "InputPluginManager", query: "Query") -> InputPluginReturn: + def execute( + self: "InputPluginManager", *, directive: "Directive", query: "Query" + ) -> InputPluginReturn: """Execute all input validation plugins. If any plugin returns `False`, execution is halted. """ result = None - for plugin in self.plugins: + for plugin in (plugin for plugin in self.plugins if directive.id in plugin.directives): if result is False: return result result = plugin.validate(query) @@ -161,13 +167,15 @@ class InputPluginManager(PluginManager[InputPlugin], type="input"): class OutputPluginManager(PluginManager[OutputPlugin], type="output"): """Manage Output Processing Plugins.""" - def execute(self: "OutputPluginManager", output: str, device: "Device") -> OutputPluginReturn: + def execute( + self: "OutputPluginManager", *, directive: "Directive", output: str, device: "Device" + ) -> OutputPluginReturn: """Execute all output parsing plugins. The result of each plugin is passed to the next plugin. """ result = output - for plugin in self.plugins: + for plugin in (plugin for plugin in self.plugins if directive.id in plugin.directives): if result is False: return result # Pass the result of each plugin to the next plugin. diff --git a/hyperglass/plugins/_output.py b/hyperglass/plugins/_output.py index b7e3cf9..7e0f219 100644 --- a/hyperglass/plugins/_output.py +++ b/hyperglass/plugins/_output.py @@ -1,10 +1,10 @@ """Device output plugins.""" # Standard Library -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union, Sequence # Local -from ._base import HyperglassPlugin +from ._base import DirectivePlugin if TYPE_CHECKING: # Project @@ -14,9 +14,11 @@ if TYPE_CHECKING: OutputPluginReturn = Union[None, "ParsedRoutes", str] -class OutputPlugin(HyperglassPlugin): +class OutputPlugin(DirectivePlugin): """Plugin to interact with device command output.""" + directive_ids: Sequence[str] = () + def process(self, output: Union["ParsedRoutes", str], device: "Device") -> OutputPluginReturn: """Process or manipulate output from a device.""" return None diff --git a/hyperglass/plugins/main.py b/hyperglass/plugins/main.py index 48a273e..aef1645 100644 --- a/hyperglass/plugins/main.py +++ b/hyperglass/plugins/main.py @@ -23,7 +23,7 @@ def _is_class(module: Any, obj: object) -> bool: return isclass(obj) and obj.__module__ == module.__name__ -def _register_from_module(module: Any) -> Tuple[str, ...]: +def _register_from_module(module: Any, **kwargs: Any) -> Tuple[str, ...]: """Register defined classes from the module.""" failures = () defs = getmembers(module, lambda o: _is_class(module, o)) @@ -35,7 +35,7 @@ def _register_from_module(module: Any) -> Tuple[str, ...]: else: failures += (name,) continue - manager.register(plugin) + manager.register(plugin, **kwargs) return failures return failures @@ -57,10 +57,10 @@ def init_plugins() -> None: _register_from_module(_builtin) -def register_plugin(plugin_file: Path) -> Tuple[str, ...]: +def register_plugin(plugin_file: Path, **kwargs) -> Tuple[str, ...]: """Register an external plugin by file path.""" if plugin_file.exists(): module = _module_from_file(plugin_file) - results = _register_from_module(module) + results = _register_from_module(module, **kwargs) return results raise FileNotFoundError(str(plugin_file))