1
0
Fork 1
mirror of https://github.com/thatmattlove/hyperglass.git synced 2026-01-17 08:48:05 +00:00
thatmattlove-hyperglass/hyperglass/models/api/query.py
2024-05-27 16:25:47 -04:00

140 lines
4.7 KiB
Python

"""Input query validation model."""
# Standard Library
import typing as t
import hashlib
import secrets
from datetime import datetime, timezone

# Third Party
from pydantic import Field, BaseModel, ConfigDict, field_validator

# Project
from hyperglass.log import log
from hyperglass.util import snake_to_camel, repr_from_attrs
from hyperglass.state import use_state
from hyperglass.plugins import InputPluginManager
from hyperglass.exceptions.public import InputInvalid, QueryTypeNotFound, QueryLocationNotFound
from hyperglass.exceptions.private import InputValidationError

# Local
from ..config.devices import Device
class SimpleQuery(BaseModel):
    """Minimal, already-validated view of a query.

    Holds only the three query fields (location, target and type) with no
    validation hooks of its own.
    """

    # Location the query is directed at.
    query_location: str
    # Target of the query: a single string or a list of strings.
    query_target: t.Union[t.List[str], str]
    # Type of query being requested.
    query_type: str

    def __repr_name__(self) -> str:
        """Report the model name as ``Query`` so log output reads clearly."""
        return "Query"
class Query(BaseModel):
    """Validation model for input query parameters.

    Field validators check the requested location and query type against the
    application state. Once the fields are validated, ``__init__`` looks up the
    matching directive on the device, passes the target through the input
    plugins' transformations and finally validates the resulting target.
    """

    # `str_strip_whitespace` is what actually trims surrounding whitespace from
    # string inputs. pydantic's `Field()` has no `strip_whitespace` argument; it
    # treats unknown keyword arguments as deprecated JSON-schema extras, so
    # passing it there never stripped anything.
    model_config = ConfigDict(
        extra="allow",
        alias_generator=snake_to_camel,
        populate_by_name=True,
        str_strip_whitespace=True,
    )

    # Device `name` field
    query_location: str = Field(strict=True, min_length=1)
    query_target: t.Union[t.List[str], str] = Field(min_length=1)
    # Directive `id` field
    query_type: str = Field(strict=True, min_length=1)
    # Raw keyword arguments the model was constructed with.
    _kwargs: t.Dict[str, t.Any]

    def __init__(self, **data) -> None:
        """Initialize the query with a UTC timestamp at initialization time.

        After pydantic validation, the directive matching ``query_type`` is
        resolved from the device, the target is transformed by the input
        plugins, and the transformed target is validated.

        Raises:
            QueryTypeNotFound: The device has no directive matching
                ``query_type``.
            InputInvalid: Target validation raised ``InputValidationError``.
        """
        super().__init__(**data)
        self._kwargs = data
        # Timezone-aware replacement for the deprecated `datetime.utcnow()`;
        # the formatted string is identical.
        self.timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

        # `self.device` reads from `self._state`, so state must be set first.
        self._state = use_state()
        query_directives = self.device.directives.matching(self.query_type)
        if len(query_directives) < 1:
            raise QueryTypeNotFound(query_type=self.query_type)
        # The first matching directive is the one used for this query.
        self.directive = query_directives[0]

        self._input_plugin_manager = InputPluginManager()
        # Transform before validating so validation sees the final target.
        self.query_target = self.transform_query_target()

        try:
            self.validate_query_target()
        except InputValidationError as err:
            # Re-raise the private validation error as its public counterpart.
            raise InputInvalid(**err.kwargs) from err

    def summary(self) -> SimpleQuery:
        """Summarized and post-validated model of a Query."""
        return SimpleQuery(
            query_location=self.query_location,
            query_target=self.query_target,
            query_type=self.query_type,
        )

    def __repr__(self) -> str:
        """Represent only the query fields."""
        return repr_from_attrs(self, ("query_location", "query_type", "query_target"))

    def __str__(self) -> str:
        """Alias __str__ to __repr__."""
        return repr(self)

    def digest(self) -> str:
        """Create SHA256 hash digest of model representation."""
        return hashlib.sha256(repr(self).encode()).hexdigest()

    def random(self) -> str:
        """Create a random string to prevent client or proxy caching.

        The representation is wrapped in random bytes on both sides before
        hashing, so each call yields a different hex digest.
        """
        return hashlib.sha256(
            secrets.token_bytes(8) + repr(self).encode() + secrets.token_bytes(8)
        ).hexdigest()

    def validate_query_target(self) -> None:
        """Validate a query target after all fields/relationships have been initialized."""
        # Run config/rule-based validations.
        self.directive.validate_target(self.query_target)
        # Run plugin-based validations.
        self._input_plugin_manager.validate(query=self)
        log.bind(query=self.summary()).debug("Validation passed")

    def transform_query_target(self) -> t.Union[t.List[str], str]:
        """Transform a query target based on defined plugins."""
        return self._input_plugin_manager.transform(query=self)

    def dict(self) -> t.Dict[str, t.Union[t.List[str], str]]:
        """Include only public fields."""
        return super().model_dump(include={"query_location", "query_target", "query_type"})

    @property
    def device(self) -> Device:
        """Get this query's device object by query_location."""
        return self._state.devices[self.query_location]

    @field_validator("query_location")
    @classmethod
    def validate_query_location(cls, value):
        """Ensure query_location is defined.

        Raises:
            QueryLocationNotFound: ``value`` is not a valid device id or name.
        """
        devices = use_state("devices")
        if not devices.valid_id_or_name(value):
            raise QueryLocationNotFound(location=value)
        return value

    @field_validator("query_type")
    @classmethod
    def validate_query_type(cls, value: t.Any):
        """Ensure a requested query type exists.

        Raises:
            QueryTypeNotFound: No device reports directives for ``value``.
        """
        devices = use_state("devices")
        if any(device.has_directives(value) for device in devices):
            return value
        raise QueryTypeNotFound(query_type=value)