forked from mirrors/thatmattlove-hyperglass
117 lines
4.1 KiB
Python
117 lines
4.1 KiB
Python
"""Interact with an http-based device."""
|
|
|
|
# Standard Library
|
|
import typing as t
|
|
|
|
# Third Party
|
|
import httpx
|
|
|
|
# Project
|
|
from hyperglass.util import get_fmt_keys
|
|
from hyperglass.exceptions.public import AuthError, RestError, DeviceTimeout, ResponseEmpty
|
|
|
|
# Local
|
|
from ._common import Connection
|
|
|
|
if t.TYPE_CHECKING:
|
|
# Project
|
|
from hyperglass.models.api import Query
|
|
from hyperglass.models.config.devices import Device
|
|
from hyperglass.models.config.http_client import HttpConfiguration
|
|
|
|
|
|
class HttpClient(Connection):
    """Interact with an http-based device."""

    # HTTP settings taken from the device definition (`device.http`).
    config: "HttpConfiguration"
    # Async client built from `config` at init time; opened and closed by `collect`.
    client: httpx.AsyncClient

    def __init__(self, device: "Device", query_data: "Query") -> None:
        """Initialize base connection and set http config & client."""
        super().__init__(device, query_data)
        self.config = device.http
        self.client = self.config.create_client(device=device)

    def setup_proxy(self: "Connection"):
        """HTTP Client does not support SSH proxies."""
        raise NotImplementedError("HTTP Client does not support SSH proxies.")

    def _field_names(self) -> t.Dict[str, str]:
        """Map each hyperglass query field to the name the remote endpoint expects.

        Reads `config._attribute_map`, the same map used for plain query parameters
        and request bodies. NOTE(review): presumably this resolved map carries a
        default name for every field the operator did not override - confirm
        against `HttpConfiguration`.
        """
        attribute_map = self.config._attribute_map
        return {
            "query_target": attribute_map.query_target,
            "query_location": attribute_map.query_location,
            "query_type": attribute_map.query_type,
        }

    def _mapped_query_fields(self) -> t.Dict[str, t.Any]:
        """Return the query's target/location/type keyed by their mapped field names."""
        return {
            name: getattr(self.query_data, field)
            for field, name in self._field_names().items()
        }

    def _query_params(self) -> t.Dict[str, str]:
        """Build the URL query-string parameters for the request.

        - `config.query is None`: send each query field under its mapped name.
        - `config.query` is a dict: each value is a `str.format` template whose
          placeholders are mapped field names, e.g. `{"q": "{query_target}"}`.
        - Anything else (e.g. `False`): send no query parameters.

        Raises:
            KeyError: if a template references a placeholder that is not a
                mapped field name (raised by `str.format`).
        """
        if self.config.query is None:
            return self._mapped_query_fields()

        if isinstance(self.config.query, dict):
            # Use the same resolved map as the branch above and `_body`. The raw
            # `config.attribute_map` may lack names for un-overridden fields, which
            # left placeholders unsubstituted and made `str.format` raise KeyError.
            field_names = self._field_names()
            params = {}
            for key, template in self.config.query.items():
                placeholders = get_fmt_keys(template)  # computed once per template
                params[key] = template.format(
                    **{
                        str(name): str(getattr(self.query_data, field, None))
                        for field, name in field_names.items()
                        if name in placeholders
                    }
                )
            return params

        return {}

    def _body(self) -> t.Dict[str, t.Union[t.Dict[str, t.Any], str]]:
        """Build the body-related keyword arguments for `client.request`.

        The payload is the query's target/location/type keyed by mapped field
        names, serialized according to `config.body_format`:

        - `json`: passed as `json=`.
        - `yaml`: dumped to a string, sent with `content-type: text/yaml`.
        - `xml`: wrapped in a `<query>` root element, sent with
          `content-type: application/xml`.
        - `text`: passed as `data=`.
        - any other format: no body arguments.
        """
        data = self._mapped_query_fields()
        body_format = self.config.body_format

        if body_format == "json":
            return {"json": data}

        if body_format == "yaml":
            # Third Party
            import yaml

            return {"content": yaml.dump(data), "headers": {"content-type": "text/yaml"}}

        if body_format == "xml":
            # Third Party
            import xmltodict  # type: ignore

            return {
                "content": xmltodict.unparse({"query": data}),
                "headers": {"content-type": "application/xml"},
            }

        if body_format == "text":
            # NOTE(review): httpx form-encodes a mapping passed via `data=`;
            # confirm that is the intended wire format for "text".
            return {"data": data}

        return {}

    async def collect(self, *args: t.Any, **kwargs: t.Any) -> t.Iterable:
        """Collect response data from an HTTP endpoint.

        `*args`/`**kwargs` are accepted for the driver interface but unused here.
        The client is used as an async context manager, so it is closed when this
        method returns or raises.

        Returns:
            A one-element tuple containing the stripped response text.

        Raises:
            DeviceTimeout: the request timed out.
            AuthError: the endpoint answered with HTTP 401.
            RestError: any other HTTP error status, or a transport failure
                (connection refused, DNS, TLS, ...).
            ResponseEmpty: the response body is empty or whitespace-only.
        """
        query = self._query_params()

        async with self.client as client:
            # Only methods that carry a request body get one.
            body = self._body() if self.config.method in ("POST", "PATCH", "PUT") else {}

            try:
                response: httpx.Response = await client.request(
                    method=self.config.method, url=self.config.path, params=query, **body
                )
                response.raise_for_status()
            except httpx.TimeoutException as error:
                # Must precede RequestError: timeouts are a RequestError subclass.
                raise DeviceTimeout(error=error, device=self.device) from error
            except httpx.HTTPStatusError as error:
                if error.response.status_code == 401:
                    raise AuthError(error=error, device=self.device) from error
                raise RestError(error=error, device=self.device) from error
            except httpx.RequestError as error:
                # Connect/DNS/TLS failures previously escaped as raw httpx errors.
                raise RestError(error=error, device=self.device) from error

            data = response.text.strip()
            if not data:
                raise ResponseEmpty(query=self.query_data)

        return (data,)
|