Skip to content

Commit

Permalink
Merge pull request #811 from circulon/fix/redis_cache_issues
Browse files Browse the repository at this point in the history
Fix/redis cache issues
  • Loading branch information
josephmancuso authored Aug 13, 2024
2 parents 9d26aec + 0d4709d commit 77a5499
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 43 deletions.
117 changes: 77 additions & 40 deletions src/masonite/cache/drivers/RedisDriver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from typing import Any, TYPE_CHECKING

import pendulum as pdlm
if TYPE_CHECKING:
from redis import Redis

Expand All @@ -9,59 +10,61 @@ class RedisDriver:
def __init__(self, application):
self.application = application
self.connection = None
self._internal_cache: "dict|None" = None
self.options = {}
self._internal_cache: dict = None

def set_options(self, options: dict) -> "RedisDriver":
self.options = options
return self

def get_connection(self) -> "Redis":
if self.connection:
return self.connection

try:
from redis import Redis
except ImportError:
raise ModuleNotFoundError(
"Could not find the 'redis' library. Run 'pip install redis' to fix this."
)

if not self.connection:
self.connection = Redis(
**self.options.get("options", {}),
host=self.options.get("host"),
port=self.options.get("port"),
password=self.options.get("password"),
decode_responses=True,
)

# populate the internal cache the first time
# the connection is established
if self._internal_cache is None and self.connection:
self._load_from_store(self.connection)
self.connection = Redis(
**self.options.get("options", {}),
host=self.options.get("host"),
port=self.options.get("port"),
password=self.options.get("password"),
decode_responses=True,
)

return self.connection

def _load_from_store(self, connection: "Redis" = None) -> None:
def _load_from_store(self) -> None:
"""
copy all the "cache" key value pairs for faster access
"""
if not connection:
if self._internal_cache is not None:
return

if self._internal_cache is None:
self._internal_cache = {}
self._internal_cache = {}

cursor = "0"
prefix = self.get_cache_namespace()
while cursor != 0:
cursor, keys = connection.scan(
cursor, keys = self.get_connection().scan(
cursor=cursor, match=prefix + "*", count=100000
)
if keys:
values = connection.mget(*keys)
values = self.get_connection().mget(*keys)
store_data = dict(zip(keys, values))
for key, value in store_data.items():
key = key.replace(prefix, "")
value = self.unpack_value(value)
self._internal_cache.setdefault(key, value)
# we dont load the ttl (expiry)
# because there is an O(N) performance hit
self._internal_cache[key] = {
"value": value,
"expires": None,
}

def get_cache_namespace(self) -> str:
"""
Expand All @@ -72,37 +75,66 @@ def get_cache_namespace(self) -> str:
return f"{namespace}cache:"

def add(self, key: str, value: Any = None) -> Any:
if not value:
if value is None:
return None

self.put(key, value)
return value

def get(self, key: str, default: Any = None, **options) -> Any:
if default and not self.has(key):
self.put(key, default, **options)
return default

return self._internal_cache.get(key)
self._load_from_store()
if not self.has(key):
return default or None

key_expiry = self._internal_cache[key].get("expires", None)
if key_expiry is None:
# the ttl value can also provide info on the
# existence of the key in the store
ttl = self.get_connection().ttl(key)
if ttl == -1:
# key exists but has no set ttl
ttl = self.get_default_timeout()
elif ttl == -2:
# key not found in store
self._internal_cache.pop(key)
return default or None

key_expiry = self._expires_from_ttl(ttl)
self._internal_cache[key]["expires"] = key_expiry

if pdlm.now() > key_expiry:
# the key has expired so remove it from the cache
self._internal_cache.pop(key)
return default or None

# the key has not yet expired
return self._internal_cache.get(key)["value"]

def put(self, key: str, value: Any = None, seconds: int = None, **options) -> Any:
if not key or value is None:
return None

time = self.get_expiration_time(seconds)

store_value = value
if isinstance(value, (dict, list, tuple)):
store_value = json.dumps(value)
elif isinstance(value, int):
store_value = str(value)

self._load_from_store()
key_ttl = seconds or self.get_default_timeout()
self.get_connection().set(
f"{self.get_cache_namespace()}{key}", store_value, ex=time
f"{self.get_cache_namespace()}{key}", store_value, ex=key_ttl
)

if not self.has(key):
self._internal_cache.update({key: value})
expires = self._expires_from_ttl(key_ttl)
self._internal_cache.update({
key: {
"value": value,
"expires": expires,
}
})

def has(self, key: str) -> bool:
self._load_from_store()
return key in self._internal_cache

def increment(self, key: str, amount: int = 1) -> int:
Expand All @@ -126,23 +158,28 @@ def remember(self, key: str, callable):
return self.get(key)

def forget(self, key: str) -> None:
if not self.has(key):
return
self.get_connection().delete(f"{self.get_cache_namespace()}{key}")
self._internal_cache.pop(key)

def flush(self) -> None:
return self.get_connection().flushall()
self.get_connection().flushall()
self._internal_cache = None

def get_expiration_time(self, seconds: int) -> int:
if seconds is None:
seconds = 31557600 * 10

return seconds
def get_default_timeout(self) -> int:
# if unset default timeout of cache vars is 1 month
return int(self.options.get("timeout", 60 * 60 * 24 * 30))

def unpack_value(self, value: Any) -> Any:
value = str(value)
if value.isdigit():
return str(value)
return int(value)

try:
return json.loads(value)
except json.decoder.JSONDecodeError:
return value

def _expires_from_ttl(self, ttl: int) -> pdlm.DateTime:
return pdlm.now().add(seconds=ttl)
10 changes: 7 additions & 3 deletions tests/features/cache/test_redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def setUp(self):
self.application.make("cache")
self.driver = self.application.make("cache").store("redis")

def test_can_add_file_driver(self):
def test_can_add_redis_driver(self):
self.assertEqual(self.driver.add("add_key", "value"), "value")

def test_can_get_driver(self):
Expand All @@ -23,9 +23,13 @@ def test_can_increment(self):
self.driver.put("count", "1")
self.assertEqual(self.driver.get("count"), "1")
self.driver.increment("count")
self.assertEqual(self.driver.get("count"), "2")
self.assertEqual(self.driver.get("count"), 2)
self.driver.increment("count", 3)
self.assertEqual(self.driver.get("count"), 5)
self.driver.decrement("count")
self.assertEqual(self.driver.get("count"), "1")
self.assertEqual(self.driver.get("count"), 4)
self.driver.decrement("count", 2)
self.assertEqual(self.driver.get("count"), 2)

def test_will_not_get_expired(self):
self.driver.put("expire", "1", 1)
Expand Down

0 comments on commit 77a5499

Please sign in to comment.