1
0
Fork 1
mirror of https://github.com/thatmattlove/hyperglass.git synced 2026-01-17 08:48:05 +00:00

Add failure_reason to InputPlugin, add BGP Community builtin plugin and tests

This commit is contained in:
thatmattlove 2021-09-16 18:32:59 -07:00
parent 37a8e40bfe
commit cb5459a72a
3 changed files with 147 additions and 3 deletions

View file

@ -0,0 +1,109 @@
"""Remove anything before the command if found in output."""
# Standard Library
import typing as t
from ipaddress import ip_address
# Third Party
from pydantic import PrivateAttr
# Project
from hyperglass.state.hooks import use_state
# Local
from .._input import InputPlugin, InputPluginReturn
if t.TYPE_CHECKING:
# Project
from hyperglass.models.api.query import Query
_32BIT = 0xFFFFFFFF
_16BIT = 0xFFFF
EXTENDED_TYPES = ("target", "origin")
def check_decimal(value: str, size: int) -> bool:
"""Verify the value is a 32 bit number."""
try:
return abs(int(value)) <= size
except Exception:
return False
def check_string(value: str) -> bool:
"""Verify part of a community is an IPv4 address, per RFC4360."""
try:
addr = ip_address(value)
return addr.version == 4
except ValueError:
return False
def validate_decimal(value: str) -> bool:
"""Verify a community is a 32 bit decimal number."""
return check_decimal(value, _32BIT)
def validate_new_format(value: str) -> bool:
"""Verify a community matches "new" format, standard or extended."""
if ":" in value:
parts = [p for p in value.split(":") if p]
if len(parts) == 3:
if parts[0].lower() not in EXTENDED_TYPES:
# Handle extended community format with `target:` or `origin:` prefix.
return False
# Remove type from parts list after it's been validated.
parts = parts[1:]
if len(parts) != 2:
# Only allow two sections in new format, e.g. 65000:1
return False
one, two = parts
if all((check_decimal(one, _16BIT), check_decimal(two, _16BIT))):
# Handle standard format, e.g. `65000:1`
return True
elif all((check_decimal(one, _16BIT), check_decimal(two, _32BIT))):
# Handle extended format, e.g. `65000:4294967295`
return True
elif all((check_string(one), check_decimal(two, _16BIT))):
# Handle IP address format, e.g. `192.0.2.1:65000`
return True
return False
def validate_large_community(value: str) -> bool:
"""Verify a community matches "large" format. E.g., `65000:65001:65002`."""
if ":" in value:
parts = [p for p in value.split(":") if p]
if len(parts) != 3:
return False
for part in parts:
if not check_decimal(part, _32BIT):
# Each member must be a 32 bit number.
return False
return True
return False
class ValidateBGPCommunity(InputPlugin):
"""Validate a BGP community string."""
__hyperglass_builtin__: bool = PrivateAttr(True)
def validate(self, query: "Query") -> InputPluginReturn:
"""Ensure an input query target is a valid BGP community."""
params = use_state("params")
if not isinstance(query.query_target, str):
return None
for validator in (validate_decimal, validate_new_format, validate_large_community):
result = validator(query.query_target)
if result is True:
return True
self.failure_reason = params.messages.invalid_input
return False

View file

@ -1,21 +1,23 @@
"""Input validation plugins."""
# Standard Library
from typing import TYPE_CHECKING, Union
import typing as t
# Local
from ._base import DirectivePlugin
if TYPE_CHECKING:
if t.TYPE_CHECKING:
# Project
from hyperglass.models.api.query import Query
InputPluginReturn = Union[None, bool]
InputPluginReturn = t.Union[None, bool]
class InputPlugin(DirectivePlugin):
"""Plugin to validate user input prior to running commands."""
failure_reason: t.Optional[str] = None
def validate(self, query: "Query") -> InputPluginReturn:
"""Validate input from hyperglass UI/API."""
return None

View file

@ -0,0 +1,33 @@
"""Test BGP Community validation."""
# Local
from .._builtin.bgp_community import ValidateBGPCommunity
CHECKS = (
("32768", True),
("65000:1", True),
("65000:4294967296", False),
("4294967295:65000", False),
("192.0.2.1:65000", True),
("65000:192.0.2.1", False),
("target:65000:1", True),
("origin:65001:1", True),
("wrong:65000:1", False),
("65000:65001:65002", True),
("4294967295:4294967294:4294967293", True),
("65000:4294967295:1", True),
("65000:192.0.2.1:1", False),
("gibberish", False),
("192.0.2.1", False),
(True, None),
(type("FakeClass", (), {}), None),
)
def test_bgp_community():
plugin = ValidateBGPCommunity()
for value, expected in CHECKS:
query = type("Query", (), {"query_target": value})
result = plugin.validate(query)
assert result == expected, f"Invalid value {value!r}"