76838887e6
Move 6 packages (core, config, data_processing, analysis, visualization, cli) into src/ to reduce root clutter. Merge tools/data.py into data_processing/transforms.py. Move docs to docs/. Path resolution via .pth file (setup_dev.py), pytest pythonpath config, and sys.path bootstrap in rxconfig.py and CLI entry points. Clean up pyproject.toml deps (remove stale pins, add snowflake-connector-python). Fix tomllib import for Python 3.10 compatibility. All 113 tests pass.
554 lines
17 KiB
Python
554 lines
17 KiB
Python
"""
|
|
Query result caching module for NHS Patient Pathway Analysis.
|
|
|
|
Provides file-based caching for Snowflake query results with TTL-based invalidation.
|
|
Supports different TTLs for historical data vs data including the current date.
|
|
|
|
Cache keys are generated from query hashes. Results are stored as compressed JSON.
|
|
|
|
Usage:
|
|
from data_processing.cache import QueryCache, get_cache
|
|
|
|
cache = get_cache()
|
|
|
|
# Check for cached result
|
|
result = cache.get(query, params)
|
|
if result is None:
|
|
# Execute query and cache result
|
|
result = execute_query(query, params)
|
|
cache.set(query, params, result, includes_current_data=False)
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, date
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
import gzip
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import time
|
|
|
|
from config import get_snowflake_config, CacheConfig
|
|
from core.logging_config import get_logger
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class CacheEntry:
|
|
"""Metadata for a cached query result."""
|
|
cache_key: str
|
|
query_hash: str
|
|
created_at: datetime
|
|
expires_at: datetime
|
|
includes_current_data: bool
|
|
row_count: int
|
|
file_size_bytes: int
|
|
file_path: Path
|
|
|
|
|
|
@dataclass
|
|
class CacheStats:
|
|
"""Statistics about the cache."""
|
|
enabled: bool
|
|
cache_dir: Path
|
|
total_entries: int
|
|
total_size_mb: float
|
|
max_size_mb: int
|
|
oldest_entry: Optional[datetime]
|
|
newest_entry: Optional[datetime]
|
|
hit_count: int
|
|
miss_count: int
|
|
|
|
|
|
class QueryCache:
|
|
"""
|
|
File-based cache for Snowflake query results.
|
|
|
|
Results are stored as gzipped JSON files with TTL-based expiration.
|
|
Supports different TTLs for historical vs current data.
|
|
|
|
Attributes:
|
|
config: CacheConfig with cache settings
|
|
cache_dir: Path to cache directory
|
|
"""
|
|
|
|
def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
|
|
"""
|
|
Initialize the query cache.
|
|
|
|
Args:
|
|
config: Optional CacheConfig. If not provided, loads from snowflake.toml
|
|
base_path: Base path for relative cache directory. Defaults to cwd.
|
|
"""
|
|
if config is None:
|
|
sf_config = get_snowflake_config()
|
|
config = sf_config.cache
|
|
|
|
self._config = config
|
|
self._base_path = base_path or Path.cwd()
|
|
|
|
# Resolve cache directory
|
|
cache_dir = Path(config.directory)
|
|
if not cache_dir.is_absolute():
|
|
cache_dir = self._base_path / cache_dir
|
|
self._cache_dir = cache_dir
|
|
|
|
# Stats tracking (in-memory only, reset on restart)
|
|
self._hit_count = 0
|
|
self._miss_count = 0
|
|
|
|
# Ensure cache directory exists if enabled
|
|
if self._config.enabled:
|
|
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
@property
|
|
def config(self) -> CacheConfig:
|
|
"""Return the cache configuration."""
|
|
return self._config
|
|
|
|
@property
|
|
def cache_dir(self) -> Path:
|
|
"""Return the cache directory path."""
|
|
return self._cache_dir
|
|
|
|
@property
|
|
def is_enabled(self) -> bool:
|
|
"""Return True if caching is enabled."""
|
|
return self._config.enabled
|
|
|
|
def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
|
|
"""
|
|
Generate a cache key from query and parameters.
|
|
|
|
Uses SHA256 hash of query + params to create unique key.
|
|
"""
|
|
# Normalize query (strip whitespace, lowercase)
|
|
normalized_query = " ".join(query.lower().split())
|
|
|
|
# Combine query and params
|
|
key_content = normalized_query
|
|
if params:
|
|
key_content += "|" + "|".join(str(p) for p in params)
|
|
|
|
# Hash to create key
|
|
hash_obj = hashlib.sha256(key_content.encode("utf-8"))
|
|
return hash_obj.hexdigest()[:32] # Use first 32 chars for readability
|
|
|
|
def _get_cache_file_path(self, cache_key: str) -> Path:
|
|
"""Get the file path for a cache entry."""
|
|
return self._cache_dir / f"{cache_key}.json.gz"
|
|
|
|
def _get_meta_file_path(self, cache_key: str) -> Path:
|
|
"""Get the metadata file path for a cache entry."""
|
|
return self._cache_dir / f"{cache_key}.meta.json"
|
|
|
|
def _is_expired(self, meta: dict) -> bool:
|
|
"""Check if a cache entry is expired based on its metadata."""
|
|
expires_at = datetime.fromisoformat(meta["expires_at"])
|
|
return datetime.now() > expires_at
|
|
|
|
def get(
|
|
self,
|
|
query: str,
|
|
params: Optional[tuple] = None,
|
|
check_expiry: bool = True
|
|
) -> Optional[list[dict]]:
|
|
"""
|
|
Get a cached query result.
|
|
|
|
Args:
|
|
query: SQL query string
|
|
params: Optional query parameters
|
|
check_expiry: If True, returns None for expired entries
|
|
|
|
Returns:
|
|
Cached result as list of dicts, or None if not cached/expired
|
|
"""
|
|
if not self.is_enabled:
|
|
self._miss_count += 1
|
|
return None
|
|
|
|
cache_key = self._generate_cache_key(query, params)
|
|
cache_file = self._get_cache_file_path(cache_key)
|
|
meta_file = self._get_meta_file_path(cache_key)
|
|
|
|
# Check if files exist
|
|
if not cache_file.exists() or not meta_file.exists():
|
|
self._miss_count += 1
|
|
logger.debug(f"Cache miss (not found): {cache_key}")
|
|
return None
|
|
|
|
# Load and check metadata
|
|
try:
|
|
with open(meta_file, "r", encoding="utf-8") as f:
|
|
meta = json.load(f)
|
|
|
|
if check_expiry and self._is_expired(meta):
|
|
self._miss_count += 1
|
|
logger.debug(f"Cache miss (expired): {cache_key}")
|
|
return None
|
|
|
|
# Load cached data
|
|
with gzip.open(cache_file, "rt", encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
|
|
self._hit_count += 1
|
|
logger.info(f"Cache hit: {cache_key} ({meta['row_count']} rows)")
|
|
return data
|
|
|
|
except (json.JSONDecodeError, KeyError, OSError) as e:
|
|
logger.warning(f"Cache read error for {cache_key}: {e}")
|
|
self._miss_count += 1
|
|
# Clean up corrupted entry
|
|
self._delete_entry(cache_key)
|
|
return None
|
|
|
|
def set(
|
|
self,
|
|
query: str,
|
|
params: Optional[tuple],
|
|
data: list[dict],
|
|
includes_current_data: bool = False,
|
|
custom_ttl_seconds: Optional[int] = None
|
|
) -> Optional[CacheEntry]:
|
|
"""
|
|
Cache a query result.
|
|
|
|
Args:
|
|
query: SQL query string
|
|
params: Optional query parameters
|
|
data: Query result as list of dicts
|
|
includes_current_data: If True, uses shorter TTL for current data
|
|
custom_ttl_seconds: Optional custom TTL (overrides config)
|
|
|
|
Returns:
|
|
CacheEntry with metadata, or None if caching disabled/failed
|
|
"""
|
|
if not self.is_enabled:
|
|
return None
|
|
|
|
cache_key = self._generate_cache_key(query, params)
|
|
cache_file = self._get_cache_file_path(cache_key)
|
|
meta_file = self._get_meta_file_path(cache_key)
|
|
|
|
# Determine TTL
|
|
if custom_ttl_seconds is not None:
|
|
ttl = custom_ttl_seconds
|
|
elif includes_current_data:
|
|
ttl = self._config.ttl_current_data_seconds
|
|
else:
|
|
ttl = self._config.ttl_seconds
|
|
|
|
now = datetime.now()
|
|
expires_at = datetime.fromtimestamp(now.timestamp() + ttl)
|
|
|
|
try:
|
|
# Write compressed data
|
|
with gzip.open(cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
|
|
json.dump(data, f, default=str)
|
|
|
|
file_size = cache_file.stat().st_size
|
|
|
|
# Write metadata
|
|
meta = {
|
|
"cache_key": cache_key,
|
|
"query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
|
|
"created_at": now.isoformat(),
|
|
"expires_at": expires_at.isoformat(),
|
|
"includes_current_data": includes_current_data,
|
|
"row_count": len(data),
|
|
"file_size_bytes": file_size,
|
|
"ttl_seconds": ttl,
|
|
}
|
|
|
|
with open(meta_file, "w", encoding="utf-8") as f:
|
|
json.dump(meta, f, indent=2)
|
|
|
|
logger.info(f"Cached {len(data)} rows as {cache_key} (expires in {ttl}s)")
|
|
|
|
# Check if we need to enforce size limit
|
|
self._enforce_size_limit()
|
|
|
|
return CacheEntry(
|
|
cache_key=cache_key,
|
|
query_hash=str(meta["query_hash"]),
|
|
created_at=now,
|
|
expires_at=expires_at,
|
|
includes_current_data=includes_current_data,
|
|
row_count=len(data),
|
|
file_size_bytes=file_size,
|
|
file_path=cache_file,
|
|
)
|
|
|
|
except (OSError, TypeError) as e:
|
|
logger.error(f"Failed to cache result: {e}")
|
|
return None
|
|
|
|
def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
|
|
"""
|
|
Invalidate a specific cache entry.
|
|
|
|
Args:
|
|
query: SQL query string
|
|
params: Optional query parameters
|
|
|
|
Returns:
|
|
True if entry was deleted, False if not found
|
|
"""
|
|
cache_key = self._generate_cache_key(query, params)
|
|
return self._delete_entry(cache_key)
|
|
|
|
def _delete_entry(self, cache_key: str) -> bool:
|
|
"""Delete a cache entry by key."""
|
|
cache_file = self._get_cache_file_path(cache_key)
|
|
meta_file = self._get_meta_file_path(cache_key)
|
|
|
|
deleted = False
|
|
|
|
if cache_file.exists():
|
|
cache_file.unlink()
|
|
deleted = True
|
|
|
|
if meta_file.exists():
|
|
meta_file.unlink()
|
|
deleted = True
|
|
|
|
if deleted:
|
|
logger.debug(f"Deleted cache entry: {cache_key}")
|
|
|
|
return deleted
|
|
|
|
def clear(self) -> int:
|
|
"""
|
|
Clear all cache entries.
|
|
|
|
Returns:
|
|
Number of entries deleted
|
|
"""
|
|
if not self._cache_dir.exists():
|
|
return 0
|
|
|
|
count = 0
|
|
for file in self._cache_dir.glob("*.json*"):
|
|
try:
|
|
file.unlink()
|
|
count += 1
|
|
except OSError as e:
|
|
logger.warning(f"Failed to delete {file}: {e}")
|
|
|
|
# Reset stats
|
|
self._hit_count = 0
|
|
self._miss_count = 0
|
|
|
|
logger.info(f"Cleared {count} cache files")
|
|
return count // 2 # Divide by 2 since we have .json.gz and .meta.json
|
|
|
|
def clear_expired(self) -> int:
|
|
"""
|
|
Remove expired cache entries.
|
|
|
|
Returns:
|
|
Number of expired entries deleted
|
|
"""
|
|
if not self._cache_dir.exists():
|
|
return 0
|
|
|
|
count = 0
|
|
for meta_file in self._cache_dir.glob("*.meta.json"):
|
|
try:
|
|
with open(meta_file, "r", encoding="utf-8") as f:
|
|
meta = json.load(f)
|
|
|
|
if self._is_expired(meta):
|
|
cache_key = meta_file.stem.replace(".meta", "")
|
|
self._delete_entry(cache_key)
|
|
count += 1
|
|
except (OSError, json.JSONDecodeError):
|
|
# Delete corrupted metadata files
|
|
cache_key = meta_file.stem.replace(".meta", "")
|
|
self._delete_entry(cache_key)
|
|
count += 1
|
|
|
|
logger.info(f"Cleared {count} expired cache entries")
|
|
return count
|
|
|
|
def _get_total_size_mb(self) -> float:
|
|
"""Calculate total cache size in MB."""
|
|
if not self._cache_dir.exists():
|
|
return 0.0
|
|
|
|
total_bytes = sum(
|
|
f.stat().st_size
|
|
for f in self._cache_dir.glob("*")
|
|
if f.is_file()
|
|
)
|
|
return total_bytes / (1024 * 1024)
|
|
|
|
def _enforce_size_limit(self) -> int:
|
|
"""
|
|
Enforce cache size limit by removing oldest entries.
|
|
|
|
Returns:
|
|
Number of entries removed
|
|
"""
|
|
max_size_mb = self._config.max_size_mb
|
|
current_size_mb = self._get_total_size_mb()
|
|
|
|
if current_size_mb <= max_size_mb:
|
|
return 0
|
|
|
|
# Get all entries sorted by creation time
|
|
entries = []
|
|
for meta_file in self._cache_dir.glob("*.meta.json"):
|
|
try:
|
|
with open(meta_file, "r", encoding="utf-8") as f:
|
|
meta = json.load(f)
|
|
entries.append((
|
|
meta_file.stem.replace(".meta", ""),
|
|
datetime.fromisoformat(meta["created_at"]),
|
|
meta.get("file_size_bytes", 0)
|
|
))
|
|
except (OSError, json.JSONDecodeError, KeyError):
|
|
# Clean up corrupted entry
|
|
cache_key = meta_file.stem.replace(".meta", "")
|
|
self._delete_entry(cache_key)
|
|
|
|
# Sort by creation time (oldest first)
|
|
entries.sort(key=lambda x: x[1])
|
|
|
|
# Remove oldest entries until under limit
|
|
removed = 0
|
|
size_to_remove_bytes = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024 # Target 90% of limit
|
|
removed_bytes = 0
|
|
|
|
for cache_key, created_at, file_size in entries:
|
|
if removed_bytes >= size_to_remove_bytes:
|
|
break
|
|
|
|
self._delete_entry(cache_key)
|
|
removed_bytes += file_size
|
|
removed += 1
|
|
|
|
logger.info(f"Removed {removed} cache entries to enforce size limit")
|
|
return removed
|
|
|
|
def get_stats(self) -> CacheStats:
|
|
"""Get cache statistics."""
|
|
if not self._cache_dir.exists():
|
|
return CacheStats(
|
|
enabled=self.is_enabled,
|
|
cache_dir=self._cache_dir,
|
|
total_entries=0,
|
|
total_size_mb=0.0,
|
|
max_size_mb=self._config.max_size_mb,
|
|
oldest_entry=None,
|
|
newest_entry=None,
|
|
hit_count=self._hit_count,
|
|
miss_count=self._miss_count,
|
|
)
|
|
|
|
entries = []
|
|
for meta_file in self._cache_dir.glob("*.meta.json"):
|
|
try:
|
|
with open(meta_file, "r", encoding="utf-8") as f:
|
|
meta = json.load(f)
|
|
entries.append(datetime.fromisoformat(meta["created_at"]))
|
|
except (OSError, json.JSONDecodeError, KeyError):
|
|
pass
|
|
|
|
oldest = min(entries) if entries else None
|
|
newest = max(entries) if entries else None
|
|
|
|
return CacheStats(
|
|
enabled=self.is_enabled,
|
|
cache_dir=self._cache_dir,
|
|
total_entries=len(entries),
|
|
total_size_mb=self._get_total_size_mb(),
|
|
max_size_mb=self._config.max_size_mb,
|
|
oldest_entry=oldest,
|
|
newest_entry=newest,
|
|
hit_count=self._hit_count,
|
|
miss_count=self._miss_count,
|
|
)
|
|
|
|
def list_entries(self) -> list[CacheEntry]:
|
|
"""List all cache entries with metadata."""
|
|
if not self._cache_dir.exists():
|
|
return []
|
|
|
|
entries = []
|
|
for meta_file in self._cache_dir.glob("*.meta.json"):
|
|
try:
|
|
with open(meta_file, "r", encoding="utf-8") as f:
|
|
meta = json.load(f)
|
|
|
|
cache_key = meta["cache_key"]
|
|
entries.append(CacheEntry(
|
|
cache_key=cache_key,
|
|
query_hash=meta.get("query_hash", ""),
|
|
created_at=datetime.fromisoformat(meta["created_at"]),
|
|
expires_at=datetime.fromisoformat(meta["expires_at"]),
|
|
includes_current_data=meta.get("includes_current_data", False),
|
|
row_count=meta.get("row_count", 0),
|
|
file_size_bytes=meta.get("file_size_bytes", 0),
|
|
file_path=self._get_cache_file_path(cache_key),
|
|
))
|
|
except (OSError, json.JSONDecodeError, KeyError):
|
|
pass
|
|
|
|
# Sort by creation time (newest first)
|
|
entries.sort(key=lambda x: x.created_at, reverse=True)
|
|
return entries
|
|
|
|
|
|
# Module-level singleton
|
|
_default_cache: Optional[QueryCache] = None
|
|
|
|
|
|
def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
|
|
"""
|
|
Get a QueryCache instance (creates singleton on first call).
|
|
|
|
Args:
|
|
config: Optional CacheConfig. If provided, creates new cache with
|
|
this config. If None, uses/creates default cache.
|
|
|
|
Returns:
|
|
QueryCache instance
|
|
"""
|
|
global _default_cache
|
|
|
|
if config is not None:
|
|
# Custom config requested, create new cache
|
|
return QueryCache(config)
|
|
|
|
if _default_cache is None:
|
|
_default_cache = QueryCache()
|
|
|
|
return _default_cache
|
|
|
|
|
|
def reset_cache() -> None:
|
|
"""Reset the default cache singleton."""
|
|
global _default_cache
|
|
_default_cache = None
|
|
|
|
|
|
def is_cache_enabled() -> bool:
|
|
"""Return True if caching is enabled in configuration."""
|
|
config = get_snowflake_config()
|
|
return config.cache.enabled
|
|
|
|
|
|
# Export public API
|
|
__all__ = [
|
|
"QueryCache",
|
|
"CacheEntry",
|
|
"CacheStats",
|
|
"get_cache",
|
|
"reset_cache",
|
|
"is_cache_enabled",
|
|
]
|