diff --git a/docs/docs/adding-devices.mdx b/docs/docs/adding-devices.mdx
index e7cb51b..bc8be79 100644
--- a/docs/docs/adding-devices.mdx
+++ b/docs/docs/adding-devices.mdx
@@ -49,6 +49,7 @@ routers:
| `port` | Integer | TCP port used to connect to the device. `22` by default. |
| `nos` | String | Network Operating System. Must be a supported platform. |
| `structured_output` | Boolean | Disabled output parsing to structured data. |
+| `driver` | String | Override the device driver. Must be 'scrapli' or 'netmiko'. |
| `credential` | | [Device Credential Configuration](#credential) |
| `vrfs` | | [Device VRF Configuration](#vrfs) |
| `proxy` | | [SSH Proxy Configuration](#proxy) |
@@ -131,11 +132,11 @@ May be set to `null` to disable IPv4 for this VRF, on the parent device.
May be set to `null` to disable IPv6 for this VRF, on the parent device.
-| Parameter | Type | Default | Description |
-| :-------------------- | :------ | :------ | :------------------------------------------------------------------------------------------------------------------------------------- |
-| `source_address` | String | | Device's source IPv6 address for directed queries (ping, traceroute). This address must be surrounded by quotes. Ex. "0000:0000:0000::"|
-| `force_cidr` | Boolean | `true` | Convert IP host queries to actual advertised containing prefix length |
-| `access_list` | | | [IPv6 Access List Configuration](#access_list) |
+| Parameter | Type | Default | Description |
+| :-------------------- | :------ | :------ | :-------------------------------------------------------------------------------------------------------------------------------------- |
+| `source_address` | String | | Device's source IPv6 address for directed queries (ping, traceroute). This address must be surrounded by quotes. Ex. "0000:0000:0000::" |
+| `force_cidr` | Boolean | `true` | Convert IP host queries to actual advertised containing prefix length |
+| `access_list` | | | [IPv6 Access List Configuration](#access_list) |
:::note
The `force_cidr` option will ensure that a **BGP Route** query for an IP host (/32 IPv4, /128 IPv6) is converted to its containing prefix. For example, a query for `1.1.1.1` would be converted to a query for `1.1.1.0/24`. This is because not all platforms support a BGP lookup for a host (this is primary a problem with IPv6, but the option applies to both address families).
diff --git a/hyperglass/execution/drivers/__init__.py b/hyperglass/execution/drivers/__init__.py
index d83c8cc..73a6ada 100644
--- a/hyperglass/execution/drivers/__init__.py
+++ b/hyperglass/execution/drivers/__init__.py
@@ -2,5 +2,6 @@
# Local
from .agent import AgentConnection
+from ._common import Connection
from .ssh_netmiko import NetmikoConnection
from .ssh_scrapli import ScrapliConnection
diff --git a/hyperglass/execution/main.py b/hyperglass/execution/main.py
index 553c126..3ccddd9 100644
--- a/hyperglass/execution/main.py
+++ b/hyperglass/execution/main.py
@@ -12,19 +12,24 @@ from typing import Any, Dict, Union, Callable, Sequence
# Project
from hyperglass.log import log
-from hyperglass.util import validate_nos
from hyperglass.exceptions import DeviceTimeout, ResponseEmpty
from hyperglass.models.api import Query
from hyperglass.configuration import params
# Local
-from .drivers import AgentConnection, NetmikoConnection, ScrapliConnection
+from .drivers import Connection, AgentConnection, NetmikoConnection, ScrapliConnection
-DRIVER_MAP = {
- "scrapli": ScrapliConnection,
- "netmiko": NetmikoConnection,
- "hyperglass_agent": AgentConnection,
-}
+
+def map_driver(driver_name: str) -> Connection:
+ """Get the correct driver class based on the driver name."""
+
+ if driver_name == "scrapli":
+ return ScrapliConnection
+
+ elif driver_name == "hyperglass_agent":
+ return AgentConnection
+
+ return NetmikoConnection
def handle_timeout(**exc_args: Any) -> Callable:
@@ -44,9 +49,7 @@ async def execute(query: Query) -> Union[str, Sequence[Dict]]:
log.debug("Received query for {}", query.json())
log.debug("Matched device config: {}", query.device)
- supported, driver_name = validate_nos(query.device.nos)
-
- mapped_driver = DRIVER_MAP.get(driver_name, NetmikoConnection)
+ mapped_driver = map_driver(query.device.driver)
driver = mapped_driver(query.device, query)
timeout_args = {
diff --git a/hyperglass/models/config/devices.py b/hyperglass/models/config/devices.py
index e57f90e..a353610 100644
--- a/hyperglass/models/config/devices.py
+++ b/hyperglass/models/config/devices.py
@@ -19,7 +19,7 @@ from pydantic import (
# Project
from hyperglass.log import log
-from hyperglass.util import validate_nos, resolve_hostname
+from hyperglass.util import get_driver, validate_nos, resolve_hostname
from hyperglass.constants import SCRAPE_HELPERS, SUPPORTED_STRUCTURED_OUTPUT
from hyperglass.exceptions import ConfigError, UnsupportedDevice
@@ -28,6 +28,7 @@ from .ssl import Ssl
from .vrf import Vrf, Info
from ..main import HyperglassModel, HyperglassModelExtra
from .proxy import Proxy
+from ..fields import SupportedDriver
from .network import Network
from .credential import Credential
@@ -93,6 +94,7 @@ class Device(HyperglassModel):
display_vrfs: List[StrictStr] = []
vrf_names: List[StrictStr] = []
structured_output: Optional[StrictBool]
+ driver: Optional[SupportedDriver]
def __init__(self, **kwargs) -> None:
"""Set the device ID."""
@@ -130,15 +132,9 @@ class Device(HyperglassModel):
return value
@validator("structured_output", pre=True, always=True)
- def validate_structured_output(cls, value, values):
- """Validate structured output is supported on the device & set a default.
+ def validate_structured_output(cls, value: bool, values: Dict) -> bool:
+ """Validate structured output is supported on the device & set a default."""
- Raises:
- ConfigError: Raised if true on a device that doesn't support structured output.
-
- Returns:
- {bool} -- True if hyperglass should return structured output for this device.
- """
if value is True and values["nos"] not in SUPPORTED_STRUCTURED_OUTPUT:
raise ConfigError(
"The 'structured_output' field is set to 'true' on device '{d}' with "
@@ -280,6 +276,11 @@ class Device(HyperglassModel):
vrfs.append(vrf)
return vrfs
+ @validator("driver")
+ def validate_driver(cls, value: Optional[str], values: Dict) -> Dict:
+ """Set the correct driver and override if supported."""
+ return get_driver(values["nos"], value)
+
class Devices(HyperglassModelExtra):
"""Validation model for device configurations."""
diff --git a/hyperglass/models/fields.py b/hyperglass/models/fields.py
index 1787714..bd4d734 100644
--- a/hyperglass/models/fields.py
+++ b/hyperglass/models/fields.py
@@ -5,10 +5,12 @@ import re
from typing import TypeVar
# Third Party
-from pydantic import StrictInt, StrictFloat
+from pydantic import StrictInt, StrictFloat, constr
IntFloat = TypeVar("IntFloat", StrictInt, StrictFloat)
+SupportedDriver = constr(regex=r"(scrapli|netmiko|hyperglass_agent)")
+
class StrictBytes(bytes):
"""Custom data type for a strict byte string.
diff --git a/hyperglass/util/__init__.py b/hyperglass/util/__init__.py
index 82924e8..cde1100 100644
--- a/hyperglass/util/__init__.py
+++ b/hyperglass/util/__init__.py
@@ -6,16 +6,21 @@ import sys
import json
import platform
from queue import Queue
-from typing import Dict, Union, Generator
+from typing import Dict, Union, Optional, Generator
from asyncio import iscoroutine
from pathlib import Path
from ipaddress import IPv4Address, IPv6Address, ip_address
# Third Party
from loguru._logger import Logger as LoguruLogger
+from netmiko.ssh_dispatcher import CLASS_MAPPER
# Project
from hyperglass.log import log
+from hyperglass.constants import DRIVER_MAP
+
+ALL_NOS = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()}
+ALL_DRIVERS = {*DRIVER_MAP.values(), "netmiko"}
def cpu_count(multiplier: int = 0) -> int:
@@ -250,22 +255,30 @@ def make_repr(_class):
def validate_nos(nos):
"""Validate device NOS is supported."""
- # Third Party
- from netmiko.ssh_dispatcher import CLASS_MAPPER
-
- # Project
- from hyperglass.constants import DRIVER_MAP
result = (False, None)
- all_nos = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()}
-
- if nos in all_nos:
+ if nos in ALL_NOS:
result = (True, DRIVER_MAP.get(nos, "netmiko"))
return result
+def get_driver(nos: str, driver: Optional[str]) -> str:
+ """Determine the appropriate driver for a device."""
+
+ if driver is None:
+ # If no driver is set, use the driver map with netmiko as
+ # fallback.
+ return DRIVER_MAP.get(nos, "netmiko")
+ elif driver in ALL_DRIVERS:
+ # If a driver is set and it is valid, allow it.
+ return driver
+ else:
+ # Otherwise, fail validation.
+ raise ValueError("{} is not a supported driver.".format(driver))
+
+
def current_log_level(logger: LoguruLogger) -> str:
"""Get the current log level of a logger instance."""