mirror of
https://github.com/thatmattlove/hyperglass.git
synced 2026-01-17 08:48:05 +00:00
Redis caching improvements
This commit is contained in:
parent
69416b4dda
commit
b5a67e7c0e
8 changed files with 311 additions and 152 deletions
|
|
@ -6,9 +6,9 @@ from ipaddress import ip_network
|
|||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.external import bgptools
|
||||
from hyperglass.exceptions import InputInvalid, InputNotAllowed
|
||||
from hyperglass.configuration import params
|
||||
from hyperglass.external.bgptools import network_info_sync
|
||||
|
||||
|
||||
def _member_of(target, network):
|
||||
|
|
@ -144,9 +144,17 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
|
|||
):
|
||||
log.debug("Getting containing prefix for {q}", q=str(valid_ip))
|
||||
|
||||
containing_prefix = bgptools.network_info_sync(
|
||||
valid_ip.network_address
|
||||
).get("prefix")
|
||||
ip_str = str(valid_ip.network_address)
|
||||
network_info = network_info_sync(ip_str)
|
||||
containing_prefix = network_info.get(ip_str, {}).get("prefix")
|
||||
|
||||
if containing_prefix is None:
|
||||
log.error(
|
||||
"Unable to find containing prefix for {}. Got: {}",
|
||||
str(valid_ip),
|
||||
network_info,
|
||||
)
|
||||
raise InputInvalid("{q} does not have a containing prefix", q=ip_str)
|
||||
|
||||
try:
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
|||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.util import clean_name, process_headers, import_public_key
|
||||
from hyperglass.cache import Cache
|
||||
from hyperglass.cache import AsyncCache
|
||||
from hyperglass.encode import jwt_decode
|
||||
from hyperglass.external import Webhook, bgptools
|
||||
from hyperglass.exceptions import HyperglassError
|
||||
|
|
@ -56,7 +56,7 @@ async def send_webhook(query_data: Query, request: Request, timestamp: datetime)
|
|||
**query_data.export_dict(pretty=True),
|
||||
"headers": headers,
|
||||
"source": host,
|
||||
"network": network_info,
|
||||
"network": network_info.get(host, {}),
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
|
|
@ -73,7 +73,7 @@ async def query(query_data: Query, request: Request, background_tasks: Backgroun
|
|||
background_tasks.add_task(send_webhook, query_data, request, timestamp)
|
||||
|
||||
# Initialize cache
|
||||
cache = Cache(db=params.cache.database, **REDIS_CONFIG)
|
||||
cache = AsyncCache(db=params.cache.database, **REDIS_CONFIG)
|
||||
log.debug("Initialized cache {}", repr(cache))
|
||||
|
||||
# Use hashed query_data string as key for for k/v cache store so
|
||||
|
|
@ -131,12 +131,10 @@ async def query(query_data: Query, request: Request, background_tasks: Backgroun
|
|||
|
||||
# If it does, return the cached entry
|
||||
cache_response = await cache.get_dict(cache_key, "output")
|
||||
response_format = "text/plain"
|
||||
|
||||
if query_data.device.structured_output:
|
||||
response_format = "application/json"
|
||||
cache_response = json.loads(cache_response)
|
||||
else:
|
||||
response_format = "text/plain"
|
||||
|
||||
log.debug(f"Cache match for {cache_key}:\n {cache_response}")
|
||||
log.success(f"Completed query execution for {query_data.summary}")
|
||||
|
|
|
|||
|
|
@ -1,139 +0,0 @@
|
|||
"""Redis cache handler."""
|
||||
|
||||
# Standard Library
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
# Third Party
|
||||
from aredis import StrictRedis
|
||||
|
||||
|
||||
class Cache:
|
||||
"""Redis cache handler."""
|
||||
|
||||
def __init__(
|
||||
self, db, host="localhost", port=6379, decode_responses=True, **kwargs
|
||||
):
|
||||
"""Initialize Redis connection."""
|
||||
self.db: int = db
|
||||
self.host: str = host
|
||||
self.port: int = port
|
||||
self.instance: StrictRedis = StrictRedis(
|
||||
db=self.db,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
decode_responses=decode_responses,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""Represent class state."""
|
||||
return f"ConfigCache(db={self.db}, host={self.host}, port={self.port})"
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Enable subscriptable syntax."""
|
||||
return self.get(item)
|
||||
|
||||
@staticmethod
|
||||
async def _parse_types(value):
|
||||
"""Parse a string to standard python types."""
|
||||
import re
|
||||
|
||||
async def _parse_string(str_value):
|
||||
|
||||
is_float = (re.compile(r"^(\d+\.\d+)$"), float)
|
||||
is_int = (re.compile(r"^(\d+)$"), int)
|
||||
is_bool = (re.compile(r"^(True|true|False|false)$"), bool)
|
||||
is_none = (re.compile(r"^(None|none|null|nil|\(nil\))$"), lambda v: None)
|
||||
|
||||
for pattern, factory in (is_float, is_int, is_bool, is_none):
|
||||
if isinstance(str_value, str) and bool(re.match(pattern, str_value)):
|
||||
str_value = factory(str_value)
|
||||
break
|
||||
return str_value
|
||||
|
||||
if isinstance(value, str):
|
||||
value = await _parse_string(value)
|
||||
elif isinstance(value, bytes):
|
||||
value = await _parse_string(value.decode("utf-8"))
|
||||
elif isinstance(value, list):
|
||||
value = [await _parse_string(i) for i in value]
|
||||
elif isinstance(value, tuple):
|
||||
value = tuple(await _parse_string(i) for i in value)
|
||||
|
||||
return value
|
||||
|
||||
async def get(self, *args):
|
||||
"""Get item(s) from cache."""
|
||||
if len(args) == 1:
|
||||
raw = await self.instance.get(args[0])
|
||||
else:
|
||||
raw = await self.instance.mget(args)
|
||||
return await self._parse_types(raw)
|
||||
|
||||
async def get_dict(self, key, field=None):
|
||||
"""Get hash map (dict) item(s)."""
|
||||
if field is None:
|
||||
raw = await self.instance.hgetall(key)
|
||||
else:
|
||||
raw = await self.instance.hget(key, field)
|
||||
return await self._parse_types(raw)
|
||||
|
||||
async def set(self, key, value):
|
||||
"""Set cache values."""
|
||||
return await self.instance.set(key, value)
|
||||
|
||||
async def set_dict(self, key, field, value):
|
||||
"""Set hash map (dict) values."""
|
||||
return await self.instance.hset(key, field, value)
|
||||
|
||||
async def wait(self, pubsub, timeout=30, **kwargs):
|
||||
"""Wait for pub/sub messages & return posted message."""
|
||||
now = time.time()
|
||||
timeout = now + timeout
|
||||
while now < timeout:
|
||||
message = await pubsub.get_message(ignore_subscribe_messages=True, **kwargs)
|
||||
if message is not None and message["type"] == "message":
|
||||
data = message["data"]
|
||||
return await self._parse_types(data)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
now = time.time()
|
||||
return None
|
||||
|
||||
async def pubsub(self):
|
||||
"""Provide an aredis.pubsub.Pubsub instance."""
|
||||
return self.instance.pubsub()
|
||||
|
||||
async def pub(self, key, value):
|
||||
"""Publish a value."""
|
||||
await asyncio.sleep(1)
|
||||
await self.instance.publish(key, value)
|
||||
|
||||
async def clear(self):
|
||||
"""Clear the cache."""
|
||||
await self.instance.flushdb()
|
||||
|
||||
async def delete(self, *keys):
|
||||
"""Delete a cache key."""
|
||||
await self.instance.delete(*keys)
|
||||
|
||||
async def expire(self, *keys, seconds):
|
||||
"""Set timeout of key in seconds."""
|
||||
for key in keys:
|
||||
await self.instance.expire(key, seconds)
|
||||
|
||||
async def aget_config(self):
|
||||
"""Get picked config object from cache."""
|
||||
import pickle
|
||||
|
||||
pickled = await self.instance.get("HYPERGLASS_CONFIG")
|
||||
return pickle.loads(pickled)
|
||||
|
||||
def get_config(self):
|
||||
"""Get picked config object from cache."""
|
||||
import pickle
|
||||
from hyperglass.compat._asyncio import aiorun
|
||||
|
||||
pickled = aiorun(self.instance.get("HYPERGLASS_CONFIG"))
|
||||
return pickle.loads(pickled)
|
||||
7
hyperglass/cache/__init__.py
vendored
Normal file
7
hyperglass/cache/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Redis cache handlers."""
|
||||
|
||||
# Project
|
||||
from hyperglass.cache.aio import AsyncCache
|
||||
from hyperglass.cache.sync import SyncCache
|
||||
|
||||
__all__ = ("AsyncCache", "SyncCache")
|
||||
113
hyperglass/cache/aio.py
vendored
Normal file
113
hyperglass/cache/aio.py
vendored
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""Asyncio Redis cache handler."""
|
||||
|
||||
# Standard Library
|
||||
import json
|
||||
import time
|
||||
import pickle
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
|
||||
# Third Party
|
||||
from aredis import StrictRedis as AsyncRedis
|
||||
from aredis.pubsub import PubSub as AsyncPubSub
|
||||
|
||||
# Project
|
||||
from hyperglass.cache.base import BaseCache
|
||||
|
||||
|
||||
class AsyncCache(BaseCache):
|
||||
"""Asynchronous Redis cache handler."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize Redis connection."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.instance: AsyncRedis = AsyncRedis(
|
||||
db=self.db,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
decode_responses=self.decode_responses,
|
||||
**self.redis_args,
|
||||
)
|
||||
|
||||
async def get(self, *args: str) -> Any:
|
||||
"""Get item(s) from cache."""
|
||||
if len(args) == 1:
|
||||
raw = await self.instance.get(args[0])
|
||||
else:
|
||||
raw = await self.instance.mget(args)
|
||||
return self.parse_types(raw)
|
||||
|
||||
async def get_dict(self, key: str, field: str = "") -> Any:
|
||||
"""Get hash map (dict) item(s)."""
|
||||
if not field:
|
||||
raw = await self.instance.hgetall(key)
|
||||
else:
|
||||
raw = await self.instance.hget(key, field)
|
||||
|
||||
return self.parse_types(raw)
|
||||
|
||||
async def set(self, key: str, value: str) -> bool:
|
||||
"""Set cache values."""
|
||||
return await self.instance.set(key, value)
|
||||
|
||||
async def set_dict(self, key: str, field: str, value: str) -> bool:
|
||||
"""Set hash map (dict) values."""
|
||||
success = False
|
||||
|
||||
if isinstance(value, Dict):
|
||||
value = json.dumps(value)
|
||||
else:
|
||||
value = str(value)
|
||||
|
||||
response = await self.instance.hset(key, field, value)
|
||||
|
||||
if response in (0, 1):
|
||||
success = True
|
||||
|
||||
return success
|
||||
|
||||
async def wait(self, pubsub: AsyncPubSub, timeout: int = 30, **kwargs) -> Any:
|
||||
"""Wait for pub/sub messages & return posted message."""
|
||||
now = time.time()
|
||||
timeout = now + timeout
|
||||
|
||||
while now < timeout:
|
||||
|
||||
message = await pubsub.get_message(ignore_subscribe_messages=True, **kwargs)
|
||||
|
||||
if message is not None and message["type"] == "message":
|
||||
data = message["data"]
|
||||
return self.parse_types(data)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
now = time.time()
|
||||
|
||||
return None
|
||||
|
||||
async def pubsub(self) -> AsyncPubSub:
|
||||
"""Provide an aredis.pubsub.Pubsub instance."""
|
||||
return self.instance.pubsub()
|
||||
|
||||
async def pub(self, key: str, value: str) -> None:
|
||||
"""Publish a value."""
|
||||
await asyncio.sleep(1)
|
||||
await self.instance.publish(key, value)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the cache."""
|
||||
await self.instance.flushdb()
|
||||
|
||||
async def delete(self, *keys: str) -> None:
|
||||
"""Delete a cache key."""
|
||||
await self.instance.delete(*keys)
|
||||
|
||||
async def expire(self, *keys: str, seconds: int) -> None:
|
||||
"""Set timeout of key in seconds."""
|
||||
for key in keys:
|
||||
await self.instance.expire(key, seconds)
|
||||
|
||||
async def get_config(self) -> Dict:
|
||||
"""Get picked config object from cache."""
|
||||
|
||||
pickled = await self.instance.get("HYPERGLASS_CONFIG")
|
||||
return pickle.loads(pickled)
|
||||
59
hyperglass/cache/base.py
vendored
Normal file
59
hyperglass/cache/base.py
vendored
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""Base Redis cache handler."""
|
||||
|
||||
# Standard Library
|
||||
import re
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseCache:
|
||||
"""Redis cache handler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: int,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
decode_responses: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize Redis connection."""
|
||||
self.db: int = db
|
||||
self.host: str = host
|
||||
self.port: int = port
|
||||
self.decode_responses: bool = decode_responses
|
||||
self.redis_args: dict = kwargs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent class state."""
|
||||
return f"HyperglassCache(db={self.db}, host={self.host}, port={self.port})"
|
||||
|
||||
def parse_types(self, value: str) -> Any:
|
||||
"""Parse a string to standard python types."""
|
||||
|
||||
def parse_string(str_value: str):
|
||||
|
||||
is_float = (re.compile(r"^(\d+\.\d+)$"), float)
|
||||
is_int = (re.compile(r"^(\d+)$"), int)
|
||||
is_bool = (re.compile(r"^(True|true|False|false)$"), bool)
|
||||
is_none = (re.compile(r"^(None|none|null|nil|\(nil\))$"), lambda v: None)
|
||||
is_jsonable = (re.compile(r"^[\{\[].*[\}\]]$"), json.loads)
|
||||
|
||||
for pattern, factory in (is_float, is_int, is_bool, is_none, is_jsonable):
|
||||
if isinstance(str_value, str) and bool(re.match(pattern, str_value)):
|
||||
str_value = factory(str_value)
|
||||
break
|
||||
return str_value
|
||||
|
||||
if isinstance(value, str):
|
||||
value = parse_string(value)
|
||||
elif isinstance(value, bytes):
|
||||
value = parse_string(value.decode("utf-8"))
|
||||
elif isinstance(value, list):
|
||||
value = [parse_string(i) for i in value]
|
||||
elif isinstance(value, tuple):
|
||||
value = tuple(parse_string(i) for i in value)
|
||||
elif isinstance(value, dict):
|
||||
value = {k: self.parse_types(v) for k, v in value.items()}
|
||||
|
||||
return value
|
||||
112
hyperglass/cache/sync.py
vendored
Normal file
112
hyperglass/cache/sync.py
vendored
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
"""Non-asyncio Redis cache handler."""
|
||||
|
||||
# Standard Library
|
||||
import json
|
||||
import time
|
||||
import pickle
|
||||
from typing import Any, Dict
|
||||
|
||||
# Third Party
|
||||
from redis import Redis as SyncRedis
|
||||
from redis.client import PubSub as SyncPubsSub
|
||||
|
||||
# Project
|
||||
from hyperglass.cache.base import BaseCache
|
||||
|
||||
|
||||
class SyncCache(BaseCache):
|
||||
"""Synchronous Redis cache handler."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize Redis connection."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.instance: SyncRedis = SyncRedis(
|
||||
db=self.db,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
decode_responses=self.decode_responses,
|
||||
**self.redis_args,
|
||||
)
|
||||
|
||||
def get(self, *args: str) -> Any:
|
||||
"""Get item(s) from cache."""
|
||||
if len(args) == 1:
|
||||
raw = self.instance.get(args[0])
|
||||
else:
|
||||
raw = self.instance.mget(args)
|
||||
return self.parse_types(raw)
|
||||
|
||||
def get_dict(self, key: str, field: str = "") -> Any:
|
||||
"""Get hash map (dict) item(s)."""
|
||||
if not field:
|
||||
raw = self.instance.hgetall(key)
|
||||
else:
|
||||
raw = self.instance.hget(key, str(field))
|
||||
|
||||
return self.parse_types(raw)
|
||||
|
||||
def set(self, key: str, value: str) -> bool:
|
||||
"""Set cache values."""
|
||||
return self.instance.set(key, str(value))
|
||||
|
||||
def set_dict(self, key: str, field: str, value: str) -> bool:
|
||||
"""Set hash map (dict) values."""
|
||||
success = False
|
||||
|
||||
if isinstance(value, Dict):
|
||||
value = json.dumps(value)
|
||||
else:
|
||||
value = str(value)
|
||||
|
||||
response = self.instance.hset(key, str(field), value)
|
||||
|
||||
if response in (0, 1):
|
||||
success = True
|
||||
|
||||
return success
|
||||
|
||||
def wait(self, pubsub: SyncPubsSub, timeout: int = 30, **kwargs) -> Any:
|
||||
"""Wait for pub/sub messages & return posted message."""
|
||||
now = time.time()
|
||||
timeout = now + timeout
|
||||
|
||||
while now < timeout:
|
||||
|
||||
message = pubsub.get_message(ignore_subscribe_messages=True, **kwargs)
|
||||
|
||||
if message is not None and message["type"] == "message":
|
||||
data = message["data"]
|
||||
return self.parse_types(data)
|
||||
|
||||
time.sleep(0.01)
|
||||
now = time.time()
|
||||
|
||||
return None
|
||||
|
||||
def pubsub(self) -> SyncPubsSub:
|
||||
"""Provide a redis.client.Pubsub instance."""
|
||||
return self.instance.pubsub()
|
||||
|
||||
def pub(self, key: str, value: str) -> None:
|
||||
"""Publish a value."""
|
||||
time.sleep(1)
|
||||
self.instance.publish(key, value)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the cache."""
|
||||
self.instance.flushdb()
|
||||
|
||||
def delete(self, *keys: str) -> None:
|
||||
"""Delete a cache key."""
|
||||
self.instance.delete(*keys)
|
||||
|
||||
def expire(self, *keys: str, seconds: int) -> None:
|
||||
"""Set timeout of key in seconds."""
|
||||
for key in keys:
|
||||
self.instance.expire(key, seconds)
|
||||
|
||||
def get_config(self) -> Dict:
|
||||
"""Get picked config object from cache."""
|
||||
|
||||
pickled = self.instance.get("HYPERGLASS_CONFIG")
|
||||
return pickle.loads(pickled)
|
||||
|
|
@ -12,7 +12,7 @@ from gunicorn.app.base import BaseApplication
|
|||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.cache import Cache
|
||||
from hyperglass.cache import AsyncCache
|
||||
from hyperglass.constants import MIN_PYTHON_VERSION, __version__
|
||||
|
||||
pretty_version = ".".join(tuple(str(v) for v in MIN_PYTHON_VERSION))
|
||||
|
|
@ -85,7 +85,7 @@ async def cache_config():
|
|||
"""Add configuration to Redis cache as a pickled object."""
|
||||
import pickle
|
||||
|
||||
cache = Cache(
|
||||
cache = AsyncCache(
|
||||
db=params.cache.database, host=params.cache.host, port=params.cache.port
|
||||
)
|
||||
await cache.set("HYPERGLASS_CONFIG", pickle.dumps(params))
|
||||
|
|
@ -123,7 +123,8 @@ def on_exit(server: Arbiter):
|
|||
log.critical("Stopping hyperglass {}", __version__)
|
||||
|
||||
async def runner():
|
||||
await clear_cache()
|
||||
if not params.developer_mode:
|
||||
await clear_cache()
|
||||
|
||||
aiorun(runner())
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue