# File: HighCostDrugsDemo/data_processing/cache.py

"""
Query result caching module for NHS Patient Pathway Analysis.
Provides file-based caching for Snowflake query results with TTL-based invalidation.
Supports different TTLs for historical data vs data including the current date.
Cache keys are generated from query hashes. Results are stored as compressed JSON.
Usage:
    from data_processing.cache import QueryCache, get_cache

    cache = get_cache()

    # Check for cached result
    result = cache.get(query, params)
    if result is None:
        # Execute query and cache result
        result = execute_query(query, params)
        cache.set(query, params, result, includes_current_data=False)
"""
import gzip
import hashlib
import json
import os
import time
import zlib
from dataclasses import dataclass
from datetime import datetime, date
from pathlib import Path
from typing import Any, Optional

from config import get_snowflake_config, CacheConfig
from core.logging_config import get_logger

# Module-level logger, obtained from the project's get_logger helper using
# this module's name.
logger = get_logger(__name__)
@dataclass
class CacheEntry:
    """Describes a single cached query result and where it lives on disk."""

    # Key identifying the entry; also used as the stem of its file names.
    cache_key: str
    # Short hash of the query text, kept for diagnostics.
    query_hash: str
    # When the entry was written.
    created_at: datetime
    # When the entry stops being served as a hit.
    expires_at: datetime
    # True when the result was cached with the "current data" TTL.
    includes_current_data: bool
    # Number of rows in the cached result.
    row_count: int
    # Size of the stored data file, in bytes.
    file_size_bytes: int
    # Location of the stored data file.
    file_path: Path
@dataclass
class CacheStats:
    """Point-in-time summary of the cache's contents and usage counters."""

    # Whether caching is switched on in the configuration.
    enabled: bool
    # Directory that holds the cache files.
    cache_dir: Path
    # Number of entries found in the cache directory.
    total_entries: int
    # Combined size of the files in the cache directory, in MB.
    total_size_mb: float
    # Configured size ceiling, in MB.
    max_size_mb: int
    # Creation time of the oldest entry, or None when the cache is empty.
    oldest_entry: Optional[datetime]
    # Creation time of the newest entry, or None when the cache is empty.
    newest_entry: Optional[datetime]
    # Number of lookups served from the cache.
    hit_count: int
    # Number of lookups that were not served from the cache.
    miss_count: int
class QueryCache:
    """
    File-based cache for Snowflake query results.

    Each entry is stored as two files in ``cache_dir``: ``<key>.json.gz``
    (the gzipped JSON rows) and ``<key>.meta.json`` (creation/expiry metadata).
    Entries expire by TTL; the TTL differs for historical data and for data
    that includes the current date.

    Attributes:
        config: CacheConfig with cache settings
        cache_dir: Path to cache directory
    """

    # Everything that can go wrong while reading or interpreting an on-disk
    # entry: I/O failures and bad gzip headers (OSError), truncated gzip
    # streams (EOFError), corrupt deflate data (zlib.error), invalid JSON,
    # bad encodings or malformed ISO timestamps (ValueError), missing
    # metadata fields (KeyError) and metadata of the wrong shape (TypeError).
    _ENTRY_READ_ERRORS = (
        OSError,
        EOFError,
        zlib.error,
        ValueError,
        KeyError,
        TypeError,
    )

    def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
        """
        Initialize the query cache.

        Args:
            config: Optional CacheConfig. If not provided, it is taken from
                ``get_snowflake_config().cache``.
            base_path: Base path for a relative cache directory. Defaults to cwd.
        """
        if config is None:
            sf_config = get_snowflake_config()
            config = sf_config.cache
        self._config = config
        self._base_path = base_path or Path.cwd()

        # Resolve cache directory (relative directories hang off base_path)
        cache_dir = Path(config.directory)
        if not cache_dir.is_absolute():
            cache_dir = self._base_path / cache_dir
        self._cache_dir = cache_dir

        # Stats tracking (in-memory only, reset on restart)
        self._hit_count = 0
        self._miss_count = 0

        # Ensure cache directory exists if enabled
        if self._config.enabled:
            self._cache_dir.mkdir(parents=True, exist_ok=True)

    @property
    def config(self) -> CacheConfig:
        """Return the cache configuration."""
        return self._config

    @property
    def cache_dir(self) -> Path:
        """Return the cache directory path."""
        return self._cache_dir

    @property
    def is_enabled(self) -> bool:
        """Return True if caching is enabled."""
        return self._config.enabled

    def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
        """
        Generate a cache key from query and parameters.

        The key is the first 32 hex characters of a SHA256 over the
        whitespace-normalized query plus an unambiguous encoding of the
        parameters.

        Args:
            query: SQL query string
            params: Optional query parameters

        Returns:
            32-character hexadecimal cache key
        """
        # Collapse whitespace only. The query is deliberately NOT lowercased:
        # string literals are case-significant, so two queries that differ
        # only in literal case must not share a cache entry.
        key_content = " ".join(query.split())
        if params:
            # repr() keeps 1 and "1" distinct, and the JSON list encoding
            # keeps ("a|b",) distinct from ("a", "b").
            key_content += "|" + json.dumps([repr(p) for p in params])
        hash_obj = hashlib.sha256(key_content.encode("utf-8"))
        return hash_obj.hexdigest()[:32]  # Use first 32 chars for readability

    def _get_cache_file_path(self, cache_key: str) -> Path:
        """Get the file path for a cache entry."""
        return self._cache_dir / f"{cache_key}.json.gz"

    def _get_meta_file_path(self, cache_key: str) -> Path:
        """Get the metadata file path for a cache entry."""
        return self._cache_dir / f"{cache_key}.meta.json"

    @staticmethod
    def _cache_key_from_meta_path(meta_file: Path) -> str:
        """Recover the cache key from a ``<key>.meta.json`` path."""
        # Path.stem drops ".json"; strip the remaining ".meta" suffix.
        return meta_file.stem.removesuffix(".meta")

    def _is_expired(self, meta: dict) -> bool:
        """
        Check if a cache entry is expired based on its metadata.

        Raises:
            KeyError: If the metadata has no ``expires_at`` field.
            ValueError: If ``expires_at`` is not a valid ISO timestamp.
        """
        expires_at = datetime.fromisoformat(meta["expires_at"])
        return datetime.now() > expires_at

    def get(
        self,
        query: str,
        params: Optional[tuple] = None,
        check_expiry: bool = True
    ) -> Optional[list[dict]]:
        """
        Get a cached query result.

        An entry that cannot be read (truncated or corrupt files, malformed
        metadata) is deleted and reported as a miss instead of raising.

        Args:
            query: SQL query string
            params: Optional query parameters
            check_expiry: If True, returns None for expired entries

        Returns:
            Cached result as list of dicts, or None if not cached/expired
        """
        if not self.is_enabled:
            self._miss_count += 1
            return None

        cache_key = self._generate_cache_key(query, params)
        cache_file = self._get_cache_file_path(cache_key)
        meta_file = self._get_meta_file_path(cache_key)

        # Both files are required for a usable entry
        if not cache_file.exists() or not meta_file.exists():
            self._miss_count += 1
            logger.debug(f"Cache miss (not found): {cache_key}")
            return None

        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if check_expiry and self._is_expired(meta):
                self._miss_count += 1
                logger.debug(f"Cache miss (expired): {cache_key}")
                return None
            with gzip.open(cache_file, "rt", encoding="utf-8") as f:
                data = json.load(f)
            # Read inside the try so a missing field counts as corruption
            # rather than as a hit followed by an error.
            row_count = meta["row_count"]
        except self._ENTRY_READ_ERRORS as e:
            logger.warning(f"Cache read error for {cache_key}: {e}")
            self._miss_count += 1
            # Clean up corrupted entry
            self._delete_entry(cache_key)
            return None

        self._hit_count += 1
        logger.info(f"Cache hit: {cache_key} ({row_count} rows)")
        return data

    def set(
        self,
        query: str,
        params: Optional[tuple],
        data: list[dict],
        includes_current_data: bool = False,
        custom_ttl_seconds: Optional[int] = None
    ) -> Optional[CacheEntry]:
        """
        Cache a query result.

        Both files are written to temporary names and then moved into place
        with ``os.replace``, so a failed write never leaves a half-written
        entry under the real file names.

        Args:
            query: SQL query string
            params: Optional query parameters
            data: Query result as list of dicts
            includes_current_data: If True, uses shorter TTL for current data
            custom_ttl_seconds: Optional custom TTL (overrides config)

        Returns:
            CacheEntry with metadata, or None if caching disabled/failed
        """
        if not self.is_enabled:
            return None

        cache_key = self._generate_cache_key(query, params)
        cache_file = self._get_cache_file_path(cache_key)
        meta_file = self._get_meta_file_path(cache_key)

        # Determine TTL
        if custom_ttl_seconds is not None:
            ttl = custom_ttl_seconds
        elif includes_current_data:
            ttl = self._config.ttl_current_data_seconds
        else:
            ttl = self._config.ttl_seconds

        now = datetime.now()
        expires_at = datetime.fromtimestamp(now.timestamp() + ttl)

        # The pid in the temp name keeps two processes that cache the same
        # key from writing into each other's temp file.
        tmp_suffix = f".{os.getpid()}.tmp"
        tmp_cache_file = cache_file.with_name(cache_file.name + tmp_suffix)
        tmp_meta_file = meta_file.with_name(meta_file.name + tmp_suffix)

        try:
            # Write compressed data
            with gzip.open(tmp_cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
                # default=str stringifies values json cannot encode natively
                json.dump(data, f, default=str)
            file_size = tmp_cache_file.stat().st_size

            # Write metadata
            meta = {
                "cache_key": cache_key,
                "query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
                "created_at": now.isoformat(),
                "expires_at": expires_at.isoformat(),
                "includes_current_data": includes_current_data,
                "row_count": len(data),
                "file_size_bytes": file_size,
                "ttl_seconds": ttl,
            }
            with open(tmp_meta_file, "w", encoding="utf-8") as f:
                json.dump(meta, f, indent=2)

            # Publish: data first, then the metadata that describes it
            os.replace(tmp_cache_file, cache_file)
            os.replace(tmp_meta_file, meta_file)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to cache result: {e}")
            # Best-effort removal of whatever temp files were left behind
            for leftover in (tmp_cache_file, tmp_meta_file):
                try:
                    leftover.unlink(missing_ok=True)
                except OSError as cleanup_error:
                    logger.debug(f"Failed to delete {leftover}: {cleanup_error}")
            return None

        logger.info(f"Cached {len(data)} rows as {cache_key} (expires in {ttl}s)")

        # The entry is already stored; a failure while trimming the cache
        # must not be reported as a failed write.
        try:
            self._enforce_size_limit()
        except OSError as e:
            logger.warning(f"Failed to enforce cache size limit: {e}")

        return CacheEntry(
            cache_key=cache_key,
            query_hash=str(meta["query_hash"]),
            created_at=now,
            expires_at=expires_at,
            includes_current_data=includes_current_data,
            row_count=len(data),
            file_size_bytes=file_size,
            file_path=cache_file,
        )

    def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query: SQL query string
            params: Optional query parameters

        Returns:
            True if entry was deleted, False if not found
        """
        cache_key = self._generate_cache_key(query, params)
        return self._delete_entry(cache_key)

    def _delete_entry(self, cache_key: str) -> bool:
        """
        Delete a cache entry by key.

        A file that is already gone is ignored; any other deletion failure is
        logged as a warning rather than raised, because this method also runs
        while recovering from a corrupted entry.

        Returns:
            True if at least one of the entry's files was deleted
        """
        deleted = False
        entry_files = (
            self._get_cache_file_path(cache_key),
            self._get_meta_file_path(cache_key),
        )
        for path in entry_files:
            try:
                path.unlink()
            except FileNotFoundError:
                continue
            except OSError as e:
                logger.warning(f"Failed to delete {path}: {e}")
                continue
            deleted = True
        if deleted:
            logger.debug(f"Deleted cache entry: {cache_key}")
        return deleted

    def clear(self) -> int:
        """
        Clear all cache entries and reset the hit/miss counters.

        Returns:
            Number of entries deleted
        """
        if not self._cache_dir.exists():
            return 0

        count = 0
        cleared_keys = set()
        for file in self._cache_dir.glob("*.json*"):
            try:
                file.unlink()
            except OSError as e:
                logger.warning(f"Failed to delete {file}: {e}")
                continue
            count += 1
            # "<key>.json.gz" and "<key>.meta.json" share everything before
            # the first dot, so an entry is counted once even when only one
            # of its two files was present.
            cleared_keys.add(file.name.split(".", 1)[0])

        # Reset stats
        self._hit_count = 0
        self._miss_count = 0
        logger.info(f"Cleared {count} cache files")
        return len(cleared_keys)

    def clear_expired(self) -> int:
        """
        Remove expired cache entries.

        Entries whose metadata cannot be read or interpreted are removed too.

        Returns:
            Number of expired entries deleted
        """
        if not self._cache_dir.exists():
            return 0

        count = 0
        for meta_file in self._cache_dir.glob("*.meta.json"):
            try:
                with open(meta_file, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                expired = self._is_expired(meta)
            except self._ENTRY_READ_ERRORS:
                # Corrupted metadata: treat the entry as expired
                expired = True
            if expired:
                self._delete_entry(self._cache_key_from_meta_path(meta_file))
                count += 1

        logger.info(f"Cleared {count} expired cache entries")
        return count

    def _get_total_size_mb(self) -> float:
        """Calculate total cache size in MB."""
        if not self._cache_dir.exists():
            return 0.0

        total_bytes = 0
        for f in self._cache_dir.glob("*"):
            try:
                if f.is_file():
                    total_bytes += f.stat().st_size
            except OSError:
                # The file disappeared between listing and stat; skip it
                continue
        return total_bytes / (1024 * 1024)

    def _enforce_size_limit(self) -> int:
        """
        Enforce cache size limit by removing oldest entries.

        Does nothing while the cache is within ``max_size_mb``. Otherwise
        removes entries oldest-first until the recorded sizes of the removed
        entries bring the total down to 90% of the limit.

        Returns:
            Number of entries removed
        """
        max_size_mb = self._config.max_size_mb
        current_size_mb = self._get_total_size_mb()
        if current_size_mb <= max_size_mb:
            return 0

        # Collect (key, created_at, recorded size) for every readable entry
        entries = []
        for meta_file in self._cache_dir.glob("*.meta.json"):
            cache_key = self._cache_key_from_meta_path(meta_file)
            try:
                with open(meta_file, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                entries.append((
                    cache_key,
                    datetime.fromisoformat(meta["created_at"]),
                    meta.get("file_size_bytes", 0),
                ))
            except self._ENTRY_READ_ERRORS:
                # Clean up corrupted entry
                self._delete_entry(cache_key)

        # Sort by creation time (oldest first)
        entries.sort(key=lambda x: x[1])

        # Remove oldest entries until under limit
        removed = 0
        size_to_remove_bytes = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024  # Target 90% of limit
        removed_bytes = 0
        for cache_key, _created_at, file_size in entries:
            if removed_bytes >= size_to_remove_bytes:
                break
            self._delete_entry(cache_key)
            removed_bytes += file_size
            removed += 1

        logger.info(f"Removed {removed} cache entries to enforce size limit")
        return removed

    def get_stats(self) -> CacheStats:
        """
        Get cache statistics.

        Entries whose metadata cannot be read are left out of the entry count
        and of the oldest/newest timestamps.
        """
        created_times = []
        if self._cache_dir.exists():
            for meta_file in self._cache_dir.glob("*.meta.json"):
                try:
                    with open(meta_file, "r", encoding="utf-8") as f:
                        meta = json.load(f)
                    created_times.append(datetime.fromisoformat(meta["created_at"]))
                except self._ENTRY_READ_ERRORS as e:
                    logger.debug(f"Skipping unreadable cache metadata {meta_file}: {e}")

        return CacheStats(
            enabled=self.is_enabled,
            cache_dir=self._cache_dir,
            total_entries=len(created_times),
            # Returns 0.0 when the directory does not exist
            total_size_mb=self._get_total_size_mb(),
            max_size_mb=self._config.max_size_mb,
            oldest_entry=min(created_times) if created_times else None,
            newest_entry=max(created_times) if created_times else None,
            hit_count=self._hit_count,
            miss_count=self._miss_count,
        )

    def list_entries(self) -> list[CacheEntry]:
        """
        List all cache entries with metadata, newest first.

        Entries whose metadata cannot be read are skipped.
        """
        if not self._cache_dir.exists():
            return []

        entries = []
        for meta_file in self._cache_dir.glob("*.meta.json"):
            try:
                with open(meta_file, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                cache_key = meta["cache_key"]
                entries.append(CacheEntry(
                    cache_key=cache_key,
                    query_hash=meta.get("query_hash", ""),
                    created_at=datetime.fromisoformat(meta["created_at"]),
                    expires_at=datetime.fromisoformat(meta["expires_at"]),
                    includes_current_data=meta.get("includes_current_data", False),
                    row_count=meta.get("row_count", 0),
                    file_size_bytes=meta.get("file_size_bytes", 0),
                    file_path=self._get_cache_file_path(cache_key),
                ))
            except self._ENTRY_READ_ERRORS as e:
                logger.debug(f"Skipping unreadable cache metadata {meta_file}: {e}")

        # Sort by creation time (newest first)
        entries.sort(key=lambda x: x.created_at, reverse=True)
        return entries
# Shared instance handed out by get_cache() when no explicit config is given;
# created lazily on first use.
_default_cache: Optional[QueryCache] = None


def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
    """
    Return a QueryCache, creating the shared default one on first use.

    Args:
        config: Optional CacheConfig. When given, a brand-new QueryCache is
            built from it on every call and the shared instance is left
            untouched. When None, the shared default cache is returned.

    Returns:
        QueryCache instance
    """
    global _default_cache

    if config is None:
        # Default path: build the singleton once, then keep returning it
        if _default_cache is None:
            _default_cache = QueryCache()
        return _default_cache

    # Custom config requested: never stored as the singleton
    return QueryCache(config)
def reset_cache() -> None:
    """Drop the shared default cache so the next get_cache() call rebuilds it."""
    global _default_cache
    _default_cache = None
def is_cache_enabled() -> bool:
    """Report whether the loaded Snowflake configuration has caching enabled."""
    return get_snowflake_config().cache.enabled
# Public API: the names exported by `from data_processing.cache import *`.
__all__ = [
    "QueryCache",
    "CacheEntry",
    "CacheStats",
    "get_cache",
    "reset_cache",
    "is_cache_enabled",
]