From b7b681f3d38e12248cdc3c2ce97580cb9ee8efa0 Mon Sep 17 00:00:00 2001 From: thatmattlove Date: Wed, 15 Dec 2021 00:50:20 -0700 Subject: [PATCH] Restructure utilities, add tests --- hyperglass/external/_base.py | 4 +- hyperglass/external/msteams.py | 2 + hyperglass/util/__init__.py | 355 +++--------------------- hyperglass/util/frontend.py | 13 - hyperglass/util/system_info.py | 50 +++- hyperglass/util/tests/test_tools.py | 191 +++++++++++++ hyperglass/util/tests/test_utilities.py | 69 ----- hyperglass/util/tools.py | 185 ++++++++++++ hyperglass/util/validation.py | 73 +++++ poetry.lock | 6 +- pyproject.toml | 2 +- 11 files changed, 548 insertions(+), 402 deletions(-) create mode 100644 hyperglass/util/tests/test_tools.py delete mode 100644 hyperglass/util/tests/test_utilities.py create mode 100644 hyperglass/util/tools.py create mode 100644 hyperglass/util/validation.py diff --git a/hyperglass/external/_base.py b/hyperglass/external/_base.py index 9de9097..cad9dbb 100644 --- a/hyperglass/external/_base.py +++ b/hyperglass/external/_base.py @@ -13,7 +13,7 @@ import httpx # Project from hyperglass.log import log -from hyperglass.util import make_repr, parse_exception +from hyperglass.util import parse_exception, repr_from_attrs from hyperglass.constants import __version__ from hyperglass.models.fields import JsonValue, HttpMethod, Primitives from hyperglass.exceptions.private import ExternalError @@ -124,7 +124,7 @@ class BaseExternal: def __repr__(self: "BaseExternal") -> str: """Return user friendly representation of instance.""" - return make_repr(self) + return repr_from_attrs(self, ("name", "base_url", "config", "parse")) def _exception( self: "BaseExternal", diff --git a/hyperglass/external/msteams.py b/hyperglass/external/msteams.py index 63667b2..b91a52a 100644 --- a/hyperglass/external/msteams.py +++ b/hyperglass/external/msteams.py @@ -1,5 +1,6 @@ """Session handler for Microsoft Teams API.""" +# Standard Library import typing as t # Project @@ -8,6 +9,7 @@ from hyperglass.external._base import BaseExternal from hyperglass.models.webhook import Webhook if t.TYPE_CHECKING: + # Project from hyperglass.models.config.logging import Http diff --git a/hyperglass/util/__init__.py b/hyperglass/util/__init__.py index 16ba04c..c1d0ab3 100644 --- a/hyperglass/util/__init__.py +++ b/hyperglass/util/__init__.py @@ -1,311 +1,48 @@ """Utility functions.""" -# Standard Library -import os -import sys -import json -import string -import typing as t -import platform -from asyncio import iscoroutine -from pathlib import Path -from ipaddress import IPv4Address, IPv6Address, ip_address - -# Third Party -from loguru._logger import Logger as LoguruLogger -from netmiko.ssh_dispatcher import CLASS_MAPPER # type: ignore - -# Project -from hyperglass.types import Series -from hyperglass.constants import DRIVER_MAP - -ALL_DEVICE_TYPES = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()} -ALL_DRIVERS = {*DRIVER_MAP.values(), "netmiko"} - -DeepConvert = t.TypeVar("DeepConvert", bound=t.Dict[str, t.Any]) - - -def cpu_count(multiplier: int = 0) -> int: - """Get server's CPU core count. - - Used to determine the number of web server workers. - """ - # Standard Library - import multiprocessing - - return multiprocessing.cpu_count() * multiplier - - -def check_python() -> str: - """Verify Python Version.""" - # Project - from hyperglass.constants import MIN_PYTHON_VERSION - - pretty_version = ".".join(tuple(str(v) for v in MIN_PYTHON_VERSION)) - if sys.version_info < MIN_PYTHON_VERSION: - raise RuntimeError(f"Python {pretty_version}+ is required.") - return platform.python_version() - - -def split_on_uppercase(s: str) -> t.List[str]: - """Split characters by uppercase letters. - - From: https://stackoverflow.com/a/40382663 - """ - string_length = len(s) - is_lower_around = lambda: s[i - 1].islower() or string_length > (i + 1) and s[i + 1].islower() - - start = 0 - parts = [] - for i in range(1, string_length): - if s[i].isupper() and is_lower_around(): - parts.append(s[start:i]) - start = i - parts.append(s[start:]) - - return parts - - -def parse_exception(exc: BaseException) -> str: - """Parse an exception and its direct cause.""" - - if not isinstance(exc, BaseException): - raise TypeError(f"'{repr(exc)}' is not an exception.") - - def get_exc_name(exc): - return " ".join(split_on_uppercase(exc.__class__.__name__)) - - def get_doc_summary(doc): - return doc.strip().split("\n")[0].strip(".") - - name = get_exc_name(exc) - parsed = [] - if exc.__doc__: - detail = get_doc_summary(exc.__doc__) - parsed.append(f"{name} ({detail})") - else: - parsed.append(name) - - if exc.__cause__: - cause = get_exc_name(exc.__cause__) - if exc.__cause__.__doc__: - cause_detail = get_doc_summary(exc.__cause__.__doc__) - parsed.append(f"{cause} ({cause_detail})") - else: - parsed.append(cause) - return ", caused by ".join(parsed) - - -def make_repr(_class): - """Create a user-friendly represention of an object.""" - - def _process_attrs(_dir): - for attr in _dir: - if not attr.startswith("_"): - attr_val = getattr(_class, attr) - - if callable(attr_val): - yield f'{attr}=' - - elif iscoroutine(attr_val): - yield f'{attr}=' - - elif isinstance(attr_val, str): - yield f'{attr}="{attr_val}"' - - else: - yield f"{attr}={str(attr_val)}" - - return f'{_class.__name__}({", ".join(_process_attrs(dir(_class)))})' - - -def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str: - """Generate a `__repr__()` value from a specific set of attribute names. - - Useful for complex models/objects where `__repr__()` should only display specific fields. - """ - # Check the object to ensure each attribute actually exists, and deduplicate - attr_names = {a for a in attrs if hasattr(obj, a)} - # Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a - # `__repr__` method. - attr_values = { - f if strip is None else f.strip(strip): v # noqa: IF100 - for f in attr_names - if hasattr((v := getattr(obj, f)), "__repr__") - } - pairs = (f"{k}={v!r}" for k, v in attr_values.items()) - return f"{obj.__class__.__name__}({', '.join(pairs)})" - - -def validate_platform(_type: str) -> t.Tuple[bool, t.Union[None, str]]: - """Validate device type is supported.""" - - result = (False, None) - - if _type in ALL_DEVICE_TYPES: - result = (True, DRIVER_MAP.get(_type, "netmiko")) - - return result - - -def get_driver(_type: str, driver: t.Optional[str]) -> str: - """Determine the appropriate driver for a device.""" - - if driver is None: - # If no driver is set, use the driver map with netmiko as - # fallback. - return DRIVER_MAP.get(_type, "netmiko") - elif driver in ALL_DRIVERS: - # If a driver is set and it is valid, allow it. - return driver - else: - # Otherwise, fail validation. - raise ValueError("{} is not a supported driver.".format(driver)) - - -def current_log_level(logger: LoguruLogger) -> str: - """Get the current log level of a logger instance.""" - - try: - handler = list(logger._core.handlers.values())[0] - levels = {v.no: k for k, v in logger._core.levels.items()} - current_level = levels[handler.levelno].lower() - - except Exception as err: - logger.error(err) - current_level = "info" - - return current_level - - -def resolve_hostname(hostname: str) -> t.Generator[t.Union[IPv4Address, IPv6Address], None, None]: - """Resolve a hostname via DNS/hostfile.""" - # Standard Library - from socket import gaierror, getaddrinfo - - # Project - from hyperglass.log import log - - log.debug("Ensuring '{}' is resolvable...", hostname) - - ip4 = None - ip6 = None - try: - res = getaddrinfo(hostname, None) - for sock in res: - if sock[0].value == 2 and ip4 is None: - ip4 = ip_address(sock[4][0]) - elif sock[0].value in (10, 30) and ip6 is None: - ip6 = ip_address(sock[4][0]) - except (gaierror, ValueError, IndexError) as err: - log.debug(str(err)) - pass - - yield ip4 - yield ip6 - - -def snake_to_camel(value: str) -> str: - """Convert a string from snake_case to camelCase.""" - parts = value.split("_") - humps = (hump.capitalize() for hump in parts[1:]) - return "".join((parts[0], *humps)) - - -def get_fmt_keys(template: str) -> t.List[str]: - """Get a list of str.format keys. - - For example, string `"The value of {key} is {value}"` returns - `["key", "value"]`. - """ - keys = [] - for block in (b for b in string.Formatter.parse("", template) if isinstance(template, str)): - key = block[1] - if key: - keys.append(key) - return keys - - -def deep_convert_keys(_dict: t.Type[DeepConvert], predicate: t.Callable[[str], str]) -> DeepConvert: - """Convert all dictionary keys and nested dictionary keys.""" - converted = {} - - def get_value(value: t.Any): - if isinstance(value, t.Dict): - return {predicate(k): get_value(v) for k, v in value.items()} - elif isinstance(value, t.List): - return [get_value(v) for v in value] - elif isinstance(value, t.Tuple): - return tuple(get_value(v) for v in value) - return value - - for key, value in _dict.items(): - converted[predicate(key)] = get_value(value) - - return converted - - -def at_least( - minimum: int, - value: int, -) -> int: - """Get a number value that is at least a specified minimum.""" - if value < minimum: - return minimum - return value - - -def compare_dicts(dict_a: t.Dict[t.Any, t.Any], dict_b: t.Dict[t.Any, t.Any]) -> bool: - """Determine if two dictationaries are (mostly) equal.""" - if isinstance(dict_a, t.Dict) and isinstance(dict_b, t.Dict): - dict_a_keys, dict_a_values = set(dict_a.keys()), set(dict_a.values()) - dict_b_keys, dict_b_values = set(dict_b.keys()), set(dict_b.values()) - return all((dict_a_keys == dict_b_keys, dict_a_values == dict_b_values)) - return False - - -def compare_init(obj_a: object, obj_b: object) -> bool: - """Compare the `__init__` annoations of two objects.""" - - def _check_obj(obj: object): - """Ensure `__annotations__` exists on the `__init__` method.""" - if hasattr(obj, "__init__") and isinstance(getattr(obj, "__init__", None), t.Callable): - if hasattr(obj.__init__, "__annotations__") and isinstance( - getattr(obj.__init__, "__annotations__", None), t.Dict - ): - return True - return False - - if all((_check_obj(obj_a), _check_obj(obj_b))): - obj_a.__init__.__annotations__.pop("self", None) - obj_b.__init__.__annotations__.pop("self", None) - return compare_dicts(obj_a.__init__.__annotations__, obj_b.__init__.__annotations__) - return False - - -def run_coroutine_in_new_thread(coroutine: t.Coroutine) -> t.Any: - """Run an async function in a separate thread and get the result.""" - # Standard Library - import asyncio - import threading - - class Resolver(threading.Thread): - def __init__(self, coro: t.Coroutine) -> None: - self.result: t.Any = None - self.coro: t.Coroutine = coro - super().__init__() - - def run(self): - self.result = asyncio.run(self.coro()) - - thread = Resolver(coroutine) - thread.start() - thread.join() - return thread.result - - -def compare_lists(left: t.List[t.Any], right: t.List[t.Any], *, ignore: Series[t.Any] = ()) -> bool: - """Determine if all items in left list exist in right list.""" - left_ignored = [i for i in left if i not in ignore] - diff_ignored = [i for i in left if i in right and i not in ignore] - return len(left_ignored) == len(diff_ignored) +# Local +from .files import copyfiles, check_path, move_files, dotenv_to_dict +from .tools import ( + at_least, + compare_init, + get_fmt_keys, + compare_dicts, + compare_lists, + snake_to_camel, + parse_exception, + repr_from_attrs, + deep_convert_keys, + split_on_uppercase, + run_coroutine_in_new_thread, +) +from .typing import is_type, is_series +from .frontend import build_ui, build_frontend +from .validation import get_driver, resolve_hostname, validate_platform +from .system_info import cpu_count, check_python + +__all__ = ( + "at_least", + "build_frontend", + "build_ui", + "check_path", + "check_python", + "compare_dicts", + "compare_init", + "compare_lists", + "copyfiles", + "cpu_count", + "deep_convert_keys", + "dotenv_to_dict", + "get_driver", + "get_fmt_keys", + "is_series", + "is_type", + "move_files", + "parse_exception", + "repr_from_attrs", + "resolve_hostname", + "run_coroutine_in_new_thread", + "snake_to_camel", + "split_on_uppercase", + "validate_platform", +) diff --git a/hyperglass/util/frontend.py b/hyperglass/util/frontend.py index 9f5c386..7c32408 100644 --- a/hyperglass/util/frontend.py +++ b/hyperglass/util/frontend.py @@ -7,7 +7,6 @@ import math import shutil import typing as t import asyncio -import subprocess from pathlib import Path # Project @@ -21,18 +20,6 @@ if t.TYPE_CHECKING: from hyperglass.models.ui import UIParameters -def get_node_version() -> t.Tuple[int, int, int]: - """Get the system's NodeJS version.""" - node_path = shutil.which("node") - - raw_version = subprocess.check_output([node_path, "--version"]).decode() # noqa: S603 - - # Node returns the version as 'v14.5.0', for example. Remove the v. - version = raw_version.replace("v", "") - # Parse the version parts. - return tuple((int(v) for v in version.split("."))) - - def get_ui_build_timeout() -> t.Optional[int]: """Read the UI build timeout from environment variables or set a default.""" timeout = None diff --git a/hyperglass/util/system_info.py b/hyperglass/util/system_info.py index bb4edb2..753099c 100644 --- a/hyperglass/util/system_info.py +++ b/hyperglass/util/system_info.py @@ -2,8 +2,9 @@ # Standard Library import os +import sys +import typing as t import platform -from typing import Dict, Tuple, Union # Third Party import psutil as _psutil @@ -12,10 +13,7 @@ from cpuinfo import get_cpu_info as _get_cpu_info # type: ignore # Project from hyperglass.constants import __version__ -# Local -from .frontend import get_node_version - -SystemData = Dict[str, Tuple[Union[str, int], str]] +SystemData = t.Dict[str, t.Tuple[t.Union[str, int], str]] def _cpu() -> SystemData: @@ -44,6 +42,48 @@ def _disk() -> SystemData: return (total_gb, usage_percent) +def get_node_version() -> t.Tuple[int, int, int]: + """Get the system's NodeJS version.""" + + # Standard Library + import shutil + import subprocess + + node_path = shutil.which("node") + + raw_version = subprocess.check_output([node_path, "--version"]).decode() # noqa: S603 + + # Node returns the version as 'v14.5.0', for example. Remove the v. + version = raw_version.replace("v", "") + # Parse the version parts. + return tuple((int(v) for v in version.split("."))) + + +def cpu_count(multiplier: int = 0) -> int: + """Get server's CPU core count. + + Used to determine the number of web server workers. + """ + # Standard Library + import multiprocessing + + return multiprocessing.cpu_count() * multiplier + + +def check_python() -> str: + """Verify Python Version.""" + # Project + from hyperglass.constants import MIN_PYTHON_VERSION + + pretty_version = ".".join(tuple(str(v) for v in MIN_PYTHON_VERSION)) + running_version = ".".join( + str(v) for v in (sys.version_info.major, sys.version_info.minor, sys.version_info.micro) + ) + if sys.version_info < MIN_PYTHON_VERSION: + raise RuntimeError(f"Python {pretty_version}+ is required (Running {running_version})") + return running_version + + def get_system_info() -> SystemData: """Get system info.""" diff --git a/hyperglass/util/tests/test_tools.py b/hyperglass/util/tests/test_tools.py new file mode 100644 index 0000000..47f00b3 --- /dev/null +++ b/hyperglass/util/tests/test_tools.py @@ -0,0 +1,191 @@ +"""Test generic utilities.""" + +# Standard Library +import asyncio + +# Third Party +import pytest + +# Local +from ..tools import ( + at_least, + compare_init, + get_fmt_keys, + compare_dicts, + compare_lists, + snake_to_camel, + parse_exception, + repr_from_attrs, + deep_convert_keys, + split_on_uppercase, + run_coroutine_in_new_thread, +) + + +def test_split_on_uppercase(): + strings = ( + ("TestOne", ["Test", "One"]), + ("testTwo", ["test", "Two"]), + ("TestingOneTwoThree", ["Testing", "One", "Two", "Three"]), + ) + for str_in, list_out in strings: + result = split_on_uppercase(str_in) + assert result == list_out + + +def test_parse_exception(): + with pytest.raises(TypeError): + parse_exception(1) + + exc1 = RuntimeError("Test1") + exc1_expected = f"Runtime Error ({(RuntimeError.__doc__ or '').strip('.')})" + exc2 = RuntimeError("Test2") + exc2_cause = f"Connection Error ({(ConnectionError.__doc__ or '').strip('.')})" + exc2_expected = f"{exc1_expected}, caused by {exc2_cause}" + try: + raise exc1 + except Exception as err: + result = parse_exception(err) + assert result == exc1_expected + try: + raise exc2 from ConnectionError + except Exception as err: + result = parse_exception(err) + assert result == exc2_expected + + +def test_repr_from_attrs(): + # Third Party + from pydantic import create_model + + model = create_model("TestModel", one=(str, ...), two=(int, ...), three=(bool, ...)) + implementation = model(one="one", two=2, three=True) + result = repr_from_attrs(implementation, ("one", "two", "three")) + assert result == "TestModel(one='one', three=True, two=2)" + + +@pytest.mark.dependency() +def test_snake_to_camel(): + keys = ( + ("test_one", "testOne"), + ("test_two_three", "testTwoThree"), + ("Test_four_five_six", "testFourFiveSix"), + ) + for key_in, key_out in keys: + result = snake_to_camel(key_in) + assert result == key_out + + +def test_get_fmt_keys(): + template = "This is a {template} for a {test}" + result = get_fmt_keys(template) + assert len(result) == 2 and "template" in result and "test" in result + + +@pytest.mark.dependency( + depends=["hyperglass/util/tests/test_tools.py::test_snake_to_camel"], scope="session" +) +def test_deep_convert_keys(): + dict_in = { + "key_one": 1, + "key_two": 2, + "key_dict": { + "key_one": "one", + "key_two": "two", + }, + "key_list_dicts": [{"key_one": 101, "key_two": 102}, {"key_three": 103, "key_four": 104}], + } + + result = deep_convert_keys(dict_in, snake_to_camel) + assert result.get("keyOne") is not None + assert result.get("keyTwo") is not None + assert result.get("keyDict") is not None + assert result["keyDict"].get("keyOne") is not None + assert result["keyDict"].get("keyTwo") is not None + assert isinstance(result.get("keyListDicts"), list) + assert result["keyListDicts"][0].get("keyOne") is not None + assert result["keyListDicts"][0].get("keyTwo") is not None + assert result["keyListDicts"][1].get("keyThree") is not None + assert result["keyListDicts"][1].get("keyFour") is not None + + +def test_at_least(): + assert at_least(8, 10) == 10 + assert at_least(8, 6) == 8 + + +def test_compare_dicts(): + + d1 = {"one": 1, "two": 2} + d2 = {"one": 1, "two": 2} + d3 = {"one": 1, "three": 3} + d4 = {"one": 1, "two": 3} + d5 = {} + d6 = {} + checks = ( + (d1, d2, True), + (d1, d3, False), + (d1, d4, False), + (d1, d1, True), + (d5, d6, True), + (d1, [], False), + ) + for a, b, expected in checks: + assert compare_dicts(a, b) is expected + + +def test_compare_init(): + class Compare1: + def __init__(self, item: str) -> None: + pass + + class Compare2: + def __init__(self: "Compare2", item: str) -> None: + pass + + class Compare3: + def __init__(self: "Compare3", item: str, other_item: int) -> None: + pass + + class Compare4: + def __init__(self: "Compare4", item: bool) -> None: + pass + + class Compare5: + pass + + checks = ( + (Compare1, Compare2, True), + (Compare1, Compare3, False), + (Compare1, Compare4, False), + (Compare1, Compare5, False), + (Compare1, Compare1, True), + ) + for a, b, expected in checks: + assert compare_init(a, b) is expected + + +def test_run_coroutine_in_new_thread(): + async def sleeper(): + await asyncio.sleep(5) + + async def test(): + return True + + asyncio.run(sleeper()) + result = run_coroutine_in_new_thread(test) + assert result is True + + +def test_compare_lists(): + # Standard Library + import random + + list1 = ["one", 2, "3"] + list2 = [4, "5", "six"] + list3 = ["one", 11, False] + list4 = [*list1, *list2] + random.shuffle(list4) + assert compare_lists(list1, list2) is False + assert compare_lists(list1, list3) is False + assert compare_lists(list1, list4) is True diff --git a/hyperglass/util/tests/test_utilities.py b/hyperglass/util/tests/test_utilities.py deleted file mode 100644 index 93ba3ca..0000000 --- a/hyperglass/util/tests/test_utilities.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Test generic utilities.""" -# Standard Library -import asyncio - -# Local -from .. import compare_init, compare_dicts, run_coroutine_in_new_thread - - -def test_compare_dicts(): - - d1 = {"one": 1, "two": 2} - d2 = {"one": 1, "two": 2} - d3 = {"one": 1, "three": 3} - d4 = {"one": 1, "two": 3} - d5 = {} - d6 = {} - checks = ( - (d1, d2, True), - (d1, d3, False), - (d1, d4, False), - (d1, d1, True), - (d5, d6, True), - (d1, [], False), - ) - for a, b, expected in checks: - assert compare_dicts(a, b) is expected - - -def test_compare_init(): - class Compare1: - def __init__(self, item: str) -> None: - pass - - class Compare2: - def __init__(self: "Compare2", item: str) -> None: - pass - - class Compare3: - def __init__(self: "Compare3", item: str, other_item: int) -> None: - pass - - class Compare4: - def __init__(self: "Compare4", item: bool) -> None: - pass - - class Compare5: - pass - - checks = ( - (Compare1, Compare2, True), - (Compare1, Compare3, False), - (Compare1, Compare4, False), - (Compare1, Compare5, False), - (Compare1, Compare1, True), - ) - for a, b, expected in checks: - assert compare_init(a, b) is expected - - -def test_run_coroutine_in_new_thread(): - async def sleeper(): - await asyncio.sleep(5) - - async def test(): - return True - - asyncio.run(sleeper()) - result = run_coroutine_in_new_thread(test) - assert result is True diff --git a/hyperglass/util/tools.py b/hyperglass/util/tools.py new file mode 100644 index 0000000..3fd745b --- /dev/null +++ b/hyperglass/util/tools.py @@ -0,0 +1,185 @@ +"""Collection of generalized functional tools.""" + +# Standard Library +import typing as t + +# Project +from hyperglass.types import Series + +DeepConvert = t.TypeVar("DeepConvert", bound=t.Dict[str, t.Any]) + + +def run_coroutine_in_new_thread(coroutine: t.Coroutine) -> t.Any: + """Run an async function in a separate thread and get the result.""" + # Standard Library + import asyncio + import threading + + class Resolver(threading.Thread): + def __init__(self, coro: t.Coroutine) -> None: + self.result: t.Any = None + self.coro: t.Coroutine = coro + super().__init__() + + def run(self): + self.result = asyncio.run(self.coro()) + + thread = Resolver(coroutine) + thread.start() + thread.join() + return thread.result + + +def split_on_uppercase(s: str) -> t.List[str]: + """Split characters by uppercase letters. + + From: https://stackoverflow.com/a/40382663 + """ + string_length = len(s) + is_lower_around = lambda: s[i - 1].islower() or string_length > (i + 1) and s[i + 1].islower() + + start = 0 + parts = [] + for i in range(1, string_length): + if s[i].isupper() and is_lower_around(): + parts.append(s[start:i]) + start = i + parts.append(s[start:]) + + return parts + + +def parse_exception(exc: BaseException) -> str: + """Parse an exception and its direct cause.""" + + if not isinstance(exc, BaseException): + raise TypeError(f"'{repr(exc)}' is not an exception.") + + def get_exc_name(exc): + return " ".join(split_on_uppercase(exc.__class__.__name__)) + + def get_doc_summary(doc): + return doc.strip().split("\n")[0].strip(".") + + name = get_exc_name(exc) + parsed = [] + if exc.__doc__: + detail = get_doc_summary(exc.__doc__) + parsed.append(f"{name} ({detail})") + else: + parsed.append(name) + + if exc.__cause__: + cause = get_exc_name(exc.__cause__) + if exc.__cause__.__doc__: + cause_detail = get_doc_summary(exc.__cause__.__doc__) + parsed.append(f"{cause} ({cause_detail})") + else: + parsed.append(cause) + return ", caused by ".join(parsed) + + +def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str: + """Generate a `__repr__()` value from a specific set of attribute names. + + Useful for complex models/objects where `__repr__()` should only display specific fields. + """ + # Check the object to ensure each attribute actually exists, and deduplicate + attr_names = {a for a in attrs if hasattr(obj, a)} + # Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a + # `__repr__` method. + attr_values = { + f if strip is None else f.strip(strip): v # noqa: IF100 + for f in attr_names + if hasattr((v := getattr(obj, f)), "__repr__") + } + pairs = (f"{k}={v!r}" for k, v in sorted(attr_values.items())) + return f"{obj.__class__.__name__}({', '.join(pairs)})" + + +def snake_to_camel(value: str) -> str: + """Convert a string from snake_case to camelCase.""" + head, *body = value.split("_") + humps = (hump.capitalize() for hump in body) + return "".join((head.lower(), *humps)) + + +def get_fmt_keys(template: str) -> t.List[str]: + """Get a list of str.format keys. + + For example, string `"The value of {key} is {value}"` returns + `["key", "value"]`. + """ + # Standard Library + import string + + keys = [] + for block in (b for b in string.Formatter.parse("", template) if isinstance(template, str)): + key = block[1] + if key: + keys.append(key) + return keys + + +def deep_convert_keys(_dict: t.Type[DeepConvert], predicate: t.Callable[[str], str]) -> DeepConvert: + """Convert all dictionary keys and nested dictionary keys.""" + converted = {} + + def get_value(value: t.Any): + if isinstance(value, t.Dict): + return {predicate(k): get_value(v) for k, v in value.items()} + elif isinstance(value, t.List): + return [get_value(v) for v in value] + elif isinstance(value, t.Tuple): + return tuple(get_value(v) for v in value) + return value + + for key, value in _dict.items(): + converted[predicate(key)] = get_value(value) + + return converted + + +def at_least( + minimum: int, + value: int, +) -> int: + """Get a number value that is at least a specified minimum.""" + if value < minimum: + return minimum + return value + + +def compare_dicts(dict_a: t.Dict[t.Any, t.Any], dict_b: t.Dict[t.Any, t.Any]) -> bool: + """Determine if two dictationaries are (mostly) equal.""" + if isinstance(dict_a, t.Dict) and isinstance(dict_b, t.Dict): + dict_a_keys, dict_a_values = set(dict_a.keys()), set(dict_a.values()) + dict_b_keys, dict_b_values = set(dict_b.keys()), set(dict_b.values()) + return all((dict_a_keys == dict_b_keys, dict_a_values == dict_b_values)) + return False + + +def compare_lists(left: t.List[t.Any], right: t.List[t.Any], *, ignore: Series[t.Any] = ()) -> bool: + """Determine if all items in left list exist in right list.""" + left_ignored = [i for i in left if i not in ignore] + diff_ignored = [i for i in left if i in right and i not in ignore] + return len(left_ignored) == len(diff_ignored) + + +def compare_init(obj_a: object, obj_b: object) -> bool: + """Compare the `__init__` annoations of two objects.""" + + def _check_obj(obj: object): + """Ensure `__annotations__` exists on the `__init__` method.""" + if hasattr(obj, "__init__") and isinstance(getattr(obj, "__init__", None), t.Callable): + if hasattr(obj.__init__, "__annotations__") and isinstance( + getattr(obj.__init__, "__annotations__", None), t.Dict + ): + return True + return False + + if all((_check_obj(obj_a), _check_obj(obj_b))): + obj_a.__init__.__annotations__.pop("self", None) + obj_b.__init__.__annotations__.pop("self", None) + return compare_dicts(obj_a.__init__.__annotations__, obj_b.__init__.__annotations__) + return False diff --git a/hyperglass/util/validation.py b/hyperglass/util/validation.py new file mode 100644 index 0000000..e0de95d --- /dev/null +++ b/hyperglass/util/validation.py @@ -0,0 +1,73 @@ +"""Validation Utilities.""" + +# Standard Library +import typing as t + +# Third Party +from netmiko.ssh_dispatcher import CLASS_MAPPER # type: ignore + +# Project +from hyperglass.constants import DRIVER_MAP + +ALL_DEVICE_TYPES = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()} +ALL_DRIVERS = {*DRIVER_MAP.values(), "netmiko"} + +if t.TYPE_CHECKING: + # Standard Library + from ipaddress import IPv4Address, IPv6Address + + +def validate_platform(_type: str) -> t.Tuple[bool, t.Union[None, str]]: + """Validate device type is supported.""" + + result = (False, None) + + if _type in ALL_DEVICE_TYPES: + result = (True, DRIVER_MAP.get(_type, "netmiko")) + + return result + + +def get_driver(_type: str, driver: t.Optional[str]) -> str: + """Determine the appropriate driver for a device.""" + + if driver is None: + # If no driver is set, use the driver map with netmiko as + # fallback. + return DRIVER_MAP.get(_type, "netmiko") + elif driver in ALL_DRIVERS: + # If a driver is set and it is valid, allow it. + return driver + else: + # Otherwise, fail validation. + raise ValueError("{} is not a supported driver.".format(driver)) + + +def resolve_hostname( + hostname: str, +) -> t.Generator[t.Union["IPv4Address", "IPv6Address"], None, None]: + """Resolve a hostname via DNS/hostfile.""" + # Standard Library + from socket import gaierror, getaddrinfo + from ipaddress import ip_address + + # Project + from hyperglass.log import log + + log.debug("Ensuring {!r} is resolvable...", hostname) + + ip4 = None + ip6 = None + try: + res = getaddrinfo(hostname, None) + for sock in res: + if sock[0].value == 2 and ip4 is None: + ip4 = ip_address(sock[4][0]) + elif sock[0].value in (10, 30) and ip6 is None: + ip6 = ip_address(sock[4][0]) + except (gaierror, ValueError, IndexError) as err: + log.debug(str(err)) + pass + + yield ip4 + yield ip6 diff --git a/poetry.lock b/poetry.lock index cb18bff..5f7566b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -820,7 +820,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "py-cpuinfo" -version = "7.0.0" +version = "8.0.0" description = "Get CPU info with pure Python 2 & 3" category = "main" optional = false @@ -1340,7 +1340,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [metadata] lock-version = "1.1" python-versions = ">=3.8.1,<4.0" -content-hash = "a13c78fca92dfe22206f8b27362c779021ee420756bb7c6d7f39ec72bc9d7148" +content-hash = "59c7bf05d11ded8cd759701bc8e9c8962ae44ecfb8b1fa0b222119b08284c273" [metadata.files] aiofiles = [ @@ -1819,7 +1819,7 @@ py = [ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, ] py-cpuinfo = [ - {file = "py-cpuinfo-7.0.0.tar.gz", hash = "sha256:9aa2e49675114959697d25cf57fec41c29b55887bff3bc4809b44ac6f5730097"}, + {file = "py-cpuinfo-8.0.0.tar.gz", hash = "sha256:5f269be0e08e33fd959de96b34cd4aeeeacac014dd8305f70eb28d06de2345c5"}, ] pycodestyle = [ {file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"}, diff --git a/pyproject.toml b/pyproject.toml index 60d0476..ffb1002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ loguru = "^0.5.3" netmiko = "^3.4.0" paramiko = "^2.7.2" psutil = "^5.7.2" -py-cpuinfo = "^7.0.0" +py-cpuinfo = "^8.0.0" pydantic = {extras = ["dotenv"], version = "^1.8.2"} python = ">=3.8.1,<4.0" redis = "^3.5.3"