mirror of
https://github.com/thatmattlove/hyperglass.git
synced 2026-01-17 08:48:05 +00:00
#142: Start multiple query target implementation
This commit is contained in:
parent
433b5cecee
commit
9300105d9f
2 changed files with 46 additions and 16 deletions
|
|
@ -22,13 +22,15 @@ from ..config.devices import Device
|
|||
|
||||
(TEXT := use_state("params").web.text)
|
||||
|
||||
QueryTarget = constr(strip_whitespace=True, min_length=1)
|
||||
|
||||
|
||||
class Query(BaseModel):
|
||||
"""Validation model for input query parameters."""
|
||||
|
||||
query_location: StrictStr # Device `name` field
|
||||
query_type: StrictStr # Directive `id` field
|
||||
query_target: constr(strip_whitespace=True, min_length=1)
|
||||
query_target: t.Union[t.List[QueryTarget], QueryTarget]
|
||||
|
||||
class Config:
|
||||
"""Pydantic model configuration."""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ from hyperglass.exceptions.private import InputValidationError
|
|||
from .main import MultiModel, HyperglassModel, HyperglassUniqueModel
|
||||
from .fields import Action
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.api.query import QueryTarget
|
||||
|
||||
IPv4PrefixLength = conint(ge=0, le=32)
|
||||
IPv6PrefixLength = conint(ge=0, le=128)
|
||||
IPNetwork = t.Union[IPv4Network, IPv6Network]
|
||||
|
|
@ -82,7 +86,7 @@ class Rule(HyperglassModel, allow_population_by_field_name=True):
|
|||
return [value]
|
||||
return value
|
||||
|
||||
def validate_target(self, target: str) -> bool:
|
||||
def validate_target(self, target: str, *, multiple: bool) -> bool:
|
||||
"""Validate a query target (Placeholder signature)."""
|
||||
raise NotImplementedError(
|
||||
f"{self._validation} rule does not implement a 'validate_target()' method"
|
||||
|
|
@ -119,8 +123,11 @@ class RuleWithIP(Rule):
|
|||
|
||||
return False
|
||||
|
||||
def validate_target(self, target: str) -> bool:
|
||||
def validate_target(self, target: "QueryTarget", *, multiple: bool) -> bool:
|
||||
"""Validate an IP address target against this rule's conditions."""
|
||||
if isinstance(target, t.List):
|
||||
self._passed = False
|
||||
raise InputValidationError("Target must be a single value")
|
||||
try:
|
||||
# Attempt to use IP object factory to create an IP address object
|
||||
valid_target = ip_network(target)
|
||||
|
|
@ -181,23 +188,43 @@ class RuleWithPattern(Rule):
|
|||
_validation: RuleValidation = PrivateAttr("pattern")
|
||||
condition: StrictStr
|
||||
|
||||
def validate_target(self, target: str) -> str:
|
||||
def validate_target(self, target: "QueryTarget", *, multiple: bool) -> str: # noqa: C901
|
||||
"""Validate a string target against configured regex patterns."""
|
||||
|
||||
if self.condition == "*":
|
||||
pattern = re.compile(".+", re.IGNORECASE)
|
||||
else:
|
||||
pattern = re.compile(self.condition, re.IGNORECASE)
|
||||
def validate_single_value(value: str) -> t.Union[bool, BaseException]:
|
||||
if self.condition == "*":
|
||||
pattern = re.compile(".+", re.IGNORECASE)
|
||||
else:
|
||||
pattern = re.compile(self.condition, re.IGNORECASE)
|
||||
is_match = pattern.match(value)
|
||||
|
||||
is_match = pattern.match(target)
|
||||
if is_match and self.action == "permit":
|
||||
if is_match and self.action == "permit":
|
||||
return True
|
||||
elif is_match and self.action == "deny":
|
||||
return InputValidationError(target=value, error="Denied")
|
||||
return False
|
||||
|
||||
if isinstance(target, t.List) and multiple:
|
||||
for result in (validate_single_value(v) for v in target):
|
||||
if isinstance(result, BaseException):
|
||||
self._passed = False
|
||||
raise result
|
||||
elif result is False:
|
||||
self._passed = False
|
||||
return result
|
||||
self._passed = True
|
||||
return True
|
||||
elif is_match and self.action == "deny":
|
||||
self._passed = False
|
||||
raise InputValidationError(target=target, error="Denied")
|
||||
|
||||
return False
|
||||
elif isinstance(target, t.List) and not multiple:
|
||||
raise InputValidationError("Target must be a single value")
|
||||
|
||||
result = validate_single_value(target)
|
||||
|
||||
if isinstance(result, BaseException):
|
||||
self._passed = False
|
||||
raise result
|
||||
self._passed = result
|
||||
return result
|
||||
|
||||
|
||||
class RuleWithoutValidation(Rule):
|
||||
|
|
@ -206,7 +233,7 @@ class RuleWithoutValidation(Rule):
|
|||
_validation: RuleValidation = PrivateAttr(None)
|
||||
condition: None
|
||||
|
||||
def validate_target(self, target: str) -> t.Literal[True]:
|
||||
def validate_target(self, target: str, *, multiple: bool) -> t.Literal[True]:
|
||||
"""Don't validate a target. Always returns `True`."""
|
||||
self._passed = True
|
||||
return True
|
||||
|
|
@ -229,11 +256,12 @@ class Directive(HyperglassUniqueModel, unique_by=("id", "table_output")):
|
|||
disable_builtins: StrictBool = False
|
||||
table_output: t.Optional[StrictStr]
|
||||
groups: t.List[StrictStr] = []
|
||||
multiple: StrictBool = False
|
||||
|
||||
def validate_target(self, target: str) -> bool:
|
||||
"""Validate a target against all configured rules."""
|
||||
for rule in self.rules:
|
||||
valid = rule.validate_target(target)
|
||||
valid = rule.validate_target(target, multiple=self.multiple)
|
||||
if valid is True:
|
||||
return True
|
||||
continue
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue