Update code formatting - line length

This commit is contained in:
thatmattlove 2021-09-12 15:09:24 -07:00
parent a62785227e
commit 52ebf4663c
50 changed files with 151 additions and 464 deletions

View file

@ -25,10 +25,7 @@ class HyperglassError(Exception):
"""hyperglass base exception.""" """hyperglass base exception."""
def __init__( def __init__(
self, self, message: str = "", level: str = "warning", keywords: Optional[List[str]] = None,
message: str = "",
level: str = "warning",
keywords: Optional[List[str]] = None,
) -> None: ) -> None:
"""Initialize the hyperglass base exception class.""" """Initialize the hyperglass base exception class."""
self._message = message self._message = message
@ -87,16 +84,12 @@ class _UnformattedHyperglassError(HyperglassError):
_level = "warning" _level = "warning"
def __init__( def __init__(self, unformatted_msg: str = "", level: Optional[str] = None, **kwargs) -> None:
self, unformatted_msg: str = "", level: Optional[str] = None, **kwargs
) -> None:
"""Format error message with keyword arguments.""" """Format error message with keyword arguments."""
self._message = unformatted_msg.format(**kwargs) self._message = unformatted_msg.format(**kwargs)
self._level = level or self._level self._level = level or self._level
self._keywords = list(kwargs.values()) self._keywords = list(kwargs.values())
super().__init__( super().__init__(message=self._message, level=self._level, keywords=self._keywords)
message=self._message, level=self._level, keywords=self._keywords
)
class _PredefinedHyperglassError(HyperglassError): class _PredefinedHyperglassError(HyperglassError):
@ -107,9 +100,7 @@ class _PredefinedHyperglassError(HyperglassError):
self._fmt_msg = self._message.format(**kwargs) self._fmt_msg = self._message.format(**kwargs)
self._level = level or self._level self._level = level or self._level
self._keywords = list(kwargs.values()) self._keywords = list(kwargs.values())
super().__init__( super().__init__(message=self._fmt_msg, level=self._level, keywords=self._keywords)
message=self._fmt_msg, level=self._level, keywords=self._keywords
)
class ConfigInvalid(HyperglassError): class ConfigInvalid(HyperglassError):

View file

@ -113,9 +113,7 @@ def _custom_openapi():
description=params.docs.description, description=params.docs.description,
routes=app.routes, routes=app.routes,
) )
openapi_schema["info"]["x-logo"] = { openapi_schema["info"]["x-logo"] = {"url": "/images/light" + params.web.logo.light.suffix}
"url": "/images/light" + params.web.logo.light.suffix
}
query_samples = [] query_samples = []
queries_samples = [] queries_samples = []
@ -123,38 +121,26 @@ def _custom_openapi():
with EXAMPLE_QUERY_CURL.open("r") as e: with EXAMPLE_QUERY_CURL.open("r") as e:
example = e.read() example = e.read()
query_samples.append( query_samples.append({"lang": "cURL", "source": example % str(params.docs.base_url)})
{"lang": "cURL", "source": example % str(params.docs.base_url)}
)
with EXAMPLE_QUERY_PY.open("r") as e: with EXAMPLE_QUERY_PY.open("r") as e:
example = e.read() example = e.read()
query_samples.append( query_samples.append({"lang": "Python", "source": example % str(params.docs.base_url)})
{"lang": "Python", "source": example % str(params.docs.base_url)}
)
with EXAMPLE_DEVICES_CURL.open("r") as e: with EXAMPLE_DEVICES_CURL.open("r") as e:
example = e.read() example = e.read()
queries_samples.append( queries_samples.append({"lang": "cURL", "source": example % str(params.docs.base_url)})
{"lang": "cURL", "source": example % str(params.docs.base_url)}
)
with EXAMPLE_DEVICES_PY.open("r") as e: with EXAMPLE_DEVICES_PY.open("r") as e:
example = e.read() example = e.read()
queries_samples.append( queries_samples.append({"lang": "Python", "source": example % str(params.docs.base_url)})
{"lang": "Python", "source": example % str(params.docs.base_url)}
)
with EXAMPLE_QUERIES_CURL.open("r") as e: with EXAMPLE_QUERIES_CURL.open("r") as e:
example = e.read() example = e.read()
devices_samples.append( devices_samples.append({"lang": "cURL", "source": example % str(params.docs.base_url)})
{"lang": "cURL", "source": example % str(params.docs.base_url)}
)
with EXAMPLE_QUERIES_PY.open("r") as e: with EXAMPLE_QUERIES_PY.open("r") as e:
example = e.read() example = e.read()
devices_samples.append( devices_samples.append({"lang": "Python", "source": example % str(params.docs.base_url)})
{"lang": "Python", "source": example % str(params.docs.base_url)}
)
openapi_schema["paths"]["/api/query/"]["post"]["x-code-samples"] = query_samples openapi_schema["paths"]["/api/query/"]["post"]["x-code-samples"] = query_samples
openapi_schema["paths"]["/api/devices"]["get"]["x-code-samples"] = devices_samples openapi_schema["paths"]["/api/devices"]["get"]["x-code-samples"] = devices_samples

View file

@ -10,16 +10,14 @@ from hyperglass.configuration import params
async def default_handler(request, exc): async def default_handler(request, exc):
"""Handle uncaught errors.""" """Handle uncaught errors."""
return JSONResponse( return JSONResponse(
{"output": params.messages.general, "level": "danger", "keywords": []}, {"output": params.messages.general, "level": "danger", "keywords": []}, status_code=500,
status_code=500,
) )
async def http_handler(request, exc): async def http_handler(request, exc):
"""Handle web server errors.""" """Handle web server errors."""
return JSONResponse( return JSONResponse(
{"output": exc.detail, "level": "danger", "keywords": []}, {"output": exc.detail, "level": "danger", "keywords": []}, status_code=exc.status_code,
status_code=exc.status_code,
) )
@ -35,6 +33,5 @@ async def validation_handler(request, exc):
"""Handle Pydantic validation errors raised by FastAPI.""" """Handle Pydantic validation errors raised by FastAPI."""
error = exc.errors()[0] error = exc.errors()[0]
return JSONResponse( return JSONResponse(
{"output": error["msg"], "level": "error", "keywords": error["loc"]}, {"output": error["msg"], "level": "error", "keywords": error["loc"]}, status_code=422,
status_code=422,
) )

View file

