flatten configuration model

This commit is contained in:
checktheroads 2020-01-28 08:59:27 -07:00
parent 6ad69ae6bb
commit 9b9fd95061
10 changed files with 143 additions and 156 deletions

View file

@ -33,21 +33,21 @@ UI_DIR = STATIC_DIR / "ui"
IMAGES_DIR = STATIC_DIR / "images"
ASGI_PARAMS = {
"host": str(params.general.listen_address),
"port": params.general.listen_port,
"debug": params.general.debug,
"host": str(params.listen_address),
"port": params.listen_port,
"debug": params.debug,
}
# Main App Definition
app = FastAPI(
debug=params.general.debug,
title=params.general.site_title,
description=params.general.site_description,
debug=params.debug,
title=params.site_title,
description=params.site_description,
version=__version__,
default_response_class=UJSONResponse,
docs_url=None,
redoc_url=None,
openapi_url=params.general.docs.openapi_url,
openapi_url=params.docs.openapi_url,
)
# Add Event Handlers
@ -73,9 +73,9 @@ app.add_exception_handler(Exception, default_handler)
def _custom_openapi():
"""Generate custom OpenAPI config."""
openapi_schema = get_openapi(
title=params.general.site_title,
title=params.site_title,
version=__version__,
description=params.general.site_description,
description=params.site_description,
routes=app.routes,
)
app.openapi_schema = openapi_schema
@ -84,11 +84,11 @@ def _custom_openapi():
app.openapi = _custom_openapi
if params.general.docs.enable:
if params.docs.enable:
log.debug(f"API Docs config: {app.openapi()}")
CORS_ORIGINS = params.general.cors_origins.copy()
if params.general.developer_mode:
CORS_ORIGINS = params.cors_origins.copy()
if params.developer_mode:
CORS_ORIGINS.append(URL_DEV)
# CORS Configuration
@ -103,10 +103,10 @@ app.add_api_route(
path="/api/query/",
endpoint=query,
methods=["POST"],
summary=params.general.docs.endpoint_summary,
description=params.general.docs.endpoint_description,
summary=params.docs.endpoint_summary,
description=params.docs.endpoint_description,
response_model=QueryResponse,
tags=[params.general.docs.group_title],
tags=[params.docs.group_title],
response_class=UJSONResponse,
)
app.add_api_route(path="api/docs", endpoint=docs, include_in_schema=False)

View file

@ -58,7 +58,7 @@ async def build_ui():
"""
try:
await build_frontend(
dev_mode=params.general.developer_mode,
dev_mode=params.developer_mode,
dev_url=URL_DEV,
prod_url=URL_PROD,
params=frontend_params,

View file

@ -64,12 +64,11 @@ async def query(query_data: Query, request: Request):
async def docs():
"""Serve custom docs."""
if params.general.docs.enable:
if params.docs.enable:
docs_func_map = {"swagger": get_swagger_ui_html, "redoc": get_redoc_html}
docs_func = docs_func_map[params.general.docs.mode]
docs_func = docs_func_map[params.docs.mode]
return docs_func(
openapi_url=params.general.docs.openapi_url,
title=params.general.site_title + " - API Docs",
openapi_url=params.docs.openapi_url, title=params.site_title + " - API Docs"
)
else:
raise HTTPException(detail="Not found", status_code=404)

View file

@ -145,7 +145,7 @@ try:
params = _params.Params()
try:
params.branding.text.subtitle = params.branding.text.subtitle.format(
**params.general.dict()
**params.dict(exclude={"branding", "features", "messages"})
)
except KeyError:
pass
@ -167,7 +167,7 @@ except ValidationError as validation_errors:
)
# Re-evaluate debug state after config is validated
_set_log_level(params.general.debug, params.general.log_file)
_set_log_level(params.debug, params.log_file)
def _build_frontend_networks():
@ -328,9 +328,7 @@ def _build_queries():
content_params = json.loads(
params.general.json(
include={"primary_asn", "org_name", "site_title", "site_description"}
)
params.json(include={"primary_asn", "org_name", "site_title", "site_description"})
)
@ -436,11 +434,11 @@ _frontend_params.update(
)
frontend_params = _frontend_params
URL_DEV = f"http://localhost:{str(params.general.listen_port)}/api/"
URL_DEV = f"http://localhost:{str(params.listen_port)}/api/"
URL_PROD = "/api/"
REDIS_CONFIG = {
"host": str(params.general.redis_host),
"port": params.general.redis_port,
"host": str(params.redis_host),
"port": params.redis_port,
"decode_responses": True,
}

View file

@ -10,7 +10,7 @@ from hyperglass.configuration.models._utils import HyperglassModel
class Docs(HyperglassModel):
"""Validation model for params.general.docs."""
"""Validation model for params.docs."""
enable: StrictBool = True
mode: constr(regex=r"(swagger|redoc)") = "swagger"

View file

@ -1,117 +0,0 @@
"""Validate general configuration variables."""
# Standard Library Imports
from datetime import datetime
from ipaddress import ip_address
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union
# Third Party Imports
from pydantic import FilePath
from pydantic import IPvAnyAddress
from pydantic import StrictBool
from pydantic import StrictInt
from pydantic import StrictStr
from pydantic import validator
# Project Imports
from hyperglass.configuration.models._utils import HyperglassModel
from hyperglass.configuration.models.docs import Docs
from hyperglass.configuration.models.opengraph import OpenGraph
class General(HyperglassModel):
"""Validation model for params.general."""
debug: StrictBool = False
developer_mode: StrictBool = False
primary_asn: StrictStr = "65001"
org_name: StrictStr = "The Company"
site_title: StrictStr = "hyperglass"
site_description: StrictStr = "{org_name} Network Looking Glass"
site_keywords: List[StrictStr] = [
"hyperglass",
"looking glass",
"lg",
"peer",
"peering",
"ipv4",
"ipv6",
"transit",
"community",
"communities",
"bgp",
"routing",
"network",
"isp",
]
opengraph: OpenGraph = OpenGraph()
docs: Docs = Docs()
google_analytics: StrictStr = ""
redis_host: StrictStr = "localhost"
redis_port: StrictInt = 6379
requires_ipv6_cidr: List[StrictStr] = ["cisco_ios", "cisco_nxos"]
request_timeout: StrictInt = 30
listen_address: Optional[Union[IPvAnyAddress, StrictStr]]
listen_port: StrictInt = 8001
log_file: Optional[FilePath]
cors_origins: List[StrictStr] = []
@validator("listen_address", pre=True, always=True)
def validate_listen_address(cls, value, values):
"""Set default listen_address based on debug mode.
Arguments:
value {str|IPvAnyAddress|None} -- listen_address
values {dict} -- already-validated entries before listen_address
Returns:
{str} -- Validated listen_address
"""
if value is None and not values["debug"]:
listen_address = "localhost"
elif value is None and values["debug"]:
listen_address = ip_address("0.0.0.0") # noqa: S104
elif isinstance(value, str) and value != "localhost":
try:
listen_address = ip_address(value)
except ValueError:
raise ValueError(str(value))
elif isinstance(value, str) and value == "localhost":
listen_address = value
else:
raise ValueError(str(value))
return listen_address
@validator("site_description")
def validate_site_description(cls, value, values):
"""Format the site descripion with the org_name field.
Arguments:
value {str} -- site_description
values {str} -- Values before site_description
Returns:
{str} -- Formatted description
"""
return value.format(org_name=values["org_name"])
@validator("log_file")
def validate_log_file(cls, value):
"""Set default logfile location if none is configured.
Arguments:
value {FilePath} -- Path to log file
Returns:
{Path} -- Logfile path object
"""
if value is None:
now = datetime.now()
now.isoformat
value = Path(
f'/tmp/hyperglass_{now.strftime(r"%Y%M%d-%H%M%S")}.log' # noqa: S108
)
return value

View file

@ -13,7 +13,7 @@ from hyperglass.configuration.models._utils import HyperglassModel
class OpenGraph(HyperglassModel):
"""Validation model for params.general.opengraph."""
"""Validation model for params.opengraph."""
width: Optional[StrictInt]
height: Optional[StrictInt]

View file

@ -1,17 +1,124 @@
"""Configuration validation entry point."""
# Standard Library Imports
from datetime import datetime
from ipaddress import ip_address
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union
# Third Party Imports
from pydantic import FilePath
from pydantic import IPvAnyAddress
from pydantic import StrictBool
from pydantic import StrictInt
from pydantic import StrictStr
from pydantic import validator
# Project Imports
from hyperglass.configuration.models._utils import HyperglassModel
from hyperglass.configuration.models.branding import Branding
from hyperglass.configuration.models.docs import Docs
from hyperglass.configuration.models.features import Features
from hyperglass.configuration.models.general import General
from hyperglass.configuration.models.messages import Messages
from hyperglass.configuration.models.opengraph import OpenGraph
class Params(HyperglassModel):
"""Validation model for all configuration variables."""
general: General = General()
debug: StrictBool = False
developer_mode: StrictBool = False
primary_asn: StrictStr = "65001"
org_name: StrictStr = "The Company"
site_title: StrictStr = "hyperglass"
site_description: StrictStr = "{org_name} Network Looking Glass"
site_keywords: List[StrictStr] = [
"hyperglass",
"looking glass",
"lg",
"peer",
"peering",
"ipv4",
"ipv6",
"transit",
"community",
"communities",
"bgp",
"routing",
"network",
"isp",
]
opengraph: OpenGraph = OpenGraph()
docs: Docs = Docs()
google_analytics: StrictStr = ""
redis_host: StrictStr = "localhost"
redis_port: StrictInt = 6379
requires_ipv6_cidr: List[StrictStr] = ["cisco_ios", "cisco_nxos"]
request_timeout: StrictInt = 30
listen_address: Optional[Union[IPvAnyAddress, StrictStr]]
listen_port: StrictInt = 8001
log_file: Optional[FilePath]
cors_origins: List[StrictStr] = []
@validator("listen_address", pre=True, always=True)
def validate_listen_address(cls, value, values):
"""Set default listen_address based on debug mode.
Arguments:
value {str|IPvAnyAddress|None} -- listen_address
values {dict} -- already-validated entries before listen_address
Returns:
{str} -- Validated listen_address
"""
if value is None and not values["debug"]:
listen_address = "localhost"
elif value is None and values["debug"]:
listen_address = ip_address("0.0.0.0") # noqa: S104
elif isinstance(value, str) and value != "localhost":
try:
listen_address = ip_address(value)
except ValueError:
raise ValueError(str(value))
elif isinstance(value, str) and value == "localhost":
listen_address = value
else:
raise ValueError(str(value))
return listen_address
@validator("site_description")
def validate_site_description(cls, value, values):
"""Format the site descripion with the org_name field.
Arguments:
value {str} -- site_description
values {str} -- Values before site_description
Returns:
{str} -- Formatted description
"""
return value.format(org_name=values["org_name"])
@validator("log_file")
def validate_log_file(cls, value):
"""Set default logfile location if none is configured.
Arguments:
value {FilePath} -- Path to log file
Returns:
{Path} -- Logfile path object
"""
if value is None:
now = datetime.now()
now.isoformat
value = Path(
f'/tmp/hyperglass_{now.strftime(r"%Y%M%d-%H%M%S")}.log' # noqa: S108
)
return value
features: Features = Features()
branding: Branding = Branding()
messages: Messages = Messages()

View file

@ -104,7 +104,7 @@ class Connect:
)
signal.signal(signal.SIGALRM, handle_timeout)
signal.alarm(params.general.request_timeout - 1)
signal.alarm(params.request_timeout - 1)
with tunnel:
log.debug(
@ -119,7 +119,7 @@ class Connect:
"username": self.device.credential.username,
"password": self.device.credential.password.get_secret_value(),
"global_delay_factor": 0.2,
"timeout": params.general.request_timeout - 1,
"timeout": params.request_timeout - 1,
}
try:
@ -194,7 +194,7 @@ class Connect:
"username": self.device.credential.username,
"password": self.device.credential.password.get_secret_value(),
"global_delay_factor": 0.2,
"timeout": params.general.request_timeout,
"timeout": params.request_timeout,
}
try:
@ -210,7 +210,7 @@ class Connect:
)
signal.signal(signal.SIGALRM, handle_timeout)
signal.alarm(params.general.request_timeout - 1)
signal.alarm(params.request_timeout - 1)
responses = []
@ -259,7 +259,7 @@ class Connect:
client_params = {
"headers": {"Content-Type": "application/json"},
"timeout": params.general.request_timeout,
"timeout": params.request_timeout,
}
if self.device.ssl is not None and self.device.ssl.enable:
http_protocol = "https"
@ -286,7 +286,7 @@ class Connect:
encoded_query = await jwt_encode(
payload=query,
secret=self.device.credential.password.get_secret_value(),
duration=params.general.request_timeout,
duration=params.request_timeout,
)
log.debug(f"Encoded JWT: {encoded_query}")

View file

@ -232,7 +232,7 @@ def ip_type_check(query_type, target, device):
if (
query_type == "bgp_route"
and prefix_attr["version"] == 6
and device.nos in params.general.requires_ipv6_cidr
and device.nos in params.requires_ipv6_cidr
and IPType().is_host(target)
):
log.debug("Failed requires IPv6 CIDR check")