@ -55,9 +55,7 @@ async def send_webhook(query_data: Query, request: Request, timestamp: datetime)
} }
) )
except Exception as err: except Exception as err:
log.error( log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err))
"Error sending webhook to {}: {}", params.logging.http.provider, str(err)
)
@log.catch @log.catch
@ -106,9 +104,7 @@ async def query(query_data: Query, request: Request, background_tasks: Backgroun
elif not cache_response: elif not cache_response:
log.debug("No existing cache entry for query {}", cache_key) log.debug("No existing cache entry for query {}", cache_key)
log.debug( log.debug("Created new cache key {} entry for query {}", cache_key, query_data.summary)
"Created new cache key {} entry for query {}", cache_key, query_data.summary
)
timestamp = query_data.timestamp timestamp = query_data.timestamp

View file

@ -8,9 +8,7 @@ from pathlib import Path
from httpx import Headers from httpx import Headers
def import_public_key( def import_public_key(app_path: Union[Path, str], device_name: str, keystring: str) -> bool:
app_path: Union[Path, str], device_name: str, keystring: str
) -> bool:
"""Import a public key for hyperglass-agent.""" """Import a public key for hyperglass-agent."""
if not isinstance(app_path, Path): if not isinstance(app_path, Path):
app_path = Path(app_path) app_path = Path(app_path)

View file

@ -52,18 +52,14 @@ def _print_version(ctx, param, value):
help=cmd_help(E.NUMBERS, "hyperglass version", supports_color), help=cmd_help(E.NUMBERS, "hyperglass version", supports_color),
) )
@help_option( @help_option(
"-h", "-h", "--help", help=cmd_help(E.FOLDED_HANDS, "Show this help message", supports_color),
"--help",
help=cmd_help(E.FOLDED_HANDS, "Show this help message", supports_color),
) )
def hg(): def hg():
"""Initialize Click Command Group.""" """Initialize Click Command Group."""
pass pass
@hg.command( @hg.command("build-ui", help=cmd_help(E.BUTTERFLY, "Create a new UI build", supports_color))
"build-ui", help=cmd_help(E.BUTTERFLY, "Create a new UI build", supports_color)
)
@option("-t", "--timeout", required=False, default=180, help="Timeout in seconds") @option("-t", "--timeout", required=False, default=180, help="Timeout in seconds")
def build_frontend(timeout): def build_frontend(timeout):
"""Create a new UI build.""" """Create a new UI build."""
@ -131,9 +127,7 @@ def start(build, direct, workers): # noqa: C901
cls=HelpColorsCommand, cls=HelpColorsCommand,
help_options_custom_colors=random_colors("-l"), help_options_custom_colors=random_colors("-l"),
) )
@option( @option("-l", "--length", "length", default=32, help="Number of characters [default: 32]")
"-l", "--length", "length", default=32, help="Number of characters [default: 32]"
)
def generate_secret(length): def generate_secret(length):
"""Generate secret for hyperglass-agent. """Generate secret for hyperglass-agent.
@ -177,9 +171,7 @@ After adding your {devices} file, you should run the {build_cmd} command.""", #
@hg.command( @hg.command(
"system-info", "system-info",
help=cmd_help( help=cmd_help(E.THERMOMETER, " Get system information for a bug report", supports_color),
E.THERMOMETER, " Get system information for a bug report", supports_color
),
cls=HelpColorsCommand, cls=HelpColorsCommand,
) )
def get_system_info(): def get_system_info():

View file

@ -96,9 +96,7 @@ def success(text, *args, **kwargs):
Returns: Returns:
{str} -- Success output {str} -- Success output
""" """
return _base_formatter( return _base_formatter(_state="success", _text=text, _callback=echo, *args, **kwargs)
_state="success", _text=text, _callback=echo, *args, **kwargs
)
def warning(text, *args, **kwargs): def warning(text, *args, **kwargs):
@ -111,9 +109,7 @@ def warning(text, *args, **kwargs):
Returns: Returns:
{str} -- Warning output {str} -- Warning output
""" """
return _base_formatter( return _base_formatter(_state="warning", _text=text, _callback=echo, *args, **kwargs)
_state="warning", _text=text, _callback=echo, *args, **kwargs
)
def label(text, *args, **kwargs): def label(text, *args, **kwargs):

View file

@ -56,12 +56,7 @@ class HelpColorsFormatter(click.HelpFormatter):
""" """
def __init__( def __init__(
self, self, headers_color=None, options_color=None, options_custom_colors=None, *args, **kwargs
headers_color=None,
options_color=None,
options_custom_colors=None,
*args,
**kwargs
): ):
"""Initialize help formatter. """Initialize help formatter.
@ -98,9 +93,7 @@ class HelpColorsFormatter(click.HelpFormatter):
def write_dl(self, rows, **kwargs): def write_dl(self, rows, **kwargs):
"""Write Options section.""" """Write Options section."""
colorized_rows = [ colorized_rows = [(click.style(row[0], **self._pick_color(row[0])), row[1]) for row in rows]
(click.style(row[0], **self._pick_color(row[0])), row[1]) for row in rows
]
super().write_dl(colorized_rows, **kwargs) super().write_dl(colorized_rows, **kwargs)

View file

@ -20,9 +20,7 @@ IGNORED_FILES = [".DS_Store"]
INSTALL_PATHS = [ INSTALL_PATHS = [
inquirer.List( inquirer.List(
"install_path", "install_path", message="Choose a directory for hyperglass", choices=[USER_PATH, ROOT_PATH],
message="Choose a directory for hyperglass",
choices=[USER_PATH, ROOT_PATH],
) )
] ]
@ -104,9 +102,7 @@ class Installer:
if not compare_post.left_list == compare_post.right_list: if not compare_post.left_list == compare_post.right_list:
error( error(
"Files in {a} do not match files in {b}", "Files in {a} do not match files in {b}", a=str(ASSET_DIR), b=str(target_dir),
a=str(ASSET_DIR),
b=str(target_dir),
) )
return False return False

View file

@ -122,6 +122,4 @@ class Message:
def __repr__(self): def __repr__(self):
"""Stringify the instance character for representation.""" """Stringify the instance character for representation."""
return "Message(msg={m}, kw={k}, emoji={e})".format( return "Message(msg={m}, kw={k}, emoji={e})".format(m=self.msg, k=self.kw, e=self.emoji)
m=self.msg, k=self.kw, e=self.emoji
)

View file

@ -62,9 +62,7 @@ def _cancel_all_tasks(loop, tasks):
for task in to_cancel: for task in to_cancel:
task.cancel() task.cancel()
loop.run_until_complete( loop.run_until_complete(asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
)
for task in to_cancel: for task in to_cancel:
if task.cancelled(): if task.cancelled():

View file

@ -117,12 +117,8 @@ def check_address(address):
elif isinstance(address, str): elif isinstance(address, str):
if os.name != "posix": if os.name != "posix":
raise ValueError("Platform does not support UNIX domain sockets") raise ValueError("Platform does not support UNIX domain sockets")
if not ( if not (os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK)):
os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK) raise ValueError("ADDRESS not a valid socket domain socket ({0})".format(address))
):
raise ValueError(
"ADDRESS not a valid socket domain socket ({0})".format(address)
)
else: else:
raise ValueError( raise ValueError(
"ADDRESS is not a tuple, string, or character buffer " "ADDRESS is not a tuple, string, or character buffer "
@ -238,16 +234,12 @@ class _ForwardHandler(socketserver.BaseRequestHandler):
if not chan.recv_ready(): if not chan.recv_ready():
break break
data = chan.recv(1024) data = chan.recv(1024)
self.logger.trace( self.logger.trace("<<< IN {0} recv: {1} <<<".format(self.info, hexlify(data)),)
"<<< IN {0} recv: {1} <<<".format(self.info, hexlify(data)),
)
self.request.sendall(data) self.request.sendall(data)
def handle(self): def handle(self):
uid = get_connection_id() uid = get_connection_id()
self.info = "#{0} <-- {1}".format( self.info = "#{0} <-- {1}".format(uid, self.client_address or self.server.local_address)
uid, self.client_address or self.server.local_address
)
src_address = self.request.getpeername() src_address = self.request.getpeername()
if not isinstance(src_address, tuple): if not isinstance(src_address, tuple):
src_address = ("dummy", 12345) src_address = ("dummy", 12345)
@ -261,9 +253,7 @@ class _ForwardHandler(socketserver.BaseRequestHandler):
except paramiko.SSHException: except paramiko.SSHException:
chan = None chan = None
if chan is None: if chan is None:
msg = "{0} to {1} was rejected by the SSH server".format( msg = "{0} to {1} was rejected by the SSH server".format(self.info, self.remote_address)
self.info, self.remote_address
)
self.logger.trace(msg) self.logger.trace(msg)
raise HandlerSSHTunnelForwarderError(msg) raise HandlerSSHTunnelForwarderError(msg)
@ -373,9 +363,7 @@ class _UnixStreamForwardServer(UnixStreamServer):
return self.RequestHandlerClass.remote_address[1] return self.RequestHandlerClass.remote_address[1]
class _ThreadingUnixStreamForwardServer( class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn, _UnixStreamForwardServer):
socketserver.ThreadingMixIn, _UnixStreamForwardServer
):
""" """
Allow concurrent connections to each tunnel Allow concurrent connections to each tunnel
""" """
@ -693,11 +681,7 @@ class SSHTunnelForwarder:
return _ThreadingForwardServer if self._threaded else _ForwardServer return _ThreadingForwardServer if self._threaded else _ForwardServer
def _make_unix_ssh_forward_server_class(self, remote_address_): def _make_unix_ssh_forward_server_class(self, remote_address_):
return ( return _ThreadingUnixStreamForwardServer if self._threaded else _UnixStreamForwardServer
_ThreadingUnixStreamForwardServer
if self._threaded
else _UnixStreamForwardServer
)
def _make_ssh_forward_server(self, remote_address, local_bind_address): def _make_ssh_forward_server(self, remote_address, local_bind_address):
""" """
@ -710,9 +694,7 @@ class SSHTunnelForwarder:
else: else:
forward_maker_class = self._make_ssh_forward_server_class forward_maker_class = self._make_ssh_forward_server_class
_Server = forward_maker_class(remote_address) _Server = forward_maker_class(remote_address)
ssh_forward_server = _Server( ssh_forward_server = _Server(local_bind_address, _Handler, logger=self.logger,)
local_bind_address, _Handler, logger=self.logger,
)
if ssh_forward_server: if ssh_forward_server:
ssh_forward_server.daemon_threads = self.daemon_forward_servers ssh_forward_server.daemon_threads = self.daemon_forward_servers
@ -724,8 +706,7 @@ class SSHTunnelForwarder:
"Problem setting up ssh {0} <> {1} forwarder. You can " "Problem setting up ssh {0} <> {1} forwarder. You can "
"suppress this exception by using the `mute_exceptions`" "suppress this exception by using the `mute_exceptions`"
"argument".format( "argument".format(
address_to_str(local_bind_address), address_to_str(local_bind_address), address_to_str(remote_address),
address_to_str(remote_address),
), ),
) )
except IOError: except IOError:
@ -802,9 +783,7 @@ class SSHTunnelForwarder:
) )
# local binds # local binds
self._local_binds = self._get_binds(local_bind_address, local_bind_addresses) self._local_binds = self._get_binds(local_bind_address, local_bind_addresses)
self._local_binds = self._consolidate_binds( self._local_binds = self._consolidate_binds(self._local_binds, self._remote_binds)
self._local_binds, self._remote_binds
)
( (
self.ssh_host, self.ssh_host,
@ -882,16 +861,12 @@ class SSHTunnelForwarder:
ssh_port = ssh_port or hostname_info.get("port") ssh_port = ssh_port or hostname_info.get("port")
proxycommand = hostname_info.get("proxycommand") proxycommand = hostname_info.get("proxycommand")
ssh_proxy = ssh_proxy or ( ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if proxycommand else None)
paramiko.ProxyCommand(proxycommand) if proxycommand else None
)
if compression is None: if compression is None:
compression = hostname_info.get("compression", "") compression = hostname_info.get("compression", "")
compression = True if compression.upper() == "YES" else False compression = True if compression.upper() == "YES" else False
except IOError: except IOError:
logger.warning( logger.warning("Could not read SSH configuration file: {f}", f=ssh_config_file)
"Could not read SSH configuration file: {f}", f=ssh_config_file
)
except (AttributeError, TypeError): # ssh_config_file is None except (AttributeError, TypeError): # ssh_config_file is None
logger.info("Skipping loading of ssh configuration file") logger.info("Skipping loading of ssh configuration file")
finally: finally:
@ -979,8 +954,7 @@ class SSHTunnelForwarder:
count = len(remote_binds) - len(local_binds) count = len(remote_binds) - len(local_binds)
if count < 0: if count < 0:
raise ValueError( raise ValueError(
"Too many local bind addresses " "Too many local bind addresses " "(local_bind_addresses > remote_bind_addresses)"
"(local_bind_addresses > remote_bind_addresses)"
) )
local_binds.extend([("0.0.0.0", 0) for x in range(count)]) local_binds.extend([("0.0.0.0", 0) for x in range(count)])
return local_binds return local_binds
@ -1002,9 +976,7 @@ class SSHTunnelForwarder:
- ``paramiko.Pkey`` - it will be transparently added to loaded keys - ``paramiko.Pkey`` - it will be transparently added to loaded keys
""" """
ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( ssh_loaded_pkeys = SSHTunnelForwarder.get_keys(
logger=logger, logger=logger, host_pkey_directories=host_pkey_directories, allow_agent=allow_agent,
host_pkey_directories=host_pkey_directories,
allow_agent=allow_agent,
) )
if isinstance(ssh_pkey, str): if isinstance(ssh_pkey, str):
@ -1058,9 +1030,7 @@ class SSHTunnelForwarder:
try: try:
self._connect_to_gateway() self._connect_to_gateway()
except socket.gaierror: # raised by paramiko.Transport except socket.gaierror: # raised by paramiko.Transport
msg = "Could not resolve IP address for {0}, aborting!".format( msg = "Could not resolve IP address for {0}, aborting!".format(self.ssh_host)
self.ssh_host
)
self.logger.error(msg) self.logger.error(msg)
return return
except (paramiko.SSHException, socket.error) as e: except (paramiko.SSHException, socket.error) as e:
@ -1109,9 +1079,7 @@ class SSHTunnelForwarder:
"""Processes optional deprecate arguments.""" """Processes optional deprecate arguments."""
if deprecated_attrib not in DEPRECATIONS: if deprecated_attrib not in DEPRECATIONS:
raise ValueError( raise ValueError("{0} not included in deprecations list".format(deprecated_attrib))
"{0} not included in deprecations list".format(deprecated_attrib)
)
if deprecated_attrib in kwargs: if deprecated_attrib in kwargs:
warnings.warn( warnings.warn(
"'{0}' is DEPRECATED use '{1}' instead".format( "'{0}' is DEPRECATED use '{1}' instead".format(
@ -1148,17 +1116,10 @@ class SSHTunnelForwarder:
for pkey_class in ( for pkey_class in (
(key_type,) (key_type,)
if key_type if key_type
else ( else (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey, paramiko.Ed25519Key,)
paramiko.RSAKey,
paramiko.DSSKey,
paramiko.ECDSAKey,
paramiko.Ed25519Key,
)
): ):
try: try:
ssh_pkey = pkey_class.from_private_key_file( ssh_pkey = pkey_class.from_private_key_file(pkey_file, password=pkey_password)
pkey_file, password=pkey_password
)
logger.debug( logger.debug(
"Private key file ({k0}, {k1}) successfully loaded", "Private key file ({k0}, {k1}) successfully loaded",
@ -1202,9 +1163,7 @@ class SSHTunnelForwarder:
else _srv.local_address else _srv.local_address
) )
s.connect(connect_to) s.connect(connect_to)
self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get(timeout=TUNNEL_TIMEOUT * 1.1)
timeout=TUNNEL_TIMEOUT * 1.1
)
self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address))
except socket.error: except socket.error:
self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address))
@ -1232,8 +1191,7 @@ class SSHTunnelForwarder:
self._create_tunnels() self._create_tunnels()
if not self.is_active: if not self.is_active:
self._raise( self._raise(
BaseSSHTunnelForwarderError, BaseSSHTunnelForwarderError, reason="Could not establish session to SSH gateway",
reason="Could not establish session to SSH gateway",
) )
for _srv in self._server_list: for _srv in self._server_list:
thread = threading.Thread( thread = threading.Thread(
@ -1247,8 +1205,7 @@ class SSHTunnelForwarder:
self.is_alive = any(self.tunnel_is_up.values()) self.is_alive = any(self.tunnel_is_up.values())
if not self.is_alive: if not self.is_alive:
self._raise( self._raise(
HandlerSSHTunnelForwarderError, HandlerSSHTunnelForwarderError, "An error occurred while opening tunnels.",
"An error occurred while opening tunnels.",
) )
def stop(self) -> None: def stop(self) -> None:
@ -1270,8 +1227,7 @@ class SSHTunnelForwarder:
""" """
self.logger.info("Closing all open connections...") self.logger.info("Closing all open connections...")
opened_address_text = ( opened_address_text = (
", ".join((address_to_str(k.local_address) for k in self._server_list)) ", ".join((address_to_str(k.local_address) for k in self._server_list)) or "None"
or "None"
) )
self.logger.debug("Listening tunnels: " + opened_address_text) self.logger.debug("Listening tunnels: " + opened_address_text)
self._stop_transport() self._stop_transport()
@ -1311,9 +1267,7 @@ class SSHTunnelForwarder:
if self.ssh_password: # avoid conflict using both pass and pkey if self.ssh_password: # avoid conflict using both pass and pkey
self.logger.debug( self.logger.debug(
"Trying to log in with password: {0}".format( "Trying to log in with password: {0}".format("*" * len(self.ssh_password))
"*" * len(self.ssh_password)
)
) )
try: try:
self._transport = self._get_transport() self._transport = self._get_transport()
@ -1364,9 +1318,7 @@ class SSHTunnelForwarder:
os.unlink(_srv.local_address) os.unlink(_srv.local_address)
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
"Unable to unlink socket {0}: {1}".format( "Unable to unlink socket {0}: {1}".format(self.local_address, repr(e))
self.local_address, repr(e)
)
) )
self.is_alive = False self.is_alive = False
if self.is_active: if self.is_active:
@ -1413,9 +1365,7 @@ class SSHTunnelForwarder:
self._check_is_started() self._check_is_started()
return [ return [
_server.local_port _server.local_port for _server in self._server_list if _server.local_port is not None
for _server in self._server_list
if _server.local_port is not None
] ]
@property @property
@ -1423,9 +1373,7 @@ class SSHTunnelForwarder:
"""Return a list containing the IP addresses listening for the tunnels.""" """Return a list containing the IP addresses listening for the tunnels."""
self._check_is_started() self._check_is_started()
return [ return [
_server.local_host _server.local_host for _server in self._server_list if _server.local_host is not None
for _server in self._server_list
if _server.local_host is not None
] ]
@property @property
@ -1461,10 +1409,7 @@ class SSHTunnelForwarder:
def __str__(self) -> str: def __str__(self) -> str:
credentials = { credentials = {
"password": self.ssh_password, "password": self.ssh_password,
"pkeys": [ "pkeys": [(key.get_name(), hexlify(key.get_fingerprint())) for key in self.ssh_pkeys]
(key.get_name(), hexlify(key.get_fingerprint()))
for key in self.ssh_pkeys
]
if any(self.ssh_pkeys) if any(self.ssh_pkeys)
else None, else None,
} }
@ -1496,9 +1441,7 @@ class SSHTunnelForwarder:
credentials, credentials,
self.ssh_host_key if self.ssh_host_key else "not checked", self.ssh_host_key if self.ssh_host_key else "not checked",
"" if self.is_alive else "not ", "" if self.is_alive else "not ",
"disabled" "disabled" if not self.set_keepalive else "every {0} sec".format(self.set_keepalive),
if not self.set_keepalive
else "every {0} sec".format(self.set_keepalive),
"disabled" if self.skip_tunnel_checkup else "enabled", "disabled" if self.skip_tunnel_checkup else "enabled",
"" if self._threaded else "not ", "" if self._threaded else "not ",
"" if self.compression else "not ", "" if self.compression else "not ",
@ -1612,8 +1555,6 @@ def _bindlist(input_str):
_port = "22" # default port if not given _port = "22" # default port if not given
return _ip, int(_port) return _ip, int(_port)
except ValueError: except ValueError:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError("Address tuple must be of type IP_ADDRESS:PORT")
"Address tuple must be of type IP_ADDRESS:PORT"
)
except AssertionError: except AssertionError:
raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!") raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!")

View file

@ -118,9 +118,7 @@ def _get_commands(data: Dict) -> List[Directive]:
return commands return commands
def _device_commands( def _device_commands(device: Dict, directives: List[Directive]) -> Generator[Directive, None, None]:
device: Dict, directives: List[Directive]
) -> Generator[Directive, None, None]:
device_commands = device.get("commands", []) device_commands = device.get("commands", [])
for directive in directives: for directive in directives:
if directive.id in device_commands: if directive.id in device_commands:
@ -176,9 +174,7 @@ enable_file_logging(
# Set up syslog logging if enabled. # Set up syslog logging if enabled.
if params.logging.syslog is not None and params.logging.syslog.enable: if params.logging.syslog is not None and params.logging.syslog.enable:
enable_syslog_logging( enable_syslog_logging(
logger=log, logger=log, syslog_host=params.logging.syslog.host, syslog_port=params.logging.syslog.port,
syslog_host=params.logging.syslog.host,
syslog_port=params.logging.syslog.port,
) )
if params.logging.http is not None and params.logging.http.enable: if params.logging.http is not None and params.logging.http.enable:
@ -196,18 +192,14 @@ try:
# If keywords are unmodified (default), add the org name & # If keywords are unmodified (default), add the org name &
# site_title. # site_title.
if Params().site_keywords == params.site_keywords: if Params().site_keywords == params.site_keywords:
params.site_keywords = sorted( params.site_keywords = sorted({*params.site_keywords, params.org_name, params.site_title})
{*params.site_keywords, params.org_name, params.site_title}
)
except KeyError: except KeyError:
pass pass
content_greeting = get_markdown( content_greeting = get_markdown(
config_path=params.web.greeting, config_path=params.web.greeting, default="", params={"title": params.web.greeting.title},
default="",
params={"title": params.web.greeting.title},
) )

View file

@ -14,9 +14,7 @@ from hyperglass.exceptions.private import ConfigInvalid
Importer = TypeVar("Importer") Importer = TypeVar("Importer")
def validate_config( def validate_config(config: Union[Dict[str, Any], List[Any]], importer: Importer) -> Importer:
config: Union[Dict[str, Any], List[Any]], importer: Importer
) -> Importer:
"""Validate a config dict against a model.""" """Validate a config dict against a model."""
validated = None validated = None
try: try:

View file

@ -124,9 +124,7 @@ class PublicHyperglassError(HyperglassError):
kwargs["error"] = error kwargs["error"] = error
self._message = self._safe_format(self._message_template, **kwargs) self._message = self._safe_format(self._message_template, **kwargs)
self._keywords = list(kwargs.values()) self._keywords = list(kwargs.values())
super().__init__( super().__init__(message=self._message, level=self._level, keywords=self._keywords)
message=self._message, level=self._level, keywords=self._keywords
)
def handle_error(self, error: Any) -> None: def handle_error(self, error: Any) -> None:
"""Add details to the error template, if provided.""" """Add details to the error template, if provided."""
@ -156,6 +154,4 @@ class PrivateHyperglassError(HyperglassError):
kwargs["error"] = error kwargs["error"] = error
self._message = self._safe_format(message, **kwargs) self._message = self._safe_format(message, **kwargs)
self._keywords = list(kwargs.values()) self._keywords = list(kwargs.values())
super().__init__( super().__init__(message=self._message, level=self._level, keywords=self._keywords)
message=self._message, level=self._level, keywords=self._keywords
)

View file

@ -10,9 +10,7 @@ from ._common import ErrorLevel, PrivateHyperglassError
class ExternalError(PrivateHyperglassError): class ExternalError(PrivateHyperglassError):
"""Raised when an error during a connection to an external service occurs.""" """Raised when an error during a connection to an external service occurs."""
def __init__( def __init__(self, message: str, level: ErrorLevel, **kwargs: Dict[str, Any]) -> None:
self, message: str, level: ErrorLevel, **kwargs: Dict[str, Any]
) -> None:
"""Set level according to level argument.""" """Set level according to level argument."""
self._level = level self._level = level
super().__init__(message, **kwargs) super().__init__(message, **kwargs)
@ -31,9 +29,7 @@ class UnsupportedDevice(PrivateHyperglassError):
drivers = ("", *[*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()].sort()) drivers = ("", *[*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()].sort())
driver_list = "\n - ".join(drivers) driver_list = "\n - ".join(drivers)
super().__init__( super().__init__(message=f"'{nos}' is not supported. Must be one of:{driver_list}")
message=f"'{nos}' is not supported. Must be one of:{driver_list}"
)
class InputValidationError(PrivateHyperglassError): class InputValidationError(PrivateHyperglassError):

View file

@ -33,9 +33,7 @@ class AuthError(
super().__init__(error=str(error), device=device.name, proxy=device.proxy) super().__init__(error=str(error), device=device.name, proxy=device.proxy)
class RestError( class RestError(PublicHyperglassError, template=params.messages.connection_error, level="danger"):
PublicHyperglassError, template=params.messages.connection_error, level="danger"
):
"""Raised upon a rest API client error.""" """Raised upon a rest API client error."""
def __init__(self, error: BaseException, *, device: Device): def __init__(self, error: BaseException, *, device: Device):
@ -86,9 +84,7 @@ class QueryLocationNotFound(NotFound):
def __init__(self, location: Any, **kwargs: Dict[str, Any]) -> None: def __init__(self, location: Any, **kwargs: Dict[str, Any]) -> None:
"""Initialize a NotFound error for a query location.""" """Initialize a NotFound error for a query location."""
super().__init__( super().__init__(type=params.web.text.query_location, name=str(location), **kwargs)
type=params.web.text.query_location, name=str(location), **kwargs
)
class QueryTypeNotFound(NotFound): class QueryTypeNotFound(NotFound):
@ -96,9 +92,7 @@ class QueryTypeNotFound(NotFound):
def __init__(self, query_type: Any, **kwargs: Dict[str, Any]) -> None: def __init__(self, query_type: Any, **kwargs: Dict[str, Any]) -> None:
"""Initialize a NotFound error for a query type.""" """Initialize a NotFound error for a query type."""
super().__init__( super().__init__(type=params.web.text.query_type, name=str(query_type), **kwargs)
type=params.web.text.query_type, name=str(query_type), **kwargs
)
class QueryGroupNotFound(NotFound): class QueryGroupNotFound(NotFound):

View file

@ -32,9 +32,7 @@ class Construct:
def __init__(self, device, query): def __init__(self, device, query):
"""Initialize command construction.""" """Initialize command construction."""
log.debug( log.debug(
"Constructing '{}' query for '{}'", "Constructing '{}' query for '{}'", query.query_type, str(query.query_target),
query.query_type,
str(query.query_target),
) )
self.query = query self.query = query
self.device = device self.device = device
@ -73,10 +71,7 @@ class Construct:
for key in [k for k in keys if k != "target"]: for key in [k for k in keys if k != "target"]:
if key not in attrs: if key not in attrs:
raise ConfigError( raise ConfigError(
( ("Command '{c}' has attribute '{k}', " "which is missing from device '{d}'"),
"Command '{c}' has attribute '{k}', "
"which is missing from device '{d}'"
),
level="danger", level="danger",
c=self.directive.name, c=self.directive.name,
k=key, k=key,

View file

@ -80,9 +80,7 @@ class AgentConnection(Connection):
) )
log.debug("Encoded JWT: {}", encoded_query) log.debug("Encoded JWT: {}", encoded_query)
raw_response = await http_client.post( raw_response = await http_client.post(endpoint, json={"encoded": encoded_query})
endpoint, json={"encoded": encoded_query}
)
log.debug("HTTP status code: {}", raw_response.status_code) log.debug("HTTP status code: {}", raw_response.status_code)
raw = raw_response.text raw = raw_response.text

View file

@ -36,9 +36,7 @@ class SSHConnection(Connection):
} }
if proxy.credential._method == "password": if proxy.credential._method == "password":
# Use password auth if no key is defined. # Use password auth if no key is defined.
tunnel_kwargs[ tunnel_kwargs["ssh_password"] = proxy.credential.password.get_secret_value()
"ssh_password"
] = proxy.credential.password.get_secret_value()
else: else:
# Otherwise, use key auth. # Otherwise, use key auth.
tunnel_kwargs["ssh_pkey"] = proxy.credential.key.as_posix() tunnel_kwargs["ssh_pkey"] = proxy.credential.key.as_posix()
@ -53,8 +51,7 @@ class SSHConnection(Connection):
except BaseSSHTunnelForwarderError as scrape_proxy_error: except BaseSSHTunnelForwarderError as scrape_proxy_error:
log.error( log.error(
f"Error connecting to device {self.device.name} via " f"Error connecting to device {self.device.name} via " f"proxy {proxy.name}"
f"proxy {proxy.name}"
) )
raise ScrapeError(error=scrape_proxy_error, device=self.device) raise ScrapeError(error=scrape_proxy_error, device=self.device)

View file

@ -78,9 +78,7 @@ class NetmikoConnection(SSHConnection):
if self.device.credential._method == "password": if self.device.credential._method == "password":
# Use password auth if no key is defined. # Use password auth if no key is defined.
driver_kwargs[ driver_kwargs["password"] = self.device.credential.password.get_secret_value()
"password"
] = self.device.credential.password.get_secret_value()
else: else:
# Otherwise, use key auth. # Otherwise, use key auth.
driver_kwargs["use_keys"] = True driver_kwargs["use_keys"] = True
@ -88,9 +86,7 @@ class NetmikoConnection(SSHConnection):
if self.device.credential._method == "encrypted_key": if self.device.credential._method == "encrypted_key":
# If the key is encrypted, use the password field as the # If the key is encrypted, use the password field as the
# private key password. # private key password.
driver_kwargs[ driver_kwargs["passphrase"] = self.device.credential.password.get_secret_value()
"passphrase"
] = self.device.credential.password.get_secret_value()
try: try:
nm_connect_direct = ConnectHandler(**driver_kwargs) nm_connect_direct = ConnectHandler(**driver_kwargs)

View file

@ -98,9 +98,7 @@ class ScrapliConnection(SSHConnection):
if self.device.credential._method == "password": if self.device.credential._method == "password":
# Use password auth if no key is defined. # Use password auth if no key is defined.
driver_kwargs[ driver_kwargs["auth_password"] = self.device.credential.password.get_secret_value()
"auth_password"
] = self.device.credential.password.get_secret_value()
else: else:
# Otherwise, use key auth. # Otherwise, use key auth.
driver_kwargs["auth_private_key"] = self.device.credential.key.as_posix() driver_kwargs["auth_private_key"] = self.device.credential.key.as_posix()
@ -112,9 +110,7 @@ class ScrapliConnection(SSHConnection):
] = self.device.credential.password.get_secret_value() ] = self.device.credential.password.get_secret_value()
driver = driver(**driver_kwargs) driver = driver(**driver_kwargs)
driver.logger = log.bind( driver.logger = log.bind(logger_name=f"scrapli.{driver.host}:{driver.port}-driver")
logger_name=f"scrapli.{driver.host}:{driver.port}-driver"
)
try: try:
responses = () responses = ()
async with driver as connection: async with driver as connection:

View file

@ -8,7 +8,7 @@ hyperglass-frr API calls, returns the output back to the front end.
# Standard Library # Standard Library
import signal import signal
from typing import Any, Dict, Union, Callable, Sequence, TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict, Union, Callable, Sequence
# Project # Project
from hyperglass.log import log from hyperglass.log import log

View file

@ -140,9 +140,7 @@ class BaseExternal:
except gaierror as err: except gaierror as err:
# Raised if the target isn't listening on the port # Raised if the target isn't listening on the port
raise self._exception( raise self._exception(f"{self.name} appears to be unreachable", err) from None
f"{self.name} appears to be unreachable", err
) from None
return True return True
@ -157,21 +155,13 @@ class BaseExternal:
supported_methods = ("GET", "POST", "PUT", "DELETE", "HEAD", "PATCH") supported_methods = ("GET", "POST", "PUT", "DELETE", "HEAD", "PATCH")
( (method, endpoint, item, headers, params, data, timeout, response_required,) = itemgetter(
method, *kwargs.keys()
endpoint, )(kwargs)
item,
headers,
params,
data,
timeout,
response_required,
) = itemgetter(*kwargs.keys())(kwargs)
if method.upper() not in supported_methods: if method.upper() not in supported_methods:
raise self._exception( raise self._exception(
f'Method must be one of {", ".join(supported_methods)}. ' f'Method must be one of {", ".join(supported_methods)}. ' f"Got: {str(method)}"
f"Got: {str(method)}"
) )
endpoint = "/".join( endpoint = "/".join(
@ -209,9 +199,7 @@ class BaseExternal:
try: try:
timeout = int(timeout) timeout = int(timeout)
except TypeError: except TypeError:
raise self._exception( raise self._exception(f"Timeout must be an int, got: {str(timeout)}")
f"Timeout must be an int, got: {str(timeout)}"
)
request["timeout"] = timeout request["timeout"] = timeout
log.debug("Constructed request parameters {}", request) log.debug("Constructed request parameters {}", request)

View file

@ -31,9 +31,7 @@ def parse_whois(output: str, targets: List[str]) -> Dict[str, str]:
def lines(raw): def lines(raw):
"""Generate clean string values for each column.""" """Generate clean string values for each column."""
for r in (r for r in raw.split("\n") if r): for r in (r for r in raw.split("\n") if r):
fields = ( fields = (re.sub(r"(\n|\r)", "", field).strip(" ") for field in r.split("|"))
re.sub(r"(\n|\r)", "", field).strip(" ") for field in r.split("|")
)
yield fields yield fields
data = {} data = {}

View file

@ -12,9 +12,7 @@ class GenericHook(BaseExternal, name="Generic"):
def __init__(self, config): def __init__(self, config):
"""Initialize external base class with http connection details.""" """Initialize external base class with http connection details."""
super().__init__( super().__init__(base_url=f"{config.host.scheme}://{config.host.host}", config=config)
base_url=f"{config.host.scheme}://{config.host.host}", config=config
)
async def send(self, query): async def send(self, query):
"""Send an incoming webhook to http endpoint.""" """Send an incoming webhook to http endpoint."""

View file

@ -12,9 +12,7 @@ class MSTeams(BaseExternal, name="MSTeams"):
def __init__(self, config): def __init__(self, config):
"""Initialize external base class with Microsoft Teams connection details.""" """Initialize external base class with Microsoft Teams connection details."""
super().__init__( super().__init__(base_url="https://outlook.office.com", config=config, parse=False)
base_url="https://outlook.office.com", config=config, parse=False
)
async def send(self, query): async def send(self, query):
"""Send an incoming webhook to Microsoft Teams.""" """Send an incoming webhook to Microsoft Teams."""

View file

@ -41,9 +41,7 @@ def rpki_state(prefix, asn):
log.error(str(err)) log.error(str(err))
state = 3 state = 3
msg = "RPKI Validation State for {} via AS{} is {}".format( msg = "RPKI Validation State for {} via AS{} is {}".format(prefix, asn, RPKI_NAME_MAP[state])
prefix, asn, RPKI_NAME_MAP[state]
)
if cached is not None: if cached is not None:
msg += " [CACHED]" msg += " [CACHED]"

View file

@ -24,6 +24,5 @@ class Webhook(BaseExternal):
return provider_class(config) return provider_class(config)
except KeyError: except KeyError:
raise UnsupportedError( raise UnsupportedError(
message="{p} is not yet supported as a webhook target.", message="{p} is not yet supported as a webhook target.", p=config.provider.title(),
p=config.provider.title(),
) )

View file

@ -108,11 +108,7 @@ def enable_file_logging(logger, log_directory, log_format, log_max_size):
lf.write(f'\n\n{"".join(log_break)}\n\n') lf.write(f'\n\n{"".join(log_break)}\n\n')
logger.add( logger.add(
log_file, log_file, format=_FMT, rotation=log_max_size, serialize=structured, enqueue=True,
format=_FMT,
rotation=log_max_size,
serialize=structured,
enqueue=True,
) )
logger.debug("Logging to {} enabled", str(log_file)) logger.debug("Logging to {} enabled", str(log_file))
@ -127,9 +123,7 @@ def enable_syslog_logging(logger, syslog_host, syslog_port):
from logging.handlers import SysLogHandler from logging.handlers import SysLogHandler
logger.add( logger.add(
SysLogHandler(address=(str(syslog_host), syslog_port)), SysLogHandler(address=(str(syslog_host), syslog_port)), format=_FMT_BASIC, enqueue=True,
format=_FMT_BASIC,
enqueue=True,
) )
logger.debug( logger.debug(
"Logging to syslog target {}:{} enabled", str(syslog_host), str(syslog_port), "Logging to syslog target {}:{} enabled", str(syslog_host), str(syslog_port),

View file

@ -125,14 +125,12 @@ def register_all_plugins(devices: "Devices") -> None:
"""Validate and register configured plugins.""" """Validate and register configured plugins."""
for plugin_file in { for plugin_file in {
Path(p) Path(p) for p in (p for d in devices.objects for c in d.commands for p in c.plugins)
for p in (p for d in devices.objects for c in d.commands for p in c.plugins)
}: }:
failures = register_plugin(plugin_file) failures = register_plugin(plugin_file)
for failure in failures: for failure in failures:
log.warning( log.warning(
"Plugin '{}' is not a valid hyperglass plugin, and was not registered", "Plugin '{}' is not a valid hyperglass plugin, and was not registered", failure,
failure,
) )

View file

@ -26,9 +26,7 @@ from hyperglass.exceptions.private import InputValidationError
from ..config.devices import Device from ..config.devices import Device
from ..commands.generic import Directive from ..commands.generic import Directive
DIRECTIVE_IDS = [ DIRECTIVE_IDS = [directive.id for device in devices.objects for directive in device.commands]
directive.id for device in devices.objects for directive in device.commands
]
DIRECTIVE_GROUPS = { DIRECTIVE_GROUPS = {
group group
@ -76,9 +74,7 @@ class Query(BaseModel):
"example": "1.1.1.0/24", "example": "1.1.1.0/24",
}, },
} }
schema_extra = { schema_extra = {"x-code-samples": [{"lang": "Python", "source": "print('stuff')"}]}
"x-code-samples": [{"lang": "Python", "source": "print('stuff')"}]
}
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Initialize the query with a UTC timestamp at initialization time.""" """Initialize the query with a UTC timestamp at initialization time."""

View file

@ -26,20 +26,14 @@ class QueryError(BaseModel):
"""Pydantic model configuration.""" """Pydantic model configuration."""
title = "Query Error" title = "Query Error"
description = ( description = "Response received when there is an error executing the requested query."
"Response received when there is an error executing the requested query."
)
fields = { fields = {
"output": { "output": {
"title": "Output", "title": "Output",
"description": "Error Details", "description": "Error Details",
"example": "192.0.2.1/32 is not allowed.", "example": "192.0.2.1/32 is not allowed.",
}, },
"level": { "level": {"title": "Level", "description": "Error Severity", "example": "danger"},
"title": "Level",
"description": "Error Severity",
"example": "danger",
},
"keywords": { "keywords": {
"title": "Keywords", "title": "Keywords",
"description": "Relevant keyword values contained in the `output` field, which can be used for formatting.", "description": "Relevant keyword values contained in the `output` field, which can be used for formatting.",
@ -189,11 +183,7 @@ class RoutersResponse(BaseModel):
description = "Device attributes" description = "Device attributes"
schema_extra = { schema_extra = {
"examples": [ "examples": [
{ {"id": "nyc_router_1", "name": "NYC Router 1", "network": "New York City, NY"}
"id": "nyc_router_1",
"name": "NYC Router 1",
"network": "New York City, NY",
}
] ]
} }
@ -217,11 +207,11 @@ class SupportedQueryResponse(BaseModel):
"""Pydantic model configuration.""" """Pydantic model configuration."""
title = "Query Type" title = "Query Type"
description = "If enabled is `true`, the `name` field may be used to specify the query type." description = (
"If enabled is `true`, the `name` field may be used to specify the query type."
)
schema_extra = { schema_extra = {
"examples": [ "examples": [{"name": "bgp_route", "display_name": "BGP Route", "enable": True}]
{"name": "bgp_route", "display_name": "BGP Route", "enable": True}
]
} }

View file

@ -71,9 +71,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
except ValueError: except ValueError:
raise InputInvalid( raise InputInvalid(
params.messages.invalid_input, params.messages.invalid_input, target=value, query_type=query_type_params.display_name,
target=value,
query_type=query_type_params.display_name,
) )
# Test the valid IP address to determine if it is: # Test the valid IP address to determine if it is:
@ -83,9 +81,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
# ...and returns an error if so. # ...and returns an error if so.
if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback: if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback:
raise InputInvalid( raise InputInvalid(
params.messages.invalid_input, params.messages.invalid_input, target=value, query_type=query_type_params.display_name,
target=value,
query_type=query_type_params.display_name,
) )
ip_version = valid_ip.version ip_version = valid_ip.version
@ -105,9 +101,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
pass pass
if ace.action == "permit": if ace.action == "permit":
log.debug( log.debug("{t} is allowed by access-list {a}", t=str(valid_ip), a=repr(ace))
"{t} is allowed by access-list {a}", t=str(valid_ip), a=repr(ace)
)
break break
elif ace.action == "deny": elif ace.action == "deny":
raise InputNotAllowed( raise InputNotAllowed(
@ -125,10 +119,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
new_ip = valid_ip.network_address new_ip = valid_ip.network_address
log.debug( log.debug(
"Converted '{o}' to '{n}' for '{q}' query", "Converted '{o}' to '{n}' for '{q}' query", o=valid_ip, n=new_ip, q=query_type,
o=valid_ip,
n=new_ip,
q=query_type,
) )
valid_ip = new_ip valid_ip = new_ip
@ -137,11 +128,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
# - Query type is bgp_route # - Query type is bgp_route
# - force_cidr option is enabled # - force_cidr option is enabled
# - Query target is not a private address/network # - Query target is not a private address/network
elif ( elif query_type in ("bgp_route",) and vrf_afi.force_cidr and not valid_ip.is_private:
query_type in ("bgp_route",)
and vrf_afi.force_cidr
and not valid_ip.is_private
):
log.debug("Getting containing prefix for {q}", q=str(valid_ip)) log.debug("Getting containing prefix for {q}", q=str(valid_ip))
ip_str = str(valid_ip.network_address) ip_str = str(valid_ip.network_address)
@ -150,9 +137,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
if containing_prefix is None: if containing_prefix is None:
log.error( log.error(
"Unable to find containing prefix for {}. Got: {}", "Unable to find containing prefix for {}. Got: {}", str(valid_ip), network_info,
str(valid_ip),
network_info,
) )
raise InputInvalid("{q} does not have a containing prefix", q=ip_str) raise InputInvalid("{q} does not have a containing prefix", q=ip_str)
@ -163,13 +148,9 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
except ValueError as err: except ValueError as err:
log.error( log.error(
"Unable to find containing prefix for {q}. Error: {e}", "Unable to find containing prefix for {q}. Error: {e}", q=str(valid_ip), e=err,
q=str(valid_ip),
e=err,
)
raise InputInvalid(
"{q} does does not have a containing prefix", q=valid_ip
) )
raise InputInvalid("{q} does does not have a containing prefix", q=valid_ip)
# For a host query with bgp_route query type and force_cidr # For a host query with bgp_route query type and force_cidr
# disabled, convert the host query to a single IP address. # disabled, convert the host query to a single IP address.

View file

@ -24,9 +24,7 @@ class _IPv6(CommandSet):
bgp_aspath: StrictStr = 'show bgp ipv6 unicast quote-regexp "{target}"' bgp_aspath: StrictStr = 'show bgp ipv6 unicast quote-regexp "{target}"'
bgp_route: StrictStr = "show bgp ipv6 unicast {target} | exclude pathid:|Epoch" bgp_route: StrictStr = "show bgp ipv6 unicast {target} | exclude pathid:|Epoch"
ping: StrictStr = "ping ipv6 {target} repeat 5 source {source}" ping: StrictStr = "ping ipv6 {target} repeat 5 source {source}"
traceroute: StrictStr = ( traceroute: StrictStr = ("traceroute ipv6 {target} timeout 1 probe 2 source {source}")
"traceroute ipv6 {target} timeout 1 probe 2 source {source}"
)
class _VPNIPv4(CommandSet): class _VPNIPv4(CommandSet):
@ -36,9 +34,7 @@ class _VPNIPv4(CommandSet):
bgp_aspath: StrictStr = 'show bgp vpnv4 unicast vrf {vrf} quote-regexp "{target}"' bgp_aspath: StrictStr = 'show bgp vpnv4 unicast vrf {vrf} quote-regexp "{target}"'
bgp_route: StrictStr = "show bgp vpnv4 unicast vrf {vrf} {target}" bgp_route: StrictStr = "show bgp vpnv4 unicast vrf {vrf} {target}"
ping: StrictStr = "ping vrf {vrf} {target} repeat 5 source {source}" ping: StrictStr = "ping vrf {vrf} {target} repeat 5 source {source}"
traceroute: StrictStr = ( traceroute: StrictStr = ("traceroute vrf {vrf} {target} timeout 1 probe 2 source {source}")
"traceroute vrf {vrf} {target} timeout 1 probe 2 source {source}"
)
class _VPNIPv6(CommandSet): class _VPNIPv6(CommandSet):
@ -48,9 +44,7 @@ class _VPNIPv6(CommandSet):
bgp_aspath: StrictStr = 'show bgp vpnv6 unicast vrf {vrf} quote-regexp "{target}"' bgp_aspath: StrictStr = 'show bgp vpnv6 unicast vrf {vrf} quote-regexp "{target}"'
bgp_route: StrictStr = "show bgp vpnv6 unicast vrf {vrf} {target}" bgp_route: StrictStr = "show bgp vpnv6 unicast vrf {vrf} {target}"
ping: StrictStr = "ping vrf {vrf} {target} repeat 5 source {source}" ping: StrictStr = "ping vrf {vrf} {target} repeat 5 source {source}"
traceroute: StrictStr = ( traceroute: StrictStr = ("traceroute vrf {vrf} {target} timeout 1 probe 2 source {source}")
"traceroute vrf {vrf} {target} timeout 1 probe 2 source {source}"
)
class CiscoIOSCommands(CommandGroup): class CiscoIOSCommands(CommandGroup):

View file

@ -292,8 +292,6 @@ class Directive(HyperglassModel):
} }
if self.field.is_select: if self.field.is_select:
value["options"] = [ value["options"] = [o.export_dict() for o in self.field.options if o is not None]
o.export_dict() for o in self.field.options if o is not None
]
return value return value

View file

@ -92,9 +92,7 @@ class Device(HyperglassModel, extra="allow"):
legacy_display_name = values.pop("display_name", None) legacy_display_name = values.pop("display_name", None)
if legacy_display_name is not None: if legacy_display_name is not None:
log.warning( log.warning("The 'display_name' field is deprecated. Use the 'name' field instead.")
"The 'display_name' field is deprecated. Use the 'name' field instead."
)
device_id = generate_id(legacy_display_name) device_id = generate_id(legacy_display_name)
display_name = legacy_display_name display_name = legacy_display_name
else: else:

View file

@ -23,9 +23,7 @@ class EndpointConfig(HyperglassModel):
description="Displayed inside each API endpoint section.", description="Displayed inside each API endpoint section.",
) )
summary: StrictStr = Field( summary: StrictStr = Field(
..., ..., title="Endpoint Summary", description="Displayed beside the API endpoint URI.",
title="Endpoint Summary",
description="Displayed beside the API endpoint URI.",
) )
@ -41,9 +39,7 @@ class Docs(HyperglassModel):
description="OpenAPI UI library to use for the hyperglass API docs. Currently, the options are [Swagger UI](/fixme) and [Redoc](/fixme).", description="OpenAPI UI library to use for the hyperglass API docs. Currently, the options are [Swagger UI](/fixme) and [Redoc](/fixme).",
) )
base_url: HttpUrl = Field( base_url: HttpUrl = Field(
"https://lg.example.net", "https://lg.example.net", title="Base URL", description="Base URL used in request samples.",
title="Base URL",
description="Base URL used in request samples.",
) )
uri: AnyUri = Field( uri: AnyUri = Field(
"/api/docs", "/api/docs",

View file

@ -11,9 +11,7 @@ class Network(HyperglassModel):
"""Validation Model for per-network/asn config in devices.yaml.""" """Validation Model for per-network/asn config in devices.yaml."""
name: StrictStr = Field( name: StrictStr = Field(
..., ..., title="Network Name", description="Internal name of the device's primary network.",
title="Network Name",
description="Internal name of the device's primary network.",
) )
display_name: StrictStr = Field( display_name: StrictStr = Field(
..., ...,

View file

@ -32,9 +32,7 @@ class OpenGraph(HyperglassModel):
supported_extensions = (".jpg", ".jpeg", ".png") supported_extensions = (".jpg", ".jpeg", ".png")
if value is not None and value.suffix not in supported_extensions: if value is not None and value.suffix not in supported_extensions:
raise ValueError( raise ValueError(
"OpenGraph image must be one of {e}".format( "OpenGraph image must be one of {e}".format(e=", ".join(supported_extensions))
e=", ".join(supported_extensions)
)
) )
return value return value

View file

@ -110,9 +110,7 @@ class Params(ParamsPublic, HyperglassModel):
description="Allowed CORS hosts. By default, no CORS hosts are allowed.", description="Allowed CORS hosts. By default, no CORS hosts are allowed.",
) )
netmiko_delay_factor: IntFloat = Field( netmiko_delay_factor: IntFloat = Field(
0.1, 0.1, title="Netmiko Delay Factor", description="Override the netmiko global delay factor.",
title="Netmiko Delay Factor",
description="Override the netmiko global delay factor.",
) )
# Sub Level Params # Sub Level Params
@ -184,9 +182,7 @@ class Params(ParamsPublic, HyperglassModel):
def content_params(self) -> Dict[str, Any]: def content_params(self) -> Dict[str, Any]:
"""Export content-specific parameters.""" """Export content-specific parameters."""
return self.dict( return self.dict(include={"primary_asn", "org_name", "site_title", "site_description"})
include={"primary_asn", "org_name", "site_title", "site_description"}
)
def frontend(self) -> Dict[str, Any]: def frontend(self) -> Dict[str, Any]:
"""Export UI-specific parameters.""" """Export UI-specific parameters."""

View file

@ -39,9 +39,7 @@ class BgpCommunityPattern(HyperglassModel):
"""Pydantic model configuration.""" """Pydantic model configuration."""
title = "Pattern" title = "Pattern"
description = ( description = "Regular expression patterns used to validate BGP Community queries."
"Regular expression patterns used to validate BGP Community queries."
)
class BgpAsPathPattern(HyperglassModel): class BgpAsPathPattern(HyperglassModel):
@ -67,9 +65,7 @@ class BgpAsPathPattern(HyperglassModel):
"""Pydantic model configuration.""" """Pydantic model configuration."""
title = "Pattern" title = "Pattern"
description = ( description = "Regular expression patterns used to validate BGP AS Path queries."
"Regular expression patterns used to validate BGP AS Path queries."
)
class Community(HyperglassModel): class Community(HyperglassModel):
@ -84,9 +80,7 @@ class BgpCommunity(HyperglassModel):
"""Validation model for bgp_community configuration.""" """Validation model for bgp_community configuration."""
enable: StrictBool = Field( enable: StrictBool = Field(
True, True, title="Enable", description="Enable or disable the BGP Community query type.",
title="Enable",
description="Enable or disable the BGP Community query type.",
) )
display_name: StrictStr = Field( display_name: StrictStr = Field(
"BGP Community", "BGP Community",
@ -115,9 +109,7 @@ class BgpAsPath(HyperglassModel):
"""Validation model for bgp_aspath configuration.""" """Validation model for bgp_aspath configuration."""
enable: StrictBool = Field( enable: StrictBool = Field(
True, True, title="Enable", description="Enable or disable the BGP AS Path query type.",
title="Enable",
description="Enable or disable the BGP AS Path query type.",
) )
display_name: StrictStr = Field( display_name: StrictStr = Field(
"BGP AS Path", "BGP AS Path",
@ -168,9 +160,7 @@ class Queries(HyperglassModel):
query_obj = getattr(self, query) query_obj = getattr(self, query)
_map[query] = { _map[query] = {
"name": query, "name": query,
**query_obj.export_dict( **query_obj.export_dict(include={"display_name", "enable", "mode", "communities"}),
include={"display_name", "enable", "mode", "communities"}
),
} }
return _map return _map
@ -185,11 +175,7 @@ class Queries(HyperglassModel):
for query in SUPPORTED_QUERY_TYPES: for query in SUPPORTED_QUERY_TYPES:
query_obj = getattr(self, query) query_obj = getattr(self, query)
_list.append( _list.append(
{ {"name": query, "display_name": query_obj.display_name, "enable": query_obj.enable}
"name": query,
"display_name": query_obj.display_name,
"enable": query_obj.enable,
}
) )
return _list return _list

View file

@ -79,6 +79,4 @@ class HyperglassModel(BaseModel):
"exclude_unset": kwargs.pop("exclude_unset", False), "exclude_unset": kwargs.pop("exclude_unset", False),
} }
return yaml.safe_dump( return yaml.safe_dump(json.loads(self.export_json(**export_kwargs)), *args, **kwargs)
json.loads(self.export_json(**export_kwargs)), *args, **kwargs
)

View file

@ -110,9 +110,7 @@ class FRRRoute(_FRRBase):
} }
) )
serialized = ParsedRoutes( serialized = ParsedRoutes(vrf=vrf, count=len(routes), routes=routes, winning_weight="high",)
vrf=vrf, count=len(routes), routes=routes, winning_weight="high",
)
log.info("Serialized FRR response: {}", serialized) log.info("Serialized FRR response: {}", serialized)
return serialized return serialized

View file

@ -83,9 +83,7 @@ class JuniperRouteTableEntry(_JuniperBase):
_path_attr = values.get("bgp-path-attributes", {}) _path_attr = values.get("bgp-path-attributes", {})
_path_attr_agg = _path_attr.get("attr-aggregator", {}).get("attr-value", {}) _path_attr_agg = _path_attr.get("attr-aggregator", {}).get("attr-value", {})
values["as-path"] = _path_attr.get("attr-as-path-effective", {}).get( values["as-path"] = _path_attr.get("attr-as-path-effective", {}).get("attr-value", "")
"attr-value", ""
)
values["source-as"] = _path_attr_agg.get("aggr-as-number", 0) values["source-as"] = _path_attr_agg.get("aggr-as-number", 0)
values["source-rid"] = _path_attr_agg.get("aggr-router-id", "") values["source-rid"] = _path_attr_agg.get("aggr-router-id", "")
values["peer-rid"] = values["peer-id"] values["peer-rid"] = values["peer-id"]
@ -171,9 +169,7 @@ class JuniperRoute(_JuniperBase):
count = 0 count = 0
for table in self.rt: for table in self.rt:
count += table.rt_entry_count count += table.rt_entry_count
prefix = "/".join( prefix = "/".join(str(i) for i in (table.rt_destination, table.rt_prefix_length))
str(i) for i in (table.rt_destination, table.rt_prefix_length)
)
for route in table.rt_entry: for route in table.rt_entry:
routes.append( routes.append(
{ {
@ -193,9 +189,7 @@ class JuniperRoute(_JuniperBase):
} }
) )
serialized = ParsedRoutes( serialized = ParsedRoutes(vrf=vrf, count=count, routes=routes, winning_weight="low",)
vrf=vrf, count=count, routes=routes, winning_weight="low",
)
log.debug("Serialized Juniper response: {}", serialized) log.debug("Serialized Juniper response: {}", serialized)
return serialized return serialized

View file

@ -75,8 +75,7 @@ class Webhook(HyperglassModel):
return f"`{str(value)}`" return f"`{str(value)}`"
header_data = [ header_data = [
{"name": k, "value": code(v)} {"name": k, "value": code(v)} for k, v in self.headers.dict(by_alias=True).items()
for k, v in self.headers.dict(by_alias=True).items()
] ]
time_fmt = self.timestamp.strftime("%Y %m %d %H:%M:%S") time_fmt = self.timestamp.strftime("%Y %m %d %H:%M:%S")
payload = { payload = {
@ -131,39 +130,21 @@ class Webhook(HyperglassModel):
header_data.append(field) header_data.append(field)
query_data = [ query_data = [
{ {"type": "mrkdwn", "text": make_field("Query Location", self.query_location)},
"type": "mrkdwn", {"type": "mrkdwn", "text": make_field("Query Target", self.query_target, code=True)},
"text": make_field("Query Location", self.query_location),
},
{
"type": "mrkdwn",
"text": make_field("Query Target", self.query_target, code=True),
},
{"type": "mrkdwn", "text": make_field("Query Type", self.query_type)}, {"type": "mrkdwn", "text": make_field("Query Type", self.query_type)},
{"type": "mrkdwn", "text": make_field("Query VRF", self.query_vrf)}, {"type": "mrkdwn", "text": make_field("Query VRF", self.query_vrf)},
] ]
source_data = [ source_data = [
{ {"type": "mrkdwn", "text": make_field("Source IP", self.source, code=True)},
"type": "mrkdwn",
"text": make_field("Source IP", self.source, code=True),
},
{ {
"type": "mrkdwn", "type": "mrkdwn",
"text": make_field("Source Prefix", self.network.prefix, code=True), "text": make_field("Source Prefix", self.network.prefix, code=True),
}, },
{ {"type": "mrkdwn", "text": make_field("Source ASN", self.network.asn, code=True)},
"type": "mrkdwn", {"type": "mrkdwn", "text": make_field("Source Country", self.network.country)},
"text": make_field("Source ASN", self.network.asn, code=True), {"type": "mrkdwn", "text": make_field("Source Organization", self.network.org)},
},
{
"type": "mrkdwn",
"text": make_field("Source Country", self.network.country),
},
{
"type": "mrkdwn",
"text": make_field("Source Organization", self.network.org),
},
] ]
time_fmt = self.timestamp.strftime("%Y %m %d %H:%M:%S") time_fmt = self.timestamp.strftime("%Y %m %d %H:%M:%S")
@ -171,20 +152,14 @@ class Webhook(HyperglassModel):
payload = { payload = {
"text": _WEBHOOK_TITLE, "text": _WEBHOOK_TITLE,
"blocks": [ "blocks": [
{ {"type": "section", "text": {"type": "mrkdwn", "text": f"*{time_fmt} UTC*"}},
"type": "section",
"text": {"type": "mrkdwn", "text": f"*{time_fmt} UTC*"},
},
{"type": "section", "fields": query_data}, {"type": "section", "fields": query_data},
{"type": "divider"}, {"type": "divider"},
{"type": "section", "fields": source_data}, {"type": "section", "fields": source_data},
{"type": "divider"}, {"type": "divider"},
{ {
"type": "section", "type": "section",
"text": { "text": {"type": "mrkdwn", "text": "*Headers*\n" + "\n".join(header_data)},
"type": "mrkdwn",
"text": "*Headers*\n" + "\n".join(header_data),
},
}, },
], ],
} }

View file

@ -58,9 +58,7 @@ def parse_juniper(output: Sequence) -> Dict: # noqa: C901
cleaned = clean_xml_output(response) cleaned = clean_xml_output(response)
try: try:
parsed = xmltodict.parse( parsed = xmltodict.parse(cleaned, force_list=("rt", "rt-entry", "community"))
cleaned, force_list=("rt", "rt-entry", "community")
)
log.debug("Initially Parsed Response: \n{}", parsed) log.debug("Initially Parsed Response: \n{}", parsed)

View file

@ -50,9 +50,7 @@ def parse_linux_ping(output):
_bytes, seq, ttl, rtt = _process_numbers(bytes_seq_ttl_rtt) _bytes, seq, ttl, rtt = _process_numbers(bytes_seq_ttl_rtt)
reply_stats.append( reply_stats.append({"bytes": _bytes, "sequence": seq, "ttl": ttl, "rtt": rtt})
{"bytes": _bytes, "sequence": seq, "ttl": ttl, "rtt": rtt}
)
stats = [l for l in _stats.splitlines() if l] stats = [l for l in _stats.splitlines() if l]

View file

@ -11,9 +11,7 @@ from threading import Thread
from hyperglass.log import log from hyperglass.log import log
async def move_files( # noqa: C901 async def move_files(src: Path, dst: Path, files: Iterable[Path]) -> Tuple[str]: # noqa: C901
src: Path, dst: Path, files: Iterable[Path]
) -> Tuple[str]:
"""Move iterable of files from source to destination. """Move iterable of files from source to destination.
Arguments: Arguments:
@ -133,9 +131,7 @@ def copyfiles(src_files: Iterable[Path], dst_files: Iterable[Path]):
return True return True
def check_path( def check_path(path: Union[Path, str], mode: str = "r", create: bool = False) -> Optional[Path]:
path: Union[Path, str], mode: str = "r", create: bool = False
) -> Optional[Path]:
"""Verify if a path exists and is accessible.""" """Verify if a path exists and is accessible."""
result = None result = None

View file

@ -22,9 +22,7 @@ def get_node_version() -> Tuple[int, int, int]:
"""Get the system's NodeJS version.""" """Get the system's NodeJS version."""
node_path = shutil.which("node") node_path = shutil.which("node")
raw_version = subprocess.check_output( # noqa: S603 raw_version = subprocess.check_output([node_path, "--version"]).decode() # noqa: S603
[node_path, "--version"]
).decode()
# Node returns the version as 'v14.5.0', for example. Remove the v. # Node returns the version as 'v14.5.0', for example. Remove the v.
version = raw_version.replace("v", "") version = raw_version.replace("v", "")
@ -162,11 +160,7 @@ async def build_ui(app_path):
def generate_opengraph( def generate_opengraph(
image_path: Path, image_path: Path, max_width: int, max_height: int, target_path: Path, background_color: str,
max_width: int,
max_height: int,
target_path: Path,
background_color: str,
): ):
"""Generate an OpenGraph compliant image.""" """Generate an OpenGraph compliant image."""
# Third Party # Third Party
@ -340,9 +334,7 @@ async def build_frontend( # noqa: C901
log.debug("Previous Build ID: {}", ef_id) log.debug("Previous Build ID: {}", ef_id)
if ef_id == build_id: if ef_id == build_id:
log.debug( log.debug("UI parameters unchanged since last build, skipping UI build...")
"UI parameters unchanged since last build, skipping UI build..."
)
return True return True
env_vars["buildId"] = build_id env_vars["buildId"] = build_id
@ -368,11 +360,7 @@ async def build_frontend( # noqa: C901
migrate_images(app_path, params) migrate_images(app_path, params)
generate_opengraph( generate_opengraph(
params.web.opengraph.image, params.web.opengraph.image, 1200, 630, images_dir, params.web.theme.colors.black,
1200,
630,
images_dir,
params.web.theme.colors.black,
) )
except Exception as err: except Exception as err: